Coverage for fastblocks/actions/gather/middleware.py: 49%
177 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-21 04:50 -0700
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-21 04:50 -0700
1"""Middleware gathering and stack building functionality."""
3import typing as t
4from contextlib import suppress
5from enum import Enum
7from acb.debug import debug
8from starlette.middleware import Middleware
9from starlette.middleware.errors import ServerErrorMiddleware
10from starlette.middleware.exceptions import ExceptionMiddleware
12from .strategies import GatherStrategy, gather_with_strategy
class MiddlewarePosition(Enum):
    """Named slots for the system middleware.

    The integer values are used directly as list indices elsewhere in this
    module (``position.value`` when applying overrides, and declaration order
    when mapping the default middleware list). They must therefore stay
    contiguous, starting at 0, in this order.
    """

    SECURITY = 0
    CORS = 1
    COMPRESSION = 2
    SESSIONS = 3
    AUTHENTICATION = 4
    CACHING = 5
    CUSTOM = 6  # last slot
class MiddlewareGatherResult:
    """Outcome of a middleware gathering pass.

    All constructor arguments are keyword-only. Any argument left as ``None``
    becomes a fresh empty container. Containers that are passed in are stored
    as-is, not copied.
    """

    def __init__(
        self,
        *,
        user_middleware: list[Middleware] | None = None,
        system_middleware: dict[MiddlewarePosition, t.Any] | None = None,
        middleware_stack: list[Middleware] | None = None,
        errors: list[Exception] | None = None,
    ) -> None:
        self.user_middleware = [] if user_middleware is None else user_middleware
        self.system_middleware = {} if system_middleware is None else system_middleware
        self.middleware_stack = [] if middleware_stack is None else middleware_stack
        self.errors = [] if errors is None else errors

    @property
    def total_middleware(self) -> int:
        """Count of user middleware plus system middleware entries.

        The built ``middleware_stack`` is not included in this count.
        """
        return sum(map(len, (self.user_middleware, self.system_middleware)))

    @property
    def has_errors(self) -> bool:
        """True when at least one error was recorded."""
        return bool(self.errors)
async def gather_middleware(
    *,
    user_middleware: list[Middleware] | None = None,
    system_overrides: dict[MiddlewarePosition, t.Any] | None = None,
    include_defaults: bool = True,
    debug_mode: bool = False,
    error_handler: t.Any | None = None,
    strategy: GatherStrategy | None = None,
) -> MiddlewareGatherResult:
    """Collect middleware from every source and build the ordered stack.

    The following are gathered through ``gather_with_strategy``:

    * the default system middleware (only when ``include_defaults`` is true),
    * the custom middleware listed in config,
    * the assembled middleware stack.

    Args:
        user_middleware: Application-supplied middleware. Not mutated.
        system_overrides: Replacement system middleware keyed by position.
            These take precedence over the gathered defaults. Not mutated.
        include_defaults: Whether to gather and stack the default system
            middleware.
        debug_mode: Forwarded to ExceptionMiddleware / ServerErrorMiddleware.
        error_handler: Optional handler passed to ServerErrorMiddleware.
        strategy: Gathering strategy. Defaults to ``GatherStrategy()``.

    Returns:
        A ``MiddlewareGatherResult`` with these fields:

        * ``user_middleware``: the user middleware followed by any custom
          middleware.
        * ``system_middleware``: the defaults overlaid with
          ``system_overrides``.
        * ``middleware_stack``: the built stack.
        * ``errors``: the errors reported by the gather step.
    """
    if strategy is None:
        strategy = GatherStrategy()

    # Copy so that merging gathered results below never mutates the
    # caller's list/dict arguments.
    user_list = list(user_middleware) if user_middleware is not None else []
    overrides = dict(system_overrides) if system_overrides is not None else {}

    # Record which result slot each task produces. The defaults task is only
    # scheduled when include_defaults is true, so positions shift otherwise.
    slots: list[str] = []
    tasks = []
    if include_defaults:
        slots.append("defaults")
        tasks.append(_gather_default_middleware())
    slots.extend(("custom", "stack"))
    tasks.extend(
        (
            _gather_custom_middleware(),
            _build_middleware_stack(
                user_list,
                overrides,
                include_defaults,
                debug_mode,
                error_handler,
            ),
        ),
    )

    # NOTE(review): the cache key only encodes include_defaults/debug_mode.
    # If gather_with_strategy caches by this key, calls that differ only in
    # user_middleware / system_overrides / error_handler could share a
    # cached result. Confirm against the strategies module.
    gather_result = await gather_with_strategy(
        tasks,
        strategy,
        cache_key=f"middleware:{include_defaults}:{debug_mode}",
    )

    system_middleware: dict[MiddlewarePosition, t.Any] = {}
    collected_user = list(user_list)
    middleware_stack: list[Middleware] = []
    # Assumes gather_result.success is ordered like `tasks`. TODO confirm.
    for slot, value in zip(slots, gather_result.success):
        if slot == "defaults":
            system_middleware.update(value)
        elif slot == "custom":
            collected_user.extend(value)
        else:
            middleware_stack = value
    # Explicit overrides win over defaults, matching how
    # _add_system_middleware applies them when building the stack.
    system_middleware.update(overrides)

    result = MiddlewareGatherResult(
        user_middleware=collected_user,
        system_middleware=system_middleware,
        middleware_stack=middleware_stack,
        errors=list(gather_result.errors),
    )

    debug(f"Gathered {result.total_middleware} middleware components")

    return result
async def _gather_default_middleware() -> dict[MiddlewarePosition, t.Any]:
    """Map the default middleware list onto ``MiddlewarePosition`` slots.

    The n-th entry of ``fastblocks.middleware.middlewares()`` is assigned to
    the n-th ``MiddlewarePosition`` member. Entries beyond the last position
    are ignored. Any failure is logged via ``debug`` and yields ``{}``.
    """
    try:
        from fastblocks.middleware import middlewares

        defaults = middlewares()
        # zip stops at the shorter side, so extra defaults get no slot.
        middleware_map = dict(zip(MiddlewarePosition, defaults))
        debug(f"Gathered {len(middleware_map)} default middleware components")
        return middleware_map
    except Exception as exc:
        debug(f"Error gathering default middleware: {exc}")
        return {}
129async def _gather_custom_middleware() -> list[Middleware]:
130 custom_middleware = []
131 with suppress(Exception):
132 from acb.depends import depends
134 config = depends.get("config")
135 if hasattr(config, "middleware") and hasattr(config.middleware, "custom"):
136 for middleware_path in config.middleware.custom:
137 try:
138 module_path, class_name = middleware_path.rsplit(".", 1)
139 module = __import__(module_path, fromlist=[class_name])
140 middleware_class = getattr(module, class_name)
141 middleware = Middleware(middleware_class)
142 custom_middleware.append(middleware)
143 debug(f"Added custom middleware: {class_name}")
144 except Exception as e:
145 debug(f"Error loading custom middleware {middleware_path}: {e}")
147 return custom_middleware
async def _build_middleware_stack(
    user_middleware: list[Middleware],
    system_overrides: dict[MiddlewarePosition, t.Any],
    include_defaults: bool,
    debug_mode: bool,
    error_handler: t.Any,
) -> list[Middleware]:
    """Assemble the middleware list in a fixed order.

    The order is:

    1. ``ExceptionMiddleware``
    2. the user middleware
    3. the system middleware (only when ``include_defaults`` is true)
    4. ``ServerErrorMiddleware``
    """
    stack: list[Middleware] = [
        Middleware(ExceptionMiddleware, debug=debug_mode),
        *user_middleware,
    ]

    if include_defaults:
        _add_system_middleware(stack, system_overrides)

    # ServerErrorMiddleware is always appended last.
    _add_error_handler_middleware(stack, error_handler, debug_mode)

    debug(f"Built middleware stack with {len(stack)} components")
    return stack
def _add_system_middleware(
    stack: list[Middleware],
    system_overrides: dict[MiddlewarePosition, t.Any],
) -> None:
    """Append the default system middleware to *stack*, applying overrides.

    Overrides work by index: each override replaces the entry at index
    ``position.value`` of ``fastblocks.middleware.middlewares()``. Overrides
    whose index is out of range are ignored.

    Entry handling: ``(cls, kwargs)`` tuples are wrapped as
    ``Middleware(cls, **kwargs)``. Any other entry is appended unchanged.

    On any failure the error is logged via ``debug`` and *stack* is left
    untouched, never partially extended.
    """
    try:
        from fastblocks.middleware import middlewares

        # Copy so overrides never mutate a list that middlewares() might
        # share. This also accepts tuples or generators.
        system_middleware = list(middlewares())

        for position, override in system_overrides.items():
            position_index = position.value
            if 0 <= position_index < len(system_middleware):
                system_middleware[position_index] = override
                debug(f"Override middleware at position {position.name}")

        built: list[Middleware] = []
        for middleware_def in system_middleware:
            if isinstance(middleware_def, tuple):
                cls, kwargs = middleware_def
                built.append(Middleware(cls, **kwargs))
            else:
                built.append(middleware_def)

    except Exception as e:
        debug(f"Error building system middleware: {e}")
        return

    # Extend only after every entry was built successfully.
    stack.extend(built)
def _add_error_handler_middleware(
    stack: list[Middleware],
    error_handler: t.Any,
    debug_mode: bool,
) -> None:
    """Append a ``ServerErrorMiddleware`` entry to *stack*.

    ``handler=error_handler`` is passed only when *error_handler* is truthy.
    ``debug=debug_mode`` is always passed.
    """
    options: dict[str, t.Any] = {"debug": debug_mode}
    if error_handler:
        # Keep "handler" ahead of "debug" so the kwargs order is unchanged.
        options = {"handler": error_handler, **options}
    stack.append(Middleware(ServerErrorMiddleware, **options))
def extract_middleware_info(middleware: t.Any) -> dict[str, t.Any]:
    """Describe a middleware entry as a plain dict.

    Returned keys by input kind:

    * ``Middleware``: ``class`` (the class name), ``args`` and ``kwargs``.
    * A tuple of length >= 2: treated as ``(cls, kwargs, ...)`` and described
      by ``class`` and ``kwargs``.
    * Anything else: ``class`` (the object's class name) and ``raw``
      (its ``str()``).
    """
    if isinstance(middleware, Middleware):
        target = middleware.cls
        return {
            "class": getattr(target, "__name__", str(target)),
            "args": middleware.args,
            "kwargs": middleware.kwargs,
        }

    is_pair = isinstance(middleware, tuple) and len(middleware) >= 2
    if not is_pair:
        return {
            "class": middleware.__class__.__name__,
            "raw": str(middleware),
        }

    cls = middleware[0]
    kwargs = middleware[1]
    name = cls.__name__ if hasattr(cls, "__name__") else str(cls)
    return {"class": name, "kwargs": kwargs}
def get_middleware_stack_info(
    middleware_stack: list[Middleware],
) -> dict[str, t.Any]:
    """Summarise a middleware stack for inspection.

    Returns a dict with these keys:

    * ``total_middleware``: the stack length.
    * ``middleware_list``: one ``extract_middleware_info`` dict per entry,
      each with an added ``position`` index.
    * ``execution_order``: the class names in stack order.
    """
    entries: list[dict[str, t.Any]] = []
    for position, item in enumerate(middleware_stack):
        entry = extract_middleware_info(item)
        entry["position"] = position
        entries.append(entry)

    return {
        "total_middleware": len(middleware_stack),
        "middleware_list": entries,
        "execution_order": [entry["class"] for entry in entries],
    }
def validate_middleware_stack(
    middleware_stack: list[Middleware],
) -> dict[str, t.Any]:
    """Run heuristic ordering/coverage checks on a middleware stack.

    Checks are based on class names (via ``extract_middleware_info``):

    * Warn if the first entry is not ``ExceptionMiddleware``.
    * Warn if the last entry is not ``ServerErrorMiddleware``.
    * Recommend security middleware when no class name contains
      CORS / TrustedHost / HTTPSRedirect.
    * Warn when any authentication-like middleware (name containing
      "Auth" or "Login") appears before the first session middleware
      (name containing "Session").

    Returns a dict with ``valid``, ``warnings``, ``errors`` and
    ``recommendations``. No check here adds to ``errors``, so ``valid`` is
    currently always True.
    """
    validation: dict[str, t.Any] = {
        "valid": True,
        "warnings": [],
        "errors": [],
        "recommendations": [],
    }

    middleware_classes = [extract_middleware_info(m)["class"] for m in middleware_stack]

    if middleware_classes and middleware_classes[0] != "ExceptionMiddleware":
        validation["warnings"].append(
            "ExceptionMiddleware should be first in the stack",
        )

    if middleware_classes and middleware_classes[-1] != "ServerErrorMiddleware":
        validation["warnings"].append(
            "ServerErrorMiddleware should be last in the stack",
        )

    security_middleware = (
        "CORSMiddleware",
        "TrustedHostMiddleware",
        "HTTPSRedirectMiddleware",
    )

    found_security = any(
        sec in cls for cls in middleware_classes for sec in security_middleware
    )

    if not found_security:
        validation["recommendations"].append(
            "Consider adding security middleware (CORS, TrustedHost, etc.)",
        )

    # Compare FIRST occurrences. The previous last-index comparison
    # false-positived on [Session, Auth, Session] and missed
    # [Auth, Session, Auth].
    first_session = next(
        (i for i, cls in enumerate(middleware_classes) if "Session" in cls),
        None,
    )
    first_auth = next(
        (
            i
            for i, cls in enumerate(middleware_classes)
            if "Auth" in cls or "Login" in cls
        ),
        None,
    )

    if first_session is not None and first_auth is not None and first_auth < first_session:
        validation["warnings"].append(
            "SessionMiddleware should come before authentication middleware",
        )

    validation["valid"] = len(validation["errors"]) == 0

    return validation
async def create_middleware_manager(
    gather_result: MiddlewareGatherResult,
) -> t.Any:
    """Build a ``fastblocks.applications.MiddlewareManager`` from a gather result.

    The manager's ``user_middleware``, ``_system_middleware`` and
    ``_middleware_stack_cache`` attributes are assigned the gather result's
    containers directly (shared, not copied).
    """
    from fastblocks.applications import MiddlewareManager

    manager = MiddlewareManager()

    # Assigned in the same order as before. Two of these are private
    # attributes of the manager.
    seeded_state = (
        ("user_middleware", gather_result.user_middleware),
        ("_system_middleware", gather_result.system_middleware),
        ("_middleware_stack_cache", gather_result.middleware_stack),
    )
    for attr_name, value in seeded_state:
        setattr(manager, attr_name, value)

    debug(
        f"Created middleware manager with {gather_result.total_middleware} components",
    )

    return manager
def add_middleware_at_position(
    middleware_stack: list[Middleware],
    new_middleware: Middleware,
    position: MiddlewarePosition,
) -> list[Middleware]:
    """Return a copy of *middleware_stack* with *new_middleware* inserted.

    Where the new entry goes:

    * Named positions map to fixed indices: SECURITY->1, CORS->2,
      COMPRESSION->3, SESSIONS->4, AUTHENTICATION->5, CACHING->6.
    * CUSTOM inserts just before the last entry.
    * The index is capped at ``len(stack) - 1``, so the new entry always
      lands before the current last element.

    The input list is not modified.
    """
    stack = middleware_stack.copy()

    fixed_slots = {
        MiddlewarePosition.SECURITY: 1,
        MiddlewarePosition.CORS: 2,
        MiddlewarePosition.COMPRESSION: 3,
        MiddlewarePosition.SESSIONS: 4,
        MiddlewarePosition.AUTHENTICATION: 5,
        MiddlewarePosition.CACHING: 6,
    }
    last_index = len(stack) - 1

    if position == MiddlewarePosition.CUSTOM:
        target = last_index
    else:
        target = fixed_slots.get(position, 1)

    stack.insert(min(target, last_index), new_middleware)
    debug(f"Added middleware at position {position.name}")

    return stack