Coverage for excalidraw_mcp/config.py: 88%

260 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-16 08:08 -0700

1"""Configuration management for Excalidraw MCP server.""" 

2 

3import os 

4from dataclasses import dataclass, field 

5from pathlib import Path 

6from typing import Any 

7from urllib.parse import urlparse 

8 

9try: 

10 import tomli 

11 

12 _tomli: Any = tomli 

13except ImportError: 

14 _tomli = None 

15 

16 

@dataclass
class SecurityConfig:
    """Security-related configuration.

    Groups the authentication (JWT), CORS and rate-limiting settings. Every
    list-valued field receives its own fresh empty list per instance.
    """

    # --- Authentication -------------------------------------------------
    # Authentication is off by default so local development needs no secret.
    auth_enabled: bool = False
    jwt_secret: str = ""
    jwt_algorithm: str = "HS256"
    token_expiration_hours: int = 24

    # --- CORS -----------------------------------------------------------
    allowed_origins: list[str] = field(default_factory=list)
    cors_credentials: bool = True
    cors_methods: list[str] = field(default_factory=list)
    cors_headers: list[str] = field(default_factory=list)

    # --- Rate limiting --------------------------------------------------
    rate_limit_window_minutes: int = 15
    rate_limit_max_requests: int = 100

    def __post_init__(self) -> None:
        """Post-init hook; intentionally performs no validation or parsing."""

39 

40 

@dataclass
class ServerConfig:
    """Server configuration settings.

    ``express_url`` is the source of truth for the Express (canvas) server
    address: ``__post_init__`` parses it and overwrites ``express_host`` and
    ``express_port`` with whatever the URL specifies. A component missing from
    the URL leaves the corresponding field at its current value.
    """

    # Express server
    express_url: str = "http://localhost:3031"
    express_host: str = "localhost"
    express_port: int = 3031

    # Health checks
    health_check_timeout_seconds: float = 5.0
    health_check_interval_seconds: int = 30
    health_check_max_failures: int = 3

    # Sync operations (retry delays grow by ``sync_retry_exponential_base``,
    # capped at ``sync_retry_max_delay_seconds`` — presumably; the consumer of
    # these values is not shown here)
    sync_operation_timeout_seconds: float = 10.0
    sync_retry_attempts: int = 3
    sync_retry_delay_seconds: float = 1.0
    sync_retry_max_delay_seconds: float = 30.0
    sync_retry_exponential_base: float = 2.0
    sync_retry_jitter: bool = True

    # Process management
    canvas_auto_start: bool = True
    startup_timeout_seconds: int = 30
    startup_retry_delay_seconds: float = 1.0
    graceful_shutdown_timeout_seconds: float = 10.0

    def __post_init__(self) -> None:
        """Derive ``express_host`` and ``express_port`` from ``express_url``.

        Safe to call again after ``express_url`` has been reassigned, to re-sync
        the derived fields.

        Raises:
            ValueError: if the URL cannot be parsed or its port is malformed.
                An out-of-range port is reported as
                "Express port must be between 1 and 65535".
        """
        try:
            parsed = urlparse(self.express_url)
            if parsed.hostname:
                self.express_host = parsed.hostname
            # ``.port`` raises ValueError for non-numeric or out-of-range ports.
            port = parsed.port
            # Compare with None: an explicit port 0 is falsy but must still be
            # applied so that range validation can reject it, instead of being
            # silently replaced by the previous port.
            if port is not None:
                self.express_port = port
        except ValueError as e:
            # Re-raise with our custom message, keeping the original as cause.
            if "Port out of range" in str(e):
                raise ValueError("Express port must be between 1 and 65535") from e
            raise

82 

83 

@dataclass
class PerformanceConfig:
    """Performance-related configuration.

    Tunables for HTTP connection pooling, WebSocket keep-alive and batching,
    in-memory element limits, and query behaviour. Units are encoded in the
    field names where they apply (``_ms``, ``_minutes``, ``_hours``).
    """

    # --- HTTP connection pooling ----------------------------------------
    http_pool_connections: int = 10
    http_pool_maxsize: int = 20
    http_keep_alive: bool = True

    # --- WebSocket keep-alive -------------------------------------------
    websocket_ping_interval: int = 30
    websocket_ping_timeout: int = 10
    websocket_close_timeout: int = 10

    # --- Memory management ----------------------------------------------
    max_elements_per_canvas: int = 10000
    element_cache_ttl_hours: int = 24
    memory_cleanup_interval_minutes: int = 60

    # --- WebSocket message batching -------------------------------------
    websocket_batch_size: int = 50
    websocket_batch_timeout_ms: int = 100

    # --- Query optimization ---------------------------------------------
    enable_spatial_indexing: bool = True
    query_result_limit: int = 1000

110 

111 

@dataclass
class MonitoringConfig:
    """Monitoring and observability configuration.

    Covers health checks, metrics, circuit breaking, alerting, resource
    monitoring and request tracing. Percentage thresholds and the trace
    sampling rate are range-checked when an instance is constructed.
    """

    # --- Core monitoring -------------------------------------------------
    enabled: bool = True
    health_check_interval_seconds: int = 10
    health_check_timeout_seconds: float = 3.0
    consecutive_failure_threshold: int = 3

    # --- Metrics collection ----------------------------------------------
    metrics_enabled: bool = True
    metrics_collection_interval_seconds: int = 30
    memory_monitoring_enabled: bool = True
    performance_metrics_enabled: bool = True

    # --- Circuit breaker -------------------------------------------------
    circuit_breaker_enabled: bool = True
    circuit_failure_threshold: int = 5
    circuit_recovery_timeout_seconds: int = 60
    circuit_half_open_max_calls: int = 3

    # --- Alerting --------------------------------------------------------
    alerting_enabled: bool = True
    alert_channels: list[str] = field(default_factory=list)
    alert_deduplication_window_seconds: int = 300
    alert_throttle_max_per_hour: int = 10

    # --- Resource monitoring ---------------------------------------------
    resource_monitoring_enabled: bool = True
    cpu_threshold_percent: float = 80.0
    memory_threshold_percent: float = 85.0
    memory_leak_detection_enabled: bool = True

    # --- Request tracing -------------------------------------------------
    request_tracing_enabled: bool = True
    trace_sampling_rate: float = 1.0
    trace_headers_enabled: bool = True

    def __post_init__(self) -> None:
        """Range-check the resource thresholds and the trace sampling rate.

        Checks run lazily in this order and stop at the first failure: CPU
        threshold, memory threshold (both must lie in (0, 100]), then trace
        sampling rate (must lie in [0, 1]). NaN fails every check.

        Raises:
            ValueError: describing the first out-of-range value.
        """
        percent_fields = (
            ("cpu_threshold_percent", "CPU threshold must be between 0 and 100"),
            ("memory_threshold_percent", "Memory threshold must be between 0 and 100"),
        )
        for attr_name, message in percent_fields:
            # ``not a < x <= b`` (rather than ``x <= a or x > b``) so that NaN
            # is rejected as out of range.
            if not 0 < getattr(self, attr_name) <= 100:
                raise ValueError(message)

        if not 0 <= self.trace_sampling_rate <= 1.0:
            raise ValueError("Trace sampling rate must be between 0 and 1")

161 

162 

163@dataclass 

164class LoggingConfig: 

165 """Logging configuration.""" 

166 

167 level: str = "INFO" 

168 format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 

169 structured_logging: bool = True 

170 json_format: bool = False 

171 file_path: str | None = None 

172 max_file_size_mb: int = 100 

173 backup_count: int = 5 

174 

175 # Security logging 

176 audit_enabled: bool = True 

177 audit_file_path: str | None = None 

178 sensitive_fields: list[str] = field(default_factory=list) 

179 

180 # Correlation tracking 

181 correlation_id_enabled: bool = True 

182 correlation_header: str = "X-Correlation-ID" 

183 

184 def __post_init__(self) -> None: 

185 pass 

186 

187 

@dataclass
class MCPConfig:
    """MCP server configuration.

    Whether the MCP server also listens over HTTP (off by default), the bind
    address used when it does, and the URL of the canvas server.
    """

    http_enabled: bool = False
    # Loopback-only bind address by default.
    http_host: str = "127.0.0.1"
    http_port: int = 3030
    canvas_server_url: str = "http://localhost:3031"

196 

197 

class Config:
    """Main configuration class.

    Instantiating ``Config`` builds every configuration section from its
    dataclass defaults and then applies overrides in this order:

    1. the ``[tool.excalidraw-mcp]`` table of ``./pyproject.toml`` (MCP
       section only, best-effort);
    2. environment variables (security, server, performance, logging and
       monitoring sections).

    Finally it validates the combined result.

    Raises:
        ValueError: from ``__init__`` when validation fails (every problem is
            joined into a single message), or when ``EXPRESS_SERVER_URL``
            cannot be parsed.
    """

    def __init__(self) -> None:
        self.security = SecurityConfig()
        self.server = ServerConfig()
        self.performance = PerformanceConfig()
        self.logging = LoggingConfig()
        self.monitoring = MonitoringConfig()
        self.mcp = MCPConfig()
        self._load_from_pyproject()
        self._load_from_environment()
        self._validate()

    def _load_from_pyproject(self) -> None:
        """Load MCP configuration from pyproject.toml.

        Reads ``pyproject.toml`` from the current working directory and copies
        recognised keys of the ``[tool.excalidraw-mcp]`` table onto
        ``self.mcp``. Note the key names: ``http_enabled`` and
        ``canvas_server_url`` map directly, while host and port are read from
        ``mcp_http_host`` and ``mcp_http_port``. Values are not type-checked.

        The load is best-effort. If the file is missing, no TOML parser is
        available, or reading or parsing fails, the MCP defaults are kept.
        """
        project_path = Path.cwd()
        pyproject_path = project_path / "pyproject.toml"

        if not pyproject_path.exists() or not _tomli:
            return

        from contextlib import suppress

        # Deliberately broad: a malformed or unreadable pyproject.toml must not
        # prevent startup. A failure part-way through leaves earlier
        # assignments in place.
        with suppress(Exception):
            with pyproject_path.open("rb") as f:
                pyproject_data = _tomli.load(f)

            mcp_config = pyproject_data.get("tool", {}).get("excalidraw-mcp", {})

            if mcp_config:
                self.mcp.http_enabled = mcp_config.get(
                    "http_enabled", self.mcp.http_enabled
                )
                self.mcp.http_host = mcp_config.get("mcp_http_host", self.mcp.http_host)
                self.mcp.http_port = mcp_config.get("mcp_http_port", self.mcp.http_port)
                self.mcp.canvas_server_url = mcp_config.get(
                    "canvas_server_url", self.mcp.canvas_server_url
                )

    def _load_security_config_from_environment(self) -> None:
        """Load security configuration from environment variables.

        Reads these variables:

        - ``AUTH_ENABLED``: always applied; only the case-insensitive value
          ``"true"`` enables authentication.
        - ``JWT_SECRET``: always applied; an empty string when unset.
        - ``ALLOWED_ORIGINS``: a comma-separated list that, when set and
          non-empty, replaces the default origins. Whitespace around each
          entry is stripped, and empty entries (e.g. from a trailing or
          doubled comma) are dropped.
        """
        self.security.auth_enabled = (
            os.getenv("AUTH_ENABLED", "false").lower() == "true"
        )
        self.security.jwt_secret = os.getenv("JWT_SECRET", "")

        origins_env = os.getenv("ALLOWED_ORIGINS")
        if origins_env:
            stripped = (origin.strip() for origin in origins_env.split(","))
            # An empty origin is never meaningful, so skip blanks.
            self.security.allowed_origins = [origin for origin in stripped if origin]

    def _load_server_config_from_environment(self) -> None:
        """Load server configuration from environment variables.

        Reads these variables:

        - ``EXPRESS_SERVER_URL``: replaces the Express URL.
        - ``CANVAS_AUTO_START``: enabled unless the value is the
          case-insensitive ``"false"``.
        - ``SYNC_RETRY_*``: numeric retry settings. Unparsable numbers are
          ignored and the previous value is kept.
        - ``SYNC_RETRY_JITTER``: when set, enabled only by ``"true"``.

        Afterwards it re-runs ``ServerConfig.__post_init__`` so that host and
        port follow the possibly updated URL; that call may raise
        ``ValueError``.
        """
        self.server.express_url = os.getenv(
            "EXPRESS_SERVER_URL", self.server.express_url
        )
        self.server.canvas_auto_start = (
            os.getenv("CANVAS_AUTO_START", "true").lower() != "false"
        )

        # Retry configuration from environment
        from contextlib import suppress

        sync_retry_attempts = os.getenv("SYNC_RETRY_ATTEMPTS")
        if sync_retry_attempts:
            with suppress(ValueError):
                self.server.sync_retry_attempts = int(sync_retry_attempts)

        sync_retry_delay = os.getenv("SYNC_RETRY_DELAY_SECONDS")
        if sync_retry_delay:
            with suppress(ValueError):
                self.server.sync_retry_delay_seconds = float(sync_retry_delay)

        sync_retry_max_delay = os.getenv("SYNC_RETRY_MAX_DELAY_SECONDS")
        if sync_retry_max_delay:
            with suppress(ValueError):
                self.server.sync_retry_max_delay_seconds = float(sync_retry_max_delay)

        sync_retry_base = os.getenv("SYNC_RETRY_EXPONENTIAL_BASE")
        if sync_retry_base:
            with suppress(ValueError):
                self.server.sync_retry_exponential_base = float(sync_retry_base)

        sync_retry_jitter = os.getenv("SYNC_RETRY_JITTER")
        if sync_retry_jitter:
            self.server.sync_retry_jitter = sync_retry_jitter.lower() == "true"

        # Parse the updated URL so express_host/express_port stay in sync.
        self.server.__post_init__()

    def _load_performance_config_from_environment(self) -> None:
        """Load performance configuration from environment variables.

        Only ``MAX_ELEMENTS`` is read. An unparsable value is ignored.
        """
        from contextlib import suppress

        max_elements = os.getenv("MAX_ELEMENTS")
        if max_elements:
            with suppress(ValueError):
                self.performance.max_elements_per_canvas = int(max_elements)

    def _load_logging_config_from_environment(self) -> None:
        """Load logging configuration from environment variables.

        ``LOG_LEVEL`` is applied only when set. ``STRUCTURED_LOGGING`` and
        ``JSON_LOGGING`` are enabled only by ``"true"``. ``LOG_FILE`` and
        ``AUDIT_LOG_FILE`` are always applied (``None`` when unset).
        """
        self.logging.level = os.getenv("LOG_LEVEL", self.logging.level)
        self.logging.structured_logging = (
            os.getenv("STRUCTURED_LOGGING", "true").lower() == "true"
        )
        self.logging.json_format = os.getenv("JSON_LOGGING", "false").lower() == "true"
        self.logging.file_path = os.getenv("LOG_FILE")
        self.logging.audit_file_path = os.getenv("AUDIT_LOG_FILE")

    def _load_monitoring_config_from_environment(self) -> None:
        """Load monitoring configuration from environment variables.

        The feature toggles are enabled only by ``"true"`` (the default when
        unset). ``HEALTH_CHECK_INTERVAL``, ``CPU_THRESHOLD`` and
        ``MEMORY_THRESHOLD`` are applied when they parse; these values bypass
        ``MonitoringConfig``'s construction-time checks and are range-checked
        later in ``_validate_monitoring_config``.
        """
        from contextlib import suppress

        self.monitoring.enabled = (
            os.getenv("MONITORING_ENABLED", "true").lower() == "true"
        )
        self.monitoring.metrics_enabled = (
            os.getenv("METRICS_ENABLED", "true").lower() == "true"
        )
        self.monitoring.alerting_enabled = (
            os.getenv("ALERTING_ENABLED", "true").lower() == "true"
        )
        self.monitoring.circuit_breaker_enabled = (
            os.getenv("CIRCUIT_BREAKER_ENABLED", "true").lower() == "true"
        )

        health_check_interval = os.getenv("HEALTH_CHECK_INTERVAL")
        if health_check_interval:
            with suppress(ValueError):
                self.monitoring.health_check_interval_seconds = int(
                    health_check_interval
                )

        cpu_threshold = os.getenv("CPU_THRESHOLD")
        if cpu_threshold:
            with suppress(ValueError):
                self.monitoring.cpu_threshold_percent = float(cpu_threshold)

        memory_threshold = os.getenv("MEMORY_THRESHOLD")
        if memory_threshold:
            with suppress(ValueError):
                self.monitoring.memory_threshold_percent = float(memory_threshold)

    def _load_from_environment(self) -> None:
        """Load configuration from environment variables, section by section."""
        self._load_security_config_from_environment()
        self._load_server_config_from_environment()
        self._load_performance_config_from_environment()
        self._load_logging_config_from_environment()
        self._load_monitoring_config_from_environment()

    def _validate_security_config(self, errors: list[str]) -> None:
        """Append security validation problems to ``errors``.

        A whitespace-only ``JWT_SECRET`` is treated as missing.
        """
        if self.security.auth_enabled and not self.security.jwt_secret.strip():
            errors.append("JWT_SECRET is required when authentication is enabled")

        if self.security.token_expiration_hours <= 0:
            errors.append("Token expiration must be positive")

    def _validate_server_config(self, errors: list[str]) -> None:
        """Append server validation problems to ``errors``."""
        if self.server.express_port <= 0 or self.server.express_port > 65535:
            errors.append("Express port must be between 1 and 65535")

        if self.server.health_check_timeout_seconds <= 0:
            errors.append("Health check timeout must be positive")

        if self.server.sync_retry_attempts < 0:
            errors.append("Sync retry attempts must be non-negative")

        if self.server.sync_retry_delay_seconds <= 0:
            errors.append("Sync retry delay must be positive")

        if self.server.sync_retry_max_delay_seconds <= 0:
            errors.append("Sync retry max delay must be positive")

        if self.server.sync_retry_exponential_base <= 1.0:
            errors.append("Sync retry exponential base must be greater than 1.0")

    def _validate_performance_config(self, errors: list[str]) -> None:
        """Append performance validation problems to ``errors``."""
        if self.performance.max_elements_per_canvas <= 0:
            errors.append("Max elements per canvas must be positive")

        if self.performance.websocket_batch_size <= 0:
            errors.append("WebSocket batch size must be positive")

    def _validate_monitoring_config(self, errors: list[str]) -> None:
        """Append monitoring validation problems to ``errors``.

        Also re-checks the percentage thresholds and the trace sampling rate.
        ``MonitoringConfig`` enforces those only at construction, and
        ``CPU_THRESHOLD``/``MEMORY_THRESHOLD`` are applied from the
        environment afterwards, so they would otherwise go unchecked.
        """
        monitoring = self.monitoring

        if monitoring.health_check_interval_seconds <= 0:
            errors.append("Health check interval must be positive")

        if monitoring.consecutive_failure_threshold <= 0:
            errors.append("Consecutive failure threshold must be positive")

        if monitoring.circuit_failure_threshold <= 0:
            errors.append("Circuit breaker failure threshold must be positive")

        # ``not (a < x <= b)`` so that NaN counts as out of range.
        if not (0 < monitoring.cpu_threshold_percent <= 100):
            errors.append("CPU threshold must be between 0 and 100")

        if not (0 < monitoring.memory_threshold_percent <= 100):
            errors.append("Memory threshold must be between 0 and 100")

        if not (0 <= monitoring.trace_sampling_rate <= 1.0):
            errors.append("Trace sampling rate must be between 0 and 1")

    def _validate(self) -> None:
        """Validate all sections and raise one combined error if any fail.

        Raises:
            ValueError: listing every problem found, separated by ``"; "``.
        """
        errors: list[str] = []
        self._validate_security_config(errors)
        self._validate_server_config(errors)
        self._validate_performance_config(errors)
        self._validate_monitoring_config(errors)

        if errors:
            raise ValueError(f"Configuration validation failed: {'; '.join(errors)}")

    @property
    def is_development(self) -> bool:
        """True when ``ENVIRONMENT`` is unset or equals "development" (any case).

        The variable is re-read on every access.
        """
        return os.getenv("ENVIRONMENT", "development").lower() == "development"

    @property
    def is_production(self) -> bool:
        """True when ``ENVIRONMENT`` equals "production" (any case).

        The variable is re-read on every access.
        """
        return os.getenv("ENVIRONMENT", "development").lower() == "production"

418 

419 

# Module-wide configuration singleton. It is built at import time, so importing
# this module reads pyproject.toml and the environment, and raises ValueError
# if the resulting configuration is invalid.
config: Config = Config()