Coverage for fastblocks/caching.py: 63%

432 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-09 00:47 -0700

1import asyncio 

2import base64 

3import email.utils 

4import hashlib 

5import re 

6import sys 

7import time 

8import typing as t 

9from collections.abc import Iterable, Sequence 

10from contextlib import suppress 

11from dataclasses import dataclass 

12from functools import partial 

13from threading import local 

14from urllib.request import parse_http_list 

15 

16from starlette.datastructures import URL, Headers, MutableHeaders 

17from starlette.requests import Request 

18from starlette.responses import Response 

19 

20HashFunc = t.Callable[[t.Any], str] 

21GetAdapterFunc = t.Callable[[str], t.Any] 

22ImportAdapterFunc = t.Callable[[str | list[str] | None], t.Any] 

23from acb.adapters import get_adapter 

24from acb.depends import depends 

25from starlette.types import ASGIApp, Message, Receive, Scope, Send 

26 

27from .exceptions import RequestNotCachable, ResponseNotCachable 

28 

29 

def _safe_log(logger: t.Any, level: str, message: str) -> None:
    """Emit *message* via ``logger.<level>`` when possible, otherwise do nothing.

    Module-level shortcut for ``CacheUtils.safe_log``; always returns ``None``.
    """
    CacheUtils.safe_log(logger, level, message)

32 

33 

# Cache adapter resolved lazily by get_cache(); None until first use.
_CacheClass: t.Any = None

# Thread-local namespace object (attributes set on it are per-thread).
_hasher_pool: local = local()

# Pre-bound aliases for encoding/serialization helpers used on the hot path.
_str_encode = str.encode
_base64_encodebytes, _base64_decodebytes = base64.encodebytes, base64.decodebytes

41 

42 

def _get_hasher() -> t.Any:
    """Return a fresh, empty MD5 hasher for building (non-security) cache keys.

    A new object is created on every call. Reusing one hasher per thread and
    "resetting" it with ``hasher.__init__(...)`` does not work: hashlib objects
    do not define ``__init__``, so that call resolves to ``object.__init__`` and
    leaves the accumulated state in place. Every digest then depends on all
    previous updates, which makes cache keys non-deterministic.
    """
    # usedforsecurity=False: MD5 is only used to shorten keys, and this keeps
    # it usable on FIPS-restricted builds.
    return hashlib.md5(usedforsecurity=False)

49 

50 

def get_cache() -> t.Any:
    """Return the cache adapter, resolving it with ``get_adapter("cache")`` once.

    The resolved adapter is memoized in the module-level ``_CacheClass``.
    """
    global _CacheClass
    if _CacheClass is not None:
        return _CacheClass
    _CacheClass = get_adapter("cache")
    return _CacheClass

56 

57 

class CacheUtils:
    """Constants shared by the caching helpers plus a defensive logging helper."""

    # HTTP method names, interned with sys.intern.
    GET = sys.intern("GET")
    HEAD = sys.intern("HEAD")
    POST = sys.intern("POST")
    PUT = sys.intern("PUT")
    PATCH = sys.intern("PATCH")
    DELETE = sys.intern("DELETE")

    # Header names, interned the same way.
    CACHE_CONTROL = sys.intern("Cache-Control")
    ETAG = sys.intern("ETag")
    LAST_MODIFIED = sys.intern("Last-Modified")
    VARY = sys.intern("Vary")

    # Request methods that may be served from / stored in the cache.
    CACHEABLE_METHODS = frozenset({GET, HEAD})
    # Response status codes eligible for caching.
    CACHEABLE_STATUS_CODES = frozenset(
        {200, 203, 204, 206, 300, 301, 404, 405, 410, 414, 501}
    )
    # Seconds in a 365-day year; used as max-age when no TTL applies.
    ONE_YEAR = 365 * 24 * 60 * 60
    # Request methods whose successful responses trigger cache invalidation.
    INVALIDATING_METHODS = frozenset({POST, PUT, PATCH, DELETE})

    @staticmethod
    def safe_log(logger: t.Any, level: str, message: str) -> None:
        """Call ``logger.<level>(message)`` if *logger* is truthy and has *level*.

        Does nothing otherwise, so callers may pass ``None`` as the logger.
        """
        if not logger or not hasattr(logger, level):
            return
        emit = getattr(logger, level)
        emit(message)

81 

82 

# Lower-case module-level aliases of the CacheUtils constants, used by the
# request/response helpers in this module.
cacheable_methods, cacheable_status_codes, one_year, invalidating_methods = (
    CacheUtils.CACHEABLE_METHODS,
    CacheUtils.CACHEABLE_STATUS_CODES,
    CacheUtils.ONE_YEAR,
    CacheUtils.INVALIDATING_METHODS,
)

87 

88 

89@dataclass 

90class Rule: 

91 match: str | re.Pattern[str] | Iterable[str | re.Pattern[str]] = "*" 

92 status: int | Iterable[int] | None = None 

93 ttl: float | None = None 

94 

95 

# These helpers live at module level *ahead of* CacheRules so the class body
# stays contiguous. Defining them between the methods ends the class early and
# hides the remaining methods inside the helper bodies.
def _check_rule_match(match: list[str | re.Pattern[str]], path: str) -> bool:
    """Return True if any entry of *match* matches *path*.

    String entries match when equal to *path* or to the ``"*"`` wildcard;
    compiled patterns are tested with ``Pattern.match`` (anchored at the start).
    """
    for item in match:
        if isinstance(item, re.Pattern):
            if item.match(path):
                return True
        elif item in ("*", path):
            return True
    return False


def _check_response_status_match(rule: Rule, response: Response) -> bool:
    """Return True if ``response.status_code`` is allowed by ``rule.status``.

    ``rule.status`` may be ``None`` (any status), a single int, or an iterable
    of ints.
    """
    if rule.status is None:
        return True
    statuses = (rule.status,) if isinstance(rule.status, int) else rule.status
    return response.status_code in statuses


class CacheRules:
    """Static helpers that select which :class:`Rule` applies to a request/response."""

    @staticmethod
    def request_matches_rule(rule: Rule, *, request: Request) -> bool:
        """Return True if the request's URL path satisfies ``rule.match``."""
        match = (
            [rule.match]
            if isinstance(rule.match, str | re.Pattern)
            else list(rule.match)
        )
        return _check_rule_match(match, request.url.path)

    @staticmethod
    def response_matches_rule(
        rule: Rule,
        *,
        request: Request,
        response: Response,
    ) -> bool:
        """Return True if the request path matches *rule* and the status is allowed."""
        if not CacheRules.request_matches_rule(rule, request=request):
            return False
        return _check_response_status_match(rule, response)

    @staticmethod
    def get_rule_matching_request(
        rules: Sequence[Rule],
        *,
        request: Request,
    ) -> Rule | None:
        """Return the first rule in *rules* that matches *request*, or ``None``."""
        return next(
            (
                rule
                for rule in rules
                if CacheRules.request_matches_rule(rule, request=request)
            ),
            None,
        )

    @staticmethod
    def get_rule_matching_response(
        rules: Sequence[Rule],
        *,
        request: Request,
        response: Response,
    ) -> Rule | None:
        """Return the first rule in *rules* matching both request and response, or ``None``."""
        return next(
            (
                rule
                for rule in rules
                if CacheRules.response_matches_rule(
                    rule,
                    request=request,
                    response=response,
                )
            ),
            None,
        )

173 

174 

def get_rule_matching_request(
    rules: Sequence[Rule],
    *,
    request: Request,
) -> Rule | None:
    """Return the first rule in *rules* that matches *request*, or ``None``.

    Module-level shortcut for ``CacheRules.get_rule_matching_request``.
    """
    return CacheRules.get_rule_matching_request(rules, request=request)

183 

184 

def get_rule_matching_response(
    rules: Sequence[Rule],
    *,
    request: Request,
    response: Response,
) -> Rule | None:
    """Return the first rule in *rules* matching both *request* and *response*.

    Module-level shortcut for ``CacheRules.get_rule_matching_response``;
    returns ``None`` when no rule matches.
    """
    return CacheRules.get_rule_matching_response(
        rules, request=request, response=response
    )

194 

195 

def request_matches_rule(rule: Rule, *, request: Request) -> bool:
    """Return True if *request*'s URL path satisfies ``rule.match``.

    Module-level shortcut for ``CacheRules.request_matches_rule``.
    """
    return CacheRules.request_matches_rule(rule, request=request)

200 

201 

def response_matches_rule(rule: Rule, *, request: Request, response: Response) -> bool:
    """Return True if *request* matches *rule* and *response*'s status is allowed.

    Module-level shortcut for ``CacheRules.response_matches_rule``.
    """
    return CacheRules.response_matches_rule(rule, request=request, response=response)

206 

207 

class CacheDirectives(t.TypedDict, total=False):
    """Keyword arguments accepted by ``patch_cache_control``.

    Keys use underscores; ``patch_cache_control`` emits them as Cache-Control
    directive names with ``_`` replaced by ``-`` (e.g. ``max_age`` -> ``max-age``).
    Boolean entries are emitted as bare tokens when True and dropped when False.
    """

    # Directives with a seconds argument.
    max_age: int
    s_maxage: int
    # Flag directives.
    no_cache: bool
    no_store: bool
    no_transform: bool
    must_revalidate: bool
    proxy_revalidate: bool
    must_understand: bool
    private: bool
    public: bool
    immutable: bool
    # Extension directives with a seconds argument.
    stale_while_revalidate: int
    stale_if_error: int

222 

223 

async def set_in_cache(
    response: Response,
    *,
    request: Request,
    rules: Sequence[Rule],
    cache: t.Any = None,
    logger: t.Any = None,
) -> None:
    """Store *response* for *request* in the cache when the rules allow it.

    Missing ``cache``/``logger`` are resolved from ``depends``. On success the
    response gains Expires/Cache-Control/Vary headers and ``X-Cache: miss``.

    Raises:
        ResponseNotCachable: if the status code is not cacheable, the response
            sets cookies on a cookieless request, no rule matches, or the
            effective TTL is 0.
        ValueError: propagated from ``learn_cache_key`` if no key can be built.
    """
    cache, logger = _init_cache_dependencies(cache, logger)
    _validate_response_cacheable(response, request, logger)

    matching_rule = get_rule_matching_response(
        rules, request=request, response=response
    )
    if not matching_rule:
        _safe_log(logger, "debug", "response_not_cacheable reason=rule")
        raise ResponseNotCachable(response)

    ttl, max_age = _calculate_cache_ttl(matching_rule, cache, logger)
    # Headers (including "X-Cache: hit") are applied before serialization, so
    # the stored copy carries them when it is replayed later.
    _set_cache_headers(response, max_age, logger)

    key = await learn_cache_key(request, response, cache=cache)
    payload = serialize_response(response)
    await _store_in_cache(cache, key, payload, ttl, logger)

    # The live response was produced by the app, so it is reported as a miss.
    response.headers["X-Cache"] = "miss"

259 

260 

def _init_cache_dependencies(cache: t.Any, logger: t.Any) -> tuple[t.Any, t.Any]:
    """Return ``(cache, logger)``, filling any ``None`` from the ``depends`` registry.

    The cache is resolved before the logger.
    """
    resolved_cache = depends.get("cache") if cache is None else cache
    resolved_logger = depends.get("logger") if logger is None else logger
    return resolved_cache, resolved_logger

268 

269 

def _validate_response_cacheable(
    response: Response, request: Request, logger: t.Any
) -> None:
    """Raise ResponseNotCachable unless *response* may be stored for *request*.

    A response is rejected when its status code is not in
    ``cacheable_status_codes``, or when it sets cookies for a request that sent
    none.
    """
    if response.status_code not in cacheable_status_codes:
        reason = "status_code"
    elif not request.cookies and "Set-Cookie" in response.headers:
        reason = "cookies_for_cookieless_request"
    else:
        return
    _safe_log(logger, "debug", f"response_not_cacheable reason={reason}")
    raise ResponseNotCachable(response)

284 

285 

def _calculate_cache_ttl(rule: Rule, cache: t.Any, logger: t.Any) -> tuple[t.Any, int]:
    """Return ``(ttl, max_age)`` for a response matched by *rule*.

    ``ttl`` is ``rule.ttl`` or, when that is ``None``, ``cache.ttl``; it is
    returned unchanged. ``max_age`` is ``int(ttl)``, or ``one_year`` when
    ``ttl`` is ``None``.

    Raises:
        ResponseNotCachable: if the effective TTL equals 0.
    """
    ttl = cache.ttl if rule.ttl is None else rule.ttl
    if ttl == 0:
        _safe_log(logger, "debug", "response_not_cacheable reason=zero_ttl")
        # NOTE(review): the exception carries a placeholder response, not the
        # response being cached, because it is not available here.
        raise ResponseNotCachable(Response(content=b"", status_code=200))

    if ttl is not None:
        max_age = int(ttl)
        _safe_log(logger, "debug", f"set_in_cache max_age={max_age!r}")
        return ttl, max_age

    max_age = one_year
    _safe_log(logger, "debug", f"max_out_ttl value={max_age!r}")
    return ttl, max_age

301 

302 

def _set_cache_headers(response: Response, max_age: int, logger: t.Any) -> None:
    """Mark *response* as ``X-Cache: hit`` and apply Expires/Cache-Control headers.

    ``get_cache_response_headers`` patches Cache-Control in place and returns
    the extra headers (Expires) that are merged here.
    """
    response.headers["X-Cache"] = "hit"
    extra_headers = get_cache_response_headers(response, max_age=max_age)
    _safe_log(logger, "debug", f"patch_response_headers headers={extra_headers!r}")
    response.headers.update(extra_headers)

309 

310 

async def _store_in_cache(
    cache: t.Any,
    cache_key: str,
    serialized_response: dict[str, t.Any],
    ttl: t.Any,
    logger: t.Any,
) -> None:
    """Write *serialized_response* under *cache_key*.

    ``ttl`` is forwarded to ``cache.set`` only when it is not ``None``, so the
    backend's default applies otherwise.
    """
    _safe_log(
        logger,
        "debug",
        f"set_response_in_cache key={cache_key!r} value={serialized_response!r}",
    )
    if ttl is None:
        await cache.set(key=cache_key, value=serialized_response)
    else:
        await cache.set(key=cache_key, value=serialized_response, ttl=ttl)

328 

329 

async def get_from_cache(
    request: Request,
    *,
    rules: Sequence[Rule],
    cache: t.Any = None,
    logger: t.Any = None,
) -> Response | None:
    """Look up a cached response for *request*.

    Missing ``cache``/``logger`` are resolved from ``depends``.

    Returns:
        The cached response (GET entry first, then HEAD), or ``None`` on a miss.

    Raises:
        RequestNotCachable: if the request method is not cacheable or no rule
            in *rules* matches the request.
    """
    cache, logger = _init_cache_dependencies(cache, logger)
    _safe_log(
        logger,
        "debug",
        f"get_from_cache request.url={str(request.url)!r} request.method={request.method!r}",
    )
    _validate_request_cacheable(request, logger)

    # Go through the module-level wrapper (as set_in_cache does for responses)
    # rather than a getattr() lookup with a constant attribute name.
    rule = get_rule_matching_request(rules, request=request)
    if rule is None:
        _safe_log(logger, "debug", "request_not_cacheable reason=rule")
        raise RequestNotCachable(request)

    return await _try_get_cached_response(request, cache, logger)

358 

359 

def _validate_request_cacheable(request: Request, logger: t.Any) -> None:
    """Raise RequestNotCachable when ``request.method`` is not in ``cacheable_methods``."""
    if request.method in cacheable_methods:
        return
    _safe_log(logger, "debug", "request_not_cacheable reason=method")
    raise RequestNotCachable(request)

365 

366 

async def _try_get_cached_response(
    request: Request, cache: t.Any, logger: t.Any
) -> Response | None:
    """Return the cached response for *request*, checking GET then HEAD entries.

    Returns ``None`` when neither entry exists, including when no cache key can
    be built (``get_cache_key`` returned ``None``).
    """
    for method in ("GET", "HEAD"):
        _safe_log(logger, "debug", f"lookup_cached_response method={method!r}")
        cache_key = await get_cache_key(request, method=method, cache=cache)
        if cache_key is None:
            continue
        serialized_response = await cache.get(cache_key)
        if serialized_response is not None:
            return _return_cached_response(cache_key, serialized_response, logger)

    _safe_log(logger, "debug", "cached_response found=False")
    return None

390 

391 

def _return_cached_response(
    cache_key: str, serialized_response: t.Any, logger: t.Any
) -> Response:
    """Log the hit for *cache_key* and rebuild the Response from its cached form."""
    hit_message = (
        f"cached_response found=True key={cache_key!r} value={serialized_response!r}"
    )
    _safe_log(logger, "debug", hit_message)
    return deserialize_response(serialized_response)

402 

403 

async def delete_from_cache(
    url: URL,
    *,
    vary: Headers,
    cache: t.Any = None,
    logger: t.Any = None,
) -> None:
    """Remove the cached GET/HEAD entries for *url* and its varying-headers record.

    *vary* supplies the header values used to rebuild the cache keys. Nothing is
    deleted when no varying-headers record exists for the URL path. Missing
    ``cache``/``logger`` are resolved from ``depends``.
    """
    if cache is None:
        cache = depends.get("cache")
    if logger is None:
        logger = depends.get("logger")

    vary_key = generate_varying_headers_cache_key(url)
    varying_headers = await cache.get(vary_key)
    if varying_headers is not None:
        await _delete_cache_entries(url, vary, cache, logger, varying_headers)
        await cache.delete(vary_key)

424 

425 

# Strong references to in-flight invalidation-event tasks. asyncio keeps only
# weak references to tasks, so an unreferenced task may vanish before it runs
# to completion; each task removes itself when done.
_invalidation_tasks: set[asyncio.Task[None]] = set()


async def _publish_invalidation_event(cache_key: str) -> None:
    """Best-effort publication of a cache-invalidation event for *cache_key*.

    Every failure, including the optional events module being unavailable, is
    suppressed inside the task so it never surfaces as an unretrieved task
    exception.
    """
    with suppress(Exception):
        from .adapters.templates._events_wrapper import (
            publish_cache_invalidation,
        )

        await publish_cache_invalidation(
            cache_key=cache_key,
            reason="url_invalidation",
            invalidated_by="cache_middleware",
            affected_templates=None,
        )


async def _delete_cache_entries(
    url: URL,
    vary: Headers,
    cache: t.Any,
    logger: t.Any,
    varying_headers: t.Any,
) -> None:
    """Delete the GET and HEAD cache entries for *url* and announce each deletion.

    Args:
        url: URL whose cached responses are removed.
        vary: Headers providing the values of *varying_headers* for the key.
        cache: Cache backend exposing ``async delete(key)``.
        logger: Logger (may be ``None``).
        varying_headers: Header names the cached entries vary on.
    """
    for method in ("GET", "HEAD"):
        cache_key = generate_cache_key(
            url,
            method=method,
            headers=vary,
            varying_headers=varying_headers,
        )
        if cache_key is None:
            continue

        _safe_log(logger, "debug", f"clear_cache key={cache_key!r}")
        await cache.delete(cache_key)

        # Fire-and-forget. The key is passed as an argument rather than read
        # through a closure, so each task publishes the key it was created for
        # even after the loop variable has moved on.
        with suppress(Exception):
            task = asyncio.create_task(_publish_invalidation_event(cache_key))
            _invalidation_tasks.add(task)
            task.add_done_callback(_invalidation_tasks.discard)

463 

464 

def serialize_response(response: Response) -> dict[str, t.Any]:
    """Convert *response* into a plain dict suitable for storing in the cache.

    The body is base64-encoded into an ASCII string so the result contains only
    ``str``/``int``/``dict`` values.
    """
    encoded_body = _base64_encodebytes(response.body).decode("ascii")
    # NOTE(review): dict() keeps one value per header name, so repeated headers
    # (e.g. several Set-Cookie lines) may be collapsed — confirm this is acceptable.
    return {
        "content": encoded_body,
        "status_code": response.status_code,
        "headers": dict(response.headers),
    }

472 

473 

def deserialize_response(serialized_response: t.Any) -> Response:
    """Rebuild a Response from a dict produced by ``serialize_response``.

    Raises:
        TypeError: if *serialized_response* is not a dict with a ``str``
            content, ``int`` status_code and ``dict`` headers.
    """
    _validate_serialized_response(serialized_response)
    body = _base64_decodebytes(_str_encode(serialized_response["content"], "ascii"))
    return Response(
        content=body,
        status_code=serialized_response["status_code"],
        headers=serialized_response["headers"],
    )

487 

488 

def _validate_serialized_response(serialized_response: t.Any) -> None:
    """Raise TypeError unless *serialized_response* has the serialized-response shape.

    Expected: a dict whose ``content`` is a str, ``status_code`` an int and
    ``headers`` a dict. A missing key is reported as ``NoneType``.
    """
    if not isinstance(serialized_response, dict):
        raise TypeError(f"Expected dict, got {type(serialized_response)}")
    expected_fields = (("content", str), ("status_code", int), ("headers", dict))
    for field_name, field_type in expected_fields:
        field_value = serialized_response.get(field_name)
        if not isinstance(field_value, field_type):
            raise TypeError(
                f"Expected {field_name} to be {field_type.__name__}, got {type(field_value)}"
            )

506 

507 

async def learn_cache_key(
    request: Request,
    response: Response,
    *,
    cache: t.Any = None,
    logger: t.Any = None,
) -> str:
    """Record the headers *response* varies on and return its cache key.

    The lower-cased header names from the response's ``Vary`` header are merged
    with those already stored for the URL path. The sorted union is written back
    to ``response.headers["Vary"]`` (when non-empty) and persisted under the
    path's varying-headers key, even when empty, so later lookups can find it.
    Missing ``cache``/``logger`` are resolved from ``depends``.

    Raises:
        ValueError: if no cache key can be generated for ``request.method``.
    """
    if cache is None:
        cache = depends.get("cache")
    if logger is None:
        logger = depends.get("logger")
    # _safe_log (as used by the sibling helpers) tolerates a missing logger.
    _safe_log(
        logger,
        "debug",
        f"learn_cache_key request.method={request.method!r} response.headers.Vary={response.headers.get('Vary')!r}",
    )
    url = request.url
    varying_headers_cache_key = generate_varying_headers_cache_key(url)
    cached_vary_headers = set(await cache.get(key=varying_headers_cache_key) or ())
    response_vary_headers = {
        header.lower() for header in parse_http_list(response.headers.get("Vary", ""))
    }
    varying_headers = sorted(response_vary_headers | cached_vary_headers)
    if varying_headers:
        response.headers["Vary"] = ", ".join(varying_headers)

    _safe_log(
        logger,
        "debug",
        f"store_varying_headers cache_key={varying_headers_cache_key!r} headers={varying_headers!r}",
    )
    await cache.set(key=varying_headers_cache_key, value=varying_headers)

    cache_key = generate_cache_key(
        url,
        method=request.method,
        headers=request.headers,
        varying_headers=varying_headers,
    )
    if cache_key is None:
        msg = f"Unable to generate cache key for method {request.method}"
        raise ValueError(msg)
    return cache_key

546 

547 

async def get_cache_key(
    request: Request,
    method: str,
    cache: t.Any = None,
    logger: t.Any = None,
) -> str | None:
    """Return the cache key for *request* under *method*, or ``None``.

    ``None`` is returned when no varying-headers list is stored for the URL
    path, or when ``generate_cache_key`` rejects *method*. Missing
    ``cache``/``logger`` are resolved from ``depends``.
    """
    if cache is None:
        cache = depends.get("cache")
    if logger is None:
        logger = depends.get("logger")

    url = request.url
    _safe_log(logger, "debug", f"get_cache_key request.url={str(url)!r} method={method!r}")

    varying_headers = await cache.get(generate_varying_headers_cache_key(url))
    if varying_headers is None:
        _safe_log(logger, "debug", "varying_headers found=False")
        return None

    _safe_log(logger, "debug", f"varying_headers found=True headers={varying_headers!r}")
    return generate_cache_key(
        request.url,
        method=method,
        headers=request.headers,
        varying_headers=varying_headers,
    )

581 

582 

def generate_cache_key(
    url: URL,
    method: str,
    headers: Headers,
    varying_headers: list[str],
    config: t.Any = None,
) -> str | None:
    """Build the cache key for *url* requested with *method*.

    Format: ``"<config.app.name>:cached:<METHOD>.<url hash>.<vary hash>"``, where
    the vary hash is empty when none of *varying_headers* is present in
    *headers*. Returns ``None`` for methods outside ``cacheable_methods``.
    ``config`` defaults to ``depends.get("config")``.
    """
    app_config = depends.get("config") if config is None else config
    if method not in cacheable_methods:
        return None
    vary_part = _generate_vary_hash(headers, varying_headers)
    url_part = _generate_url_hash(url)
    return f"{app_config.app.name}:cached:{method}.{url_part}.{vary_part}"

600 

601 

def _generate_vary_hash(headers: Headers, varying_headers: list[str]) -> str:
    """Hash the ``name:value`` pairs of the *varying_headers* present in *headers*.

    Pairs are joined with ``|`` in the order of *varying_headers*; returns ``""``
    when none of the headers is present.
    """
    pairs: list[str] = []
    for name in varying_headers:
        value = headers.get(name)
        if value is not None:
            pairs.append(f"{name}:{value}")
    if not pairs:
        return ""

    hasher = _get_hasher()
    hasher.update(_str_encode("|".join(pairs)))
    hex_digest: str = hasher.hexdigest()
    return hex_digest

616 

617 

def _generate_url_hash(url: URL) -> str:
    """Return the hex digest of the full URL string (path and query included)."""
    hasher = _get_hasher()
    encoded_url = _str_encode(str(url))
    hasher.update(encoded_url)
    hex_digest: str = hasher.hexdigest()
    return hex_digest

623 

624 

def generate_varying_headers_cache_key(url: URL) -> str:
    """Return the key under which the Vary header names for *url* are stored.

    Only ``url.path`` is hashed, so URLs differing only in their query string
    share one record.
    """
    hasher = _get_hasher()
    hasher.update(_str_encode(str(url.path)))
    return "varying_headers." + str(hasher.hexdigest())

630 

631 

def get_cache_response_headers(response: Response, *, max_age: int) -> dict[str, str]:
    """Apply ``max-age`` to *response* and return any extra caching headers.

    Negative *max_age* is clamped to 0. Cache-Control on ``response.headers`` is
    patched in place; an ``Expires`` value is returned (not applied) only when
    the response does not already have one.
    """
    clamped_max_age = max(max_age, 0)
    extra_headers: dict[str, str] = {}
    if "Expires" not in response.headers:
        expires_at = time.time() + clamped_max_age
        extra_headers["Expires"] = email.utils.formatdate(expires_at, usegmt=True)
    patch_cache_control(response.headers, max_age=clamped_max_age)
    return extra_headers

640 

641 

def patch_cache_control(
    headers: MutableHeaders,
    **kwargs: t.Unpack[CacheDirectives],
) -> None:
    """Merge *kwargs* into the ``Cache-Control`` header of *headers* in place.

    Existing directives are kept unless *kwargs* overrides them; when both the
    header and *kwargs* specify a max-age, the smaller one wins. Keyword names
    map to directive names by replacing ``_`` with ``-``. ``False`` drops a
    directive, ``True`` emits a bare token, any other value emits
    ``name=value``. If no directives remain, the header is deleted.

    Raises:
        NotImplementedError: if ``public`` or ``private`` is passed.
    """
    cache_control: dict[str, t.Any] = {}
    for field in parse_http_list(headers.get("Cache-Control", "")):
        # Split on the first "=" only; a quoted argument may itself contain "=".
        name, has_value, raw_value = field.partition("=")
        name = name.strip()
        if not name:
            # Skip empty list members such as the one in "no-cache, , no-store".
            continue
        cache_control[name] = raw_value.strip() if has_value else True

    if "max_age" in kwargs and "max-age" in cache_control:
        # The existing argument may be a token or a quoted-string. A malformed
        # value (or a bare "max-age" flag) is ignored instead of crashing the
        # response; the requested max-age then replaces it.
        with suppress(ValueError):
            existing_max_age = int(str(cache_control["max-age"]).strip('"'))
            kwargs["max_age"] = min(existing_max_age, kwargs["max_age"])

    _check_unsupported_directives(kwargs)

    for key, value in kwargs.items():
        cache_control[key.replace("_", "-")] = value

    directives = [
        name if value is True else f"{name}={value}"
        for name, value in cache_control.items()
        if value is not False
    ]
    if directives:
        headers["Cache-Control"] = ", ".join(directives)
    else:
        del headers["Cache-Control"]

680 

681 

def _check_unsupported_directives(kwargs: t.Any) -> None:
    """Raise NotImplementedError if *kwargs* mentions ``public`` or ``private``.

    ``public`` is checked first, so it is reported when both are present.
    """
    for directive in ("public", "private"):
        if directive in kwargs:
            msg = f"The '{directive}' cache control directive isn't supported yet."
            raise NotImplementedError(msg)

690 

691 

class CacheResponder:
    """ASGI app wrapper that serves responses from the cache and stores new ones.

    Per-request state (``request``, ``initial_message``,
    ``is_response_cacheable``) is kept on the instance.
    NOTE(review): this assumes one instance per request — confirm callers do
    not share an instance across concurrent requests.

    Cache backend failures (lookup, store, invalidation) are logged as warnings
    and the request is served uncached, so a cache outage never turns a valid
    application response into an error.
    """

    def __init__(self, app: ASGIApp, *, rules: Sequence[Rule]) -> None:
        self.app = app
        self.rules = rules
        try:
            self.logger = depends.get("logger")
        except Exception:
            import logging

            self.logger = logging.getLogger("fastblocks.cache")
        try:
            self.cache = depends.get("cache")
        except Exception:
            # None lets the cache helpers resolve the cache themselves.
            self.cache = None
        self.initial_message: Message = {}
        self.is_response_cacheable = True
        self.request: Request | None = None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Serve a cached response when available; otherwise run the wrapped app.

        Cacheable misses go through ``send_with_caching``; non-cacheable
        requests with an invalidating method go through ``send_then_invalidate``.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        self.request = request = Request(scope)
        try:
            response = await get_from_cache(request, cache=self.cache, rules=self.rules)
        except RequestNotCachable:
            if request.method in invalidating_methods:
                send = partial(self.send_then_invalidate, send=send)
        except Exception as exc:
            # Cache lookup failed: serve the request uncached instead of failing it.
            _safe_log(self.logger, "warning", f"cache_lookup_failed error={exc!r}")
        else:
            if response is not None:
                _safe_log(self.logger, "debug", "cache_lookup HIT")
                await response(scope, receive, send)
                return
            send = partial(self.send_with_caching, send=send)
            _safe_log(self.logger, "debug", "cache_lookup MISS")
        await self.app(scope, receive, send)

    async def send_with_caching(self, message: Message, *, send: Send) -> None:
        """Buffer the response start, try to cache the body, then forward both.

        Streaming responses (``more_body``) and responses rejected by
        ``set_in_cache`` are forwarded unchanged.
        """
        message_type = message["type"]
        if not self.is_response_cacheable or message_type not in (
            "http.response.start",
            "http.response.body",
        ):
            await send(message)
            return
        if message_type == "http.response.start":
            # Hold the start message until we know whether caching adds headers.
            self.initial_message = message
            return
        if message.get("more_body", False):
            _safe_log(
                self.logger,
                "debug",
                "response_not_cacheable reason=is_streaming",
            )
            self.is_response_cacheable = False
            await send(self.initial_message)
            await send(message)
            return
        if self.request is None:
            # Cannot cache without a request; forward instead of dropping.
            self.is_response_cacheable = False
            await send(self.initial_message)
            await send(message)
            return

        response = Response(
            content=message.get("body", b""),
            status_code=self.initial_message["status"],
        )
        # Work on a copy so the original start message is untouched on failure.
        response.raw_headers = list(self.initial_message["headers"])
        try:
            await set_in_cache(
                response,
                request=self.request,
                cache=self.cache,
                rules=self.rules,
            )
        except ResponseNotCachable:
            self.is_response_cacheable = False
        except Exception as exc:
            # Storing failed: still deliver the response the app produced.
            _safe_log(self.logger, "warning", f"cache_store_failed error={exc!r}")
            self.is_response_cacheable = False
        else:
            # Adopt the headers added by set_in_cache (X-Cache, Vary, ...).
            self.initial_message["headers"] = response.raw_headers.copy()
        await send(self.initial_message)
        await send(message)

    async def send_then_invalidate(self, message: Message, *, send: Send) -> None:
        """Forward *message*, first purging cached entries on a 2xx/3xx response start."""
        if (
            self.request is not None
            and message["type"] == "http.response.start"
            and 200 <= message["status"] < 400
        ):
            try:
                await delete_from_cache(
                    self.request.url,
                    vary=self.request.headers,
                    cache=self.cache,
                )
            except Exception as exc:
                _safe_log(
                    self.logger, "warning", f"cache_invalidation_failed error={exc!r}"
                )
        await send(message)

780 

781 

class CacheControlResponder:
    """ASGI app wrapper that merges fixed Cache-Control directives into responses.

    The directives given as keyword arguments are applied with
    ``patch_cache_control`` to every HTTP response-start message.
    """

    def __init__(self, app: ASGIApp, **kwargs: t.Unpack[CacheDirectives]) -> None:
        self.app = app
        self.kwargs = kwargs
        try:
            self.logger = depends.get("logger")
        except Exception:
            import logging

            self.logger = logging.getLogger("fastblocks.cache")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app, intercepting ``send`` only for HTTP scopes."""
        downstream = (
            partial(self.send_with_caching, send=send)
            if scope["type"] == "http"
            else send
        )
        await self.app(scope, receive, downstream)

    @staticmethod
    def kvformat(**kwargs: t.Any) -> str:
        """Render keyword arguments as space-separated ``key=value`` pairs."""
        pairs = [f"{key}={value}" for key, value in kwargs.items()]
        return " ".join(pairs)

    async def send_with_caching(self, message: Message, *, send: Send) -> None:
        """Patch Cache-Control on the response-start message; forward everything."""
        if message["type"] != "http.response.start":
            await send(message)
            return
        _safe_log(
            self.logger,
            "debug",
            f"patch_cache_control {self.kvformat(**self.kwargs)}",
        )
        patched = MutableHeaders(raw=list(message["headers"]))
        patch_cache_control(patched, **self.kwargs)
        message["headers"] = patched.raw
        await send(message)