Coverage for fastblocks/caching.py: 67%

401 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-21 04:50 -0700

1import base64 

2import email.utils 

3import hashlib 

4import re 

5import sys 

6import time 

7import typing as t 

8from collections.abc import Iterable, Sequence 

9from dataclasses import dataclass 

10from functools import partial 

11from threading import local 

12from urllib.request import parse_http_list 

13 

14from starlette.datastructures import URL, Headers, MutableHeaders 

15from starlette.requests import Request 

16from starlette.responses import Response 

17 

# Callable taking any single value and returning a str (by its name, a hash function).
HashFunc = t.Callable[[t.Any], str]

# Callable taking a single str and returning any object
# (presumably an adapter looked up by name -- confirm against callers).
GetAdapterFunc = t.Callable[[str], t.Any]

# Callable taking a str, a list of str, or None and returning any object.
ImportAdapterFunc = t.Callable[[str | list[str] | None], t.Any]

21from acb.adapters import get_adapter 

22from acb.depends import depends 

23from starlette.types import ASGIApp, Message, Receive, Scope, Send 

24 

25from .exceptions import RequestNotCachable, ResponseNotCachable 

26 

27 

def _safe_log(logger: t.Any, level: str, message: str) -> None:
    """Module-level shorthand that forwards to ``CacheUtils.safe_log``.

    Args:
        logger: Object expected to expose logging methods; may be falsy.
        level: Name of the logging method to call (e.g. ``"debug"``).
        message: Text passed to that method.
    """
    return CacheUtils.safe_log(logger, level, message)

30 

31 

# Module-level slot that get_cache() fills with the result of
# get_adapter("cache") on first use.
_CacheClass = None

# threading.local() instance: attributes stored on it are per-thread.
_hasher_pool = local()

# Module-level aliases for callables used when hashing and (de)serializing.
_str_encode = str.encode
_base64_encodebytes = base64.encodebytes
_base64_decodebytes = base64.decodebytes

39 

40 

def _get_hasher():
    """Return a fresh, empty MD5 hasher for building cache keys.

    A new object is created on every call. Re-running ``__init__`` on an
    existing hashlib object does not clear the data already fed to it, so
    reusing one hasher would make each digest depend on every previous
    ``update()`` call and cache keys would never be reproducible.

    Returns:
        A hashlib MD5 object created with ``usedforsecurity=False`` (the
        digest is used only as a cache-key fingerprint).
    """
    return hashlib.md5(usedforsecurity=False)

47 

48 

def get_cache() -> t.Any:
    """Return the cache adapter, resolving it once and memoizing it.

    The first call stores ``get_adapter("cache")`` in the module-level
    ``_CacheClass``; later calls return that stored value while it is
    not ``None``.
    """
    global _CacheClass
    if _CacheClass is not None:
        return _CacheClass
    _CacheClass = get_adapter("cache")
    return _CacheClass

54 

55 

class CacheUtils:
    """Namespace for HTTP caching constants and a defensive logging helper."""

    # HTTP method names (interned strings).
    GET = sys.intern("GET")
    HEAD = sys.intern("HEAD")
    POST = sys.intern("POST")
    PUT = sys.intern("PUT")
    PATCH = sys.intern("PATCH")
    DELETE = sys.intern("DELETE")

    # Header names (interned strings).
    CACHE_CONTROL = sys.intern("Cache-Control")
    ETAG = sys.intern("ETag")
    LAST_MODIFIED = sys.intern("Last-Modified")
    VARY = sys.intern("Vary")

    # Methods whose responses may be cached.
    CACHEABLE_METHODS = frozenset({GET, HEAD})
    # Status codes whose responses may be cached.
    CACHEABLE_STATUS_CODES = frozenset(
        {200, 203, 204, 206, 300, 301, 404, 405, 410, 414, 501},
    )
    # Number of seconds in 365 days.
    ONE_YEAR = 365 * 24 * 60 * 60
    # Methods after which cached entries are invalidated.
    INVALIDATING_METHODS = frozenset({POST, PUT, PATCH, DELETE})

    @staticmethod
    def safe_log(logger: t.Any, level: str, message: str) -> None:
        """Call ``logger.<level>(message)`` if that is possible.

        Nothing happens when ``logger`` is falsy or has no attribute
        named ``level``.
        """
        if not logger:
            return
        if not hasattr(logger, level):
            return
        emit = getattr(logger, level)
        emit(message)

79 

80 

# Module-level aliases for the constants defined on CacheUtils.
cacheable_methods = CacheUtils.CACHEABLE_METHODS
cacheable_status_codes = CacheUtils.CACHEABLE_STATUS_CODES
one_year = CacheUtils.ONE_YEAR
invalidating_methods = CacheUtils.INVALIDATING_METHODS

85 

86 

87@dataclass 

88class Rule: 

89 match: str | re.Pattern[str] | Iterable[str | re.Pattern[str]] = "*" 

90 status: int | Iterable[int] | None = None 

91 ttl: float | None = None 

92 

93 

class CacheRules:
    """Static helpers that decide which caching rule applies to a request/response."""

    @staticmethod
    def request_matches_rule(rule: Rule, *, request: Request) -> bool:
        """Return True when the request path satisfies ``rule.match``.

        A compiled pattern matches via ``Pattern.match`` on the path; any
        other item matches when it equals ``"*"`` or the path itself.
        """
        selector = rule.match
        if isinstance(selector, (str, re.Pattern)):
            candidates = [selector]
        else:
            candidates = list(selector)
        return any(
            bool(candidate.match(request.url.path))
            if isinstance(candidate, re.Pattern)
            else candidate in ("*", request.url.path)
            for candidate in candidates
        )

    @staticmethod
    def response_matches_rule(
        rule: Rule,
        *,
        request: Request,
        response: Response,
    ) -> bool:
        """Return True when both the request path and the response status satisfy ``rule``."""
        if not CacheRules.request_matches_rule(rule, request=request):
            return False
        if rule.status is None:
            # No status restriction on this rule.
            return True
        allowed = [rule.status] if isinstance(rule.status, int) else rule.status
        return response.status_code in allowed

    @staticmethod
    def get_rule_matching_request(
        rules: Sequence[Rule],
        *,
        request: Request,
    ) -> Rule | None:
        """Return the first rule whose path selector matches the request, else None."""
        for candidate in rules:
            if CacheRules.request_matches_rule(candidate, request=request):
                return candidate
        return None

    @staticmethod
    def get_rule_matching_response(
        rules: Sequence[Rule],
        *,
        request: Request,
        response: Response,
    ) -> Rule | None:
        """Return the first rule matching both request and response, else None."""
        for candidate in rules:
            if CacheRules.response_matches_rule(
                candidate,
                request=request,
                response=response,
            ):
                return candidate
        return None

159 

160 

def request_matches_rule(rule: Rule, *, request: Request) -> bool:
    """Function-style alias for ``CacheRules.request_matches_rule``."""
    matched = CacheRules.request_matches_rule(rule, request=request)
    return matched

163 

164 

def response_matches_rule(rule: Rule, *, request: Request, response: Response) -> bool:
    """Function-style alias for ``CacheRules.response_matches_rule``."""
    matched = CacheRules.response_matches_rule(
        rule,
        request=request,
        response=response,
    )
    return matched

167 

168 

def get_rule_matching_request(
    rules: Sequence[Rule],
    *,
    request: Request,
) -> Rule | None:
    """Function-style alias for ``CacheRules.get_rule_matching_request``."""
    selected = CacheRules.get_rule_matching_request(rules, request=request)
    return selected

175 

176 

def get_rule_matching_response(
    rules: Sequence[Rule],
    *,
    request: Request,
    response: Response,
) -> Rule | None:
    """Function-style alias for ``CacheRules.get_rule_matching_response``."""
    selected = CacheRules.get_rule_matching_response(
        rules,
        request=request,
        response=response,
    )
    return selected

188 

189 

class CacheDirectives(t.TypedDict, total=False):
    """Keyword options accepted when patching a ``Cache-Control`` header.

    All keys are optional (``total=False``). Keys use underscores; the
    code that renders the header replaces them with hyphens.
    """

    # Directives carrying an integer value.
    max_age: int
    s_maxage: int

    # Flag-style directives.
    no_cache: bool
    no_store: bool
    no_transform: bool
    must_revalidate: bool
    proxy_revalidate: bool
    must_understand: bool
    private: bool
    public: bool
    immutable: bool

    # Further directives carrying an integer value.
    stale_while_revalidate: int
    stale_if_error: int

204 

205 

async def set_in_cache(
    response: Response,
    *,
    request: Request,
    rules: Sequence[Rule],
    cache: t.Any = None,
    logger: t.Any = None,
) -> None:
    """Store ``response`` in the cache if it is eligible.

    Args:
        response: Response to store; its headers are patched in place.
        request: Request that produced the response.
        rules: Rules consulted to decide eligibility and TTL.
        cache: Cache backend; resolved through ``depends`` when None.
        logger: Logger; resolved through ``depends`` when None.

    Raises:
        ResponseNotCachable: If the status code is not cacheable, the
            response sets a cookie for a cookie-less request, no rule
            matches, or the effective TTL is zero.
    """
    if cache is None:
        cache = depends.get("cache")
    if logger is None:
        logger = depends.get("logger")
    debug = partial(_safe_log, logger, "debug")

    if response.status_code not in cacheable_status_codes:
        debug("response_not_cacheable reason=status_code")
        raise ResponseNotCachable(response)

    if not request.cookies and "Set-Cookie" in response.headers:
        debug("response_not_cacheable reason=cookies_for_cookieless_request")
        raise ResponseNotCachable(response)

    matched_rule = get_rule_matching_response(rules, request=request, response=response)
    if not matched_rule:
        debug("response_not_cacheable reason=rule")
        raise ResponseNotCachable(response)

    # The rule's TTL wins; otherwise fall back to the backend's default.
    ttl = cache.ttl if matched_rule.ttl is None else matched_rule.ttl
    if ttl == 0:
        debug("response_not_cacheable reason=zero_ttl")
        raise ResponseNotCachable(response)

    if ttl is not None:
        max_age = int(ttl)
        debug(f"set_in_cache max_age={max_age!r}")
    else:
        max_age = one_year
        debug(f"max_out_ttl value={max_age!r}")

    # The stored copy is marked as a hit; the live response is flipped
    # back to "miss" once it has been serialized and stored.
    response.headers["X-Cache"] = "hit"
    extra_headers = get_cache_response_headers(response, max_age=max_age)
    debug(f"patch_response_headers headers={extra_headers!r}")
    response.headers.update(extra_headers)

    cache_key = await learn_cache_key(request, response, cache=cache)
    debug(f"learnt_cache_key cache_key={cache_key!r}")

    payload = serialize_response(response)
    debug(f"set_response_in_cache key={cache_key!r} value={payload!r}")

    if ttl is None:
        await cache.set(key=cache_key, value=payload)
    else:
        await cache.set(key=cache_key, value=payload, ttl=ttl)
    response.headers["X-Cache"] = "miss"

260 

261 

async def get_from_cache(
    request: Request,
    *,
    rules: Sequence[Rule],
    cache: t.Any = None,
    logger: t.Any = None,
) -> Response | None:
    """Look up a cached response for ``request``.

    The GET entry is tried first; if nothing is stored under it, the
    HEAD entry is tried.

    Args:
        request: Incoming request.
        rules: Rules deciding whether the request is cacheable at all.
        cache: Cache backend; resolved through ``depends`` when None.
        logger: Logger; resolved through ``depends`` when None.

    Returns:
        The deserialized cached response, or None on a cache miss.

    Raises:
        RequestNotCachable: If the method is not cacheable or no rule
            matches the request.
    """
    if cache is None:
        cache = depends.get("cache")
    if logger is None:
        logger = depends.get("logger")
    debug = partial(_safe_log, logger, "debug")

    debug(
        f"get_from_cache request.url={str(request.url)!r} request.method={request.method!r}",
    )
    if request.method not in cacheable_methods:
        debug("request_not_cacheable reason=method")
        raise RequestNotCachable(request)
    if get_rule_matching_request(rules, request=request) is None:
        debug("request_not_cacheable reason=rule")
        raise RequestNotCachable(request)

    # First attempt: entry stored for GET.
    debug("lookup_cached_response method='GET'")
    key = await get_cache_key(request, method="GET", cache=cache)
    if key is None:
        debug("cache_key found=False")
        return None
    debug(f"cache_key found=True cache_key={key!r}")
    stored = await cache.get(key)

    # Second attempt: entry stored for HEAD.
    if stored is None:
        debug("lookup_cached_response method='HEAD'")
        key = await get_cache_key(request, method="HEAD", cache=cache)
        if key is None:
            return None
        debug(f"cache_key found=True cache_key={key!r}")
        stored = await cache.get(key)

    if stored is None:
        debug("cached_response found=False")
        return None

    debug(f"cached_response found=True key={key!r} value={stored!r}")
    return deserialize_response(stored)

309 

310 

async def delete_from_cache(
    url: URL,
    *,
    vary: Headers,
    cache: t.Any = None,
    logger: t.Any = None,
) -> None:
    """Remove the cached GET and HEAD entries for ``url``.

    Args:
        url: URL whose cached responses should be removed.
        vary: Headers used to compute the varying part of the cache keys.
        cache: Cache backend; resolved through ``depends`` when None.
        logger: Logger; resolved through ``depends`` when None.

    Nothing is deleted when no varying-headers entry is stored for the
    URL. Otherwise both method-specific entries and the varying-headers
    entry itself are deleted.
    """
    if cache is None:
        cache = depends.get("cache")
    if logger is None:
        logger = depends.get("logger")

    varying_headers_cache_key = generate_varying_headers_cache_key(url)
    varying_headers = await cache.get(varying_headers_cache_key)
    if varying_headers is None:
        return

    for method in ("GET", "HEAD"):
        cache_key = generate_cache_key(
            url,
            method=method,
            headers=vary,
            varying_headers=varying_headers,
        )
        # Log through _safe_log like the rest of the module, so a logger
        # lacking ``debug`` cannot abort an invalidation.
        _safe_log(logger, "debug", f"clear_cache key={cache_key!r}")
        await cache.delete(cache_key)
    await cache.delete(varying_headers_cache_key)

337 

338 

def serialize_response(response: Response) -> dict[str, t.Any]:
    """Convert a response into a plain dict suitable for storing in the cache.

    Returns:
        A dict with the base64-encoded body as ASCII text under
        ``"content"``, plus ``"status_code"`` and a ``"headers"`` dict.
    """
    encoded_body = base64.encodebytes(response.body)
    serialized: dict[str, t.Any] = {}
    serialized["content"] = encoded_body.decode("ascii")
    serialized["status_code"] = response.status_code
    serialized["headers"] = dict(response.headers)
    return serialized

345 

346 

def deserialize_response(serialized_response: t.Any) -> Response:
    """Rebuild a ``Response`` from the dict produced by ``serialize_response``.

    Raises:
        TypeError: If the value is not a dict, or if ``content``,
            ``status_code`` or ``headers`` has the wrong type. Fields
            are checked in that order.
    """
    if not isinstance(serialized_response, dict):
        msg = f"Expected dict, got {type(serialized_response)}"
        raise TypeError(msg)

    validated: dict[str, t.Any] = {}
    for field_name, expected in (
        ("content", str),
        ("status_code", int),
        ("headers", dict),
    ):
        found = serialized_response.get(field_name)
        if not isinstance(found, expected):
            msg = f"Expected {field_name} to be {expected.__name__}, got {type(found)}"
            raise TypeError(msg)
        validated[field_name] = found

    raw_body = _base64_decodebytes(_str_encode(validated["content"], "ascii"))
    return Response(
        content=raw_body,
        status_code=validated["status_code"],
        headers=validated["headers"],
    )

368 

369 

async def learn_cache_key(
    request: Request,
    response: Response,
    *,
    cache: t.Any = None,
    logger: t.Any = None,
) -> str:
    """Compute the cache key for a request/response pair and record its varying headers.

    The header names in the response's ``Vary`` header are merged with
    those already stored for the URL. The merged, sorted list is written
    back to the response's ``Vary`` header (when non-empty) and stored in
    the cache, then used to build the key.

    Args:
        request: Request being cached.
        response: Response being cached; its ``Vary`` header may be rewritten.
        cache: Cache backend; resolved through ``depends`` when None.
        logger: Logger; resolved through ``depends`` when None.

    Returns:
        The cache key for this request.

    Raises:
        ValueError: If no key can be generated for the request's method.
    """
    if cache is None:
        cache = depends.get("cache")
    if logger is None:
        logger = depends.get("logger")
    # Log through _safe_log like the rest of the module, so a logger
    # lacking ``debug`` cannot abort caching.
    _safe_log(
        logger,
        "debug",
        f"learn_cache_key request.method={request.method!r} response.headers.Vary={response.headers.get('Vary')!r}",
    )
    url = request.url
    varying_headers_cache_key = generate_varying_headers_cache_key(url)
    cached_vary_headers = set(await cache.get(key=varying_headers_cache_key) or ())
    response_vary_headers = {
        header.lower() for header in parse_http_list(response.headers.get("Vary", ""))
    }
    varying_headers = sorted(response_vary_headers | cached_vary_headers)
    if varying_headers:
        response.headers["Vary"] = ", ".join(varying_headers)
    _safe_log(
        logger,
        "debug",
        f"store_varying_headers cache_key={varying_headers_cache_key!r} headers={varying_headers!r}",
    )
    await cache.set(key=varying_headers_cache_key, value=varying_headers)
    cache_key = generate_cache_key(
        url,
        method=request.method,
        headers=request.headers,
        varying_headers=varying_headers,
    )
    if cache_key is None:
        msg = f"Unable to generate cache key for method {request.method}"
        raise ValueError(msg)
    return cache_key

408 

409 

async def get_cache_key(
    request: Request,
    method: str,
    cache: t.Any = None,
    logger: t.Any = None,
) -> str | None:
    """Return the cache key for ``request`` under ``method``, if one can be built.

    Args:
        request: Incoming request.
        method: HTTP method to build the key for.
        cache: Cache backend; resolved through ``depends`` when None.
        logger: Logger; resolved through ``depends`` when None.

    Returns:
        None when no varying-headers entry is stored for the URL;
        otherwise whatever ``generate_cache_key`` returns.
    """
    if cache is None:
        cache = depends.get("cache")
    if logger is None:
        logger = depends.get("logger")
    debug = partial(_safe_log, logger, "debug")

    target = request.url
    debug(f"get_cache_key request.url={str(target)!r} method={method!r}")

    known_vary = await cache.get(generate_varying_headers_cache_key(target))
    if known_vary is None:
        debug("varying_headers found=False")
        return None

    debug(f"varying_headers found=True headers={known_vary!r}")
    return generate_cache_key(
        request.url,
        method=method,
        headers=request.headers,
        varying_headers=known_vary,
    )

443 

444 

def generate_cache_key(
    url: URL,
    method: str,
    headers: Headers,
    varying_headers: list[str],
    config: t.Any = None,
) -> str | None:
    """Build the cache key for a URL, method and set of varying header values.

    Args:
        url: Request URL; its string form is hashed.
        method: HTTP method; only cacheable methods yield a key.
        headers: Request headers supplying the varying values.
        varying_headers: Names of headers that take part in the key.
        config: Object exposing ``app.name``; resolved through
            ``depends`` when None.

    Returns:
        ``"<app name>:cached:<method>.<url hash>.<vary hash>"``, or None
        for a non-cacheable method. The vary hash is empty when none of
        the varying headers is present in ``headers``.
    """
    if config is None:
        config = depends.get("config")

    if method not in cacheable_methods:
        return None

    # Only headers actually present on the request contribute.
    present: list[str] = []
    for name in varying_headers:
        header_value = headers.get(name)
        if header_value is not None:
            present.append(f"{name}:{header_value}")

    if present:
        vary_digest = _get_hasher()
        vary_digest.update(_str_encode("|".join(present)))
        vary_hash = vary_digest.hexdigest()
    else:
        vary_hash = ""

    url_digest = _get_hasher()
    url_digest.update(_str_encode(str(url)))
    url_hash = url_digest.hexdigest()

    return f"{config.app.name}:cached:{method}.{url_hash}.{vary_hash}"

475 

476 

def generate_varying_headers_cache_key(url: URL) -> str:
    """Return the key under which the varying header names for ``url`` are stored.

    Only ``url.path`` is hashed.
    """
    path_digest = _get_hasher()
    path_bytes = _str_encode(str(url.path))
    path_digest.update(path_bytes)
    return f"varying_headers.{path_digest.hexdigest()}"

482 

483 

def get_cache_response_headers(response: Response, *, max_age: int) -> dict[str, str]:
    """Compute extra caching headers for ``response`` and patch its ``Cache-Control``.

    Args:
        response: Response whose headers are inspected; its
            ``Cache-Control`` header is patched in place.
        max_age: Lifetime in seconds; negative values are treated as 0.

    Returns:
        A dict holding an ``Expires`` header when the response has none,
        otherwise an empty dict.
    """
    max_age = max(max_age, 0)
    extra: dict[str, str] = {}
    if "Expires" not in response.headers:
        expires_at = time.time() + max_age
        extra["Expires"] = email.utils.formatdate(expires_at, usegmt=True)
    patch_cache_control(response.headers, max_age=max_age)
    return extra

492 

493 

def patch_cache_control(
    headers: MutableHeaders,
    **kwargs: t.Unpack[CacheDirectives],
) -> None:
    """Merge the given directives into the ``Cache-Control`` header, in place.

    Existing directives are parsed first. A requested ``max_age`` never
    raises an existing ``max-age``: the smaller of the two is kept.
    Directives set to ``False`` are omitted, ``True`` renders the bare
    name, and anything else renders as ``name=value``. The header is
    deleted when nothing remains to render.

    Raises:
        NotImplementedError: If ``public`` or ``private`` is passed.
    """
    merged: dict[str, t.Any] = {}
    for field in parse_http_list(headers.get("Cache-Control", "")):
        pieces = field.split("=")
        if len(pieces) == 2:
            merged[pieces[0]] = pieces[1]
        else:
            # No "=" (or more than one): keep the field whole as a flag.
            merged[field] = True

    if "max-age" in merged and "max_age" in kwargs:
        kwargs["max_age"] = min(int(merged["max-age"]), kwargs["max_age"])

    for unsupported in ("public", "private"):
        if unsupported in kwargs:
            msg = f"The '{unsupported}' cache control directive isn't supported yet."
            raise NotImplementedError(
                msg,
            )

    merged.update(
        (name.replace("_", "-"), setting) for name, setting in kwargs.items()
    )

    rendered = [
        name if setting is True else f"{name}={setting}"
        for name, setting in merged.items()
        if setting is not False
    ]

    header_value = ", ".join(rendered)
    if header_value:
        headers["Cache-Control"] = header_value
    else:
        del headers["Cache-Control"]

540 

541 

class CacheResponder:
    """ASGI wrapper that serves cached responses and stores fresh ones.

    State for the request being handled (``request``, ``initial_message``,
    ``is_response_cacheable``) lives on the instance.
    """

    def __init__(self, app: ASGIApp, *, rules: Sequence[Rule]) -> None:
        """Store the wrapped app and rules, then resolve logger and cache.

        If the logger cannot be resolved through ``depends``, a standard
        library logger named ``"fastblocks.cache"`` is used. If the cache
        cannot be resolved, ``self.cache`` is left as None.
        """
        self.app = app
        self.rules = rules
        try:
            self.logger = depends.get("logger")
        except Exception:
            import logging

            self.logger = logging.getLogger("fastblocks.cache")
        try:
            self.cache = depends.get("cache")
        except Exception:
            self.cache = None
        # The buffered "http.response.start" message of the current response.
        self.initial_message: Message = {}
        self.is_response_cacheable = True
        self.request: Request | None = None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle one ASGI connection.

        Non-HTTP scopes pass straight through. A cache hit is answered
        from the cache. Otherwise the wrapped app runs with ``send``
        wrapped to store the response (cacheable request) or to
        invalidate cached entries (invalidating method).
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        self.request = request = Request(scope)
        try:
            # Pass our own logger so the fallback chosen in __init__ is
            # honored instead of being re-resolved by the helper.
            response = await get_from_cache(
                request,
                cache=self.cache,
                rules=self.rules,
                logger=self.logger,
            )
        except RequestNotCachable:
            if request.method in invalidating_methods:
                send = partial(self.send_then_invalidate, send=send)
        else:
            if response is not None:
                _safe_log(self.logger, "debug", "cache_lookup HIT")
                await response(scope, receive, send)
                return
            send = partial(self.send_with_caching, send=send)
            _safe_log(self.logger, "debug", "cache_lookup MISS")
        await self.app(scope, receive, send)

    async def send_with_caching(self, message: Message, *, send: Send) -> None:
        """Buffer the response start, then cache and forward the response.

        The start message is held back until the first body message so
        that headers patched by ``set_in_cache`` can be sent with it.
        Streaming responses (``more_body`` set) are forwarded uncached.
        """
        if not self.is_response_cacheable or message["type"] not in (
            "http.response.start",
            "http.response.body",
        ):
            await send(message)
            return
        if message["type"] == "http.response.start":
            self.initial_message = message
            return
        # From here on the message is necessarily "http.response.body".
        if message.get("more_body", False):
            _safe_log(
                self.logger,
                "debug",
                "response_not_cacheable reason=is_streaming",
            )
            self.is_response_cacheable = False
            await send(self.initial_message)
            await send(message)
            return
        if self.request is None:
            # Without a request nothing can be cached, but the buffered
            # start and this body must still reach the client.
            self.is_response_cacheable = False
            await send(self.initial_message)
            await send(message)
            return
        body = message["body"]
        response = Response(content=body, status_code=self.initial_message["status"])
        response.raw_headers = list(self.initial_message["headers"])
        try:
            await set_in_cache(
                response,
                request=self.request,
                cache=self.cache,
                rules=self.rules,
                logger=self.logger,
            )
        except ResponseNotCachable:
            self.is_response_cacheable = False
        else:
            # Send the headers as patched by set_in_cache.
            self.initial_message["headers"] = list(response.raw_headers)
        await send(self.initial_message)
        await send(message)

    async def send_then_invalidate(self, message: Message, *, send: Send) -> None:
        """Forward ``message``, first invalidating cached entries on a 2xx/3xx start.

        The message is always forwarded, even when no request is
        recorded and invalidation therefore has to be skipped.
        """
        if (
            self.request is not None
            and message["type"] == "http.response.start"
            and 200 <= message["status"] < 400
        ):
            await delete_from_cache(
                self.request.url,
                vary=self.request.headers,
                cache=self.cache,
                logger=self.logger,
            )
        await send(message)

630 

631 

class CacheControlResponder:
    """ASGI wrapper that patches ``Cache-Control`` on outgoing HTTP responses."""

    def __init__(self, app: ASGIApp, **kwargs: t.Unpack[CacheDirectives]) -> None:
        """Store the wrapped app and the directives to apply.

        If the logger cannot be resolved through ``depends``, a standard
        library logger named ``"fastblocks.cache"`` is used.
        """
        self.app = app
        self.kwargs = kwargs
        try:
            self.logger = depends.get("logger")
        except Exception:
            import logging

            self.logger = logging.getLogger("fastblocks.cache")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app; for HTTP scopes, wrap ``send`` to patch headers."""
        if scope["type"] == "http":
            send = partial(self.send_with_caching, send=send)
        await self.app(scope, receive, send)

    @staticmethod
    def kvformat(**kwargs: t.Any) -> str:
        """Render keyword arguments as space-separated ``key=value`` pairs."""
        pairs = [f"{name}={setting}" for name, setting in kwargs.items()]
        return " ".join(pairs)

    async def send_with_caching(self, message: Message, *, send: Send) -> None:
        """Patch ``Cache-Control`` on the response start, then forward the message."""
        if message["type"] != "http.response.start":
            await send(message)
            return
        _safe_log(
            self.logger,
            "debug",
            f"patch_cache_control {self.kvformat(**self.kwargs)}",
        )
        patched = MutableHeaders(raw=list(message["headers"]))
        patch_cache_control(patched, **self.kwargs)
        message["headers"] = patched.raw
        await send(message)