Coverage for fastblocks/caching.py: 67%
401 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-21 04:50 -0700
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-21 04:50 -0700
1import base64
2import email.utils
3import hashlib
4import re
5import sys
6import time
7import typing as t
8from collections.abc import Iterable, Sequence
9from dataclasses import dataclass
10from functools import partial
11from threading import local
12from urllib.request import parse_http_list
14from starlette.datastructures import URL, Headers, MutableHeaders
15from starlette.requests import Request
16from starlette.responses import Response
18HashFunc = t.Callable[[t.Any], str]
19GetAdapterFunc = t.Callable[[str], t.Any]
20ImportAdapterFunc = t.Callable[[str | list[str] | None], t.Any]
21from acb.adapters import get_adapter
22from acb.depends import depends
23from starlette.types import ASGIApp, Message, Receive, Scope, Send
25from .exceptions import RequestNotCachable, ResponseNotCachable
def _safe_log(logger: t.Any, level: str, message: str) -> None:
    """Emit *message* at *level* via *logger*, doing nothing if that is impossible.

    Thin module-level shortcut for ``CacheUtils.safe_log``: a falsy logger, or
    one without an attribute named *level*, produces no output.
    """
    CacheUtils.safe_log(logger, level, message)
# Memoised "cache" adapter; filled in lazily by get_cache().
_CacheClass: t.Any = None

# threading.local namespace for per-thread hasher state.
_hasher_pool = local()

# Pre-bound codec helpers used when building keys and (de)serializing bodies.
_str_encode = str.encode
_base64_encodebytes = base64.encodebytes
_base64_decodebytes = base64.decodebytes
def _get_hasher():
    """Return a fresh, empty MD5 hasher for building cache-key digests.

    MD5 is used only to shorten keys (``usedforsecurity=False``), not for any
    security property.

    A new object is created on every call. Reusing one cached hasher per
    thread and "resetting" it by calling ``__init__`` again does not work:
    hashlib objects cannot be re-initialised that way, so the digest state
    is kept and each ``update()`` accumulates onto all earlier input. That
    made identical URLs hash to different keys over time.
    """
    return hashlib.md5(usedforsecurity=False)
def get_cache() -> t.Any:
    """Return the "cache" adapter, resolving it with ``get_adapter`` on first use.

    The result is memoised in the module-level ``_CacheClass``. A ``None``
    result is not memoised, so the lookup is retried on the next call.
    """
    global _CacheClass
    if _CacheClass is not None:
        return _CacheClass
    _CacheClass = get_adapter("cache")
    return _CacheClass
class CacheUtils:
    """Shared constants and a logging helper for the response-caching layer.

    String constants go through ``sys.intern`` so equal values share one
    object.
    """

    # HTTP method names.
    GET = sys.intern("GET")
    HEAD = sys.intern("HEAD")
    POST = sys.intern("POST")
    PUT = sys.intern("PUT")
    PATCH = sys.intern("PATCH")
    DELETE = sys.intern("DELETE")

    # Header names.
    CACHE_CONTROL = sys.intern("Cache-Control")
    ETAG = sys.intern("ETag")
    LAST_MODIFIED = sys.intern("Last-Modified")
    VARY = sys.intern("Vary")

    # Request methods eligible for cache lookup and storage.
    CACHEABLE_METHODS = frozenset({GET, HEAD})
    # Response statuses eligible for storage.
    CACHEABLE_STATUS_CODES = frozenset(
        {200, 203, 204, 206, 300, 301, 404, 405, 410, 414, 501},
    )
    # Seconds in a (non-leap) year; fallback max-age when no ttl is set.
    ONE_YEAR = 365 * 24 * 60 * 60
    # Request methods whose successful responses evict cached entries.
    INVALIDATING_METHODS = frozenset({POST, PUT, PATCH, DELETE})

    @staticmethod
    def safe_log(logger: t.Any, level: str, message: str) -> None:
        """Call ``logger.<level>(message)`` if *logger* is truthy and has *level*."""
        if not logger or not hasattr(logger, level):
            return
        getattr(logger, level)(message)
# Module-level aliases of the CacheUtils policy constants.
(
    cacheable_methods,
    cacheable_status_codes,
    one_year,
    invalidating_methods,
) = (
    CacheUtils.CACHEABLE_METHODS,
    CacheUtils.CACHEABLE_STATUS_CODES,
    CacheUtils.ONE_YEAR,
    CacheUtils.INVALIDATING_METHODS,
)
87@dataclass
88class Rule:
89 match: str | re.Pattern[str] | Iterable[str | re.Pattern[str]] = "*"
90 status: int | Iterable[int] | None = None
91 ttl: float | None = None
class CacheRules:
    """Static helpers that evaluate ``Rule`` objects against requests/responses."""

    @staticmethod
    def request_matches_rule(rule: Rule, *, request: Request) -> bool:
        """Return True if the request's URL path satisfies ``rule.match``.

        A string matches when it is ``"*"`` or equals the path. A compiled
        pattern matches when ``pattern.match(path)`` succeeds, which anchors
        at the start of the path only.
        """
        targets = rule.match
        candidates = (
            [targets] if isinstance(targets, (str, re.Pattern)) else list(targets)
        )
        for candidate in candidates:
            path = request.url.path
            if isinstance(candidate, re.Pattern):
                if candidate.match(path):
                    return True
                continue
            if candidate in ("*", path):
                return True
        return False

    @staticmethod
    def response_matches_rule(
        rule: Rule,
        *,
        request: Request,
        response: Response,
    ) -> bool:
        """Return True if the request matches *rule* and the status is allowed."""
        if not CacheRules.request_matches_rule(rule, request=request):
            return False
        allowed = rule.status
        if allowed is None:
            return True
        if isinstance(allowed, int):
            allowed = (allowed,)
        return response.status_code in allowed

    @staticmethod
    def get_rule_matching_request(
        rules: Sequence[Rule],
        *,
        request: Request,
    ) -> Rule | None:
        """Return the first rule whose ``match`` covers the request, else None."""
        for rule in rules:
            if CacheRules.request_matches_rule(rule, request=request):
                return rule
        return None

    @staticmethod
    def get_rule_matching_response(
        rules: Sequence[Rule],
        *,
        request: Request,
        response: Response,
    ) -> Rule | None:
        """Return the first rule matching both request path and status, else None."""
        for rule in rules:
            if CacheRules.response_matches_rule(
                rule,
                request=request,
                response=response,
            ):
                return rule
        return None
def request_matches_rule(rule: Rule, *, request: Request) -> bool:
    """Module-level alias for ``CacheRules.request_matches_rule``."""
    matched = CacheRules.request_matches_rule(rule, request=request)
    return matched
def response_matches_rule(rule: Rule, *, request: Request, response: Response) -> bool:
    """Module-level alias for ``CacheRules.response_matches_rule``."""
    matched = CacheRules.response_matches_rule(
        rule,
        request=request,
        response=response,
    )
    return matched
def get_rule_matching_request(
    rules: Sequence[Rule],
    *,
    request: Request,
) -> Rule | None:
    """Module-level alias for ``CacheRules.get_rule_matching_request``."""
    found = CacheRules.get_rule_matching_request(rules, request=request)
    return found
def get_rule_matching_response(
    rules: Sequence[Rule],
    *,
    request: Request,
    response: Response,
) -> Rule | None:
    """Module-level alias for ``CacheRules.get_rule_matching_response``."""
    found = CacheRules.get_rule_matching_response(
        rules,
        request=request,
        response=response,
    )
    return found
class CacheDirectives(t.TypedDict, total=False):
    """Keyword arguments accepted by ``patch_cache_control``.

    Each key becomes a ``Cache-Control`` directive by replacing ``_`` with
    ``-`` (``max_age`` -> ``max-age``). Every key is optional. ``public``
    and ``private`` are declared here but rejected by ``patch_cache_control``
    with ``NotImplementedError``.
    """

    # Integer-valued directives (seconds).
    max_age: int
    s_maxage: int
    # Boolean flags: True emits the bare directive, False drops it.
    no_cache: bool
    no_store: bool
    no_transform: bool
    must_revalidate: bool
    proxy_revalidate: bool
    must_understand: bool
    private: bool
    public: bool
    immutable: bool
    # Integer-valued extension directives (seconds).
    stale_while_revalidate: int
    stale_if_error: int
async def set_in_cache(
    response: Response,
    *,
    request: Request,
    rules: Sequence[Rule],
    cache: t.Any = None,
    logger: t.Any = None,
) -> None:
    """Store *response* in the cache when it qualifies for caching.

    The response's headers are patched in place (``X-Cache``, ``Expires``,
    ``Cache-Control`` and ``Vary``) before it is serialized and written.

    Args:
        response: Buffered response to store; its ``body`` is serialized.
        request: Request that produced *response*; supplies cookies, the URL
            and the header values used to build the cache key.
        rules: Rules searched in order; the first match decides the ttl.
        cache: Cache adapter; resolved through ``depends`` when omitted.
        logger: Logger; resolved through ``depends`` when omitted. It is also
            forwarded to ``learn_cache_key`` so that a caller-supplied logger
            is honoured throughout.

    Raises:
        ResponseNotCachable: if the status code is not cacheable, the
            response sets a cookie for a cookieless request, no rule matches,
            or the effective ttl is 0.
    """
    if cache is None:
        cache = depends.get("cache")
    if logger is None:
        logger = depends.get("logger")

    if response.status_code not in cacheable_status_codes:
        _safe_log(logger, "debug", "response_not_cacheable reason=status_code")
        raise ResponseNotCachable(response)
    if not request.cookies and "Set-Cookie" in response.headers:
        # Never replay a freshly issued cookie to other cookieless clients.
        _safe_log(
            logger,
            "debug",
            "response_not_cacheable reason=cookies_for_cookieless_request",
        )
        raise ResponseNotCachable(response)
    rule = get_rule_matching_response(rules, request=request, response=response)
    if rule is None:
        _safe_log(logger, "debug", "response_not_cacheable reason=rule")
        raise ResponseNotCachable(response)

    # A rule-level ttl overrides the cache adapter's default ttl.
    ttl = rule.ttl if rule.ttl is not None else cache.ttl
    if ttl == 0:
        _safe_log(logger, "debug", "response_not_cacheable reason=zero_ttl")
        raise ResponseNotCachable(response)
    if ttl is None:
        # No expiry configured: advertise the maximum client-side max-age.
        max_age = one_year
        _safe_log(logger, "debug", f"max_out_ttl value={max_age!r}")
    else:
        max_age = int(ttl)
    _safe_log(logger, "debug", f"set_in_cache max_age={max_age!r}")

    # The stored copy is labelled "hit" because that is what later lookups
    # serve; the live response is relabelled "miss" once the copy is written.
    response.headers["X-Cache"] = "hit"
    cache_headers = get_cache_response_headers(response, max_age=max_age)
    _safe_log(logger, "debug", f"patch_response_headers headers={cache_headers!r}")
    response.headers.update(cache_headers)

    cache_key = await learn_cache_key(request, response, cache=cache, logger=logger)
    _safe_log(logger, "debug", f"learnt_cache_key cache_key={cache_key!r}")
    serialized_response = serialize_response(response)
    _safe_log(
        logger,
        "debug",
        f"set_response_in_cache key={cache_key!r} value={serialized_response!r}",
    )
    kwargs: dict[str, t.Any] = {}
    if ttl is not None:
        kwargs["ttl"] = ttl
    await cache.set(key=cache_key, value=serialized_response, **kwargs)
    response.headers["X-Cache"] = "miss"
async def get_from_cache(
    request: Request,
    *,
    rules: Sequence[Rule],
    cache: t.Any = None,
    logger: t.Any = None,
) -> Response | None:
    """Look up a cached response for *request*.

    The GET-keyed entry is tried first, then the HEAD-keyed one.

    Args:
        request: Incoming request.
        rules: Rules searched in order; one must match the request path.
        cache: Cache adapter; resolved through ``depends`` when omitted.
        logger: Logger; resolved through ``depends`` when omitted. It is also
            forwarded to ``get_cache_key``.

    Returns:
        The deserialized cached response, or ``None`` on a miss.

    Raises:
        RequestNotCachable: if the method is not cacheable or no rule matches.
    """
    if cache is None:
        cache = depends.get("cache")
    if logger is None:
        logger = depends.get("logger")
    _safe_log(
        logger,
        "debug",
        f"get_from_cache request.url={str(request.url)!r} request.method={request.method!r}",
    )
    if request.method not in cacheable_methods:
        _safe_log(logger, "debug", "request_not_cacheable reason=method")
        raise RequestNotCachable(request)
    rule = get_rule_matching_request(rules, request=request)
    if rule is None:
        _safe_log(logger, "debug", "request_not_cacheable reason=rule")
        raise RequestNotCachable(request)

    _safe_log(logger, "debug", "lookup_cached_response method='GET'")
    cache_key = await get_cache_key(request, method="GET", cache=cache, logger=logger)
    if cache_key is None:
        _safe_log(logger, "debug", "cache_key found=False")
        return None
    _safe_log(logger, "debug", f"cache_key found=True cache_key={cache_key!r}")
    serialized_response = await cache.get(cache_key)

    if serialized_response is None:
        # No GET entry: fall back to an entry stored for a HEAD request.
        _safe_log(logger, "debug", "lookup_cached_response method='HEAD'")
        cache_key = await get_cache_key(
            request, method="HEAD", cache=cache, logger=logger
        )
        if cache_key is None:
            return None
        _safe_log(logger, "debug", f"cache_key found=True cache_key={cache_key!r}")
        serialized_response = await cache.get(cache_key)
        if serialized_response is None:
            _safe_log(logger, "debug", "cached_response found=False")
            return None

    _safe_log(
        logger,
        "debug",
        f"cached_response found=True key={cache_key!r} value={serialized_response!r}",
    )
    return deserialize_response(serialized_response)
async def delete_from_cache(
    url: URL,
    *,
    vary: Headers,
    cache: t.Any = None,
    logger: t.Any = None,
) -> None:
    """Evict the cached GET and HEAD responses for *url*.

    *vary* supplies the header values used to rebuild the cache keys, so only
    the entries for that header combination are deleted explicitly. The
    varying-headers record, which is keyed on the URL path alone, is then
    deleted too. Without that record, key lookups for this path return
    ``None`` until a new response is learnt. Nothing happens if no record
    exists.

    Args:
        url: URL whose entries should be evicted.
        vary: Headers used to resolve the varying-header values.
        cache: Cache adapter; resolved through ``depends`` when omitted.
        logger: Logger; resolved through ``depends`` when omitted.
    """
    if cache is None:
        cache = depends.get("cache")
    if logger is None:
        logger = depends.get("logger")
    varying_headers_cache_key = generate_varying_headers_cache_key(url)
    varying_headers = await cache.get(varying_headers_cache_key)
    if varying_headers is None:
        return
    for method in ("GET", "HEAD"):
        cache_key = generate_cache_key(
            url,
            method=method,
            headers=vary,
            varying_headers=varying_headers,
        )
        if cache_key is None:
            continue
        # _safe_log keeps this consistent with the rest of the module and
        # tolerates a logger without a ``debug`` method.
        _safe_log(logger, "debug", f"clear_cache key={cache_key!r}")
        await cache.delete(cache_key)
    await cache.delete(varying_headers_cache_key)
def serialize_response(response: Response) -> dict[str, t.Any]:
    """Convert *response* into a plain dict suitable for storing in the cache.

    The body is base64-encoded into an ASCII string. Headers are copied into
    a dict, which can hold only one value per header name.
    """
    encoded_body = _base64_encodebytes(response.body).decode("ascii")
    payload: dict[str, t.Any] = {"content": encoded_body}
    payload["status_code"] = response.status_code
    payload["headers"] = dict(response.headers)
    return payload
def deserialize_response(serialized_response: t.Any) -> Response:
    """Rebuild a ``Response`` from the dict produced by ``serialize_response``.

    Raises:
        TypeError: if the payload is not a dict, or if ``content``,
            ``status_code`` or ``headers`` is missing or has the wrong type
            (checked in that order).
    """
    if not isinstance(serialized_response, dict):
        raise TypeError(f"Expected dict, got {type(serialized_response)}")
    expectations = (
        ("content", str, "str"),
        ("status_code", int, "int"),
        ("headers", dict, "dict"),
    )
    fields: dict[str, t.Any] = {}
    for field_name, expected_type, type_label in expectations:
        field_value = serialized_response.get(field_name)
        if not isinstance(field_value, expected_type):
            raise TypeError(
                f"Expected {field_name} to be {type_label}, got {type(field_value)}"
            )
        fields[field_name] = field_value
    body = _base64_decodebytes(_str_encode(fields["content"], "ascii"))
    return Response(
        content=body,
        status_code=fields["status_code"],
        headers=fields["headers"],
    )
async def learn_cache_key(
    request: Request,
    response: Response,
    *,
    cache: t.Any = None,
    logger: t.Any = None,
) -> str:
    """Record the response's Vary headers and return its cache key.

    The lowercased ``Vary`` names from *response* are merged with those
    already stored for the URL path. The sorted union is written back to the
    cache without a ttl argument and, when non-empty, replaces the
    response's ``Vary`` header.

    Args:
        request: Request the response answers.
        response: Response whose ``Vary`` header is read and rewritten.
        cache: Cache adapter; resolved through ``depends`` when omitted.
        logger: Logger; resolved through ``depends`` when omitted.

    Returns:
        The cache key under which the response should be stored.

    Raises:
        ValueError: if no key can be generated for ``request.method``.
    """
    if cache is None:
        cache = depends.get("cache")
    if logger is None:
        logger = depends.get("logger")
    # _safe_log matches the other cache helpers and tolerates a missing logger.
    _safe_log(
        logger,
        "debug",
        f"learn_cache_key request.method={request.method!r} response.headers.Vary={response.headers.get('Vary')!r}",
    )
    url = request.url
    varying_headers_cache_key = generate_varying_headers_cache_key(url)
    cached_vary_headers = set(await cache.get(key=varying_headers_cache_key) or ())
    response_vary_headers = {
        header.lower() for header in parse_http_list(response.headers.get("Vary", ""))
    }
    varying_headers = sorted(response_vary_headers | cached_vary_headers)
    if varying_headers:
        response.headers["Vary"] = ", ".join(varying_headers)
    _safe_log(
        logger,
        "debug",
        f"store_varying_headers cache_key={varying_headers_cache_key!r} headers={varying_headers!r}",
    )
    await cache.set(key=varying_headers_cache_key, value=varying_headers)
    cache_key = generate_cache_key(
        url,
        method=request.method,
        headers=request.headers,
        varying_headers=varying_headers,
    )
    if cache_key is None:
        msg = f"Unable to generate cache key for method {request.method}"
        raise ValueError(msg)
    return cache_key
async def get_cache_key(
    request: Request,
    method: str,
    cache: t.Any = None,
    logger: t.Any = None,
) -> str | None:
    """Return the cache key for *request* under *method*, if Vary data is known.

    Returns ``None`` when the cache holds no varying-headers record for the
    request's URL path (``learn_cache_key`` writes that record).
    """
    if cache is None:
        cache = depends.get("cache")
    if logger is None:
        logger = depends.get("logger")
    url = request.url
    _safe_log(
        logger,
        "debug",
        f"get_cache_key request.url={str(url)!r} method={method!r}",
    )
    known_varying = await cache.get(generate_varying_headers_cache_key(url))
    if known_varying is None:
        _safe_log(logger, "debug", "varying_headers found=False")
        return None
    _safe_log(
        logger,
        "debug",
        f"varying_headers found=True headers={known_varying!r}",
    )
    return generate_cache_key(
        url,
        method=method,
        headers=request.headers,
        varying_headers=known_varying,
    )
def generate_cache_key(
    url: URL,
    method: str,
    headers: Headers,
    varying_headers: list[str],
    config: t.Any = None,
) -> str | None:
    """Build the cache key for *url*/*method* under the given Vary headers.

    Format: ``"<app name>:cached:<METHOD>.<md5(url)>.<md5(vary pairs)>"``. The
    vary digest is empty when none of *varying_headers* is present in
    *headers*.

    Returns:
        The key, or ``None`` when *method* is not cacheable.
    """
    app_config = depends.get("config") if config is None else config
    if method not in cacheable_methods:
        return None

    def digest(text: str) -> str:
        hasher = _get_hasher()
        hasher.update(_str_encode(text))
        return hasher.hexdigest()

    present_pairs = []
    for header_name in varying_headers:
        header_value = headers.get(header_name)
        if header_value is not None:
            present_pairs.append(f"{header_name}:{header_value}")

    vary_hash = digest("|".join(present_pairs)) if present_pairs else ""
    url_hash = digest(str(url))
    return f"{app_config.app.name}:cached:{method}.{url_hash}.{vary_hash}"
def generate_varying_headers_cache_key(url: URL) -> str:
    """Return the cache key holding the Vary header list for *url*'s path.

    Only the path is hashed, so all query strings of a path share one record.
    """
    path_hasher = _get_hasher()
    path_hasher.update(_str_encode(str(url.path)))
    return "varying_headers." + path_hasher.hexdigest()
def get_cache_response_headers(response: Response, *, max_age: int) -> dict[str, str]:
    """Compute caching headers for *response* and patch its Cache-Control.

    *max_age* is clamped to at least 0. ``Cache-Control`` is patched directly
    on ``response.headers``. The returned dict holds only an ``Expires``
    header, and only when the response did not already carry one.
    """
    max_age = max(max_age, 0)
    extra_headers: dict[str, str] = {}
    if "Expires" not in response.headers:
        expires_at = time.time() + max_age
        extra_headers["Expires"] = email.utils.formatdate(expires_at, usegmt=True)
    patch_cache_control(response.headers, max_age=max_age)
    return extra_headers
def patch_cache_control(
    headers: MutableHeaders,
    **kwargs: t.Unpack[CacheDirectives],
) -> None:
    """Merge the given directives into the ``Cache-Control`` header in place.

    Existing directives are kept unless a keyword argument overrides them.
    Keyword names map to directives by replacing ``_`` with ``-``. A value of
    ``True`` emits a bare directive, ``False`` drops it, and any other value
    is rendered as ``name=value``. When both the header and *kwargs* carry a
    max-age, the smaller one wins. If no directive remains, the header is
    removed.

    Directive names are compared case-insensitively (RFC 9111 section 5.2),
    so an upstream ``Max-Age=60`` is merged with ``max_age`` instead of being
    duplicated. An upstream max-age that is not a (possibly quoted) integer
    is ignored for the min() comparison instead of raising or being read as 1.

    Raises:
        NotImplementedError: if ``public`` or ``private`` is passed.
    """
    cache_control: dict[str, t.Any] = {}
    for field in parse_http_list(headers.get("Cache-Control", "")):
        name, has_value, raw_value = field.partition("=")
        name = name.strip().lower()
        if not name:
            continue
        cache_control[name] = raw_value.strip() if has_value else True

    existing_max_age = cache_control.get("max-age")
    if "max_age" in kwargs and isinstance(existing_max_age, str):
        try:
            upstream_max_age: int | None = int(existing_max_age.strip('"'))
        except ValueError:
            upstream_max_age = None  # malformed upstream value: new max-age wins
        if upstream_max_age is not None:
            kwargs["max_age"] = min(upstream_max_age, kwargs["max_age"])

    if "public" in kwargs:
        msg = "The 'public' cache control directive isn't supported yet."
        raise NotImplementedError(msg)
    if "private" in kwargs:
        msg = "The 'private' cache control directive isn't supported yet."
        raise NotImplementedError(msg)

    for key, value in kwargs.items():
        cache_control[key.replace("_", "-")] = value

    directives = [
        key if value is True else f"{key}={value}"
        for key, value in cache_control.items()
        if value is not False
    ]
    patched_cache_control = ", ".join(directives)
    if patched_cache_control:
        headers["Cache-Control"] = patched_cache_control
    else:
        del headers["Cache-Control"]
class CacheResponder:
    """ASGI wrapper that serves cached responses and stores cacheable ones.

    Behaviour per request:

    - Non-HTTP scopes are passed straight through.
    - A cache hit is served without calling the wrapped app.
    - On a miss, the final response is captured and offered to
      ``set_in_cache``.
    - Requests that cannot be cached but use an invalidating method evict
      the cached entries for their URL after a 2xx/3xx response.

    The request and the held-back start message are stored on the instance,
    so an instance presumably serves one request at a time (TODO confirm
    with the middleware that constructs it).
    """

    def __init__(self, app: ASGIApp, *, rules: Sequence[Rule]) -> None:
        self.app = app
        self.rules = rules
        try:
            self.logger = depends.get("logger")
        except Exception:
            import logging

            self.logger = logging.getLogger("fastblocks.cache")
        try:
            self.cache = depends.get("cache")
        except Exception:
            self.cache = None
        self.initial_message: Message = {}
        self.is_response_cacheable = True
        self.request: Request | None = None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        self.request = request = Request(scope)
        try:
            # Forward self.logger so the fallback logger is used end to end
            # instead of being re-resolved (and possibly failing) via depends.
            response = await get_from_cache(
                request,
                cache=self.cache,
                rules=self.rules,
                logger=self.logger,
            )
        except RequestNotCachable:
            if request.method in invalidating_methods:
                send = partial(self.send_then_invalidate, send=send)
        else:
            if response is not None:
                _safe_log(self.logger, "debug", "cache_lookup HIT")
                await response(scope, receive, send)
                return
            send = partial(self.send_with_caching, send=send)
            _safe_log(self.logger, "debug", "cache_lookup MISS")
        await self.app(scope, receive, send)

    async def send_with_caching(self, message: Message, *, send: Send) -> None:
        """Buffer the start message and try to cache a single-chunk response."""
        message_type = message["type"]
        if not self.is_response_cacheable or message_type not in (
            "http.response.start",
            "http.response.body",
        ):
            await send(message)
            return
        if message_type == "http.response.start":
            # Held back: set_in_cache may still add or patch headers.
            self.initial_message = message
            return
        if message.get("more_body", False):
            _safe_log(
                self.logger,
                "debug",
                "response_not_cacheable reason=is_streaming",
            )
            self.is_response_cacheable = False
            await send(self.initial_message)
            await send(message)
            return
        if self.request is not None:
            # ASGI makes "body" and "headers" optional; default to empty.
            body = message.get("body", b"")
            response = Response(
                content=body,
                status_code=self.initial_message["status"],
            )
            response.raw_headers = list(self.initial_message.get("headers", []))
            try:
                await set_in_cache(
                    response,
                    request=self.request,
                    cache=self.cache,
                    rules=self.rules,
                    logger=self.logger,
                )
            except ResponseNotCachable:
                self.is_response_cacheable = False
            else:
                self.initial_message["headers"] = list(response.raw_headers)
        # Always forward the response, even if it could not be cached.
        await send(self.initial_message)
        await send(message)

    async def send_then_invalidate(self, message: Message, *, send: Send) -> None:
        """Evict cached entries for the URL on a 2xx/3xx start, then forward."""
        if (
            self.request is not None
            and message["type"] == "http.response.start"
            and 200 <= message["status"] < 400
        ):
            await delete_from_cache(
                self.request.url,
                vary=self.request.headers,
                cache=self.cache,
                logger=self.logger,
            )
        await send(message)
class CacheControlResponder:
    """ASGI wrapper that patches ``Cache-Control`` on every HTTP response start.

    The directives given at construction are applied through
    ``patch_cache_control``. Non-HTTP scopes pass through untouched.
    """

    def __init__(self, app: ASGIApp, **kwargs: t.Unpack[CacheDirectives]) -> None:
        self.app = app
        self.kwargs = kwargs
        try:
            self.logger = depends.get("logger")
        except Exception:
            import logging

            self.logger = logging.getLogger("fastblocks.cache")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":
            send = partial(self.send_with_caching, send=send)
        await self.app(scope, receive, send)

    @staticmethod
    def kvformat(**kwargs: t.Any) -> str:
        """Render keyword arguments as space-separated ``key=value`` pairs."""
        return " ".join(f"{key}={value}" for key, value in kwargs.items())

    async def send_with_caching(self, message: Message, *, send: Send) -> None:
        """Rewrite the start message's headers, then forward every message."""
        if message["type"] == "http.response.start":
            _safe_log(
                self.logger,
                "debug",
                f"patch_cache_control {self.kvformat(**self.kwargs)}",
            )
            patched = MutableHeaders(raw=list(message["headers"]))
            patch_cache_control(patched, **self.kwargs)
            message["headers"] = patched.raw
        await send(message)
664 await send(message)