Coverage for fastblocks/middleware.py: 73%

302 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-21 04:50 -0700

1import sys 

2import typing as t 

3from collections.abc import Mapping, Sequence 

4from contextvars import ContextVar 

5from enum import IntEnum 

6 

7from acb.debug import debug 

8from acb.depends import depends 

9from brotli_asgi import BrotliMiddleware 

10from secure import Secure 

11from starlette.datastructures import URL, Headers, MutableHeaders 

12from starlette.middleware import Middleware 

13from starlette.middleware.sessions import SessionMiddleware 

14from starlette.requests import Request 

15from starlette.types import ASGIApp, Message, Receive, Scope, Send 

16from starlette_csrf.middleware import CSRFMiddleware 

17 

18from .caching import ( 

19 CacheControlResponder, 

20 CacheDirectives, 

21 CacheResponder, 

22 Rule, 

23 delete_from_cache, 

24) 

25from .htmx import HtmxDetails 

26 

# Keyword options forwarded to a middleware class when it is instantiated
# (MiddlewareStackManager keeps one of these per position).
MiddlewareOptions = dict[str, t.Any]
# Any class object; MiddlewareStackManager's registry maps positions to these.
MiddlewareClass = type[t.Any]
# A callable that takes an ASGI app and returns an ASGI app.
MiddlewareCallable = t.Callable[[ASGIApp], ASGIApp]

30from .exceptions import MissingCaching 

31 

32 

class MiddlewarePosition(IntEnum):
    """Ordering slots for the framework's middleware stack.

    ``MiddlewareStackManager.build_stack`` sorts registered middleware by
    these integer values, so a lower value appears earlier in the list.
    """

    CSRF = 0  # CSRF protection
    SESSION = 1  # cookie-backed sessions
    HTMX = 2  # attaches HTMX request details to the scope
    CURRENT_REQUEST = 3  # exposes the current scope via a context variable
    COMPRESSION = 4  # response compression
    SECURITY_HEADERS = 5  # security-related response headers

40 

41 

class HtmxMiddleware:
    """ASGI middleware that attaches ``HtmxDetails`` to HTTP and WebSocket scopes.

    The details object is stored under ``scope["htmx"]`` before the wrapped
    app is called; other scope types pass straight through.
    """

    def __init__(self, app: ASGIApp) -> None:
        self._app = app
        debug("HtmxMiddleware: Initialized FastBlocks native HTMX middleware")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] not in ("http", "websocket"):
            await self._app(scope, receive, send)
            return

        details = HtmxDetails(scope)
        scope["htmx"] = details
        if debug.enabled:
            self._log_request(scope, details)
        await self._app(scope, receive, send)

    @staticmethod
    def _log_request(scope: Scope, details: HtmxDetails) -> None:
        """Emit debug lines for the request and, for HTMX requests, its HTMX headers."""
        request_method = scope.get("method", "UNKNOWN")
        request_path = scope.get("path", "unknown")
        htmx_request = bool(details)
        debug(f"HtmxMiddleware: {request_method} {request_path} - HTMX: {htmx_request}")
        if not htmx_request:
            return
        for name, value in details.get_all_headers().items():
            debug(f"HtmxMiddleware: {name}: {value}")

61 

62 

class HtmxResponseMiddleware:
    """ASGI middleware hook for HTTP responses to HTMX requests.

    It relies on ``scope["htmx"]`` having been populated upstream. For HTMX
    requests it logs and normalizes the response-start headers into a list.
    """

    def __init__(self, app: ASGIApp) -> None:
        self._app = app
        debug("HtmxResponseMiddleware: Initialized FastBlocks HTMX response middleware")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self._app(scope, receive, send)
            return

        async def _forward(message: Message) -> None:
            # Only the response-start message of an HTMX request is touched.
            if message["type"] == "http.response.start" and scope.get("htmx"):
                debug("HtmxResponseMiddleware: Processing HTMX response")
                # NOTE(review): headers are only copied into a list here;
                # no HTMX-specific headers are added yet.
                message["headers"] = list(message.get("headers", []))
            await send(message)

        await self._app(scope, receive, _forward)

83 

84 

class MiddlewareUtils:
    """Shared constants and request-context helpers for the middleware module."""

    # Interned strings for ASGI scope keys/values and HTTP method names.
    HTTP = sys.intern("http")
    WEBSOCKET = sys.intern("websocket")
    TYPE = sys.intern("type")
    METHOD = sys.intern("method")
    PATH = sys.intern("path")
    GET = sys.intern("GET")
    HEAD = sys.intern("HEAD")
    POST = sys.intern("POST")
    PUT = sys.intern("PUT")
    PATCH = sys.intern("PATCH")
    DELETE = sys.intern("DELETE")

    # Loose alias for the cache backend's type.
    Cache = t.Any

    # Single shared ``secure.Secure`` instance whose headers are applied by
    # the security-headers middleware.
    secure_headers = Secure()

    # ASGI scope key under which CacheMiddleware registers itself.
    scope_name = "__starlette_caches__"

    # Holds the ASGI scope of the request being handled; ``None`` by default.
    _request_ctx_var: ContextVar[Scope | None] = ContextVar("request", default=None)

    @classmethod
    def get_request(cls) -> Scope | None:
        """Return the scope stored in the request context variable (``None`` if unset)."""
        return cls._request_ctx_var.get()

    @classmethod
    def set_request(cls, scope: Scope | None) -> None:
        """Store *scope* in the request context variable.

        The reset token returned by ``ContextVar.set`` is discarded.
        """
        cls._request_ctx_var.set(scope)

113 

114 

# Module-level shortcuts to the shared MiddlewareUtils attributes; the
# middleware classes in this module refer to these names directly.
_request_ctx_var = MiddlewareUtils._request_ctx_var
scope_name = MiddlewareUtils.scope_name
secure_headers = MiddlewareUtils.secure_headers
Cache = MiddlewareUtils.Cache

119 

120 

def get_request() -> Scope | None:
    """Return the ASGI scope stored for the current context.

    Delegates to ``MiddlewareUtils.get_request``; the result is ``None`` when
    nothing has been stored in the request context variable.
    """
    current_scope = MiddlewareUtils.get_request()
    return current_scope

123 

124 

class CurrentRequestMiddleware:
    """Publish the current ASGI scope in the request context variable.

    For HTTP and WebSocket scopes the scope is stored in ``_request_ctx_var``
    (readable via ``get_request()``) for the duration of the downstream call.
    Other scope types pass straight through.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope[MiddlewareUtils.TYPE] not in (
            MiddlewareUtils.HTTP,
            MiddlewareUtils.WEBSOCKET,
        ):
            await self.app(scope, receive, send)
            return None
        token = _request_ctx_var.set(scope)
        try:
            return await self.app(scope, receive, send)
        finally:
            # Always restore the previous value, even when the app raises;
            # otherwise this scope would remain visible through get_request()
            # in the current context after the request has failed.
            _request_ctx_var.reset(token)

140 

141 

class SecureHeadersMiddleware:
    """Add the shared ``secure_headers`` headers to every HTTP response."""

    def __init__(self, app: ASGIApp) -> None:
        self.app = app
        # Optional logger; failure to resolve it is tolerated. It is not
        # referenced elsewhere in this class.
        logger = None
        try:
            logger = depends.get("logger")
        except Exception:
            logger = None
        self.logger = logger

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        async def _inject_headers(message: Message) -> None:
            if message["type"] == "http.response.start":
                response_headers = MutableHeaders(scope=message)
                # NOTE(review): ``append`` (not assignment) is used, so a
                # header the app already set will be sent twice - confirm
                # this is intended.
                for name, value in secure_headers.headers.items():
                    response_headers.append(name, value)
            await send(message)

        await self.app(scope, receive, _inject_headers)
        return None

163 

164 

class CacheValidator:
    """Own the caching rules and guard against stacking ``CacheMiddleware``."""

    def __init__(self, rules: Sequence[Rule] | None = None) -> None:
        # ``None`` or an empty sequence both fall back to one default Rule.
        self.rules = rules or [Rule()]

    def check_for_duplicate_middleware(self, app: ASGIApp) -> None:
        """Raise ``DuplicateCaching`` if *app* exposes a ``CacheMiddleware`` instance.

        Only a non-callable ``middleware`` attribute on *app* is inspected
        (it is iterated); a missing or callable attribute is ignored.
        """
        if not hasattr(app, "middleware"):
            return
        stack = app.middleware  # type: ignore[attr-defined]
        if callable(stack):
            return
        if any(isinstance(entry, CacheMiddleware) for entry in stack):
            from .exceptions import DuplicateCaching

            msg = "CacheMiddleware detected in middleware stack"
            raise DuplicateCaching(msg)

    def is_duplicate_in_scope(self, scope: Scope) -> bool:
        """Return True when a CacheMiddleware already registered itself in *scope*."""
        already_registered = scope_name in scope
        return already_registered

186 

187 

188class CacheKeyManager: 

189 def __init__(self, cache: t.Any | None = None) -> None: 

190 self.cache = cache 

191 self._cache_dict = {} 

192 

193 def get_cache_instance(self): 

194 if self.cache is None: 

195 from .exceptions import safe_depends_get 

196 

197 self.cache = safe_depends_get("cache", self._cache_dict) 

198 return self.cache 

199 

200 

class CacheMiddleware:
    """ASGI middleware that routes HTTP requests through ``CacheResponder``.

    On each HTTP request it registers itself in the scope under
    ``scope_name`` (so ``CacheHelper`` can find it) and delegates to a
    ``CacheResponder`` built with this instance's rules.

    Raises:
        DuplicateCaching: at construction if the wrapped app already exposes
            a ``CacheMiddleware``, or per request if another one already
            registered itself in the scope.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        cache: t.Any | None = None,
        rules: Sequence[Rule] | None = None,
    ) -> None:
        self.app = app

        self.validator = CacheValidator(rules)
        self.key_manager = CacheKeyManager(cache)

        # May be None until the first HTTP request resolves the backend.
        self.cache = cache

        self.rules = self.validator.rules

        self.validator.check_for_duplicate_middleware(app)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        # Resolve the backend only for HTTP traffic: lifespan and WebSocket
        # scopes never use it, so looking it up there was wasted work (and a
        # needless failure point during startup).
        self.cache = self.key_manager.get_cache_instance()
        if self.validator.is_duplicate_in_scope(scope):
            from .exceptions import DuplicateCaching

            msg = (
                "Another `CacheMiddleware` was detected in the middleware stack.\n"
                "HINT: this exception probably occurred because:\n"
                "- You wrapped an application around `CacheMiddleware` multiple times.\n"
                "- You tried to apply `@cached()` onto an endpoint, but the application "
                "is already wrapped around a `CacheMiddleware`."
            )
            raise DuplicateCaching(
                msg,
            )
        scope[scope_name] = self
        responder = CacheResponder(self.app, rules=self.rules)
        await responder(scope, receive, send)

242 

243 

class _BaseCacheMiddlewareHelper:
    """Base for request-bound helpers that need the active ``CacheMiddleware``.

    Raises:
        MissingCaching: if the request scope has no ``scope_name`` entry, or
            the entry is not a ``CacheMiddleware`` instance.
    """

    def __init__(self, request: Request) -> None:
        self.request = request
        request_scope = request.scope
        if scope_name not in request_scope:
            msg = (
                "No CacheMiddleware instance found in the ASGI scope. "
                "Did you forget to wrap the ASGI application with `CacheMiddleware`?"
            )
            raise MissingCaching(msg)
        registered = request_scope[scope_name]
        if not isinstance(registered, CacheMiddleware):
            msg = (
                f"A scope variable named {scope_name!r} was found, but it does not "
                "contain a `CacheMiddleware` instance. It is likely that an "
                "incompatible middleware was added to the middleware stack."
            )
            raise MissingCaching(msg)
        self.middleware = registered

259 

260 

class CacheHelper(_BaseCacheMiddlewareHelper):
    """Request-bound helper for invalidating cached responses."""

    async def invalidate_cache_for(
        self,
        url: str | URL,
        *,
        headers: Mapping[str, str] | None = None,
    ) -> None:
        """Delete the cached entry for *url* from the middleware's cache.

        Args:
            url: A ``URL``, or a string that is passed to
                ``request.url_for`` (i.e. treated as a route name, not a path).
            headers: Headers passed as ``vary`` to ``delete_from_cache``;
                non-``Headers`` values (including ``None``) are wrapped in
                ``Headers``.
        """
        target = url if isinstance(url, URL) else self.request.url_for(url)
        vary = headers if isinstance(headers, Headers) else Headers(headers)
        await delete_from_cache(target, vary=vary, cache=self.middleware.cache)

273 

274 

class CacheControlMiddleware:
    """ASGI middleware applying ``Cache-Control`` directives to HTTP responses.

    Directives are given as keyword arguments (``CacheDirectives``). They are
    forwarded to ``CacheControlResponder`` for ASGI traffic and also stored as
    attributes for ``process_response``.
    """

    app: ASGIApp
    kwargs: CacheDirectives
    max_age: int | None
    s_maxage: int | None
    no_cache: bool
    no_store: bool
    no_transform: bool
    must_revalidate: bool
    proxy_revalidate: bool
    must_understand: bool
    private: bool
    public: bool
    immutable: bool
    stale_while_revalidate: int | None
    stale_if_error: int | None

    def __init__(self, app: ASGIApp, **kwargs: t.Unpack[CacheDirectives]) -> None:
        self.app = app
        self.kwargs = kwargs
        self.max_age = None
        self.s_maxage = None
        self.no_cache = False
        self.no_store = False
        self.no_transform = False
        self.must_revalidate = False
        self.proxy_revalidate = False
        self.must_understand = False
        self.private = False
        self.public = False
        self.immutable = False
        self.stale_while_revalidate = None
        self.stale_if_error = None
        # Explicit directives override the defaults above.
        for key, value in kwargs.items():
            setattr(self, key, value)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        responder = CacheControlResponder(self.app, **self.kwargs)
        await responder(scope, receive, send)

    def process_response(self, response: t.Any) -> None:
        """Set ``response.headers["Cache-Control"]`` from this instance's directives.

        ``public`` takes precedence over ``private`` (at most one is emitted).
        Every directive stored on the instance is rendered; the header is
        left untouched when none is enabled.
        """
        parts: list[str] = []
        if getattr(self, "public", False):
            parts.append("public")
        elif getattr(self, "private", False):
            parts.append("private")
        # These directives keep their original order so previously produced
        # header values are unchanged.
        for attr in ("no_cache", "no_store", "must_revalidate"):
            if getattr(self, attr, False):
                parts.append(attr.replace("_", "-"))
        max_age = getattr(self, "max_age", None)
        if max_age is not None:
            parts.append(f"max-age={max_age}")
        # The remaining directives accepted by __init__ were previously
        # dropped silently; they are appended after the original set.
        s_maxage = getattr(self, "s_maxage", None)
        if s_maxage is not None:
            parts.append(f"s-maxage={s_maxage}")
        for attr in ("no_transform", "proxy_revalidate", "must_understand", "immutable"):
            if getattr(self, attr, False):
                parts.append(attr.replace("_", "-"))
        for attr in ("stale_while_revalidate", "stale_if_error"):
            seconds = getattr(self, attr, None)
            if seconds is not None:
                parts.append(f"{attr.replace('_', '-')}={seconds}")
        if parts:
            response.headers["Cache-Control"] = ", ".join(parts)

335 

336 

def get_middleware_positions() -> dict[str, int]:
    """Return a mapping of each ``MiddlewarePosition`` member name to its integer slot."""
    return {slot.name: slot.value for slot in MiddlewarePosition}

339 

340 

class MiddlewareStackManager:
    """Assemble the Starlette middleware list from position-keyed registrations.

    Built-in defaults and config-dependent middleware are registered on
    ``initialize()``. Explicit ``register_middleware`` calls always win over
    those defaults, whether they happen before or after initialization.
    ``add_custom_middleware`` entries override both. ``build_stack`` returns
    the result sorted by ``MiddlewarePosition`` value.
    """

    def __init__(
        self,
        config: t.Any | None = None,
        logger: t.Any | None = None,
    ) -> None:
        self.config = config
        self.logger = logger
        self._middleware_registry: dict[MiddlewarePosition, MiddlewareClass] = {}
        self._middleware_options: dict[MiddlewarePosition, MiddlewareOptions] = {}
        self._custom_middleware: dict[MiddlewarePosition, Middleware] = {}
        self._initialized = False

    def _ensure_dependencies(self) -> None:
        """Resolve ``config`` and ``logger`` through ``depends`` when not supplied.

        A failing ``config`` lookup propagates; a failing ``logger`` lookup
        leaves ``self.logger`` as ``None``.
        """
        if self.config is None:
            self.config = depends.get("config")
        if self.logger is None:
            try:
                self.logger = depends.get("logger")
            except Exception:
                self.logger = None

    def _register_default_middleware(self) -> None:
        """Register the always-on middleware (HTMX, current request, Brotli)."""
        self._middleware_registry.update(
            {
                MiddlewarePosition.HTMX: HtmxMiddleware,
                MiddlewarePosition.CURRENT_REQUEST: CurrentRequestMiddleware,
                MiddlewarePosition.COMPRESSION: BrotliMiddleware,
            },
        )
        self._middleware_options[MiddlewarePosition.COMPRESSION] = {"quality": 3}

    def _register_conditional_middleware(self) -> None:
        """Register config-dependent middleware.

        CSRF is always registered when a config is available. Sessions are
        registered when an ``auth`` adapter is configured. Security headers
        are registered when deployed or ``config.debug.production`` is truthy.
        """
        self._ensure_dependencies()
        if not self.config:
            return
        from acb.adapters import get_adapter

        self._middleware_registry[MiddlewarePosition.CSRF] = CSRFMiddleware
        self._middleware_options[MiddlewarePosition.CSRF] = {
            "secret": self.config.app.secret_key.get_secret_value(),
            "cookie_name": f"{getattr(self.config.app, 'token_id', '_fb_')}_csrf",
            "cookie_secure": self.config.deployed,
        }
        if get_adapter("auth"):
            self._middleware_registry[MiddlewarePosition.SESSION] = SessionMiddleware
            self._middleware_options[MiddlewarePosition.SESSION] = {
                "secret_key": self.config.app.secret_key.get_secret_value(),
                "session_cookie": f"{getattr(self.config.app, 'token_id', '_fb_')}_app",
                "https_only": self.config.deployed,
            }
        if self.config.deployed or getattr(self.config.debug, "production", False):
            self._middleware_registry[MiddlewarePosition.SECURITY_HEADERS] = (
                SecureHeadersMiddleware
            )

    def initialize(self) -> None:
        """Register defaults once, without clobbering explicit registrations."""
        if self._initialized:
            return
        # Registrations made before initialization were previously overwritten
        # by the defaults below; snapshot them and re-apply afterwards.
        explicit_registry = dict(self._middleware_registry)
        explicit_options = dict(self._middleware_options)
        self._register_default_middleware()
        self._register_conditional_middleware()
        for position, middleware_class in explicit_registry.items():
            self._middleware_registry[position] = middleware_class
            if position in explicit_options:
                self._middleware_options[position] = explicit_options[position]
            else:
                # Don't hand a default's options (e.g. Brotli ``quality``)
                # to the explicitly registered class.
                self._middleware_options.pop(position, None)
        self._initialized = True

    def register_middleware(
        self,
        middleware_class: MiddlewareClass,
        position: MiddlewarePosition,
        **options: t.Any,
    ) -> None:
        """Register *middleware_class* at *position*, replacing any previous occupant.

        *options* become the class's constructor keyword arguments. When none
        are given, options left by the previous occupant are discarded.
        """
        self._middleware_registry[position] = middleware_class
        if options:
            self._middleware_options[position] = options
        else:
            self._middleware_options.pop(position, None)

    def add_custom_middleware(
        self,
        middleware: Middleware,
        position: MiddlewarePosition,
    ) -> None:
        """Place a prebuilt ``Middleware`` at *position*; it overrides registrations there."""
        self._custom_middleware[position] = middleware

    def build_stack(self) -> list[Middleware]:
        """Return the middleware list sorted by position (initializing first if needed)."""
        if not self._initialized:
            self.initialize()
        middleware_stack: dict[MiddlewarePosition, Middleware] = {}
        for position, middleware_class in self._middleware_registry.items():
            options = self._middleware_options.get(position, {})
            middleware_stack[position] = Middleware(middleware_class, **options)
        middleware_stack.update(self._custom_middleware)

        return [middleware_stack[position] for position in sorted(middleware_stack)]

    def get_middleware_info(self) -> dict[str, t.Any]:
        """Describe registered classes, custom entries, and the position table."""
        if not self._initialized:
            self.initialize()

        return {
            "registered": {
                pos.name: cls.__name__ for pos, cls in self._middleware_registry.items()
            },
            "custom": {
                pos.name: str(middleware)
                for pos, middleware in self._custom_middleware.items()
            },
            "positions": get_middleware_positions(),
        }

449 

450 

def middlewares() -> list[Middleware]:
    """Build the application's middleware list.

    Uses a fresh ``MiddlewareStackManager`` created without an explicit
    config or logger, and returns its ``build_stack()`` result.
    """
    manager = MiddlewareStackManager()
    stack = manager.build_stack()
    return stack