Coverage for fastblocks/actions/gather/middleware.py: 51%
192 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-09 00:47 -0700
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-09 00:47 -0700
1"""Middleware gathering and stack building functionality."""
3import typing as t
4from contextlib import suppress
5from enum import Enum
7from acb.debug import debug
8from starlette.middleware import Middleware
9from starlette.middleware.errors import ServerErrorMiddleware
10from starlette.middleware.exceptions import ExceptionMiddleware
12from .strategies import GatherStrategy, gather_with_strategy
class MiddlewarePosition(Enum):
    """Named slots in the default system-middleware list.

    The integer values are load-bearing: ``_apply_system_middleware`` uses
    ``position.value`` as an index into the default middleware list, and
    ``_gather_default_middleware`` pairs members with defaults in declaration
    order. Do not renumber or reorder members.
    """

    SECURITY = 0  # index 0 of the default list
    CORS = 1  # index 1
    COMPRESSION = 2  # index 2
    SESSIONS = 3  # index 3
    AUTHENTICATION = 4  # index 4
    CACHING = 5  # index 5
    CUSTOM = 6  # index 6; also means "just before the last entry" in _calculate_insert_index
class MiddlewareGatherResult:
    """Outcome of a middleware gathering pass.

    All constructor arguments are keyword-only. Any argument left as ``None``
    is replaced by a fresh empty container, so instances never share mutable
    state through defaults. Containers that are passed in are stored as-is
    (not copied).
    """

    def __init__(
        self,
        *,
        user_middleware: list[Middleware] | None = None,
        system_middleware: dict[MiddlewarePosition, t.Any] | None = None,
        middleware_stack: list[Middleware] | None = None,
        errors: list[Exception] | None = None,
    ) -> None:
        self.user_middleware = [] if user_middleware is None else user_middleware
        self.system_middleware = {} if system_middleware is None else system_middleware
        self.middleware_stack = [] if middleware_stack is None else middleware_stack
        self.errors = [] if errors is None else errors

    @property
    def total_middleware(self) -> int:
        """Count of user middleware plus system middleware entries.

        The assembled ``middleware_stack`` is deliberately not part of the count.
        """
        return sum(map(len, (self.user_middleware, self.system_middleware)))

    @property
    def has_errors(self) -> bool:
        """True when at least one error was recorded."""
        return bool(self.errors)
async def gather_middleware(
    *,
    user_middleware: list[Middleware] | None = None,
    system_overrides: dict[MiddlewarePosition, t.Any] | None = None,
    include_defaults: bool = True,
    debug_mode: bool = False,
    error_handler: t.Any | None = None,
    strategy: GatherStrategy | None = None,
) -> MiddlewareGatherResult:
    """Gather default, custom and user middleware and build the final stack.

    Up to three coroutines run through ``gather_with_strategy``:

    - the default system middleware (only when ``include_defaults`` is true);
    - custom middleware declared in config;
    - the assembled middleware stack.

    Args:
        user_middleware: Caller-supplied middleware. It is copied, so the
            caller's list is never mutated.
        system_overrides: Replacements keyed by position. They are copied and
            take precedence over gathered defaults at the same position.
        include_defaults: Whether to gather and stack the default system middleware.
        debug_mode: Forwarded to the exception and error middleware.
        error_handler: Optional handler for ``ServerErrorMiddleware``.
        strategy: Gather strategy. A default ``GatherStrategy()`` is used when None.

    Returns:
        A ``MiddlewareGatherResult`` holding the user and system middleware,
        the built stack, and any errors reported by the gather strategy.
    """
    if strategy is None:
        strategy = GatherStrategy()

    # Copy inputs so the extend/merge below never leaks into the caller's objects.
    user_list = list(user_middleware) if user_middleware is not None else []
    overrides = dict(system_overrides) if system_overrides is not None else {}

    result = MiddlewareGatherResult(
        user_middleware=user_list,
        system_middleware=dict(overrides),
    )

    # Label each task so results are routed by meaning. A fixed index would
    # shift when the defaults task is omitted.
    labels: list[str] = []
    tasks: list[t.Coroutine[t.Any, t.Any, t.Any]] = []

    if include_defaults:
        labels.append("defaults")
        tasks.append(_gather_default_middleware())

    labels.append("custom")
    tasks.append(_gather_custom_middleware())

    labels.append("stack")
    tasks.append(
        _build_middleware_stack(
            list(user_list),
            overrides,
            include_defaults,
            debug_mode,
            error_handler,
        ),
    )

    # NOTE(review): the cache key ignores user_middleware, system_overrides and
    # error_handler. If the strategy caches by key, calls that differ only in
    # those inputs could share a result. Confirm against gather_with_strategy.
    gather_result = await gather_with_strategy(
        tasks,
        strategy,
        cache_key=f"middleware:{include_defaults}:{debug_mode}",
    )

    # NOTE(review): assumes ``success`` holds one entry per task, in submission
    # order (the original index-based routing assumed the same).
    for label, value in zip(labels, gather_result.success):
        if label == "defaults":
            # Explicit overrides win over defaults at the same position,
            # matching _apply_system_middleware.
            result.system_middleware = {**value, **overrides}
        elif label == "custom":
            result.user_middleware.extend(value)
        else:
            result.middleware_stack = value

    result.errors.extend(gather_result.errors)

    debug(f"Gathered {result.total_middleware} middleware components")

    return result
async def _gather_default_middleware() -> dict[MiddlewarePosition, t.Any]:
    """Map the project's default middleware onto ``MiddlewarePosition`` slots.

    Members are paired with the entries of ``fastblocks.middleware.middlewares()``
    in declaration order. Defaults beyond the number of positions are ignored.
    Any failure is logged and yields an empty mapping.
    """
    try:
        from fastblocks.middleware import middlewares

        # zip stops at the shorter side, which reproduces the
        # "only as many defaults as there are positions" bound.
        positioned = dict(zip(MiddlewarePosition, middlewares()))
        debug(f"Gathered {len(positioned)} default middleware components")
        return positioned
    except Exception as exc:
        debug(f"Error gathering default middleware: {exc}")
        return {}
async def _gather_custom_middleware() -> list[Middleware]:
    """Load custom middleware classes listed in ``config.middleware.custom``.

    Each entry is split on its last dot into ``module`` and ``ClassName``. The
    module is imported and the class is wrapped in a ``Middleware`` with no
    arguments.

    Loading is best-effort:

    - a bad entry is logged and skipped;
    - a failure to obtain or iterate the config is logged, and whatever was
      collected so far is returned.

    Returns:
        The successfully loaded middleware, in config order.
    """
    import importlib

    custom_middleware: list[Middleware] = []
    try:
        from acb.depends import depends

        config = depends.get("config")
        if hasattr(config, "middleware") and hasattr(config.middleware, "custom"):
            for middleware_path in config.middleware.custom:
                try:
                    module_path, class_name = middleware_path.rsplit(".", 1)
                    # importlib.import_module is the documented replacement
                    # for __import__(..., fromlist=...) and returns the leaf module.
                    module = importlib.import_module(module_path)
                    middleware_class = getattr(module, class_name)
                    custom_middleware.append(Middleware(middleware_class))
                    debug(f"Added custom middleware: {class_name}")
                except Exception as e:
                    debug(f"Error loading custom middleware {middleware_path}: {e}")
    except Exception as e:
        # Previously swallowed silently; log it so misconfiguration is visible.
        debug(f"Error reading custom middleware config: {e}")

    return custom_middleware
async def _build_middleware_stack(
    user_middleware: list[Middleware],
    system_overrides: dict[MiddlewarePosition, t.Any],
    include_defaults: bool,
    debug_mode: bool,
    error_handler: t.Any,
) -> list[Middleware]:
    """Assemble the ordered middleware list.

    The order is:

    1. ``ExceptionMiddleware``;
    2. the user middleware;
    3. the system middleware, only when ``include_defaults`` is true;
    4. ``ServerErrorMiddleware``, last.

    NOTE(review): Starlette's own stack builder puts ``ServerErrorMiddleware``
    first (outermost) and ``ExceptionMiddleware`` last (innermost). Confirm
    that the consumer of this list applies it in the order assumed here.
    """
    stack: list[Middleware] = [
        Middleware(ExceptionMiddleware, debug=debug_mode),
        *user_middleware,
    ]

    if include_defaults:
        _add_system_middleware(stack, system_overrides)

    _add_error_handler_middleware(stack, error_handler, debug_mode)

    debug(f"Built middleware stack with {len(stack)} components")
    return stack
def _add_system_middleware(
    stack: list[Middleware],
    system_overrides: dict[MiddlewarePosition, t.Any],
) -> None:
    """Best-effort wrapper around ``_apply_system_middleware``.

    Any exception is logged and suppressed. Entries appended to ``stack``
    before the failure are left in place.
    """
    try:
        _apply_system_middleware(stack, system_overrides)
    except Exception as exc:
        debug(f"Error building system middleware: {exc}")
def _apply_system_middleware(
    stack: list[Middleware],
    system_overrides: dict[MiddlewarePosition, t.Any],
) -> None:
    """Apply system middleware to the stack.

    Steps:

    1. Fetch the default list from ``fastblocks.middleware.middlewares()``.
    2. Replace each entry whose index equals an override's ``position.value``.
    3. Append every entry to ``stack``. ``(cls, kwargs)`` tuples are wrapped
       in ``Middleware``; anything else is appended as-is.

    Overrides whose index falls outside the default list are skipped and logged.
    """
    from fastblocks.middleware import middlewares

    # Work on a copy so overrides never mutate the object middlewares() returned.
    # Whether it is shared or cached is not visible from here.
    system_middleware = list(middlewares())

    for position, override in system_overrides.items():
        position_index = position.value
        if 0 <= position_index < len(system_middleware):
            system_middleware[position_index] = override
            debug(f"Override middleware at position {position.name}")
        else:
            debug(
                f"Skipping override at position {position.name}: "
                f"no default middleware at index {position_index}",
            )

    for middleware_def in system_middleware:
        if isinstance(middleware_def, tuple):
            cls, kwargs = middleware_def
            stack.append(Middleware(cls, **kwargs))
        else:
            stack.append(middleware_def)
def _add_error_handler_middleware(
    stack: list[Middleware],
    error_handler: t.Any,
    debug_mode: bool,
) -> None:
    """Append the ``ServerErrorMiddleware`` entry built by ``_create_error_middleware``."""
    stack.append(_create_error_middleware(error_handler, debug_mode))
def _create_error_middleware(error_handler: t.Any, debug_mode: bool) -> Middleware:
    """Build the ``ServerErrorMiddleware`` entry.

    ``handler`` is only passed when ``error_handler`` is truthy. Otherwise it
    is omitted, so ``ServerErrorMiddleware`` uses its own default.
    """
    options: dict[str, t.Any] = (
        {"handler": error_handler, "debug": debug_mode}
        if error_handler
        else {"debug": debug_mode}
    )
    return Middleware(ServerErrorMiddleware, **options)
228def extract_middleware_info(middleware: t.Any) -> dict[str, t.Any]:
229 if isinstance(middleware, Middleware):
230 return {
231 "class": getattr(middleware.cls, "__name__", str(middleware.cls)),
232 "args": middleware.args,
233 "kwargs": middleware.kwargs,
234 }
235 if isinstance(middleware, tuple) and len(middleware) >= 2:
236 cls, kwargs = middleware[0], middleware[1]
237 return {
238 "class": cls.__name__ if hasattr(cls, "__name__") else str(cls),
239 "kwargs": kwargs,
240 }
241 return {
242 "class": middleware.__class__.__name__,
243 "raw": str(middleware),
244 }
def get_middleware_stack_info(
    middleware_stack: list[Middleware],
) -> dict[str, t.Any]:
    """Summarise a middleware stack.

    Returns ``total_middleware`` (the count), ``middleware_list`` (one
    described entry per item) and ``execution_order`` (class names in list
    order).
    """
    return _populate_middleware_info(
        middleware_stack,
        {
            "total_middleware": len(middleware_stack),
            "middleware_list": [],
            "execution_order": [],
        },
    )
def _populate_middleware_info(
    middleware_stack: list[Middleware], info: dict[str, t.Any]
) -> dict[str, t.Any]:
    """Fill ``info`` with a per-entry description and the class-name order.

    Each described entry gains a ``position`` key holding its index in the
    stack. ``info`` is mutated in place and also returned.
    """
    for index, described in enumerate(map(extract_middleware_info, middleware_stack)):
        described["position"] = index
        info["middleware_list"].append(described)
        info["execution_order"].append(described["class"])

    return info
def validate_middleware_stack(
    middleware_stack: list[Middleware],
) -> dict[str, t.Any]:
    """Run the ordering, security and session/auth checks over a stack.

    Returns a dict with ``warnings``, ``errors`` and ``recommendations``
    lists, and ``valid``. ``valid`` is true exactly when no errors were
    recorded; warnings and recommendations do not affect it.
    """
    report: dict[str, t.Any] = {
        "valid": True,
        "warnings": [],
        "errors": [],
        "recommendations": [],
    }

    class_names = [extract_middleware_info(entry)["class"] for entry in middleware_stack]

    # Order matters only for the order of messages within each list.
    for check in (
        _check_middleware_ordering,
        _check_security_middleware,
        _check_session_auth_ordering,
    ):
        check(class_names, report)

    report["valid"] = not report["errors"]

    return report
298def _check_middleware_ordering(
299 middleware_classes: list[str], validation: dict[str, t.Any]
300) -> None:
301 """Check if middleware is in the correct order."""
302 if middleware_classes and middleware_classes[0] != "ExceptionMiddleware":
303 validation["warnings"].append(
304 "ExceptionMiddleware should be first in the stack",
305 )
307 if middleware_classes and middleware_classes[-1] != "ServerErrorMiddleware":
308 validation["warnings"].append(
309 "ServerErrorMiddleware should be last in the stack",
310 )
313def _check_security_middleware(
314 middleware_classes: list[str], validation: dict[str, t.Any]
315) -> None:
316 """Check if security middleware is present."""
317 security_middleware = [
318 "CORSMiddleware",
319 "TrustedHostMiddleware",
320 "HTTPSRedirectMiddleware",
321 ]
323 found_security = any(
324 any(sec in cls for sec in security_middleware) for cls in middleware_classes
325 )
327 if not found_security:
328 validation["recommendations"].append(
329 "Consider adding security middleware (CORS, TrustedHost, etc.)",
330 )
333def _check_session_auth_ordering(
334 middleware_classes: list[str], validation: dict[str, t.Any]
335) -> None:
336 """Check if session and auth middleware are in the correct order."""
337 session_index = -1
338 auth_index = -1
340 for i, cls in enumerate(middleware_classes):
341 if "Session" in cls:
342 session_index = i
343 if "Auth" in cls or "Login" in cls:
344 auth_index = i
346 if session_index > -1 and auth_index > -1 and session_index > auth_index:
347 validation["warnings"].append(
348 "SessionMiddleware should come before authentication middleware",
349 )
async def create_middleware_manager(
    gather_result: MiddlewareGatherResult,
) -> t.Any:
    """Build a ``MiddlewareManager`` preloaded from a gather result.

    The result's containers are assigned directly (not copied), so the
    manager and ``gather_result`` share the same list and dict objects.
    """
    from fastblocks.applications import MiddlewareManager

    manager = MiddlewareManager()

    # Keep the assignment order: user middleware, system middleware, then the stack cache.
    manager.user_middleware = gather_result.user_middleware
    manager._system_middleware = gather_result.system_middleware  # type: ignore[assignment]
    manager._middleware_stack_cache = gather_result.middleware_stack

    component_count = gather_result.total_middleware
    debug(
        f"Created middleware manager with {component_count} components",
    )

    return manager
def add_middleware_at_position(
    middleware_stack: list[Middleware],
    new_middleware: Middleware,
    position: MiddlewarePosition,
) -> list[Middleware]:
    """Return a copy of ``middleware_stack`` with ``new_middleware`` inserted.

    The insertion index comes from ``_calculate_insert_index``. The input
    list is not modified.
    """
    updated = middleware_stack.copy()
    updated.insert(_calculate_insert_index(position, updated), new_middleware)
    debug(f"Added middleware at position {position.name}")
    return updated
def _calculate_insert_index(
    position: MiddlewarePosition, stack: list[Middleware]
) -> int:
    """Translate a ``MiddlewarePosition`` into a list index for insertion.

    Mapping:

    - SECURITY through CACHING: fixed slots 1 through 6;
    - CUSTOM: ``len(stack) - 1``, i.e. just before the final entry;
    - unrecognised values: slot 1.

    The result is capped at ``len(stack) - 1``, so new middleware never lands
    after the last entry.
    """
    last_index = len(stack) - 1

    if position == MiddlewarePosition.CUSTOM:
        return last_index

    # Compared with ``==`` (not a dict lookup) to match equality semantics exactly.
    fixed_slots = (
        (MiddlewarePosition.SECURITY, 1),
        (MiddlewarePosition.CORS, 2),
        (MiddlewarePosition.COMPRESSION, 3),
        (MiddlewarePosition.SESSIONS, 4),
        (MiddlewarePosition.AUTHENTICATION, 5),
        (MiddlewarePosition.CACHING, 6),
    )
    slot = next((index for member, index in fixed_slots if position == member), 1)

    return min(slot, last_index)