Coverage for fastblocks/actions/gather/middleware.py: 51%

192 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-09 00:47 -0700

1"""Middleware gathering and stack building functionality.""" 

2 

3import typing as t 

4from contextlib import suppress 

5from enum import Enum 

6 

7from acb.debug import debug 

8from starlette.middleware import Middleware 

9from starlette.middleware.errors import ServerErrorMiddleware 

10from starlette.middleware.exceptions import ExceptionMiddleware 

11 

12from .strategies import GatherStrategy, gather_with_strategy 

13 

14 

class MiddlewarePosition(Enum):
    """Named slots for system middleware.

    The integer value of each member is used elsewhere in this module as an
    index into the list of default middleware, so the numbering is
    significant and must stay contiguous from zero.
    """

    SECURITY = 0
    CORS = 1
    COMPRESSION = 2
    SESSIONS = 3
    AUTHENTICATION = 4
    CACHING = 5
    CUSTOM = 6

23 

24 

class MiddlewareGatherResult:
    """Container for the outcome of a middleware gathering pass.

    Attributes are plain, mutable containers:

    * ``user_middleware``   -- list of user/custom middleware entries
    * ``system_middleware`` -- mapping of position -> system middleware
    * ``middleware_stack``  -- the assembled stack
    * ``errors``            -- exceptions collected while gathering
    """

    def __init__(
        self,
        *,
        user_middleware: list[Middleware] | None = None,
        system_middleware: dict[MiddlewarePosition, t.Any] | None = None,
        middleware_stack: list[Middleware] | None = None,
        errors: list[Exception] | None = None,
    ) -> None:
        """Store the given containers, creating fresh empty ones for ``None``.

        Only ``None`` is replaced: a caller-supplied container (even an empty
        one) is stored as-is, without copying.
        """
        if user_middleware is None:
            user_middleware = []
        if system_middleware is None:
            system_middleware = {}
        if middleware_stack is None:
            middleware_stack = []
        if errors is None:
            errors = []

        self.user_middleware = user_middleware
        self.system_middleware = system_middleware
        self.middleware_stack = middleware_stack
        self.errors = errors

    @property
    def total_middleware(self) -> int:
        """Number of user entries plus number of system entries."""
        user_count = len(self.user_middleware)
        system_count = len(self.system_middleware)
        return user_count + system_count

    @property
    def has_errors(self) -> bool:
        """``True`` when at least one error was recorded."""
        return bool(self.errors)

48 

49 

async def gather_middleware(
    *,
    user_middleware: list[Middleware] | None = None,
    system_overrides: dict[MiddlewarePosition, t.Any] | None = None,
    include_defaults: bool = True,
    debug_mode: bool = False,
    error_handler: t.Any | None = None,
    strategy: GatherStrategy | None = None,
) -> MiddlewareGatherResult:
    """Gather default and custom middleware and build the middleware stack.

    Args:
        user_middleware: Middleware supplied by the caller. Not mutated.
        system_overrides: Per-position replacements for default middleware.
            Not mutated.
        include_defaults: When true, default system middleware is gathered
            and included in the built stack.
        debug_mode: Passed as ``debug`` to the exception/error middleware.
        error_handler: Optional handler for the server-error middleware.
        strategy: Gathering strategy; a default ``GatherStrategy()`` is used
            when omitted.

    Returns:
        A ``MiddlewareGatherResult`` holding the user middleware (caller's
        plus custom ones loaded from configuration), the system middleware
        mapping (defaults with the caller's overrides applied on top), the
        built stack, and any errors reported by the gather helper.
    """
    if strategy is None:
        strategy = GatherStrategy()
    if user_middleware is None:
        user_middleware = []
    if system_overrides is None:
        system_overrides = {}

    # Work on copies: folding gathered results into the result object must
    # not silently grow the caller's own list/dict.
    result = MiddlewareGatherResult(
        user_middleware=list(user_middleware),
        system_middleware=dict(system_overrides),
    )

    # Label every task so results are matched by meaning rather than by a
    # hard-coded index. The defaults task is optional, which shifts the
    # position of the remaining tasks when it is absent.
    labelled_tasks: list[tuple[str, t.Coroutine[t.Any, t.Any, t.Any]]] = []
    if include_defaults:
        labelled_tasks.append(("defaults", _gather_default_middleware()))
    labelled_tasks.append(("custom", _gather_custom_middleware()))
    labelled_tasks.append(
        (
            "stack",
            _build_middleware_stack(
                user_middleware,
                system_overrides,
                include_defaults,
                debug_mode,
                error_handler,
            ),
        ),
    )

    # NOTE(review): the cache key only reflects include_defaults/debug_mode;
    # if the strategy caches results, different middleware inputs may share
    # an entry -- confirm against gather_with_strategy.
    gather_result = await gather_with_strategy(
        [task for _, task in labelled_tasks],
        strategy,
        cache_key=f"middleware:{include_defaults}:{debug_mode}",
    )

    # Assumes gather_result.success is ordered like the submitted tasks
    # (the same assumption positional indexing relies on) -- TODO confirm.
    for (label, _), success in zip(labelled_tasks, gather_result.success):
        if label == "defaults":
            # Overrides take precedence over gathered defaults, matching how
            # the stack itself is built.
            result.system_middleware = {**success, **result.system_middleware}
        elif label == "custom":
            result.user_middleware.extend(success)
        else:
            result.middleware_stack = success

    result.errors.extend(gather_result.errors)

    debug(f"Gathered {result.total_middleware} middleware components")

    return result

110 

111 

async def _gather_default_middleware() -> dict[MiddlewarePosition, t.Any]:
    """Map the default middleware list onto ``MiddlewarePosition`` members.

    The n-th default middleware is paired with the n-th position in enum
    definition order; defaults beyond the number of positions are ignored.
    Any failure (including a failing import) is logged and yields ``{}``.
    """
    try:
        from fastblocks.middleware import middlewares

        # zip() stops at the shorter input, which gives the same truncation
        # as bounding the index by len(MiddlewarePosition).
        by_position = dict(zip(MiddlewarePosition, middlewares()))
        debug(f"Gathered {len(by_position)} default middleware components")
        return by_position
    except Exception as exc:
        debug(f"Error gathering default middleware: {exc}")
        return {}

127 

128 

async def _gather_custom_middleware() -> list[Middleware]:
    """Load custom middleware classes named in the application config.

    Reads ``config.middleware.custom`` (dotted ``package.module.ClassName``
    paths), imports each class and wraps it in ``Middleware``. Loading is
    best-effort: a path that cannot be imported is logged and skipped, and a
    failure to read the configuration is logged instead of being silently
    discarded.

    Returns:
        The successfully loaded middleware, in configuration order. Empty
        when nothing is configured or the configuration is unavailable.
    """
    from importlib import import_module

    custom_middleware: list[Middleware] = []
    try:
        from acb.depends import depends

        config = depends.get("config")
        configured = getattr(getattr(config, "middleware", None), "custom", None)
        # A single dotted path is treated as a one-element list rather than
        # being iterated character by character.
        if isinstance(configured, str):
            configured = [configured]

        for middleware_path in configured or ():
            try:
                module_path, class_name = middleware_path.rsplit(".", 1)
                middleware_class = getattr(import_module(module_path), class_name)
                custom_middleware.append(Middleware(middleware_class))
                debug(f"Added custom middleware: {class_name}")
            except Exception as e:
                debug(f"Error loading custom middleware {middleware_path}: {e}")
    except Exception as e:
        # Still best-effort, but leave a trace so a broken setup is visible.
        debug(f"Error gathering custom middleware: {e}")

    return custom_middleware

148 

149 

async def _build_middleware_stack(
    user_middleware: list[Middleware],
    system_overrides: dict[MiddlewarePosition, t.Any],
    include_defaults: bool,
    debug_mode: bool,
    error_handler: t.Any,
) -> list[Middleware]:
    """Assemble the middleware stack as a new list.

    Order of the result: ``ExceptionMiddleware``, the user middleware, the
    system middleware (only when ``include_defaults`` is true), and finally
    the server-error middleware.
    """
    assembled = [
        Middleware(ExceptionMiddleware, debug=debug_mode),
        *user_middleware,
    ]

    if include_defaults:
        _add_system_middleware(assembled, system_overrides)

    _add_error_handler_middleware(assembled, error_handler, debug_mode)

    debug(f"Built middleware stack with {len(assembled)} components")
    return assembled

170 

171 

def _add_system_middleware(
    stack: list[Middleware],
    system_overrides: dict[MiddlewarePosition, t.Any],
) -> None:
    """Append system middleware to ``stack`` in place, never raising.

    Any exception raised while applying the system middleware is logged via
    ``debug`` and swallowed; entries appended before the failure stay in
    ``stack``.
    """
    try:
        _apply_system_middleware(stack, system_overrides)
    except Exception as exc:
        debug(f"Error building system middleware: {exc}")

180 

181 

def _apply_system_middleware(
    stack: list[Middleware],
    system_overrides: dict[MiddlewarePosition, t.Any],
) -> None:
    """Apply system middleware to the stack.

    Fetches the default middleware, replaces the entries whose index matches
    an overridden position's value, and appends every entry to ``stack`` in
    place. ``(cls, kwargs)`` tuples are wrapped in ``Middleware``; anything
    else is appended unchanged.

    Args:
        stack: List to append to (mutated).
        system_overrides: Position -> replacement middleware. Positions whose
            value falls outside the default list are ignored.
    """
    from fastblocks.middleware import middlewares

    # Copy before writing overrides: the object returned by middlewares()
    # may be shared or immutable, and must not be modified here.
    system_middleware = list(middlewares())
    slot_count = len(system_middleware)

    for position, override in system_overrides.items():
        position_index = position.value
        if 0 <= position_index < slot_count:
            system_middleware[position_index] = override
            debug(f"Override middleware at position {position.name}")

    for middleware_def in system_middleware:
        if isinstance(middleware_def, tuple):
            cls, kwargs = middleware_def
            stack.append(Middleware(cls, **kwargs))
        else:
            stack.append(middleware_def)

203 

204 

def _add_error_handler_middleware(
    stack: list[Middleware],
    error_handler: t.Any,
    debug_mode: bool,
) -> None:
    """Append the server-error middleware to the end of ``stack`` in place."""
    stack.append(_create_error_middleware(error_handler, debug_mode))

212 

213 

def _create_error_middleware(error_handler: t.Any, debug_mode: bool) -> Middleware:
    """Create error handler middleware.

    Wraps ``ServerErrorMiddleware`` with ``debug=debug_mode``; the ``handler``
    keyword is passed only when ``error_handler`` is truthy.
    """
    options: dict[str, t.Any] = {}
    if error_handler:
        options["handler"] = error_handler
    options["debug"] = debug_mode
    return Middleware(ServerErrorMiddleware, **options)

226 

227 

def extract_middleware_info(middleware: t.Any) -> dict[str, t.Any]:
    """Describe a middleware entry as a plain dictionary.

    Three shapes are recognised:

    * a ``Middleware`` instance -> ``class``, ``args`` and ``kwargs`` keys
    * a tuple of at least two items, read as ``(cls, kwargs, ...)`` ->
      ``class`` and ``kwargs`` keys
    * anything else -> ``class`` (type name of the object) and ``raw``
      (its ``str()``)
    """

    def _label(obj: t.Any) -> str:
        # Prefer the object's __name__, falling back to its string form.
        return getattr(obj, "__name__", str(obj))

    if isinstance(middleware, Middleware):
        return {
            "class": _label(middleware.cls),
            "args": middleware.args,
            "kwargs": middleware.kwargs,
        }

    if isinstance(middleware, tuple) and len(middleware) >= 2:
        cls, kwargs = middleware[:2]
        return {"class": _label(cls), "kwargs": kwargs}

    return {
        "class": type(middleware).__name__,
        "raw": str(middleware),
    }

245 

246 

def get_middleware_stack_info(
    middleware_stack: list[Middleware],
) -> dict[str, t.Any]:
    """Summarise a middleware stack.

    Returns a dict with ``total_middleware`` (stack length),
    ``middleware_list`` (one info dict per entry) and ``execution_order``
    (the entries' class names, in stack order).
    """
    return _populate_middleware_info(
        middleware_stack,
        {
            "total_middleware": len(middleware_stack),
            "middleware_list": [],
            "execution_order": [],
        },
    )

257 

258 

def _populate_middleware_info(
    middleware_stack: list[Middleware], info: dict[str, t.Any]
) -> dict[str, t.Any]:
    """Populate middleware information.

    For each stack entry, appends its info dict (tagged with its ``position``
    index) to ``info["middleware_list"]`` and its class name to
    ``info["execution_order"]``. ``info`` is mutated and returned.
    """
    described = map(extract_middleware_info, middleware_stack)
    for index, entry in enumerate(described):
        entry["position"] = index
        info["middleware_list"].append(entry)
        info["execution_order"].append(entry["class"])

    return info

270 

271 

def validate_middleware_stack(
    middleware_stack: list[Middleware],
) -> dict[str, t.Any]:
    """Run the ordering/security checks against a middleware stack.

    Returns a dict with ``valid``, ``warnings``, ``errors`` and
    ``recommendations`` keys. ``valid`` is true exactly when ``errors`` is
    empty after all checks have run.
    """
    report: dict[str, t.Any] = {
        "valid": True,
        "warnings": [],
        "errors": [],
        "recommendations": [],
    }

    class_names = [extract_middleware_info(entry)["class"] for entry in middleware_stack]

    # Same checks, same order: overall ordering, presence of security
    # middleware, then session-before-auth ordering.
    checks = (
        _check_middleware_ordering,
        _check_security_middleware,
        _check_session_auth_ordering,
    )
    for check in checks:
        check(class_names, report)

    report["valid"] = not report["errors"]

    return report

296 

297 

298def _check_middleware_ordering( 

299 middleware_classes: list[str], validation: dict[str, t.Any] 

300) -> None: 

301 """Check if middleware is in the correct order.""" 

302 if middleware_classes and middleware_classes[0] != "ExceptionMiddleware": 

303 validation["warnings"].append( 

304 "ExceptionMiddleware should be first in the stack", 

305 ) 

306 

307 if middleware_classes and middleware_classes[-1] != "ServerErrorMiddleware": 

308 validation["warnings"].append( 

309 "ServerErrorMiddleware should be last in the stack", 

310 ) 

311 

312 

313def _check_security_middleware( 

314 middleware_classes: list[str], validation: dict[str, t.Any] 

315) -> None: 

316 """Check if security middleware is present.""" 

317 security_middleware = [ 

318 "CORSMiddleware", 

319 "TrustedHostMiddleware", 

320 "HTTPSRedirectMiddleware", 

321 ] 

322 

323 found_security = any( 

324 any(sec in cls for sec in security_middleware) for cls in middleware_classes 

325 ) 

326 

327 if not found_security: 

328 validation["recommendations"].append( 

329 "Consider adding security middleware (CORS, TrustedHost, etc.)", 

330 ) 

331 

332 

333def _check_session_auth_ordering( 

334 middleware_classes: list[str], validation: dict[str, t.Any] 

335) -> None: 

336 """Check if session and auth middleware are in the correct order.""" 

337 session_index = -1 

338 auth_index = -1 

339 

340 for i, cls in enumerate(middleware_classes): 

341 if "Session" in cls: 

342 session_index = i 

343 if "Auth" in cls or "Login" in cls: 

344 auth_index = i 

345 

346 if session_index > -1 and auth_index > -1 and session_index > auth_index: 

347 validation["warnings"].append( 

348 "SessionMiddleware should come before authentication middleware", 

349 ) 

350 

351 

async def create_middleware_manager(
    gather_result: MiddlewareGatherResult,
) -> t.Any:
    """Create a ``MiddlewareManager`` pre-populated from a gather result.

    The manager's ``user_middleware``, ``_system_middleware`` and
    ``_middleware_stack_cache`` attributes are set, in that order, from the
    corresponding fields of ``gather_result``.
    """
    from fastblocks.applications import MiddlewareManager

    manager = MiddlewareManager()

    # Assignment order is kept as-is in case the manager reacts to it.
    populated = (
        ("user_middleware", gather_result.user_middleware),
        ("_system_middleware", gather_result.system_middleware),
        ("_middleware_stack_cache", gather_result.middleware_stack),
    )
    for attribute, value in populated:
        setattr(manager, attribute, value)

    debug(
        f"Created middleware manager with {gather_result.total_middleware} components",
    )

    return manager

370 

371 

def add_middleware_at_position(
    middleware_stack: list[Middleware],
    new_middleware: Middleware,
    position: MiddlewarePosition,
) -> list[Middleware]:
    """Return a copy of the stack with ``new_middleware`` inserted.

    The input list is left untouched; the insertion index is derived from
    ``position`` and the current stack length.
    """
    updated = list(middleware_stack)
    updated.insert(_calculate_insert_index(position, updated), new_middleware)
    debug(f"Added middleware at position {position.name}")

    return updated

385 

386 

def _calculate_insert_index(
    position: MiddlewarePosition, stack: list[Middleware]
) -> int:
    """Calculate the insert index based on the middleware position.

    Fixed positions map to indices 1 (SECURITY) through 6 (CACHING); CUSTOM
    maps to the index of the last element. The result is capped at
    ``len(stack) - 1``, so it is 0 for a one-element stack and -1 for an
    empty one.
    """
    last_slot = len(stack) - 1

    if position == MiddlewarePosition.CUSTOM:
        wanted = last_slot
    else:
        fixed_slots = {
            MiddlewarePosition.SECURITY: 1,
            MiddlewarePosition.CORS: 2,
            MiddlewarePosition.COMPRESSION: 3,
            MiddlewarePosition.SESSIONS: 4,
            MiddlewarePosition.AUTHENTICATION: 5,
            MiddlewarePosition.CACHING: 6,
        }
        # Anything unrecognised falls back to index 1.
        wanted = fixed_slots.get(position, 1)

    return min(wanted, last_slot)