Coverage for fastblocks/middleware.py: 72%
313 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-09 00:47 -0700
1import sys
2import typing as t
3from collections.abc import Mapping, Sequence
4from contextvars import ContextVar
5from enum import IntEnum
7from acb.debug import debug
8from acb.depends import depends
9from brotli_asgi import BrotliMiddleware
10from secure import Secure
11from starlette.datastructures import URL, Headers, MutableHeaders
12from starlette.middleware import Middleware
13from starlette.middleware.sessions import SessionMiddleware
14from starlette.requests import Request
15from starlette.types import ASGIApp, Message, Receive, Scope, Send
16from starlette_csrf.middleware import CSRFMiddleware
18from .caching import (
19 CacheControlResponder,
20 CacheDirectives,
21 CacheResponder,
22 Rule,
23 delete_from_cache,
24)
25from .htmx import HtmxDetails
from .exceptions import MissingCaching

# A middleware factory: receives the wrapped ASGI app and returns a new ASGI app.
MiddlewareCallable: t.TypeAlias = t.Callable[[ASGIApp], ASGIApp]
# Any class that can be wrapped in Starlette's ``Middleware(cls, **options)``.
MiddlewareClass: t.TypeAlias = type[t.Any]
# Keyword arguments forwarded to a middleware class constructor.
MiddlewareOptions: t.TypeAlias = dict[str, t.Any]
class MiddlewarePosition(IntEnum):
    """Slot assigned to each built-in middleware in the generated stack.

    ``MiddlewareStackManager.build_stack`` sorts entries by these values, so a
    lower number appears earlier in the resulting middleware list.
    """

    # Explicit integers (not ``auto()``): the slots are compared and sorted
    # numerically, and CSRF must start at 0.
    CSRF = 0
    SESSION = 1
    HTMX = 2
    CURRENT_REQUEST = 3
    COMPRESSION = 4
    SECURITY_HEADERS = 5
class HtmxMiddleware:
    """Attach an ``HtmxDetails`` object to every HTTP and WebSocket scope.

    Downstream handlers read it back from ``scope["htmx"]``. All other scope
    types (e.g. lifespan) are forwarded untouched.
    """

    def __init__(self, app: ASGIApp) -> None:
        self._app = app
        debug("HtmxMiddleware: Initialized FastBlocks native HTMX middleware")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        wants_htmx = scope["type"] in ("http", "websocket")
        if wants_htmx:
            await self._process_htmx_request(scope)
        await self._app(scope, receive, send)

    async def _process_htmx_request(self, scope: Scope) -> None:
        """Build the HTMX details for *scope* and store them under ``"htmx"``."""
        details = HtmxDetails(scope)
        scope["htmx"] = details
        # Logging is skipped entirely unless debug output is enabled.
        if not debug.enabled:
            return
        self._log_htmx_details(scope, details)

    def _log_htmx_details(self, scope: Scope, htmx_details: HtmxDetails) -> None:
        """Log the request line and, for HTMX requests, each HTMX header."""
        verb = scope.get("method", "UNKNOWN")
        route = scope.get("path", "unknown")
        htmx_request = bool(htmx_details)
        debug(f"HtmxMiddleware: {verb} {route} - HTMX: {htmx_request}")
        if not htmx_request:
            return
        for name, value in htmx_details.get_all_headers().items():
            debug(f"HtmxMiddleware: {name}: {value}")
class HtmxResponseMiddleware:
    """Intercept HTTP response messages for requests flagged as HTMX.

    Reads ``scope["htmx"]`` (populated earlier in the stack). When it is
    missing or falsy, messages are forwarded unchanged. Non-HTTP scopes bypass
    the wrapper completely.
    """

    def __init__(self, app: ASGIApp) -> None:
        self._app = app
        debug("HtmxResponseMiddleware: Initialized FastBlocks HTMX response middleware")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":

            async def send_wrapper(message: Message) -> None:
                await self._process_response_message(message, scope, send)

            await self._app(scope, receive, send_wrapper)
        else:
            await self._app(scope, receive, send)

    async def _process_response_message(
        self, message: Message, scope: Scope, send: Send
    ) -> None:
        """Forward *message*, normalizing headers on HTMX response starts."""
        starts_response = message["type"] == "http.response.start"
        if starts_response and scope.get("htmx"):
            debug("HtmxResponseMiddleware: Processing HTMX response")
            # Copies the header sequence into a list; no headers are added here.
            message["headers"] = list(message.get("headers", []))
        await send(message)
class MiddlewareUtils:
    """Shared constants and request-context accessors for the middleware stack."""

    # Cache objects are not typed more precisely than ``Any`` here.
    Cache = t.Any

    # Single ``secure`` header builder, created once when the class is defined.
    secure_headers = Secure()

    # Scope key under which CacheMiddleware registers itself.
    scope_name = "__starlette_caches__"

    # Scope of the request being handled in the current context (None if unset).
    _request_ctx_var: ContextVar[Scope | None] = ContextVar("request", default=None)

    # Interned strings for frequently compared scope keys and values.
    HTTP, WEBSOCKET = map(sys.intern, ("http", "websocket"))
    TYPE, METHOD, PATH = map(sys.intern, ("type", "method", "path"))
    GET, HEAD, POST, PUT, PATCH, DELETE = map(
        sys.intern, ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE")
    )

    @classmethod
    def get_request(cls) -> Scope | None:
        """Return the scope stored for the current context, or ``None``."""
        stored = cls._request_ctx_var.get()
        return stored

    @classmethod
    def set_request(cls, scope: Scope | None) -> None:
        """Store *scope* as the current request for this context."""
        cls._request_ctx_var.set(scope)
# Module-level shortcuts to MiddlewareUtils attributes; the middleware classes
# in this module reference ``secure_headers``, ``scope_name`` and
# ``_request_ctx_var`` through these names.
(
    Cache,
    secure_headers,
    scope_name,
    _request_ctx_var,
) = (
    MiddlewareUtils.Cache,
    MiddlewareUtils.secure_headers,
    MiddlewareUtils.scope_name,
    MiddlewareUtils._request_ctx_var,
)
def get_request() -> Scope | None:
    """Return the ASGI scope stored for the current context.

    Thin wrapper over ``MiddlewareUtils.get_request``. The result is ``None``
    when nothing has been stored for this context.
    """
    current_scope = MiddlewareUtils.get_request()
    return current_scope
class CurrentRequestMiddleware:
    """Expose the current ASGI scope through a context variable.

    While the wrapped app handles an HTTP or WebSocket connection, the scope is
    available via ``get_request()``. Other scope types pass straight through.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> t.Any:  # type: ignore[func-returns-value,no-any-return]
        if scope[MiddlewareUtils.TYPE] not in (
            MiddlewareUtils.HTTP,
            MiddlewareUtils.WEBSOCKET,
        ):
            await self.app(scope, receive, send)
            return None  # type: ignore[func-returns-value]
        token = _request_ctx_var.set(scope)
        try:
            return await self.app(scope, receive, send)  # type: ignore[func-returns-value]
        finally:
            # Restore the previous value even when the app raises. Otherwise the
            # request scope leaks into code running later in the same context
            # (e.g. outer exception handlers).
            _request_ctx_var.reset(token)
class SecureHeadersMiddleware:
    """Append the module's ``secure`` headers to every HTTP response start."""

    def __init__(self, app: ASGIApp) -> None:
        self.app = app
        # The logger is optional: any lookup failure leaves it as None.
        try:
            resolved_logger = depends.get("logger")
        except Exception:
            resolved_logger = None
        self.logger = resolved_logger

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":

            async def send_with_secure_headers(message: Message) -> None:
                if message["type"] == "http.response.start":
                    outgoing = MutableHeaders(scope=message)
                    for name, value in secure_headers.headers.items():
                        outgoing.append(name, value)
                await send(message)

            await self.app(scope, receive, send_with_secure_headers)
            return None
        return await self.app(scope, receive, send)
class CacheValidator:
    """Hold caching rules and guard against stacking ``CacheMiddleware`` twice."""

    def __init__(self, rules: Sequence[Rule] | None = None) -> None:
        # A missing or empty rule set falls back to a single default Rule().
        self.rules = rules if rules else [Rule()]

    def check_for_duplicate_middleware(self, app: ASGIApp) -> None:
        """Raise ``DuplicateCaching`` if ``app.middleware`` already holds a CacheMiddleware.

        The check only runs when *app* has a ``middleware`` attribute that is
        not callable. Otherwise this is a no-op.
        """
        if not hasattr(app, "middleware"):
            return
        stack = app.middleware  # type: ignore[attr-defined]
        if not callable(stack):
            self._check_for_cache_middleware_duplicates(stack)

    def _check_for_cache_middleware_duplicates(self, middleware: t.Any) -> None:
        """Raise ``DuplicateCaching`` on the first CacheMiddleware in *middleware*."""
        if any(isinstance(entry, CacheMiddleware) for entry in middleware):
            from .exceptions import DuplicateCaching

            msg = "CacheMiddleware detected in middleware stack"
            raise DuplicateCaching(msg)

    def is_duplicate_in_scope(self, scope: Scope) -> bool:
        """Return True if a CacheMiddleware has already stored itself in *scope*."""
        return scope_name in scope
class CacheKeyManager:
    """Lazily resolve and remember the cache backend used by CacheMiddleware."""

    def __init__(self, cache: t.Any | None = None) -> None:
        self.cache = cache
        # Passed as the second argument to ``safe_depends_get`` (presumably the
        # fallback when no "cache" dependency is registered; verify there).
        self._cache_dict: dict[t.Any, t.Any] = {}

    def get_cache_instance(self) -> t.Any:
        """Return ``self.cache``, resolving it via ``safe_depends_get`` while it is None."""
        if self.cache is not None:
            return self.cache
        from .exceptions import safe_depends_get

        self.cache = safe_depends_get("cache", self._cache_dict)
        return self.cache
class CacheMiddleware:
    """ASGI middleware that routes HTTP traffic through ``CacheResponder``.

    Args:
        app: The wrapped ASGI application.
        cache: Optional cache backend; resolved lazily when omitted.
        rules: Caching rules; defaults to a single ``Rule()``.

    Raises:
        DuplicateCaching: if another CacheMiddleware is found, either in
            ``app.middleware`` at construction or in the request scope.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        cache: t.Any | None = None,
        rules: Sequence[Rule] | None = None,
    ) -> None:
        self.app = app
        self.validator = CacheValidator(rules)
        self.key_manager = CacheKeyManager(cache)
        self.cache = cache
        # Use the validator's normalized rules (it substitutes the default).
        self.rules = self.validator.rules
        self.validator.check_for_duplicate_middleware(app)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Resolved on every call, including non-HTTP scopes, so that
        # ``self.cache`` is populated for helpers reading ``middleware.cache``.
        self.cache = self.key_manager.get_cache_instance()  # type: ignore[no-untyped-call]
        if scope["type"] == "http":
            if self.validator.is_duplicate_in_scope(scope):
                from .exceptions import DuplicateCaching

                raise DuplicateCaching(
                    "Another `CacheMiddleware` was detected in the middleware stack.\n"
                    "HINT: this exception probably occurred because:\n"
                    "- You wrapped an application around `CacheMiddleware` multiple times.\n"
                    "- You tried to apply `@cached()` onto an endpoint, but the application "
                    "is already wrapped around a `CacheMiddleware`."
                )
            scope[scope_name] = self
            await CacheResponder(self.app, rules=self.rules)(scope, receive, send)
            return
        await self.app(scope, receive, send)
class _BaseCacheMiddlewareHelper:
    """Locate the ``CacheMiddleware`` instance that handled *request*.

    Raises:
        MissingCaching: if the scope has no cache-middleware entry, or the entry
            is not a ``CacheMiddleware`` instance.
    """

    def __init__(self, request: Request) -> None:
        self.request = request
        request_scope = request.scope
        if scope_name not in request_scope:
            raise MissingCaching(
                "No CacheMiddleware instance found in the ASGI scope. Did you forget to wrap the ASGI application with `CacheMiddleware`?"
            )
        found = request_scope[scope_name]
        if not isinstance(found, CacheMiddleware):
            raise MissingCaching(
                f"A scope variable named {scope_name!r} was found, but it does not contain a `CacheMiddleware` instance. It is likely that an incompatible middleware was added to the middleware stack."
            )
        self.middleware = found
class CacheHelper(_BaseCacheMiddlewareHelper):
    """Request-bound helper operating on the active CacheMiddleware's cache."""

    async def invalidate_cache_for(
        self,
        url: str | URL,
        *,
        headers: Mapping[str, str] | None = None,
    ) -> None:
        """Delete cached entries for *url*.

        Args:
            url: A ``URL`` instance, or a route name resolved through
                ``request.url_for``.
            headers: Headers handed to ``delete_from_cache`` as ``vary``;
                ``None`` yields an empty ``Headers``.
        """
        target = url if isinstance(url, URL) else self.request.url_for(url)
        vary = headers if isinstance(headers, Headers) else Headers(headers)
        await delete_from_cache(target, vary=vary, cache=self.middleware.cache)
class CacheControlMiddleware:
    """ASGI middleware that applies ``Cache-Control`` directives.

    The keyword arguments are kept verbatim in ``kwargs``, which is forwarded
    to ``CacheControlResponder`` for HTTP traffic. Each one is also copied onto
    an attribute of the same name, which ``process_response`` reads.
    """

    app: ASGIApp
    kwargs: CacheDirectives
    max_age: int | None
    s_maxage: int | None
    no_cache: bool
    no_store: bool
    no_transform: bool
    must_revalidate: bool
    proxy_revalidate: bool
    must_understand: bool
    private: bool
    public: bool
    immutable: bool
    stale_while_revalidate: int | None
    stale_if_error: int | None

    # (attribute, directive, takes_value) in the order process_response emits
    # them; ``public``/``private`` are handled separately because they are
    # mutually exclusive.
    _DIRECTIVES: t.ClassVar[tuple[tuple[str, str, bool], ...]] = (
        ("no_cache", "no-cache", False),
        ("no_store", "no-store", False),
        ("must_revalidate", "must-revalidate", False),
        ("max_age", "max-age", True),
        ("s_maxage", "s-maxage", True),
        ("no_transform", "no-transform", False),
        ("proxy_revalidate", "proxy-revalidate", False),
        ("must_understand", "must-understand", False),
        ("immutable", "immutable", False),
        ("stale_while_revalidate", "stale-while-revalidate", True),
        ("stale_if_error", "stale-if-error", True),
    )

    def __init__(self, app: ASGIApp, **kwargs: t.Unpack[CacheDirectives]) -> None:
        self.app = app
        self.kwargs = kwargs
        self.max_age = None
        self.s_maxage = None
        self.no_cache = False
        self.no_store = False
        self.no_transform = False
        self.must_revalidate = False
        self.proxy_revalidate = False
        self.must_understand = False
        self.private = False
        self.public = False
        self.immutable = False
        self.stale_while_revalidate = None
        self.stale_if_error = None
        # Explicit keyword arguments override the defaults above.
        for key, value in kwargs.items():
            setattr(self, key, value)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        responder = CacheControlResponder(self.app, **self.kwargs)
        await responder(scope, receive, send)

    def process_response(self, response: t.Any) -> None:
        """Set ``response.headers["Cache-Control"]`` from the configured directives.

        Every directive accepted by ``__init__`` is emitted. ``public`` wins
        over ``private`` when both are set. Valued directives are emitted
        whenever they are not ``None`` (so ``0`` is kept). No header is written
        when no directive applies.

        Args:
            response: Any object exposing a mutable ``headers`` mapping.
        """
        cache_control_parts: list[str] = []
        if getattr(self, "public", False):
            cache_control_parts.append("public")
        elif getattr(self, "private", False):
            cache_control_parts.append("private")
        for attr, directive, takes_value in self._DIRECTIVES:
            value = getattr(self, attr, None)
            if takes_value:
                if value is not None:
                    cache_control_parts.append(f"{directive}={value}")
            elif value:
                cache_control_parts.append(directive)
        if cache_control_parts:
            response.headers["Cache-Control"] = ", ".join(cache_control_parts)
def get_middleware_positions() -> dict[str, int]:
    """Map each ``MiddlewarePosition`` member name to its integer slot."""
    table: dict[str, int] = {}
    for member in MiddlewarePosition:
        table[member.name] = member.value
    return table
class MiddlewareStackManager:
    """Assemble the ordered list of Starlette ``Middleware`` entries.

    Built-in middleware is registered lazily by ``initialize`` (triggered by
    ``build_stack`` or ``get_middleware_info``). The built-ins only fill
    positions that are still free. Anything registered earlier via
    ``register_middleware`` is therefore kept, not overwritten.

    Args:
        config: Application config; resolved with ``depends.get("config")``
            when omitted.
        logger: Optional logger; resolved with ``depends.get("logger")`` when
            omitted, and left as ``None`` if that lookup fails.
    """

    def __init__(
        self,
        config: t.Any | None = None,
        logger: t.Any | None = None,
    ) -> None:
        self.config = config
        self.logger = logger
        self._middleware_registry: dict[MiddlewarePosition, MiddlewareClass] = {}
        self._middleware_options: dict[MiddlewarePosition, MiddlewareOptions] = {}
        self._custom_middleware: dict[MiddlewarePosition, Middleware] = {}
        self._initialized = False

    def _ensure_dependencies(self) -> None:
        """Resolve ``config`` and ``logger`` from ``depends`` when not supplied."""
        if self.config is None:
            self.config = depends.get("config")
        if self.logger is None:
            try:
                self.logger = depends.get("logger")
            except Exception:
                # The logger is optional; continue without one.
                self.logger = None

    def _register_default(
        self,
        position: MiddlewarePosition,
        middleware_class: MiddlewareClass,
        options: MiddlewareOptions | None = None,
    ) -> None:
        """Register a built-in middleware unless *position* is already taken."""
        if position in self._middleware_registry:
            return
        self._middleware_registry[position] = middleware_class
        if options is not None:
            self._middleware_options[position] = options

    def _register_default_middleware(self) -> None:
        """Register the always-on built-ins: HTMX, current request, Brotli."""
        self._register_default(MiddlewarePosition.HTMX, HtmxMiddleware)
        self._register_default(
            MiddlewarePosition.CURRENT_REQUEST, CurrentRequestMiddleware
        )
        self._register_default(
            MiddlewarePosition.COMPRESSION, BrotliMiddleware, {"quality": 3}
        )

    def _register_conditional_middleware(self) -> None:
        """Register config-dependent built-ins (CSRF, sessions, security headers)."""
        self._ensure_dependencies()
        if not self.config:
            return
        from acb.adapters import get_adapter

        app_config = self.config.app
        secret = app_config.secret_key.get_secret_value()
        token_id = getattr(app_config, "token_id", "_fb_")
        deployed = self.config.deployed

        self._register_default(
            MiddlewarePosition.CSRF,
            CSRFMiddleware,
            {
                "secret": secret,
                "cookie_name": f"{token_id}_csrf",
                "cookie_secure": deployed,
            },
        )
        # Session middleware is registered only when an auth adapter exists.
        if get_adapter("auth"):
            self._register_default(
                MiddlewarePosition.SESSION,
                SessionMiddleware,
                {
                    "secret_key": secret,
                    "session_cookie": f"{token_id}_app",
                    "https_only": deployed,
                },
            )
        if deployed or getattr(self.config.debug, "production", False):
            self._register_default(
                MiddlewarePosition.SECURITY_HEADERS, SecureHeadersMiddleware
            )

    def initialize(self) -> None:
        """Register the built-in middleware once; later calls are no-ops."""
        if self._initialized:
            return
        self._register_default_middleware()
        self._register_conditional_middleware()
        self._initialized = True

    def register_middleware(
        self,
        middleware_class: MiddlewareClass,
        position: MiddlewarePosition,
        **options: t.Any,
    ) -> None:
        """Register *middleware_class* at *position*, replacing any existing entry.

        *options* are stored only when non-empty. They are passed to the class
        when the stack is built.
        """
        self._middleware_registry[position] = middleware_class
        if options:
            self._middleware_options[position] = options

    def add_custom_middleware(
        self,
        middleware: Middleware,
        position: MiddlewarePosition,
    ) -> None:
        """Add a prebuilt ``Middleware`` entry.

        In ``build_stack`` it takes precedence over a registered class at the
        same *position*.
        """
        self._custom_middleware[position] = middleware

    def build_stack(self) -> list[Middleware]:
        """Return all middleware entries ordered by ascending position."""
        if not self._initialized:
            self.initialize()

        middleware_stack: dict[MiddlewarePosition, Middleware] = {}
        self._build_middleware_stack(middleware_stack)
        # Custom entries override registered classes at the same position.
        middleware_stack.update(self._custom_middleware)

        return [middleware_stack[position] for position in sorted(middleware_stack)]

    def _build_middleware_stack(
        self, middleware_stack: dict[MiddlewarePosition, Middleware]
    ) -> None:
        """Wrap each registered class and its options in a ``Middleware``."""
        for position, middleware_class in self._middleware_registry.items():
            options = self._middleware_options.get(position, {})
            middleware_stack[position] = Middleware(middleware_class, **options)

    def get_middleware_info(self) -> dict[str, t.Any]:
        """Describe registered classes, custom entries and the position table."""
        if not self._initialized:
            self.initialize()

        return {
            "registered": {
                pos.name: cls.__name__ for pos, cls in self._middleware_registry.items()
            },
            "custom": {
                pos.name: str(middleware)
                for pos, middleware in self._custom_middleware.items()
            },
            "positions": get_middleware_positions(),
        }
def middlewares() -> list[Middleware]:
    """Build the default FastBlocks middleware stack.

    Creates a fresh ``MiddlewareStackManager`` with no explicit config or
    logger and returns its ordered ``Middleware`` list.
    """
    manager = MiddlewareStackManager()
    return manager.build_stack()