Coverage for excalidraw_mcp/monitoring/circuit_breaker.py: 98%

140 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-16 08:08 -0700

1"""Circuit breaker implementation for preventing cascading failures.""" 

2 

3import asyncio 

4import logging 

5import time 

6from collections.abc import Callable 

7from dataclasses import dataclass, field 

8from enum import Enum 

9from typing import Any 

10 

11from ..config import config 

12 

13logger = logging.getLogger(__name__) 

14 

15 

class CircuitState(Enum):
    """Enumerates the states a circuit breaker can be in.

    Each member's value is the lowercase string form of the state.
    """

    # Normal operation
    CLOSED = "closed"
    # Failing, blocking requests
    OPEN = "open"
    # Testing recovery
    HALF_OPEN = "half_open"

22 

23 

24@dataclass 

25class CircuitStats: 

26 """Circuit breaker statistics.""" 

27 

28 total_calls: int = 0 

29 successful_calls: int = 0 

30 failed_calls: int = 0 

31 rejected_calls: int = 0 

32 last_failure_time: float | None = None 

33 last_success_time: float | None = None 

34 state_change_time: float = field(default_factory=time.time) 

35 

36 

class CircuitBreakerError(Exception):
    """Signals that the circuit breaker refused to run a call."""

41 

42 

class CircuitBreaker:
    """Circuit breaker for canvas server operations.

    The breaker wraps calls made through :meth:`call` and moves between three
    states:

    * ``CLOSED`` - every call is executed. Once the recorded failure count
      reaches ``failure_threshold`` the breaker opens. The count is cumulative
      since the breaker was last closed or reset; successes do not clear it.
    * ``OPEN`` - calls are rejected with :class:`CircuitBreakerError` until
      ``recovery_timeout`` seconds have passed since the last recorded failure.
    * ``HALF_OPEN`` - at most ``half_open_max_calls`` trial calls are let
      through. Any failure re-opens the breaker; a success observed once all
      trial slots have been used closes it.

    State and statistics are mutated only while holding an ``asyncio.Lock``.
    """

    def __init__(
        self,
        failure_threshold: int | None = None,
        recovery_timeout: int | None = None,
        half_open_max_calls: int | None = None,
    ):
        """Create a breaker in the CLOSED state.

        Args:
            failure_threshold: Failures needed to open the breaker. ``None``
                (or ``0``) falls back to the configured value.
            recovery_timeout: Seconds to wait in OPEN before a trial call is
                allowed. ``None`` falls back to the configured value; an
                explicit ``0`` is honoured and means "no cool-down".
            half_open_max_calls: Trial calls allowed in HALF_OPEN. ``None``
                (or ``0``) falls back to the configured value.
        """
        # A threshold of 0 has no useful meaning, so falsy values keep
        # falling back to the configured default.
        self._failure_threshold = (
            failure_threshold or config.monitoring.circuit_failure_threshold
        )
        # 0 is a legitimate timeout, so only None selects the default here.
        self._recovery_timeout = (
            recovery_timeout
            if recovery_timeout is not None
            else config.monitoring.circuit_recovery_timeout_seconds
        )
        # 0 trial calls would leave the breaker stuck in HALF_OPEN, so falsy
        # values keep falling back to the configured default.
        self._half_open_max_calls = (
            half_open_max_calls or config.monitoring.circuit_half_open_max_calls
        )

        self._state = CircuitState.CLOSED
        self._stats = CircuitStats()
        self._half_open_calls = 0
        self._lock = asyncio.Lock()

    async def call(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Execute function with circuit breaker protection.

        Args:
            func: Sync callable or coroutine function to run.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitBreakerError: If the breaker is not accepting calls.
            Exception: Any exception raised by ``func`` is recorded as a
                failure and re-raised unchanged.
        """
        async with self._lock:
            # Check if we should allow the call
            if not self._should_allow_call():
                self._stats.rejected_calls += 1
                raise CircuitBreakerError(
                    f"Circuit breaker is {self._state.value}, blocking call"
                )

            # Track the call
            self._stats.total_calls += 1

            if self._state == CircuitState.HALF_OPEN:
                self._half_open_calls += 1

        # The lock is released before running ``func``: the success/failure
        # handlers re-acquire it and asyncio.Lock is not reentrant.
        try:
            if asyncio.iscoroutinefunction(func):
                result = await func(*args, **kwargs)
            else:
                result = func(*args, **kwargs)
        except Exception as e:
            await self._on_failure(e)
            raise
        else:
            # Kept outside the ``try`` body so that only errors coming from
            # ``func`` itself are counted as call failures.
            await self._on_success()
            return result

    def _should_allow_call(self) -> bool:
        """Determine if a call should be allowed.

        May move the breaker from OPEN to HALF_OPEN as a side effect when the
        recovery timeout has elapsed. Callers are expected to hold the lock.
        """
        if self._state == CircuitState.CLOSED:
            return True

        if self._state == CircuitState.OPEN:
            # Check if enough time has passed to try recovery
            if self._should_attempt_recovery():
                self._transition_to_half_open()
                return True
            return False

        if self._state == CircuitState.HALF_OPEN:
            # Allow limited calls to test recovery
            return self._half_open_calls < self._half_open_max_calls

        # Defensive fallback: unknown states never allow calls.
        return False

    def _should_attempt_recovery(self) -> bool:
        """Check if we should attempt recovery from open state.

        Recovery is timed from the last recorded failure. When no failure has
        been recorded this returns ``False``.
        """
        last_failure = self._stats.last_failure_time
        if last_failure is None:
            return False

        return time.time() - last_failure >= self._recovery_timeout

    async def _on_success(self) -> None:
        """Handle successful call.

        In HALF_OPEN the breaker closes once a success is observed after all
        ``half_open_max_calls`` trial slots have been claimed. The lifetime
        ``successful_calls`` counter is deliberately not consulted: it also
        includes successes from before the breaker opened and would close the
        breaker on the very first trial success.
        """
        async with self._lock:
            self._stats.successful_calls += 1
            self._stats.last_success_time = time.time()

            if (
                self._state == CircuitState.HALF_OPEN
                and self._half_open_calls >= self._half_open_max_calls
            ):
                self._transition_to_closed()

            logger.debug("Circuit breaker: successful call")

    async def _on_failure(self, error: Exception) -> None:
        """Handle failed call.

        Args:
            error: The exception raised by the protected callable; it is only
                used for the log message.
        """
        async with self._lock:
            self._stats.failed_calls += 1
            self._stats.last_failure_time = time.time()

            # Check if we should open the circuit
            if self._state == CircuitState.CLOSED:
                if self._stats.failed_calls >= self._failure_threshold:
                    self._transition_to_open()

            elif self._state == CircuitState.HALF_OPEN:
                # Any failure in half-open state should open the circuit
                self._transition_to_open()

            logger.warning(
                f"Circuit breaker: call failed with {type(error).__name__}: {error}"
            )

    def _transition_to_open(self) -> None:
        """Transition circuit to open state and clear the trial-call counter."""
        previous_state = self._state
        self._state = CircuitState.OPEN
        self._stats.state_change_time = time.time()
        self._half_open_calls = 0

        logger.warning(
            f"Circuit breaker opened (was {previous_state.value}). "
            f"Failed calls: {self._stats.failed_calls}/{self._stats.total_calls}"
        )

    def _transition_to_half_open(self) -> None:
        """Transition circuit to half-open state with a fresh trial-call counter."""
        previous_state = self._state
        self._state = CircuitState.HALF_OPEN
        self._stats.state_change_time = time.time()
        self._half_open_calls = 0

        logger.info(
            f"Circuit breaker transitioning to half-open (was {previous_state.value})"
        )

    def _transition_to_closed(self) -> None:
        """Transition circuit to closed state and reset the failure count."""
        previous_state = self._state
        self._state = CircuitState.CLOSED
        self._stats.state_change_time = time.time()
        self._half_open_calls = 0

        # Reset failure count on successful recovery
        self._stats.failed_calls = 0

        logger.info(
            f"Circuit breaker closed (was {previous_state.value}). Recovery successful."
        )

    async def force_open(self) -> None:
        """Manually force circuit to open state.

        NOTE: automatic recovery is timed from the last recorded failure, not
        from this call. With no recorded failure the breaker stays open until
        ``force_close()`` or ``reset()``; with a failure older than the
        recovery timeout the next call moves it straight to half-open.
        """
        async with self._lock:
            self._transition_to_open()
            logger.warning("Circuit breaker manually forced open")

    async def force_close(self) -> None:
        """Manually force circuit to closed state (also resets the failure count)."""
        async with self._lock:
            self._transition_to_closed()
            logger.info("Circuit breaker manually forced closed")

    async def reset(self) -> None:
        """Reset circuit breaker to initial state, discarding all statistics."""
        async with self._lock:
            self._state = CircuitState.CLOSED
            self._stats = CircuitStats()
            self._half_open_calls = 0
            logger.info("Circuit breaker reset to initial state")

    @property
    def state(self) -> CircuitState:
        """Get current circuit state."""
        return self._state

    @property
    def is_open(self) -> bool:
        """Check if circuit is open."""
        return self._state == CircuitState.OPEN

    @property
    def is_closed(self) -> bool:
        """Check if circuit is closed."""
        return self._state == CircuitState.CLOSED

    @property
    def is_half_open(self) -> bool:
        """Check if circuit is half-open."""
        return self._state == CircuitState.HALF_OPEN

    def get_stats(self) -> dict[str, Any]:
        """Get circuit breaker statistics.

        Returns:
            A dict snapshot of the current state, counters, timestamps and
            settings. ``failure_rate_percent`` is ``failed_calls`` divided by
            ``total_calls`` as a percentage rounded to two decimals, or ``0``
            when no call has been made.
        """
        stats = self._stats
        total = stats.total_calls
        failure_rate = stats.failed_calls / total * 100 if total > 0 else 0

        return {
            "state": self._state.value,
            "total_calls": total,
            "successful_calls": stats.successful_calls,
            "failed_calls": stats.failed_calls,
            "rejected_calls": stats.rejected_calls,
            "failure_rate_percent": round(failure_rate, 2),
            "last_failure_time": stats.last_failure_time,
            "last_success_time": stats.last_success_time,
            "state_change_time": stats.state_change_time,
            "half_open_calls": self._half_open_calls,
            "failure_threshold": self._failure_threshold,
            "recovery_timeout_seconds": self._recovery_timeout,
        }

    def is_healthy(self) -> bool:
        """Check if circuit breaker indicates healthy state (i.e. is closed)."""
        return self._state == CircuitState.CLOSED

    def get_time_until_recovery(self) -> float | None:
        """Get time in seconds until recovery attempt (if in open state).

        Returns:
            Seconds remaining (never negative), or ``None`` when the breaker
            is not open or no failure has been recorded.
        """
        last_failure = self._stats.last_failure_time
        if self._state != CircuitState.OPEN or last_failure is None:
            return None

        elapsed = time.time() - last_failure
        remaining = self._recovery_timeout - elapsed

        # Float literal keeps the return type consistent with the annotation.
        return max(0.0, remaining)

272 

273 return max(0, remaining)