Coverage for fastblocks/middleware.py: 73%
302 statements
coverage.py v7.10.6, created at 2025-09-21 04:50 -0700
1import sys
2import typing as t
3from collections.abc import Mapping, Sequence
4from contextvars import ContextVar
5from enum import IntEnum
7from acb.debug import debug
8from acb.depends import depends
9from brotli_asgi import BrotliMiddleware
10from secure import Secure
11from starlette.datastructures import URL, Headers, MutableHeaders
12from starlette.middleware import Middleware
13from starlette.middleware.sessions import SessionMiddleware
14from starlette.requests import Request
15from starlette.types import ASGIApp, Message, Receive, Scope, Send
16from starlette_csrf.middleware import CSRFMiddleware
18from .caching import (
19 CacheControlResponder,
20 CacheDirectives,
21 CacheResponder,
22 Rule,
23 delete_from_cache,
24)
25from .htmx import HtmxDetails
# A middleware factory: takes the wrapped ASGI app and returns the wrapping app.
MiddlewareCallable: t.TypeAlias = t.Callable[[ASGIApp], ASGIApp]
# Any class that may be handed to ``starlette.middleware.Middleware``.
MiddlewareClass: t.TypeAlias = type[t.Any]
# Keyword arguments forwarded to a middleware class when the stack is built.
MiddlewareOptions: t.TypeAlias = dict[str, t.Any]
30from .exceptions import MissingCaching
class MiddlewarePosition(IntEnum):
    """Fixed slots in the FastBlocks middleware stack.

    ``MiddlewareStackManager.build_stack`` sorts its entries by these values in
    ascending order, so a lower number appears earlier in the returned list.
    """

    CSRF = 0  # starlette_csrf CSRF protection
    SESSION = 1  # cookie sessions (registered only when an auth adapter exists)
    HTMX = 2  # attaches HtmxDetails to the scope
    CURRENT_REQUEST = 3  # publishes the current scope through a ContextVar
    COMPRESSION = 4  # Brotli response compression
    SECURITY_HEADERS = 5  # secure response headers (deployed / production only)
class HtmxMiddleware:
    """ASGI middleware that stores an ``HtmxDetails`` object in ``scope["htmx"]``.

    Applies to HTTP and WebSocket scopes; every other scope type is forwarded
    unchanged. When debug output is enabled, the request line and, for HTMX
    requests, every HTMX header are logged.
    """

    def __init__(self, app: ASGIApp) -> None:
        self._app = app
        debug("HtmxMiddleware: Initialized FastBlocks native HTMX middleware")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] not in ("http", "websocket"):
            await self._app(scope, receive, send)
            return

        details = HtmxDetails(scope)
        scope["htmx"] = details

        if debug.enabled:
            verb = scope.get("method", "UNKNOWN")
            route = scope.get("path", "unknown")
            flagged = bool(details)
            debug(f"HtmxMiddleware: {verb} {route} - HTMX: {flagged}")
            if flagged:
                for name, value in details.get_all_headers().items():
                    debug(f"HtmxMiddleware: {name}: {value}")

        await self._app(scope, receive, send)
class HtmxResponseMiddleware:
    """Hook into HTTP responses for requests whose ``scope["htmx"]`` is truthy.

    For such requests the ``http.response.start`` message's ``headers`` entry is
    replaced by a list copy of itself; no header is added or changed yet.
    Non-HTTP scopes are forwarded untouched.
    """

    def __init__(self, app: ASGIApp) -> None:
        self._app = app
        debug("HtmxResponseMiddleware: Initialized FastBlocks HTMX response middleware")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":

            async def send_wrapper(message: Message) -> None:
                if message["type"] == "http.response.start" and scope.get("htmx"):
                    debug("HtmxResponseMiddleware: Processing HTMX response")
                    message["headers"] = list(message.get("headers", []))
                await send(message)

            await self._app(scope, receive, send_wrapper)
        else:
            await self._app(scope, receive, send)
class MiddlewareUtils:
    """Shared constants and request-context accessors for FastBlocks middleware."""

    # Loose alias (``typing.Any``) used as the cache type.
    Cache = t.Any

    # One ``Secure`` instance shared by the whole module.
    secure_headers = Secure()

    # ASGI scope key under which CacheMiddleware registers itself.
    scope_name = "__starlette_caches__"

    # ASGI scope of the request handled in the current context, if any.
    _request_ctx_var: ContextVar[Scope | None] = ContextVar("request", default=None)

    # Interned scope types and scope keys.
    HTTP, WEBSOCKET = sys.intern("http"), sys.intern("websocket")
    TYPE, METHOD, PATH = sys.intern("type"), sys.intern("method"), sys.intern("path")

    # Interned HTTP method names.
    GET, HEAD, POST = sys.intern("GET"), sys.intern("HEAD"), sys.intern("POST")
    PUT, PATCH, DELETE = sys.intern("PUT"), sys.intern("PATCH"), sys.intern("DELETE")

    @classmethod
    def get_request(cls) -> Scope | None:
        """Return the scope stored in the request context variable (``None`` if unset)."""
        ctx = cls._request_ctx_var
        return ctx.get()

    @classmethod
    def set_request(cls, scope: Scope | None) -> None:
        """Store *scope* in the request context variable (the reset token is discarded)."""
        ctx = cls._request_ctx_var
        ctx.set(scope)
# Module-level shortcuts to the shared MiddlewareUtils attributes.
Cache, secure_headers, scope_name, _request_ctx_var = (
    MiddlewareUtils.Cache,
    MiddlewareUtils.secure_headers,
    MiddlewareUtils.scope_name,
    MiddlewareUtils._request_ctx_var,
)
def get_request() -> Scope | None:
    """Return the ASGI scope of the request being handled in this context.

    Returns ``None`` when no scope has been stored (e.g. outside a request
    wrapped by ``CurrentRequestMiddleware``).
    """
    current = MiddlewareUtils.get_request()
    return current
class CurrentRequestMiddleware:
    """Publish the current HTTP/WebSocket ASGI scope through ``_request_ctx_var``.

    While the wrapped app handles a request, ``get_request()`` returns its scope.
    Other scope types (e.g. lifespan) are forwarded without touching the variable.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope[MiddlewareUtils.TYPE] not in (
            MiddlewareUtils.HTTP,
            MiddlewareUtils.WEBSOCKET,
        ):
            await self.app(scope, receive, send)
            return None
        token = _request_ctx_var.set(scope)
        try:
            return await self.app(scope, receive, send)
        finally:
            # Restore the previous value even when the app raises; otherwise a
            # failed request's scope would stay visible to later code running
            # in the same context.
            _request_ctx_var.reset(token)
class SecureHeadersMiddleware:
    """Add the headers of the shared ``secure_headers`` object to HTTP responses.

    A header the downstream app already set is left as-is rather than appended
    a second time, so responses never carry duplicate (possibly conflicting)
    security headers. Non-HTTP scopes are forwarded untouched.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app
        # Optional logger; lookup failures are tolerated because request
        # handling in this class does not use it.
        try:
            self.logger = depends.get("logger")
        except Exception:
            self.logger = None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        async def send_with_secure_headers(message: Message) -> None:
            if message["type"] == "http.response.start":
                headers = MutableHeaders(scope=message)
                for header_name, header_value in secure_headers.headers.items():
                    # setdefault instead of append: keep a value set by the
                    # endpoint and avoid emitting the same header twice.
                    headers.setdefault(header_name, header_value)
            await send(message)

        await self.app(scope, receive, send_with_secure_headers)
        return None
class CacheValidator:
    """Rule defaults and duplicate-``CacheMiddleware`` detection for ``CacheMiddleware``."""

    def __init__(self, rules: Sequence[Rule] | None = None) -> None:
        # A missing or empty rule set falls back to a single default ``Rule()``.
        self.rules = rules if rules else [Rule()]

    def check_for_duplicate_middleware(self, app: ASGIApp) -> None:
        """Raise ``DuplicateCaching`` if *app* lists a ``CacheMiddleware`` instance.

        Only apps exposing a non-callable ``middleware`` attribute are inspected;
        a callable one (such as a decorator method) is ignored.
        """
        try:
            stack = app.middleware  # type: ignore[attr-defined]
        except AttributeError:
            return
        if callable(stack):
            return
        if any(isinstance(item, CacheMiddleware) for item in stack):
            from .exceptions import DuplicateCaching

            raise DuplicateCaching("CacheMiddleware detected in middleware stack")

    def is_duplicate_in_scope(self, scope: Scope) -> bool:
        """Return True if a CacheMiddleware already registered itself in *scope*."""
        return scope_name in scope
188class CacheKeyManager:
189 def __init__(self, cache: t.Any | None = None) -> None:
190 self.cache = cache
191 self._cache_dict = {}
193 def get_cache_instance(self):
194 if self.cache is None:
195 from .exceptions import safe_depends_get
197 self.cache = safe_depends_get("cache", self._cache_dict)
198 return self.cache
class CacheMiddleware:
    """ASGI middleware that delegates HTTP requests to ``CacheResponder``.

    It stores itself in the ASGI scope under ``scope_name`` so that
    ``CacheHelper`` can find it. It raises ``DuplicateCaching`` when another
    ``CacheMiddleware`` is already active. Non-HTTP scopes are forwarded
    untouched.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        cache: t.Any | None = None,
        rules: Sequence[Rule] | None = None,
    ) -> None:
        self.app = app

        self.validator = CacheValidator(rules)
        self.key_manager = CacheKeyManager(cache)

        # Mirrors key_manager's backend; refreshed on every HTTP request.
        self.cache = cache

        self.rules = self.validator.rules

        # Fail fast if *app* already exposes a CacheMiddleware in its stack.
        self.validator.check_for_duplicate_middleware(app)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        # Resolve the cache only for HTTP requests. The first scope an ASGI app
        # receives is typically lifespan, before the inner app's startup runs,
        # and CacheKeyManager keeps the first non-None value it resolves, so a
        # lookup made there could pin a fallback permanently.
        self.cache = self.key_manager.get_cache_instance()
        if self.validator.is_duplicate_in_scope(scope):
            from .exceptions import DuplicateCaching

            msg = (
                "Another `CacheMiddleware` was detected in the middleware stack.\n"
                "HINT: this exception probably occurred because:\n"
                "- You wrapped an application around `CacheMiddleware` multiple times.\n"
                "- You tried to apply `@cached()` onto an endpoint, but the application "
                "is already wrapped around a `CacheMiddleware`."
            )
            raise DuplicateCaching(
                msg,
            )
        scope[scope_name] = self
        responder = CacheResponder(self.app, rules=self.rules)
        await responder(scope, receive, send)
class _BaseCacheMiddlewareHelper:
    """Base for request-bound helpers that need the active ``CacheMiddleware``.

    Raises:
        MissingCaching: if ``request.scope`` has no ``scope_name`` entry, or the
            entry is not a ``CacheMiddleware`` instance.
    """

    def __init__(self, request: Request) -> None:
        self.request = request
        scope = request.scope
        if scope_name not in scope:
            raise MissingCaching(
                "No CacheMiddleware instance found in the ASGI scope. "
                "Did you forget to wrap the ASGI application with `CacheMiddleware`?"
            )
        candidate = scope[scope_name]
        if not isinstance(candidate, CacheMiddleware):
            raise MissingCaching(
                f"A scope variable named {scope_name!r} was found, but it does not "
                "contain a `CacheMiddleware` instance. It is likely that an "
                "incompatible middleware was added to the middleware stack."
            )
        self.middleware = candidate
class CacheHelper(_BaseCacheMiddlewareHelper):
    """Request-bound helper for invalidating entries of the response cache."""

    async def invalidate_cache_for(
        self,
        url: str | URL,
        *,
        headers: Mapping[str, str] | None = None,
    ) -> None:
        """Delete cached entries for *url* from the middleware's cache.

        Args:
            url: a ``URL``, or a route name resolved with ``request.url_for``.
            headers: passed to ``delete_from_cache`` as ``vary``; wrapped in
                ``Headers`` unless it already is one (``None`` gives empty headers).
        """
        target = url if isinstance(url, URL) else self.request.url_for(url)
        vary = headers if isinstance(headers, Headers) else Headers(headers)
        await delete_from_cache(target, vary=vary, cache=self.middleware.cache)
class CacheControlMiddleware:
    """ASGI middleware that applies Cache-Control directives.

    The directive keyword arguments are kept verbatim in ``kwargs``, which is
    forwarded to ``CacheControlResponder`` for HTTP requests. They are also set
    as attributes on the instance, which ``process_response`` reads.
    """

    app: ASGIApp
    kwargs: CacheDirectives
    max_age: int | None
    s_maxage: int | None
    no_cache: bool
    no_store: bool
    no_transform: bool
    must_revalidate: bool
    proxy_revalidate: bool
    must_understand: bool
    private: bool
    public: bool
    immutable: bool
    stale_while_revalidate: int | None
    stale_if_error: int | None

    # Boolean directives (attribute, header token), in emission order after
    # the public/private visibility directive.
    _FLAG_DIRECTIVES: t.ClassVar[tuple[tuple[str, str], ...]] = (
        ("no_cache", "no-cache"),
        ("no_store", "no-store"),
        ("no_transform", "no-transform"),
        ("must_revalidate", "must-revalidate"),
        ("proxy_revalidate", "proxy-revalidate"),
        ("must_understand", "must-understand"),
        ("immutable", "immutable"),
    )
    # Directives carrying a value, emitted as ``token=value`` after the flags.
    _VALUE_DIRECTIVES: t.ClassVar[tuple[tuple[str, str], ...]] = (
        ("max_age", "max-age"),
        ("s_maxage", "s-maxage"),
        ("stale_while_revalidate", "stale-while-revalidate"),
        ("stale_if_error", "stale-if-error"),
    )

    def __init__(self, app: ASGIApp, **kwargs: t.Unpack[CacheDirectives]) -> None:
        self.app = app
        self.kwargs = kwargs
        self.max_age = None
        self.s_maxage = None
        self.no_cache = False
        self.no_store = False
        self.no_transform = False
        self.must_revalidate = False
        self.proxy_revalidate = False
        self.must_understand = False
        self.private = False
        self.public = False
        self.immutable = False
        self.stale_while_revalidate = None
        self.stale_if_error = None
        for key, value in kwargs.items():
            setattr(self, key, value)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        responder = CacheControlResponder(self.app, **self.kwargs)
        await responder(scope, receive, send)

    def process_response(self, response: t.Any) -> None:
        """Write a ``Cache-Control`` header built from this instance's directives.

        ``public`` takes precedence over ``private``. Every declared directive
        is emitted when set: flags as bare tokens, and valued directives when
        they are not ``None``. ``response.headers`` is left untouched when no
        directive is set.
        """
        parts: list[str] = []
        if getattr(self, "public", False):
            parts.append("public")
        elif getattr(self, "private", False):
            parts.append("private")
        for attr, token in self._FLAG_DIRECTIVES:
            if getattr(self, attr, False):
                parts.append(token)
        for attr, token in self._VALUE_DIRECTIVES:
            value = getattr(self, attr, None)
            if value is not None:
                parts.append(f"{token}={value}")
        if parts:
            response.headers["Cache-Control"] = ", ".join(parts)
def get_middleware_positions() -> dict[str, int]:
    """Map every ``MiddlewarePosition`` member name to its integer slot."""
    positions: dict[str, int] = {}
    for member in MiddlewarePosition:
        positions[member.name] = member.value
    return positions
class MiddlewareStackManager:
    """Collect middleware by ``MiddlewarePosition`` and build an ordered stack.

    Built-in middleware is registered lazily by ``initialize()``, which
    ``build_stack()`` and ``get_middleware_info()`` call when needed. Built-ins
    only fill positions that are still empty. A class registered with
    ``register_middleware()`` before initialization is therefore not overwritten.
    """

    def __init__(
        self,
        config: t.Any | None = None,
        logger: t.Any | None = None,
    ) -> None:
        self.config = config
        self.logger = logger
        self._middleware_registry: dict[MiddlewarePosition, MiddlewareClass] = {}
        self._middleware_options: dict[MiddlewarePosition, MiddlewareOptions] = {}
        self._custom_middleware: dict[MiddlewarePosition, Middleware] = {}
        self._initialized = False

    def _ensure_dependencies(self) -> None:
        """Resolve a missing ``config`` (required) and ``logger`` (optional) via ``depends``."""
        if self.config is None:
            self.config = depends.get("config")
        if self.logger is None:
            try:
                self.logger = depends.get("logger")
            except Exception:
                self.logger = None

    def _register_builtin(
        self,
        position: MiddlewarePosition,
        middleware_class: MiddlewareClass,
        options: MiddlewareOptions | None = None,
    ) -> None:
        """Register a built-in middleware unless *position* is already taken."""
        if position in self._middleware_registry:
            return
        self._middleware_registry[position] = middleware_class
        if options is not None:
            self._middleware_options[position] = options

    def _register_default_middleware(self) -> None:
        """Register the always-on built-ins: HTMX, current request, Brotli."""
        self._register_builtin(MiddlewarePosition.HTMX, HtmxMiddleware)
        self._register_builtin(
            MiddlewarePosition.CURRENT_REQUEST,
            CurrentRequestMiddleware,
        )
        self._register_builtin(
            MiddlewarePosition.COMPRESSION,
            BrotliMiddleware,
            {"quality": 3},
        )

    def _register_conditional_middleware(self) -> None:
        """Register config-dependent built-ins: CSRF, sessions, security headers."""
        self._ensure_dependencies()
        if not self.config:
            return
        from acb.adapters import get_adapter

        app_config = self.config.app
        secret = app_config.secret_key.get_secret_value()
        token_id = getattr(app_config, "token_id", "_fb_")
        deployed = self.config.deployed

        self._register_builtin(
            MiddlewarePosition.CSRF,
            CSRFMiddleware,
            {
                "secret": secret,
                "cookie_name": f"{token_id}_csrf",
                "cookie_secure": deployed,
            },
        )
        if get_adapter("auth"):
            self._register_builtin(
                MiddlewarePosition.SESSION,
                SessionMiddleware,
                {
                    "secret_key": secret,
                    "session_cookie": f"{token_id}_app",
                    "https_only": deployed,
                },
            )
        # Tolerate configs without a ``debug`` section instead of raising.
        debug_config = getattr(self.config, "debug", None)
        if deployed or getattr(debug_config, "production", False):
            self._register_builtin(
                MiddlewarePosition.SECURITY_HEADERS,
                SecureHeadersMiddleware,
            )

    def initialize(self) -> None:
        """Register built-in middleware once; later calls are no-ops."""
        if self._initialized:
            return
        self._register_default_middleware()
        self._register_conditional_middleware()
        self._initialized = True

    def register_middleware(
        self,
        middleware_class: MiddlewareClass,
        position: MiddlewarePosition,
        **options: t.Any,
    ) -> None:
        """Place *middleware_class* at *position*, replacing any previous entry.

        The options are always replaced, even when empty. Otherwise options
        left over from the previous class at this position (e.g. Brotli's
        ``quality``) would be passed to a class that may not accept them.
        """
        self._middleware_registry[position] = middleware_class
        self._middleware_options[position] = options

    def add_custom_middleware(
        self,
        middleware: Middleware,
        position: MiddlewarePosition,
    ) -> None:
        """Use a prebuilt ``Middleware`` at *position*; it overrides the registry there."""
        self._custom_middleware[position] = middleware

    def build_stack(self) -> list[Middleware]:
        """Return the middleware list ordered by ascending ``MiddlewarePosition``."""
        if not self._initialized:
            self.initialize()
        middleware_stack: dict[MiddlewarePosition, Middleware] = {}
        for position, middleware_class in self._middleware_registry.items():
            options = self._middleware_options.get(position, {})
            middleware_stack[position] = Middleware(middleware_class, **options)
        # Custom entries win over registered classes at the same position.
        middleware_stack.update(self._custom_middleware)

        return [
            middleware_stack[position] for position in sorted(middleware_stack.keys())
        ]

    def get_middleware_info(self) -> dict[str, t.Any]:
        """Describe registered classes, custom entries, and all known positions."""
        if not self._initialized:
            self.initialize()

        return {
            "registered": {
                # Factories such as functools.partial have no __name__.
                pos.name: getattr(cls, "__name__", repr(cls))
                for pos, cls in self._middleware_registry.items()
            },
            "custom": {
                pos.name: str(middleware)
                for pos, middleware in self._custom_middleware.items()
            },
            "positions": get_middleware_positions(),
        }
def middlewares() -> list[Middleware]:
    """Build the default FastBlocks middleware stack.

    A new ``MiddlewareStackManager`` (with no explicit config or logger) is
    created on every call, and its ``build_stack()`` result is returned.
    """
    manager = MiddlewareStackManager()
    return manager.build_stack()