Coverage for fastblocks/middleware.py: 72%

313 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-09 00:47 -0700

1import sys 

2import typing as t 

3from collections.abc import Mapping, Sequence 

4from contextvars import ContextVar 

5from enum import IntEnum 

6 

7from acb.debug import debug 

8from acb.depends import depends 

9from brotli_asgi import BrotliMiddleware 

10from secure import Secure 

11from starlette.datastructures import URL, Headers, MutableHeaders 

12from starlette.middleware import Middleware 

13from starlette.middleware.sessions import SessionMiddleware 

14from starlette.requests import Request 

15from starlette.types import ASGIApp, Message, Receive, Scope, Send 

16from starlette_csrf.middleware import CSRFMiddleware 

17 

18from .caching import ( 

19 CacheControlResponder, 

20 CacheDirectives, 

21 CacheResponder, 

22 Rule, 

23 delete_from_cache, 

24) 

25from .htmx import HtmxDetails 

26 

# Type aliases used by the middleware stack manager below.
# A callable that wraps one ASGI application in another.
MiddlewareCallable: t.TypeAlias = t.Callable[[ASGIApp], ASGIApp]
# Any middleware class; classes are wrapped in a starlette ``Middleware`` spec.
MiddlewareClass: t.TypeAlias = type[t.Any]
# Keyword arguments forwarded to a middleware class's ``Middleware`` spec.
MiddlewareOptions: t.TypeAlias = dict[str, t.Any]

30from .exceptions import MissingCaching 

31 

32 

class MiddlewarePosition(IntEnum):
    """Ordering slots for the FastBlocks middleware stack.

    ``MiddlewareStackManager.build_stack`` sorts the registered middleware by
    these integer values, so a lower value appears earlier in the returned
    list. Each slot holds at most one middleware.
    """

    CSRF = 0
    SESSION = 1
    HTMX = 2
    CURRENT_REQUEST = 3
    COMPRESSION = 4
    SECURITY_HEADERS = 5

41 

class HtmxMiddleware:
    """ASGI middleware that attaches an ``HtmxDetails`` object to the scope.

    For ``http`` and ``websocket`` scopes the wrapped application sees a
    ``scope["htmx"]`` entry; every other scope type passes through untouched.
    """

    def __init__(self, app: ASGIApp) -> None:
        self._app = app
        debug("HtmxMiddleware: Initialized FastBlocks native HTMX middleware")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        is_connection = scope["type"] in ("http", "websocket")
        if is_connection:
            await self._process_htmx_request(scope)
        await self._app(scope, receive, send)

    async def _process_htmx_request(self, scope: Scope) -> None:
        """Store ``HtmxDetails`` under ``scope["htmx"]``; log it when debugging."""
        details = HtmxDetails(scope)
        scope["htmx"] = details
        if not debug.enabled:
            return
        self._log_htmx_details(scope, details)

    def _log_htmx_details(self, scope: Scope, htmx_details: HtmxDetails) -> None:
        """Emit one debug line for the request, plus one per HTMX header."""
        request_line = f"{scope.get('method', 'UNKNOWN')} {scope.get('path', 'unknown')}"
        htmx_request = bool(htmx_details)
        debug(f"HtmxMiddleware: {request_line} - HTMX: {htmx_request}")
        if not htmx_request:
            return
        for name, value in htmx_details.get_all_headers().items():
            debug(f"HtmxMiddleware: {name}: {value}")

69 

70 

class HtmxResponseMiddleware:
    """ASGI middleware that inspects responses to HTMX requests.

    Only ``http`` scopes are intercepted. On the ``http.response.start``
    message of a request whose ``scope["htmx"]`` is truthy, a debug line is
    logged and the message's headers are normalized to a ``list``. Every
    message is then forwarded to the original ``send``.
    """

    def __init__(self, app: ASGIApp) -> None:
        self._app = app
        debug("HtmxResponseMiddleware: Initialized FastBlocks HTMX response middleware")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":

            async def send_wrapper(message: Message) -> None:
                await self._process_response_message(message, scope, send)

            await self._app(scope, receive, send_wrapper)
        else:
            await self._app(scope, receive, send)

    async def _process_response_message(
        self, message: Message, scope: Scope, send: Send
    ) -> None:
        """Adjust the response-start message of HTMX requests, then forward it."""
        if message["type"] == "http.response.start":
            details = scope.get("htmx")
            if details and bool(details):
                debug("HtmxResponseMiddleware: Processing HTMX response")
                message["headers"] = list(message.get("headers", []))
        await send(message)

97 

98 

class MiddlewareUtils:
    """Shared state and interned constants for the FastBlocks middleware.

    Holds the ``Secure`` header provider, the scope key under which
    ``CacheMiddleware`` registers itself, and the context variable that
    stores the current request scope.
    """

    # Alias of ``typing.Any`` used as the cache type.
    Cache = t.Any

    # NOTE(review): with the ``secure`` 1.x API a bare ``Secure()`` may carry
    # no headers (``Secure.with_default_headers()`` enables the defaults) —
    # confirm which headers this instance actually emits.
    secure_headers = Secure()

    # Scope key used to store the active CacheMiddleware instance.
    scope_name = "__starlette_caches__"

    # Request scope for the current context; ``None`` when none has been set.
    _request_ctx_var: ContextVar[Scope | None] = ContextVar("request", default=None)

    # Interned ASGI scope keys/values and HTTP method names.
    HTTP, WEBSOCKET = sys.intern("http"), sys.intern("websocket")
    TYPE, METHOD, PATH = sys.intern("type"), sys.intern("method"), sys.intern("path")
    GET, HEAD, POST, PUT, PATCH, DELETE = map(
        sys.intern, ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE")
    )

    @classmethod
    def get_request(cls) -> Scope | None:
        """Return the request scope stored for the current context, or ``None``."""
        current: Scope | None = cls._request_ctx_var.get()
        return current

    @classmethod
    def set_request(cls, scope: Scope | None) -> None:
        """Store *scope* as the request scope for the current context."""
        cls._request_ctx_var.set(scope)

127 

128 

# Module-level aliases of MiddlewareUtils attributes, used by the code in
# this module (e.g. ``scope_name`` and ``_request_ctx_var``).
Cache = MiddlewareUtils.Cache
scope_name = MiddlewareUtils.scope_name
secure_headers = MiddlewareUtils.secure_headers
_request_ctx_var = MiddlewareUtils._request_ctx_var

133 

134 

def get_request() -> Scope | None:
    """Return the ASGI scope of the request being handled in this context.

    The value is whatever ``CurrentRequestMiddleware`` (or
    ``MiddlewareUtils.set_request``) stored for the running context; it is
    ``None`` when no request scope has been set.
    """
    scope = MiddlewareUtils.get_request()
    return scope

137 

138 

class CurrentRequestMiddleware:
    """Expose the current HTTP/WebSocket scope through ``get_request()``.

    While the wrapped application runs, the scope is stored in the module's
    request context variable. The previous value is restored afterwards, even
    if the application raises. Other scope types are passed through without
    touching the context variable.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> t.Any:  # type: ignore[func-returns-value,no-any-return]
        if scope[MiddlewareUtils.TYPE] not in (
            MiddlewareUtils.HTTP,
            MiddlewareUtils.WEBSOCKET,
        ):
            await self.app(scope, receive, send)
            return None  # type: ignore[func-returns-value]
        token = _request_ctx_var.set(scope)
        try:
            response = await self.app(scope, receive, send)  # type: ignore[func-returns-value]
        finally:
            # Reset in ``finally`` so a failing request cannot leave its scope
            # visible to code that runs later in the same context.
            _request_ctx_var.reset(token)
        return response  # type: ignore[no-any-return]

154 

155 

class SecureHeadersMiddleware:
    """Add the module's ``secure_headers`` to every HTTP response.

    Each header from ``secure_headers.headers`` is set on the
    ``http.response.start`` message. An existing value for the same header
    name is replaced rather than duplicated, so a response never carries two
    (possibly conflicting) copies of a security header. Non-HTTP scopes pass
    through unchanged.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app
        try:
            self.logger = depends.get("logger")
        except Exception:
            # The logger is optional here; construction must not fail without it.
            self.logger = None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        async def send_with_secure_headers(message: Message) -> None:
            if message["type"] == "http.response.start":
                headers = MutableHeaders(scope=message)
                for header_name, header_value in secure_headers.headers.items():
                    # Item assignment replaces any existing value(s) for this
                    # name; ``append`` would emit duplicate header lines.
                    headers[header_name] = header_value
            await send(message)

        await self.app(scope, receive, send_with_secure_headers)
        return None

177 

178 

class CacheValidator:
    """Hold ``CacheMiddleware`` rules and check for duplicate cache middleware."""

    def __init__(self, rules: Sequence[Rule] | None = None) -> None:
        # A missing or empty rule sequence falls back to a single default Rule().
        self.rules = rules if rules else [Rule()]

    def check_for_duplicate_middleware(self, app: ASGIApp) -> None:
        """Raise ``DuplicateCaching`` if *app* already lists a CacheMiddleware.

        The check runs only when *app* has a non-callable ``middleware``
        attribute; a callable attribute (e.g. a decorator method) is ignored.
        """
        if hasattr(app, "middleware"):
            candidate = app.middleware  # type: ignore[attr-defined]
            if not callable(candidate):
                self._check_for_cache_middleware_duplicates(candidate)

    def _check_for_cache_middleware_duplicates(self, middleware: t.Any) -> None:
        """Check if CacheMiddleware is already in the middleware stack."""
        # NOTE(review): only live CacheMiddleware instances are detected; a
        # starlette ``Middleware(CacheMiddleware, ...)`` spec would not match —
        # confirm what this attribute contains for the apps passed in.
        if any(isinstance(entry, CacheMiddleware) for entry in middleware):
            from .exceptions import DuplicateCaching

            msg = "CacheMiddleware detected in middleware stack"
            raise DuplicateCaching(msg)

    def is_duplicate_in_scope(self, scope: Scope) -> bool:
        """Return True if a CacheMiddleware already registered itself in *scope*."""
        already_present = scope_name in scope
        return already_present

205 

206 

207class CacheKeyManager: 

208 def __init__(self, cache: t.Any | None = None) -> None: 

209 self.cache = cache 

210 self._cache_dict: dict[t.Any, t.Any] = {} 

211 

212 def get_cache_instance(self) -> t.Any: 

213 if self.cache is None: 

214 from .exceptions import safe_depends_get 

215 

216 self.cache = safe_depends_get("cache", self._cache_dict) 

217 return self.cache 

218 

219 

class CacheMiddleware:
    """ASGI middleware that serves HTTP requests through ``CacheResponder``.

    On each HTTP request it stores itself in the scope under ``scope_name``
    and refuses to run if another CacheMiddleware already did so. Non-HTTP
    scopes are passed straight to the wrapped app.

    Raises:
        DuplicateCaching: at construction if the wrapped app already lists a
            CacheMiddleware, or per request if one is already in the scope.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        cache: t.Any | None = None,
        rules: Sequence[Rule] | None = None,
    ) -> None:
        self.app = app
        self.cache = cache
        self.validator = CacheValidator(rules)
        self.key_manager = CacheKeyManager(cache)
        # Rules after defaulting (CacheValidator substitutes [Rule()] when empty).
        self.rules = self.validator.rules
        self.validator.check_for_duplicate_middleware(app)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Resolve the cache on every call so ``self.cache`` reflects the backend in use.
        self.cache = self.key_manager.get_cache_instance()  # type: ignore[no-untyped-call]
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        if self.validator.is_duplicate_in_scope(scope):
            from .exceptions import DuplicateCaching

            raise DuplicateCaching(
                "Another `CacheMiddleware` was detected in the middleware stack.\n"
                "HINT: this exception probably occurred because:\n"
                "- You wrapped an application around `CacheMiddleware` multiple times.\n"
                "- You tried to apply `@cached()` onto an endpoint, but the application "
                "is already wrapped around a `CacheMiddleware`."
            )
        scope[scope_name] = self
        await CacheResponder(self.app, rules=self.rules)(scope, receive, send)

261 

262 

class _BaseCacheMiddlewareHelper:
    """Locate the ``CacheMiddleware`` instance that is handling *request*.

    Raises:
        MissingCaching: if the scope has no ``scope_name`` entry, or if that
            entry is not a ``CacheMiddleware`` instance.
    """

    def __init__(self, request: Request) -> None:
        self.request = request
        scope = request.scope
        if scope_name not in scope:
            raise MissingCaching(
                "No CacheMiddleware instance found in the ASGI scope. Did you forget to wrap the ASGI application with `CacheMiddleware`?"
            )
        candidate = scope[scope_name]
        if not isinstance(candidate, CacheMiddleware):
            raise MissingCaching(
                f"A scope variable named {scope_name!r} was found, but it does not contain a `CacheMiddleware` instance. It is likely that an incompatible middleware was added to the middleware stack."
            )
        self.middleware = candidate

278 

279 

class CacheHelper(_BaseCacheMiddlewareHelper):
    """Request-scoped helper for invalidating cached responses."""

    async def invalidate_cache_for(
        self,
        url: str | URL,
        *,
        headers: Mapping[str, str] | None = None,
    ) -> None:
        """Delete cached entries for *url* from the middleware's cache.

        Args:
            url: A ``URL``, or a route name that is resolved with
                ``request.url_for``.
            headers: Forwarded as ``vary`` to ``delete_from_cache``; anything
                that is not already ``Headers`` (including ``None``) is
                wrapped in ``Headers``.
        """
        target = url if isinstance(url, URL) else self.request.url_for(url)
        vary = headers if isinstance(headers, Headers) else Headers(headers)
        await delete_from_cache(target, vary=vary, cache=self.middleware.cache)

292 

293 

class CacheControlMiddleware:
    """ASGI middleware that applies ``Cache-Control`` directives to responses.

    Keyword directives (see ``CacheDirectives``) are kept in ``self.kwargs``.
    They are forwarded to ``CacheControlResponder`` for each HTTP request and
    are also mirrored onto attributes, which ``process_response`` reads.
    """

    app: ASGIApp
    kwargs: CacheDirectives
    max_age: int | None
    s_maxage: int | None
    no_cache: bool
    no_store: bool
    no_transform: bool
    must_revalidate: bool
    proxy_revalidate: bool
    must_understand: bool
    private: bool
    public: bool
    immutable: bool
    stale_while_revalidate: int | None
    stale_if_error: int | None

    def __init__(self, app: ASGIApp, **kwargs: t.Unpack[CacheDirectives]) -> None:
        self.app = app
        self.kwargs = kwargs
        self.max_age = None
        self.s_maxage = None
        self.no_cache = False
        self.no_store = False
        self.no_transform = False
        self.must_revalidate = False
        self.proxy_revalidate = False
        self.must_understand = False
        self.private = False
        self.public = False
        self.immutable = False
        self.stale_while_revalidate = None
        self.stale_if_error = None
        for key, value in kwargs.items():
            setattr(self, key, value)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        responder = CacheControlResponder(self.app, **self.kwargs)
        await responder(scope, receive, send)

    def process_response(self, response: t.Any) -> None:
        """Set ``response.headers["Cache-Control"]`` from this middleware's directives.

        ``public`` wins over ``private`` when both are set. Every directive the
        class accepts is emitted; previously ``s-maxage``, ``no-transform``,
        ``proxy-revalidate``, ``must-understand``, ``immutable`` and the
        ``stale-*`` directives were silently dropped. The header is left
        untouched when no directive is enabled.
        """
        parts: list[str] = []
        if getattr(self, "public", False):
            parts.append("public")
        elif getattr(self, "private", False):
            parts.append("private")
        # Directives emitted before this fix keep their original order; the
        # newly supported ones are appended after them.
        for attr, directive in (
            ("no_cache", "no-cache"),
            ("no_store", "no-store"),
            ("must_revalidate", "must-revalidate"),
        ):
            if getattr(self, attr, False):
                parts.append(directive)
        for attr, directive in (("max_age", "max-age"), ("s_maxage", "s-maxage")):
            value = getattr(self, attr, None)
            if value is not None:
                parts.append(f"{directive}={value}")
        for attr, directive in (
            ("no_transform", "no-transform"),
            ("proxy_revalidate", "proxy-revalidate"),
            ("must_understand", "must-understand"),
            ("immutable", "immutable"),
        ):
            if getattr(self, attr, False):
                parts.append(directive)
        for attr, directive in (
            ("stale_while_revalidate", "stale-while-revalidate"),
            ("stale_if_error", "stale-if-error"),
        ):
            value = getattr(self, attr, None)
            if value is not None:
                parts.append(f"{directive}={value}")
        if parts:
            response.headers["Cache-Control"] = ", ".join(parts)

354 

355 

def get_middleware_positions() -> dict[str, int]:
    """Map each ``MiddlewarePosition`` member name to its integer slot."""
    positions: dict[str, int] = {}
    for member in MiddlewarePosition:
        positions[member.name] = member.value
    return positions

358 

359 

class MiddlewareStackManager:
    """Collect middleware registrations and build an ordered Starlette stack.

    Middleware is keyed by ``MiddlewarePosition``. ``build_stack`` returns the
    entries sorted by position; entries added with ``add_custom_middleware``
    replace registered ones at the same position.

    Built-in defaults are registered lazily on the first ``initialize`` call
    (triggered by ``build_stack`` / ``get_middleware_info``):
    - always: HTMX, current-request and Brotli compression;
    - depending on config: CSRF, session and security-headers middleware.

    Defaults never overwrite a position the caller already registered with
    ``register_middleware``.
    """

    def __init__(
        self,
        config: t.Any | None = None,
        logger: t.Any | None = None,
    ) -> None:
        self.config = config
        self.logger = logger
        self._middleware_registry: dict[MiddlewarePosition, MiddlewareClass] = {}
        self._middleware_options: dict[MiddlewarePosition, MiddlewareOptions] = {}
        self._custom_middleware: dict[MiddlewarePosition, Middleware] = {}
        self._initialized = False

    def _ensure_dependencies(self) -> None:
        """Resolve ``config`` (required) and ``logger`` (optional) via ``depends``."""
        if self.config is None:
            self.config = depends.get("config")
        if self.logger is None:
            try:
                self.logger = depends.get("logger")
            except Exception:
                # The logger is optional; its absence must not block stack building.
                self.logger = None

    def _register_default(
        self,
        position: MiddlewarePosition,
        middleware_class: MiddlewareClass,
        make_options: t.Callable[[], MiddlewareOptions] | None = None,
    ) -> None:
        """Register a built-in default unless *position* is already taken.

        Options are built lazily by *make_options*, so that config lookups only
        happen when the default is actually used.
        """
        if position in self._middleware_registry:
            return
        self._middleware_registry[position] = middleware_class
        if make_options is not None:
            self._middleware_options[position] = make_options()

    def _register_default_middleware(self) -> None:
        """Register the always-on defaults: HTMX, current request and Brotli."""
        self._register_default(MiddlewarePosition.HTMX, HtmxMiddleware)
        self._register_default(
            MiddlewarePosition.CURRENT_REQUEST, CurrentRequestMiddleware
        )
        self._register_default(
            MiddlewarePosition.COMPRESSION, BrotliMiddleware, lambda: {"quality": 3}
        )

    def _register_conditional_middleware(self) -> None:
        """Register the CSRF, session and security-headers defaults from config."""
        self._ensure_dependencies()
        if not self.config:
            return
        from acb.adapters import get_adapter

        config = self.config
        self._register_default(
            MiddlewarePosition.CSRF,
            CSRFMiddleware,
            lambda: {
                "secret": config.app.secret_key.get_secret_value(),
                "cookie_name": f"{getattr(config.app, 'token_id', '_fb_')}_csrf",
                "cookie_secure": config.deployed,
            },
        )
        if get_adapter("auth"):
            self._register_default(
                MiddlewarePosition.SESSION,
                SessionMiddleware,
                lambda: {
                    "secret_key": config.app.secret_key.get_secret_value(),
                    "session_cookie": f"{getattr(config.app, 'token_id', '_fb_')}_app",
                    "https_only": config.deployed,
                },
            )
        if config.deployed or getattr(config.debug, "production", False):
            self._register_default(
                MiddlewarePosition.SECURITY_HEADERS, SecureHeadersMiddleware
            )

    def initialize(self) -> None:
        """Register the built-in defaults once; later calls are no-ops."""
        if self._initialized:
            return
        self._register_default_middleware()
        self._register_conditional_middleware()
        self._initialized = True

    def register_middleware(
        self,
        middleware_class: MiddlewareClass,
        position: MiddlewarePosition,
        **options: t.Any,
    ) -> None:
        """Register *middleware_class* at *position*, replacing any previous entry.

        Options left over from a previous registration at *position* are
        discarded, so they are never passed to a different class.
        """
        self._middleware_registry[position] = middleware_class
        if options:
            self._middleware_options[position] = options
        else:
            self._middleware_options.pop(position, None)

    def add_custom_middleware(
        self,
        middleware: Middleware,
        position: MiddlewarePosition,
    ) -> None:
        """Place a ready-made ``Middleware`` spec at *position* (overrides registry)."""
        self._custom_middleware[position] = middleware

    def build_stack(self) -> list[Middleware]:
        """Return the middleware specs ordered by ascending position."""
        if not self._initialized:
            self.initialize()

        middleware_stack: dict[MiddlewarePosition, Middleware] = {}
        self._build_middleware_stack(middleware_stack)
        # Custom specs win over registered classes at the same position.
        middleware_stack.update(self._custom_middleware)

        return [middleware_stack[position] for position in sorted(middleware_stack)]

    def _build_middleware_stack(
        self, middleware_stack: dict[MiddlewarePosition, Middleware]
    ) -> None:
        """Build the middleware stack from registered middleware."""
        for position, middleware_class in self._middleware_registry.items():
            options = self._middleware_options.get(position, {})
            middleware_stack[position] = Middleware(middleware_class, **options)

    def get_middleware_info(self) -> dict[str, t.Any]:
        """Describe registered classes, custom specs and the position table."""
        if not self._initialized:
            self.initialize()

        return {
            "registered": {
                pos.name: cls.__name__ for pos, cls in self._middleware_registry.items()
            },
            "custom": {
                pos.name: str(middleware)
                for pos, middleware in self._custom_middleware.items()
            },
            "positions": get_middleware_positions(),
        }

475 

476 

def middlewares() -> list[Middleware]:
    """Build the default FastBlocks middleware stack.

    Creates a fresh ``MiddlewareStackManager`` with no explicit config or
    logger (so both are resolved through ``depends``) and returns its ordered
    list of ``Middleware`` specs.
    """
    manager = MiddlewareStackManager()
    return manager.build_stack()