Coverage for excalidraw_mcp/monitoring/circuit_breaker.py: 98%
140 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-16 08:08 -0700
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-16 08:08 -0700
1"""Circuit breaker implementation for preventing cascading failures."""
3import asyncio
4import logging
5import time
6from collections.abc import Callable
7from dataclasses import dataclass, field
8from enum import Enum
9from typing import Any
11from ..config import config
13logger = logging.getLogger(__name__)
class CircuitState(Enum):
    """States a circuit breaker can be in."""

    # Normal operation: calls are passed through.
    CLOSED = "closed"
    # The protected operation is failing: calls are blocked.
    OPEN = "open"
    # Recovery is being tested.
    HALF_OPEN = "half_open"
24@dataclass
25class CircuitStats:
26 """Circuit breaker statistics."""
28 total_calls: int = 0
29 successful_calls: int = 0
30 failed_calls: int = 0
31 rejected_calls: int = 0
32 last_failure_time: float | None = None
33 last_success_time: float | None = None
34 state_change_time: float = field(default_factory=time.time)
class CircuitBreakerError(Exception):
    """Raised when the circuit breaker blocks a call instead of running it."""
class CircuitBreaker:
    """Circuit breaker for canvas server operations.

    The breaker starts CLOSED and lets every call through.  Once the number
    of recorded failures reaches ``failure_threshold`` it moves to OPEN and
    rejects calls with :class:`CircuitBreakerError`.  After
    ``recovery_timeout`` seconds have elapsed since the last recorded
    failure, the next call moves it to HALF_OPEN, where at most
    ``half_open_max_calls`` trial calls are admitted.  When that many trial
    calls have succeeded the breaker closes again; any failure while
    HALF_OPEN re-opens it.

    Args:
        failure_threshold: Failures needed to open the circuit.  ``None``
            falls back to ``config.monitoring.circuit_failure_threshold``.
        recovery_timeout: Seconds to wait after the last failure before
            probing for recovery.  ``None`` falls back to
            ``config.monitoring.circuit_recovery_timeout_seconds``.
        half_open_max_calls: Trial calls admitted (and successes required)
            while half-open.  ``None`` falls back to
            ``config.monitoring.circuit_half_open_max_calls``.
    """

    def __init__(
        self,
        failure_threshold: int | None = None,
        recovery_timeout: int | None = None,
        half_open_max_calls: int | None = None,
    ):
        # Compare against None rather than relying on truthiness so that an
        # explicit override of 0 (e.g. recovery_timeout=0) is honoured
        # instead of being silently replaced by the configured default.
        self._failure_threshold = (
            failure_threshold
            if failure_threshold is not None
            else config.monitoring.circuit_failure_threshold
        )
        self._recovery_timeout = (
            recovery_timeout
            if recovery_timeout is not None
            else config.monitoring.circuit_recovery_timeout_seconds
        )
        self._half_open_max_calls = (
            half_open_max_calls
            if half_open_max_calls is not None
            else config.monitoring.circuit_half_open_max_calls
        )

        self._state = CircuitState.CLOSED
        self._stats = CircuitStats()
        # Trial calls admitted / succeeded during the current half-open
        # episode; both are reset on every state transition.
        self._half_open_calls = 0
        self._half_open_successes = 0
        self._lock = asyncio.Lock()

    async def call(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Execute function with circuit breaker protection.

        Args:
            func: Callable to run.  Coroutine functions are awaited; any
                other callable is invoked synchronously.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitBreakerError: If the breaker is currently blocking calls.
            Exception: Any exception raised by ``func`` is recorded as a
                failure and re-raised unchanged.
        """
        async with self._lock:
            # Check if we should allow the call
            if not self._should_allow_call():
                self._stats.rejected_calls += 1
                raise CircuitBreakerError(
                    f"Circuit breaker is {self._state.value}, blocking call"
                )

            # Track the call
            self._stats.total_calls += 1

            if self._state == CircuitState.HALF_OPEN:
                self._half_open_calls += 1

        # The function runs outside the lock: asyncio.Lock is not reentrant
        # and _on_success/_on_failure acquire it themselves.
        # NOTE(review): only Exception subclasses are recorded as failures; a
        # cancelled half-open trial call records neither outcome and keeps
        # its admitted slot.
        try:
            if asyncio.iscoroutinefunction(func):
                result = await func(*args, **kwargs)
            else:
                result = func(*args, **kwargs)
        except Exception as e:
            # Handle failure
            await self._on_failure(e)
            raise
        else:
            # Handle success (kept out of the try body so an error in the
            # bookkeeping is never mistaken for a failure of ``func``).
            await self._on_success()
            return result

    def _should_allow_call(self) -> bool:
        """Determine if a call should be allowed.

        May move the breaker from OPEN to HALF_OPEN as a side effect when the
        recovery timeout has elapsed.  Callers are expected to hold the lock.
        """
        if self._state == CircuitState.CLOSED:
            return True

        if self._state == CircuitState.OPEN:
            # Check if enough time has passed to try recovery
            if self._should_attempt_recovery():
                self._transition_to_half_open()
                return True
            return False

        if self._state == CircuitState.HALF_OPEN:
            # Allow limited calls to test recovery
            return self._half_open_calls < self._half_open_max_calls

        # Defensive fallback for an unexpected state value.
        return False  # type: ignore

    def _should_attempt_recovery(self) -> bool:
        """Check if we should attempt recovery from open state.

        Recovery timing is measured from the last recorded failure.  If no
        failure has been recorded (for example after ``force_open`` on a
        fresh breaker) this returns False.
        """
        last_failure = self._stats.last_failure_time
        if last_failure is None:
            return False

        return time.time() - last_failure >= self._recovery_timeout

    async def _on_success(self) -> None:
        """Record a successful call and close the circuit once recovered."""
        async with self._lock:
            self._stats.successful_calls += 1
            self._stats.last_success_time = time.time()

            if self._state == CircuitState.HALF_OPEN:
                # Count successes of the current half-open episode only; the
                # lifetime successful_calls counter says nothing about
                # whether the dependency has recovered *now*.
                self._half_open_successes += 1
                if self._half_open_successes >= self._half_open_max_calls:
                    self._transition_to_closed()

            logger.debug("Circuit breaker: successful call")

    async def _on_failure(self, error: Exception) -> None:
        """Record a failed call and open the circuit when warranted.

        Args:
            error: The exception raised by the protected call (used for
                logging only).
        """
        async with self._lock:
            self._stats.failed_calls += 1
            self._stats.last_failure_time = time.time()

            # Check if we should open the circuit
            if self._state == CircuitState.CLOSED:
                if self._stats.failed_calls >= self._failure_threshold:
                    self._transition_to_open()

            elif self._state == CircuitState.HALF_OPEN:
                # Any failure in half-open state should open the circuit
                self._transition_to_open()

            logger.warning(
                f"Circuit breaker: call failed with {type(error).__name__}: {error}"
            )

    def _transition_to_open(self) -> None:
        """Transition circuit to open state."""
        previous_state = self._state
        self._state = CircuitState.OPEN
        self._stats.state_change_time = time.time()
        self._half_open_calls = 0
        self._half_open_successes = 0

        logger.warning(
            f"Circuit breaker opened (was {previous_state.value}). "
            f"Failed calls: {self._stats.failed_calls}/{self._stats.total_calls}"
        )

    def _transition_to_half_open(self) -> None:
        """Transition circuit to half-open state."""
        previous_state = self._state
        self._state = CircuitState.HALF_OPEN
        self._stats.state_change_time = time.time()
        self._half_open_calls = 0
        self._half_open_successes = 0

        logger.info(
            f"Circuit breaker transitioning to half-open (was {previous_state.value})"
        )

    def _transition_to_closed(self) -> None:
        """Transition circuit to closed state."""
        previous_state = self._state
        self._state = CircuitState.CLOSED
        self._stats.state_change_time = time.time()
        self._half_open_calls = 0
        self._half_open_successes = 0

        # Reset failure count on successful recovery
        self._stats.failed_calls = 0

        logger.info(
            f"Circuit breaker closed (was {previous_state.value}). Recovery successful."
        )

    async def force_open(self) -> None:
        """Manually force circuit to open state."""
        async with self._lock:
            self._transition_to_open()
            logger.warning("Circuit breaker manually forced open")

    async def force_close(self) -> None:
        """Manually force circuit to closed state."""
        async with self._lock:
            self._transition_to_closed()
            logger.info("Circuit breaker manually forced closed")

    async def reset(self) -> None:
        """Reset circuit breaker to initial state (CLOSED, fresh statistics)."""
        async with self._lock:
            self._state = CircuitState.CLOSED
            self._stats = CircuitStats()
            self._half_open_calls = 0
            self._half_open_successes = 0
            logger.info("Circuit breaker reset to initial state")

    @property
    def state(self) -> CircuitState:
        """Get current circuit state."""
        return self._state

    @property
    def is_open(self) -> bool:
        """Check if circuit is open."""
        return self._state == CircuitState.OPEN

    @property
    def is_closed(self) -> bool:
        """Check if circuit is closed."""
        return self._state == CircuitState.CLOSED

    @property
    def is_half_open(self) -> bool:
        """Check if circuit is half-open."""
        return self._state == CircuitState.HALF_OPEN

    def get_stats(self) -> dict[str, Any]:
        """Get circuit breaker statistics.

        Returns:
            A dict snapshot of the current state, the call counters, the
            failure rate as a percentage rounded to two decimals, the
            recorded timestamps and the configured thresholds.
        """
        total = self._stats.total_calls
        # Guard against division by zero before any call has been tracked.
        failure_rate = self._stats.failed_calls / total * 100 if total > 0 else 0.0

        return {
            "state": self._state.value,
            "total_calls": total,
            "successful_calls": self._stats.successful_calls,
            "failed_calls": self._stats.failed_calls,
            "rejected_calls": self._stats.rejected_calls,
            "failure_rate_percent": round(failure_rate, 2),
            "last_failure_time": self._stats.last_failure_time,
            "last_success_time": self._stats.last_success_time,
            "state_change_time": self._stats.state_change_time,
            "half_open_calls": self._half_open_calls,
            "failure_threshold": self._failure_threshold,
            "recovery_timeout_seconds": self._recovery_timeout,
        }

    def is_healthy(self) -> bool:
        """Check if circuit breaker indicates healthy state (i.e. CLOSED)."""
        return self._state == CircuitState.CLOSED

    def get_time_until_recovery(self) -> float | None:
        """Get time in seconds until recovery attempt (if in open state).

        Returns:
            Remaining seconds (never negative), or ``None`` when the circuit
            is not open or no failure has been recorded.
        """
        last_failure = self._stats.last_failure_time
        if self._state != CircuitState.OPEN or last_failure is None:
            return None

        elapsed = time.time() - last_failure
        remaining = self._recovery_timeout - elapsed

        return max(0.0, remaining)