Coverage for fastblocks/actions/gather/middleware.py: 49%

177 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-21 04:50 -0700

1"""Middleware gathering and stack building functionality.""" 

2 

3import typing as t 

4from contextlib import suppress 

5from enum import Enum 

6 

7from acb.debug import debug 

8from starlette.middleware import Middleware 

9from starlette.middleware.errors import ServerErrorMiddleware 

10from starlette.middleware.exceptions import ExceptionMiddleware 

11 

12from .strategies import GatherStrategy, gather_with_strategy 

13 

14 

class MiddlewarePosition(Enum):
    """Logical slots for fastblocks' built-in (system) middleware.

    Members are numbered contiguously from 0 in declaration order, and that
    number is used as a list index elsewhere in this module: default
    middleware are paired with members in declaration order, and overrides
    replace the entry at ``position.value``.  Keep the numbering contiguous
    and zero-based.
    """

    # Unpacking a range assigns 0..6 to the members in the order listed.
    (
        SECURITY,
        CORS,
        COMPRESSION,
        SESSIONS,
        AUTHENTICATION,
        CACHING,
        CUSTOM,
    ) = range(7)

23 

24 

class MiddlewareGatherResult:
    """Outcome of a middleware-gathering pass.

    Attributes:
        user_middleware: application-level ``Middleware`` entries.
        system_middleware: system middleware keyed by ``MiddlewarePosition``.
        middleware_stack: the assembled, ordered middleware list.
        errors: exceptions collected while gathering.

    Any argument left as ``None`` is replaced by a fresh empty container.
    A container that is supplied is stored as-is; it is not copied.
    """

    def __init__(
        self,
        *,
        user_middleware: list[Middleware] | None = None,
        system_middleware: dict[MiddlewarePosition, t.Any] | None = None,
        middleware_stack: list[Middleware] | None = None,
        errors: list[Exception] | None = None,
    ) -> None:
        self.user_middleware = [] if user_middleware is None else user_middleware
        self.system_middleware = {} if system_middleware is None else system_middleware
        self.middleware_stack = [] if middleware_stack is None else middleware_stack
        self.errors = [] if errors is None else errors

    @property
    def total_middleware(self) -> int:
        """Return the count of user entries plus system entries.

        The assembled ``middleware_stack`` is not included in the count.
        """
        return sum(map(len, (self.user_middleware, self.system_middleware)))

    @property
    def has_errors(self) -> bool:
        """Return ``True`` when at least one error was recorded."""
        return bool(self.errors)

48 

49 

async def gather_middleware(
    *,
    user_middleware: list[Middleware] | None = None,
    system_overrides: dict[MiddlewarePosition, t.Any] | None = None,
    include_defaults: bool = True,
    debug_mode: bool = False,
    error_handler: t.Any | None = None,
    strategy: GatherStrategy | None = None,
) -> MiddlewareGatherResult:
    """Gather default, custom and user middleware and build the ordered stack.

    The following coroutines run through ``gather_with_strategy``:

    * the default system middleware (only when ``include_defaults``);
    * custom middleware declared in configuration;
    * the assembled middleware stack.

    Args:
        user_middleware: Application middleware. It is copied, never mutated.
        system_overrides: Replacements for system middleware, keyed by
            position. They take precedence over the gathered defaults. The
            dict is copied, never mutated.
        include_defaults: Whether to gather and insert fastblocks' default
            system middleware.
        debug_mode: Forwarded to the exception and server-error middleware.
        error_handler: Optional handler for ``ServerErrorMiddleware``.
        strategy: Gathering strategy. A default ``GatherStrategy()`` is used
            when omitted.

    Returns:
        A ``MiddlewareGatherResult``. Errors reported by
        ``gather_with_strategy`` are appended to ``result.errors``.
    """
    if strategy is None:
        strategy = GatherStrategy()

    # Private copies: the result object is extended/updated below, and the
    # caller's containers must not be mutated as a side effect.
    user_middleware = list(user_middleware) if user_middleware is not None else []
    system_overrides = (
        dict(system_overrides) if system_overrides is not None else {}
    )

    result = MiddlewareGatherResult(user_middleware=list(user_middleware))

    tasks: list[t.Any] = []
    if include_defaults:
        tasks.append(_gather_default_middleware())
    # Task positions depend on whether the defaults task was scheduled, so
    # record them rather than hard-coding indices 1 and 2.
    custom_index = len(tasks)
    tasks.append(_gather_custom_middleware())
    stack_index = len(tasks)
    tasks.append(
        _build_middleware_stack(
            user_middleware,
            system_overrides,
            include_defaults,
            debug_mode,
            error_handler,
        ),
    )

    # NOTE(review): the cache key reflects only include_defaults/debug_mode.
    # If gather_with_strategy caches results, calls with different user
    # middleware, overrides or error handlers could share an entry. Confirm.
    gather_result = await gather_with_strategy(
        tasks,
        strategy,
        cache_key=f"middleware:{include_defaults}:{debug_mode}",
    )

    # NOTE(review): this assumes gather_result.success is positionally aligned
    # with `tasks`. Confirm how gather_with_strategy reports failed tasks.
    defaults: dict[MiddlewarePosition, t.Any] = {}
    for i, success in enumerate(gather_result.success):
        if include_defaults and i == 0:
            defaults = dict(success)
        elif i == custom_index:
            result.user_middleware.extend(success)
        elif i == stack_index:
            result.middleware_stack = success

    # Explicit overrides win over gathered defaults. This matches how
    # _add_system_middleware applies them when building the stack.
    result.system_middleware = {**defaults, **system_overrides}

    result.errors.extend(gather_result.errors)

    debug(f"Gathered {result.total_middleware} middleware components")

    return result

110 

111 

async def _gather_default_middleware() -> dict[MiddlewarePosition, t.Any]:
    """Map fastblocks' default middleware onto ``MiddlewarePosition`` slots.

    The i-th entry returned by ``fastblocks.middleware.middlewares()`` is
    assigned to the i-th ``MiddlewarePosition`` member in declaration order.
    Entries beyond the last member are ignored.

    Any failure, including the import itself, is logged via ``debug`` and
    produces an empty mapping.
    """
    try:
        from fastblocks.middleware import middlewares

        # zip() stops at the shorter input, so surplus entries are dropped.
        by_position = dict(zip(MiddlewarePosition, middlewares()))
        debug(f"Gathered {len(by_position)} default middleware components")
        return by_position
    except Exception as exc:
        debug(f"Error gathering default middleware: {exc}")
        return {}

127 

128 

async def _gather_custom_middleware() -> list[Middleware]:
    """Load custom middleware classes listed in ``config.middleware.custom``.

    Each entry is a dotted path such as ``"package.module.ClassName"``. The
    class is imported and wrapped in ``Middleware`` with no arguments.

    Failure handling (both cases are logged via ``debug``):

    * An entry that fails to load is skipped.
    * If the configuration cannot be resolved, an empty list is returned.

    Returns:
        The successfully loaded middleware, in configuration order.
    """
    import importlib

    custom_middleware: list[Middleware] = []

    try:
        from acb.depends import depends

        config = depends.get("config")
        middleware_cfg = getattr(config, "middleware", None)
        paths = getattr(middleware_cfg, "custom", None)
        # Materialize inside the try so a non-iterable setting is reported
        # here instead of escaping from the loop below.
        paths = list(paths) if paths is not None else []
    except Exception as e:
        debug(f"Error resolving custom middleware config: {e}")
        return custom_middleware

    for middleware_path in paths:
        try:
            module_path, class_name = middleware_path.rsplit(".", 1)
            module = importlib.import_module(module_path)
            middleware_class = getattr(module, class_name)
            custom_middleware.append(Middleware(middleware_class))
            debug(f"Added custom middleware: {class_name}")
        except Exception as e:
            # Best effort: one bad entry must not prevent the others loading.
            debug(f"Error loading custom middleware {middleware_path}: {e}")

    return custom_middleware

148 

149 

async def _build_middleware_stack(
    user_middleware: list[Middleware],
    system_overrides: dict[MiddlewarePosition, t.Any],
    include_defaults: bool,
    debug_mode: bool,
    error_handler: t.Any,
) -> list[Middleware]:
    """Assemble the ordered middleware list.

    The resulting order is:

    1. ``ExceptionMiddleware``;
    2. the entries of ``user_middleware``;
    3. the system middleware, only when ``include_defaults`` is true;
    4. ``ServerErrorMiddleware``, configured with ``error_handler`` and
       ``debug_mode``.
    """
    ordered = [Middleware(ExceptionMiddleware, debug=debug_mode), *user_middleware]

    if include_defaults:
        _add_system_middleware(ordered, system_overrides)
    _add_error_handler_middleware(ordered, error_handler, debug_mode)

    debug(f"Built middleware stack with {len(ordered)} components")
    return ordered

170 

171 

def _add_system_middleware(
    stack: list[Middleware],
    system_overrides: dict[MiddlewarePosition, t.Any],
) -> None:
    """Append fastblocks' system middleware to ``stack``, applying overrides.

    Overrides:

    * Each entry in ``system_overrides`` replaces the default whose list
      index equals ``position.value``.
    * Overrides for positions past the end of the default list are ignored
      and logged.

    Entry formats:

    * A ``Middleware`` is appended as-is.
    * A ``(cls, kwargs)`` tuple is wrapped as ``Middleware(cls, **kwargs)``.

    This is best-effort. If the defaults cannot be loaded or the overrides
    cannot be applied, nothing is appended. A single entry that cannot be
    converted is logged and skipped, and the entries around it are kept.
    """
    try:
        from fastblocks.middleware import middlewares

        # Copy so overrides never write into a list middlewares() may share,
        # and so a tuple return value still supports item assignment.
        system_middleware = list(middlewares())

        for position, override in system_overrides.items():
            position_index = position.value
            if 0 <= position_index < len(system_middleware):
                system_middleware[position_index] = override
                debug(f"Override middleware at position {position.name}")
            else:
                debug(f"Ignoring override for unavailable position {position.name}")
    except Exception as e:
        debug(f"Error building system middleware: {e}")
        return

    for middleware_def in system_middleware:
        try:
            if isinstance(middleware_def, tuple):
                cls, kwargs = middleware_def
                stack.append(Middleware(cls, **kwargs))
            else:
                stack.append(middleware_def)
        except Exception as e:
            # Skip only the malformed entry; previously one bad entry left the
            # stack half-built and silently dropped everything after it.
            debug(f"Error building system middleware: {e}")

196 

197 

def _add_error_handler_middleware(
    stack: list[Middleware],
    error_handler: t.Any,
    debug_mode: bool,
) -> None:
    """Append ``ServerErrorMiddleware`` to the end of ``stack``.

    ``handler`` is passed only when ``error_handler`` is truthy. Otherwise the
    middleware is created with just the ``debug`` option.
    """
    options: dict[str, t.Any] = {"debug": debug_mode}
    if error_handler:
        options = {"handler": error_handler, **options}
    stack.append(Middleware(ServerErrorMiddleware, **options))

218 

219 

220def extract_middleware_info(middleware: t.Any) -> dict[str, t.Any]: 

221 if isinstance(middleware, Middleware): 

222 return { 

223 "class": getattr(middleware.cls, "__name__", str(middleware.cls)), 

224 "args": middleware.args, 

225 "kwargs": middleware.kwargs, 

226 } 

227 if isinstance(middleware, tuple) and len(middleware) >= 2: 

228 cls, kwargs = middleware[0], middleware[1] 

229 return { 

230 "class": cls.__name__ if hasattr(cls, "__name__") else str(cls), 

231 "kwargs": kwargs, 

232 } 

233 return { 

234 "class": middleware.__class__.__name__, 

235 "raw": str(middleware), 

236 } 

237 

238 

def get_middleware_stack_info(
    middleware_stack: list[Middleware],
) -> dict[str, t.Any]:
    """Summarize a middleware stack.

    Returns a dict with these keys:

    * ``total_middleware``: the number of entries.
    * ``middleware_list``: per-entry dicts from ``extract_middleware_info``,
      each extended with its zero-based ``position``.
    * ``execution_order``: the ``class`` names in stack order.
    """
    described = [
        {**extract_middleware_info(entry), "position": index}
        for index, entry in enumerate(middleware_stack)
    ]
    return {
        "total_middleware": len(middleware_stack),
        "middleware_list": described,
        "execution_order": [item["class"] for item in described],
    }

255 

256 

def validate_middleware_stack(
    middleware_stack: list[Middleware],
) -> dict[str, t.Any]:
    """Check a middleware stack against ordering and composition heuristics.

    Entries are compared by class name, as reported by
    ``extract_middleware_info``:

    * Warning if the first entry is not ``ExceptionMiddleware``.
    * Warning if the last entry is not ``ServerErrorMiddleware``.
    * Recommendation if no name contains ``CORSMiddleware``,
      ``TrustedHostMiddleware`` or ``HTTPSRedirectMiddleware``. This also
      fires for an empty stack.
    * Warning if any session middleware (name containing ``"Session"``)
      comes after any auth middleware (name containing ``"Auth"`` or
      ``"Login"``).

    Returns:
        A dict with keys ``valid``, ``warnings``, ``errors`` and
        ``recommendations``. ``valid`` is ``True`` exactly when ``errors``
        is empty. No check here currently records an error.
    """
    validation: dict[str, t.Any] = {
        "valid": True,
        "warnings": [],
        "errors": [],
        "recommendations": [],
    }

    middleware_classes = [extract_middleware_info(m)["class"] for m in middleware_stack]

    if middleware_classes and middleware_classes[0] != "ExceptionMiddleware":
        validation["warnings"].append(
            "ExceptionMiddleware should be first in the stack",
        )

    if middleware_classes and middleware_classes[-1] != "ServerErrorMiddleware":
        validation["warnings"].append(
            "ServerErrorMiddleware should be last in the stack",
        )

    security_middleware = [
        "CORSMiddleware",
        "TrustedHostMiddleware",
        "HTTPSRedirectMiddleware",
    ]

    found_security = any(
        any(sec in cls for sec in security_middleware) for cls in middleware_classes
    )

    if not found_security:
        validation["recommendations"].append(
            "Consider adding security middleware (CORS, TrustedHost, etc.)",
        )

    session_indices = [
        i for i, cls in enumerate(middleware_classes) if "Session" in cls
    ]
    auth_indices = [
        i
        for i, cls in enumerate(middleware_classes)
        if "Auth" in cls or "Login" in cls
    ]

    # Every session middleware must precede every auth middleware. Compare
    # the last session entry with the FIRST auth entry; comparing against the
    # last auth entry misses an auth layer that sits before the sessions
    # whenever another auth layer follows them.
    if session_indices and auth_indices and max(session_indices) > min(auth_indices):
        validation["warnings"].append(
            "SessionMiddleware should come before authentication middleware",
        )

    validation["valid"] = len(validation["errors"]) == 0

    return validation

311 

312 

async def create_middleware_manager(
    gather_result: MiddlewareGatherResult,
) -> t.Any:
    """Create a ``fastblocks.applications.MiddlewareManager`` from a gather result.

    The manager receives the result's user middleware, system middleware and
    assembled stack by reference; nothing is copied.
    """
    from fastblocks.applications import MiddlewareManager

    source = gather_result
    mgr = MiddlewareManager()

    # NOTE(review): these may be properties with side effects on the manager
    # (for example cache invalidation), so the assignment order is kept as-is.
    mgr.user_middleware = source.user_middleware
    mgr._system_middleware = source.system_middleware  # type: ignore[misc]
    mgr._middleware_stack_cache = source.middleware_stack

    debug(
        f"Created middleware manager with {source.total_middleware} components",
    )
    return mgr

331 

332 

def add_middleware_at_position(
    middleware_stack: list[Middleware],
    new_middleware: Middleware,
    position: MiddlewarePosition,
) -> list[Middleware]:
    """Return a copy of ``middleware_stack`` with ``new_middleware`` inserted.

    Target index by position:

    * ``SECURITY`` through ``CACHING`` map to indices 1..6. Index 0 is left
      for the leading entry.
    * ``CUSTOM`` maps to just before the last entry.
    * Any unrecognized value falls back to index 1.

    The index is then clamped to ``len(stack) - 1``, so the new entry never
    lands after the final element. The input list is not modified.
    """
    stack = middleware_stack.copy()
    last_index = len(stack) - 1

    if position == MiddlewarePosition.CUSTOM:
        target = last_index
    else:
        target = {
            MiddlewarePosition.SECURITY: 1,
            MiddlewarePosition.CORS: 2,
            MiddlewarePosition.COMPRESSION: 3,
            MiddlewarePosition.SESSIONS: 4,
            MiddlewarePosition.AUTHENTICATION: 5,
            MiddlewarePosition.CACHING: 6,
        }.get(position, 1)

    stack.insert(min(target, last_index), new_middleware)
    debug(f"Added middleware at position {position.name}")

    return stack