Coverage for src/chat_limiter/limiter.py: 77% (383 statements)
coverage.py v7.9.2, created at 2025-09-18 21:15 +0100

1""" 

2Core rate limiter implementation using PyrateLimiter. 

3""" 

4 

5import asyncio 

6from collections.abc import AsyncIterator, Iterator 

7from contextlib import asynccontextmanager, contextmanager 

8from dataclasses import dataclass, field 

9import logging 

10import time 

11from typing import Any 

12 

13import httpx 

14from pyrate_limiter import Duration, Limiter, Rate 

15from tenacity import ( 

16 retry, 

17 retry_if_exception_type, 

18 stop_after_attempt, 

19 wait_exponential, 

20) 

21 

22from .adapters import get_adapter 

23from .providers import ( 

24 Provider, 

25 ProviderConfig, 

26 RateLimitInfo, 

27 detect_provider_from_url, 

28 get_provider_config, 

29) 

30from .types import ( 

31 ChatCompletionRequest, 

32 ChatCompletionResponse, 

33 Message, 

34 MessageRole, 

35 detect_provider_from_model, 

36) 

37from .models import detect_provider_from_model_sync 

38 

39logger = logging.getLogger(__name__) 


@dataclass
class LimiterState:
    """Current state of the rate limiter."""

    # Current limits (None if not yet discovered)
    request_limit: int | None = None
    token_limit: int | None = None

    # Usage tracking
    requests_used: int = 0
    tokens_used: int = 0

    # Timing
    last_request_time: float = field(default_factory=time.time)
    last_limit_update: float = field(default_factory=time.time)

    # Rate limit info from last response
    last_rate_limit_info: RateLimitInfo | None = None

    # Adaptive behavior
    consecutive_rate_limit_errors: int = 0
    adaptive_backoff_factor: float = 1.0


class ChatLimiter:
    """
    A Pythonic rate limiter for API calls supporting OpenAI, Anthropic, and OpenRouter.

    Features:
    - Automatic rate limit discovery and adaptation
    - Sync and async support with context managers
    - Intelligent retry logic with exponential backoff
    - Token and request rate limiting
    - Provider-specific optimizations

    Example:
        # High-level interface (recommended)
        async with ChatLimiter.for_model("gpt-4o", api_key="sk-...") as limiter:
            response = await limiter.chat_completion(
                model="gpt-4o",
                messages=[Message(role=MessageRole.USER, content="Hello!")],
            )

        # Low-level interface (for advanced users)
        async with ChatLimiter(provider=Provider.OPENAI, api_key="sk-...") as limiter:
            response = await limiter.request("POST", "/chat/completions", json=data)
    """

    def __init__(
        self,
        provider: Provider | None = None,
        api_key: str | None = None,
        base_url: str | None = None,
        config: ProviderConfig | None = None,
        http_client: httpx.AsyncClient | None = None,
        sync_http_client: httpx.Client | None = None,
        enable_adaptive_limits: bool = True,
        enable_token_estimation: bool = True,
        request_limit: int | None = None,
        token_limit: int | None = None,
        max_retries: int | None = None,
        base_backoff: float | None = None,
        timeout: float | None = None,
        **kwargs: Any,
    ):
        """
        Initialize the ChatLimiter.

        Args:
            provider: The API provider (OpenAI, Anthropic, or OpenRouter)
            api_key: API key for authentication
            base_url: Base URL for API requests
            config: Custom provider configuration
            http_client: Custom async HTTP client
            sync_http_client: Custom sync HTTP client
            enable_adaptive_limits: Enable adaptive rate limit adjustment
            enable_token_estimation: Enable token usage estimation
            request_limit: Override request limit (otherwise discovered from the API)
            token_limit: Override token limit (otherwise discovered from the API)
            max_retries: Override max retries (defaults to 3)
            base_backoff: Override base backoff in seconds (defaults to 1.0)
            timeout: HTTP request timeout in seconds (defaults to 120.0 for reliability)
            **kwargs: Additional arguments passed to the HTTP clients
        """
        # Determine provider and config
        if config:
            self.config = config
            self.provider = config.provider
        elif provider:
            self.provider = provider
            self.config = get_provider_config(provider)
        elif base_url:
            detected_provider = detect_provider_from_url(base_url)
            if detected_provider:
                self.provider = detected_provider
                self.config = get_provider_config(detected_provider)
            else:
                raise ValueError(f"Could not detect provider from URL: {base_url}")
        else:
            raise ValueError("Must provide either provider, config, or base_url")

        # Override base_url if provided
        if base_url:
            self.config.base_url = base_url

        # Store configuration
        self.api_key = api_key
        self.enable_adaptive_limits = enable_adaptive_limits
        self.enable_token_estimation = enable_token_estimation

        # Store user-provided overrides; explicit None checks so that explicit
        # 0 / 0.0 values are honored rather than silently replaced
        self._user_request_limit = request_limit
        self._user_token_limit = token_limit
        self._user_max_retries = max_retries if max_retries is not None else 3
        self._user_base_backoff = base_backoff if base_backoff is not None else 1.0
        # Default to 120 seconds for better reliability on slow responses
        self._user_timeout = timeout if timeout is not None else 120.0

        # Determine initial limits (user override, config default, or None for discovery)
        initial_request_limit = (
            request_limit or self.config.default_request_limit or None
        )
        initial_token_limit = token_limit or self.config.default_token_limit or None

        # Initialize state - limits stay None until discovered from the API
        self.state = LimiterState(
            request_limit=initial_request_limit,
            token_limit=initial_token_limit,
        )

        # Flag to track whether limits still need to be discovered
        self._limits_discovered = (
            initial_request_limit is not None and initial_token_limit is not None
        )

        # Initialize HTTP clients
        self._init_http_clients(http_client, sync_http_client, **kwargs)

        # Initialize rate limiters
        self._init_rate_limiters()

        # Context manager state
        self._async_context_active = False
        self._sync_context_active = False

        # Logging configuration
        self._print_rate_limit_info = False
        self._print_request_initiation = False

    @classmethod
    def for_model(
        cls,
        model: str,
        api_key: str | None = None,
        provider: str | Provider | None = None,
        use_dynamic_discovery: bool = True,
        request_limit: int | None = None,
        token_limit: int | None = None,
        max_retries: int | None = None,
        base_backoff: float | None = None,
        timeout: float | None = None,
        **kwargs: Any,
    ) -> "ChatLimiter":
        """
        Create a ChatLimiter, automatically detecting the provider from the model name.

        Args:
            model: The model name (e.g., "gpt-4o", "claude-3-sonnet-20240229")
            api_key: API key for the provider. If None, it is read from the environment
                (OPENAI_API_KEY, ANTHROPIC_API_KEY, OPENROUTER_API_KEY)
            provider: Override provider detection. Can be "openai", "anthropic",
                "openrouter", or a Provider enum. If None, auto-detected from the model name
            use_dynamic_discovery: Whether to query live APIs for model availability
                (default: True). Requires appropriate API keys. Falls back to hardcoded
                model lists when disabled or when API calls fail
            request_limit: Override request limit (otherwise discovered from the API)
            token_limit: Override token limit (otherwise discovered from the API)
            max_retries: Override max retries (defaults to 3)
            base_backoff: Override base backoff in seconds (defaults to 1.0)
            timeout: HTTP request timeout in seconds (defaults to 120.0)
            **kwargs: Additional arguments passed to ChatLimiter

        Returns:
            Configured ChatLimiter instance

        Raises:
            ValueError: If the provider cannot be determined from the model name,
                or if no API key can be found

        Example:
            # Auto-detect provider with dynamic discovery (default behavior)
            async with ChatLimiter.for_model("gpt-4o") as limiter:
                response = await limiter.simple_chat("gpt-4o", "Hello!")

            # Override provider detection
            async with ChatLimiter.for_model("custom-model", provider="openai") as limiter:
                response = await limiter.simple_chat("custom-model", "Hello!")

            # Disable dynamic discovery to use only hardcoded model lists
            async with ChatLimiter.for_model("gpt-4o", use_dynamic_discovery=False) as limiter:
                response = await limiter.simple_chat("gpt-4o", "Hello!")
        """
        import os

        # Determine provider
        if provider is not None:
            # Use the provided provider
            if isinstance(provider, str):
                provider_enum = Provider(provider)
            else:
                provider_enum = provider
            provider_name = provider_enum.value
        else:
            # Auto-detect from the model name. If dynamic discovery is
            # requested, collect the available API keys first.
            api_keys_for_discovery = {}
            if use_dynamic_discovery:
                env_var_map = {
                    "openai": "OPENAI_API_KEY",
                    "anthropic": "ANTHROPIC_API_KEY",
                    "openrouter": "OPENROUTER_API_KEY",
                }

                for provider_key, env_var in env_var_map.items():
                    key_value = os.getenv(env_var)
                    if key_value:
                        api_keys_for_discovery[provider_key] = key_value

            # Note: discovery_result is not populated by the call below, so the
            # detailed error branches that inspect it only apply if detection
            # is later changed to return a result object.
            discovery_result = None
            detected_provider = detect_provider_from_model(
                model, use_dynamic_discovery, api_keys_for_discovery
            )

            if not detected_provider:
                discovery_msg = (
                    " with dynamic API discovery" if use_dynamic_discovery else ""
                )
                error_msg = f"Could not determine provider from model '{model}'{discovery_msg}. "

                # Add details about available models if we have discovery results
                if discovery_result and discovery_result.get_total_models_found() > 0:
                    error_msg += f"\n\nFound {discovery_result.get_total_models_found()} models across providers:\n"
                    for (
                        provider_name,
                        models,
                    ) in discovery_result.get_all_models().items():
                        error_msg += f"  {provider_name}: {len(models)} models\n"
                        for example in sorted(models):
                            error_msg += f"    - {example}\n"
                    error_msg += "\nPlease check the model name or specify the provider explicitly using the 'provider' parameter."
                else:
                    error_msg += "Please specify the provider explicitly using the 'provider' parameter."

                # Add information about discovery errors, if any
                if discovery_result and discovery_result.errors:
                    error_msg += "\n\nDiscovery errors encountered:\n"
                    for provider_name, error in discovery_result.errors.items():
                        error_msg += f"  {provider_name}: {error}\n"

                raise ValueError(error_msg)
            assert detected_provider is not None  # Help MyPy with type narrowing
            provider_name = detected_provider
            provider_enum = Provider(provider_name)

        # Determine API key
        if api_key is None:
            # Try to read it from environment variables
            env_var_map = {
                "openai": "OPENAI_API_KEY",
                "anthropic": "ANTHROPIC_API_KEY",
                "openrouter": "OPENROUTER_API_KEY",
            }

            env_var_name: str | None = env_var_map.get(provider_name)
            if env_var_name:
                api_key = os.getenv(env_var_name)
                if not api_key:
                    raise ValueError(
                        f"API key not provided and {env_var_name} environment variable not set. "
                        f"Please provide the api_key parameter or set {env_var_name}."
                    )
            else:
                raise ValueError(
                    f"Unknown provider '{provider_name}'. Cannot determine the environment variable for the API key."
                )

        return cls(
            provider=provider_enum,
            api_key=api_key,
            request_limit=request_limit,
            token_limit=token_limit,
            max_retries=max_retries,
            base_backoff=base_backoff,
            timeout=timeout,
            **kwargs,
        )
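
    # Usage sketch (hedged): a typical call, assuming OPENAI_API_KEY is set in
    # the environment; the model name is illustrative.
    #
    #     async with ChatLimiter.for_model("gpt-4o") as limiter:
    #         text = await limiter.simple_chat("gpt-4o", "Hello!")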

    def _init_http_clients(
        self,
        http_client: httpx.AsyncClient | None,
        sync_http_client: httpx.Client | None,
        **kwargs: Any,
    ) -> None:
        """Initialize HTTP clients with the proper headers."""
        # Prepare headers
        headers = {
            "User-Agent": f"chat-limiter/0.1.0 ({self.provider.value})",
        }

        # Add provider-specific headers
        if self.api_key:
            if self.provider == Provider.OPENAI:
                headers["Authorization"] = f"Bearer {self.api_key}"
            elif self.provider == Provider.ANTHROPIC:
                headers["x-api-key"] = self.api_key
                headers["anthropic-version"] = "2023-06-01"
            elif self.provider == Provider.OPENROUTER:
                headers["Authorization"] = f"Bearer {self.api_key}"
                headers["HTTP-Referer"] = "https://github.com/your-repo/chat-limiter"

        # Merge with user-provided headers
        if "headers" in kwargs:
            headers.update(kwargs["headers"])
        kwargs["headers"] = headers

        # Initialize clients; the prepared headers are only applied to clients
        # created here, not to user-supplied ones
        if http_client:
            self.async_client = http_client
        else:
            self.async_client = httpx.AsyncClient(
                base_url=self.config.base_url,
                timeout=httpx.Timeout(self._user_timeout),  # configurable timeout
                **kwargs,
            )

        if sync_http_client:
            self.sync_client = sync_http_client
        else:
            self.sync_client = httpx.Client(
                base_url=self.config.base_url,
                timeout=httpx.Timeout(self._user_timeout),  # configurable timeout
                **kwargs,
            )
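
    # Note / sketch (hedged): because auth headers are only attached to clients
    # created by this method, a user-supplied client must carry its own
    # credentials, e.g.
    #
    #     client = httpx.AsyncClient(headers={"Authorization": "Bearer sk-..."})
    #     async with ChatLimiter(provider=Provider.OPENAI, api_key="sk-...",
    #                            http_client=client) as limiter:
    #         ...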

    def _init_rate_limiters(self) -> None:
        """Initialize PyrateLimiter instances."""
        # Only initialize if we have limits
        if self.state.request_limit is None or self.state.token_limit is None:
            # Cannot initialize rate limiters without limits; this method is
            # called again after limits are discovered
            self.request_limiter = None
            self.token_limiter = None
            self._effective_request_limit = None
            self._effective_token_limit = None
            return

        # Calculate effective limits with a safety buffer
        effective_request_limit = int(
            self.state.request_limit * self.config.request_buffer_ratio
        )
        effective_token_limit = int(
            self.state.token_limit * self.config.token_buffer_ratio
        )

        # Request rate limiter
        self.request_limiter = Limiter(
            Rate(
                effective_request_limit,
                Duration.MINUTE,
            )
        )

        # Token rate limiter
        self.token_limiter = Limiter(
            Rate(
                effective_token_limit,
                Duration.MINUTE,
            )
        )

        # Store effective limits for logging
        self._effective_request_limit = effective_request_limit
        self._effective_token_limit = effective_token_limit
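
    # Worked example (hedged, illustrative numbers): with a request limit of
    # 500/min and request_buffer_ratio = 0.9, the effective limit becomes
    # int(500 * 0.9) = 450 requests/min; a token limit of 30_000/min with
    # token_buffer_ratio = 0.8 becomes int(30_000 * 0.8) = 24_000 tokens/min.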

    async def __aenter__(self) -> "ChatLimiter":
        """Async context manager entry."""
        if self._async_context_active:
            raise RuntimeError(
                "ChatLimiter is already active as an async context manager"
            )

        self._async_context_active = True

        # Discover rate limits if supported
        if self.config.supports_dynamic_limits:
            await self._discover_rate_limits()

        # Print rate limit information if enabled
        if self._print_rate_limit_info:
            self._print_rate_limit_info_details()

        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: object,
    ) -> None:
        """Async context manager exit."""
        self._async_context_active = False
        await self.async_client.aclose()

    def __enter__(self) -> "ChatLimiter":
        """Sync context manager entry."""
        if self._sync_context_active:
            raise RuntimeError(
                "ChatLimiter is already active as a sync context manager"
            )

        self._sync_context_active = True

        # Discover rate limits if supported
        if self.config.supports_dynamic_limits:
            self._discover_rate_limits_sync()

        # Print rate limit information if enabled
        if self._print_rate_limit_info:
            self._print_rate_limit_info_details()

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: object,
    ) -> None:
        """Sync context manager exit."""
        self._sync_context_active = False
        self.sync_client.close()

    async def _discover_rate_limits(self) -> None:
        """Discover current rate limits from the API."""
        try:
            if self.provider == Provider.OPENROUTER and self.config.auth_endpoint:
                # OpenRouter exposes a dedicated auth endpoint
                response = await self.async_client.get(self.config.auth_endpoint)
                response.raise_for_status()

                data = response.json()
                # Update limits based on the response. This is a simplified
                # version; a full implementation would parse the payload.
                logger.info(f"Discovered OpenRouter limits: {data}")

            else:
                # For other providers, limits are discovered on the first request
                if self._print_rate_limit_info:
                    print(
                        f"Rate limit discovery will happen on first request for {self.provider.value}"
                    )
                logger.info(
                    f"Rate limit discovery will happen on first request for {self.provider.value}"
                )

        except Exception as e:
            logger.warning(f"Failed to discover rate limits: {e}")

    def _discover_rate_limits_sync(self) -> None:
        """Sync version of rate limit discovery."""
        try:
            if self.provider == Provider.OPENROUTER and self.config.auth_endpoint:
                response = self.sync_client.get(self.config.auth_endpoint)
                response.raise_for_status()

                data = response.json()
                logger.info(f"Discovered OpenRouter limits: {data}")
            else:
                logger.info(
                    f"Rate limit discovery will happen on first request for {self.provider.value}"
                )

        except Exception as e:
            logger.warning(f"Failed to discover rate limits: {e}")

    def _update_rate_limits(self, rate_limit_info: RateLimitInfo) -> None:
        """Update rate limits based on response headers."""
        updated = False
        was_uninitialized = (
            self.state.request_limit is None or self.state.token_limit is None
        )

        # Update request limits
        if (
            rate_limit_info.requests_limit
            and rate_limit_info.requests_limit != self.state.request_limit
        ):
            old_limit = self.state.request_limit
            self.state.request_limit = rate_limit_info.requests_limit
            updated = True
            if was_uninitialized:
                message = (
                    f"Discovered request limit: {self.state.request_limit} req/min"
                )
            else:
                message = f"Updated request limit: {old_limit} -> {self.state.request_limit} req/min"
            if self._print_rate_limit_info:
                print(message)
            logger.info(message)

        # Update token limits
        if (
            rate_limit_info.tokens_limit
            and rate_limit_info.tokens_limit != self.state.token_limit
        ):
            old_limit = self.state.token_limit
            self.state.token_limit = rate_limit_info.tokens_limit
            updated = True
            if was_uninitialized:
                message = f"Discovered token limit: {self.state.token_limit} tokens/min"
            else:
                message = f"Updated token limit: {old_limit} -> {self.state.token_limit} tokens/min"
            if self._print_rate_limit_info:
                print(message)
            logger.info(message)

        if updated:
            # Reinitialize rate limiters with the new limits
            self._init_rate_limiters()

            # Mark limits as discovered once both are available
            if (
                self.state.request_limit is not None
                and self.state.token_limit is not None
            ):
                self._limits_discovered = True

            if was_uninitialized:
                message = "Rate limiters initialized after discovery"
                if self._print_rate_limit_info:
                    print(message)
                    # Print the updated rate limit info after discovery
                    self._print_rate_limit_info_details()
                logger.info(message)

        # Store the rate limit info
        self.state.last_rate_limit_info = rate_limit_info
        self.state.last_limit_update = time.time()
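
    # Flow sketch (hedged): the request methods below feed response headers
    # back into this method, roughly:
    #
    #     rate_limit_info = extract_rate_limit_info(dict(response.headers), self.config)
    #     self._update_rate_limits(rate_limit_info)  # may reinitialize the limiters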

    def _estimate_tokens(self, request_data: dict[str, Any]) -> int:
        """Estimate token usage from request data."""
        if not self.enable_token_estimation:
            return 0

        # Simple token estimation. This is a placeholder; a real implementation
        # would use tiktoken or similar.
        if "messages" in request_data:
            text = ""
            for message in request_data["messages"]:
                if isinstance(message, dict) and "content" in message:
                    text += str(message["content"])

            # Rough estimation: 1 token ≈ 4 characters
            return len(text) // 4

        return 0
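
    # Worked example (hedged): under the 4-characters-per-token heuristic, a
    # request whose message contents total 1,000 characters is estimated at
    # 1000 // 4 = 250 tokens.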

    @asynccontextmanager
    async def _acquire_rate_limits(
        self, estimated_tokens: int = 0
    ) -> AsyncIterator[None]:
        """Acquire rate limits before making a request."""
        # Check whether the rate limiters are initialized
        if self.request_limiter is None or self.token_limiter is None:
            # Limits not yet discovered; this request will help discover them
            logger.info(
                "Rate limits not yet discovered, proceeding without rate limiting for discovery"
            )
        else:
            # Wait for the request rate limit
            await asyncio.to_thread(self.request_limiter.try_acquire, "request")

        # Wait for the token rate limit if token estimation is in use and the
        # limiters are initialized
        if (
            estimated_tokens > 0
            and self.token_limiter is not None
            and self._effective_token_limit is not None
        ):
            # Check whether the request is too large for the bucket capacity
            if estimated_tokens > self._effective_token_limit:
                logger.warning(
                    f"Request estimated at {estimated_tokens} tokens exceeds bucket capacity "
                    f"of {self._effective_token_limit} tokens. This may cause delays."
                )
                # For very large requests, split the acquisition into chunks to
                # avoid overflowing the bucket
                remaining_tokens = estimated_tokens
                while remaining_tokens > 0:
                    chunk_size = min(
                        remaining_tokens, self._effective_token_limit // 2
                    )
                    await asyncio.to_thread(
                        self.token_limiter.try_acquire, "token", chunk_size
                    )
                    remaining_tokens -= chunk_size
                    if remaining_tokens > 0:
                        # Brief pause to let the bucket refill
                        await asyncio.sleep(0.1)
            else:
                # Normal acquisition for smaller requests
                await asyncio.to_thread(
                    self.token_limiter.try_acquire, "token", estimated_tokens
                )

        try:
            yield
        finally:
            # Update usage tracking
            self.state.requests_used += 1
            self.state.tokens_used += estimated_tokens
            self.state.last_request_time = time.time()

    @contextmanager
    def _acquire_rate_limits_sync(self, estimated_tokens: int = 0) -> Iterator[None]:
        """Sync version of rate limit acquisition."""
        # Check whether the rate limiters are initialized
        if self.request_limiter is None or self.token_limiter is None:
            # Limits not yet discovered; this request will help discover them
            logger.info(
                "Rate limits not yet discovered, proceeding without rate limiting for discovery"
            )
        else:
            # Wait for the request rate limit
            self.request_limiter.try_acquire("request")

        # Wait for the token rate limit if token estimation is in use and the
        # limiters are initialized
        if (
            estimated_tokens > 0
            and self.token_limiter is not None
            and self._effective_token_limit is not None
        ):
            # Check whether the request is too large for the bucket capacity
            if estimated_tokens > self._effective_token_limit:
                logger.warning(
                    f"Request estimated at {estimated_tokens} tokens exceeds bucket capacity "
                    f"of {self._effective_token_limit} tokens. This may cause delays."
                )
                # For very large requests, split the acquisition into chunks to
                # avoid overflowing the bucket
                remaining_tokens = estimated_tokens
                while remaining_tokens > 0:
                    chunk_size = min(
                        remaining_tokens, self._effective_token_limit // 2
                    )
                    self.token_limiter.try_acquire("token", chunk_size)
                    remaining_tokens -= chunk_size
                    if remaining_tokens > 0:
                        # Brief pause to let the bucket refill
                        time.sleep(0.1)
            else:
                # Normal acquisition for smaller requests
                self.token_limiter.try_acquire("token", estimated_tokens)

        try:
            yield
        finally:
            # Update usage tracking
            self.state.requests_used += 1
            self.state.tokens_used += estimated_tokens
            self.state.last_request_time = time.time()
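
    # Worked example (hedged): with an effective token limit of 1,000 and a
    # request estimated at 2,500 tokens, the chunked path acquires
    # min(remaining, 1000 // 2) = 500 tokens at a time: five acquisitions of
    # 500 with 0.1 s pauses in between, instead of one oversized acquisition
    # that could never fit the bucket.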

    def _get_retry_decorator(self) -> Any:
        """Get a retry decorator with the user-configured parameters."""
        return retry(
            stop=stop_after_attempt(self._user_max_retries),
            wait=wait_exponential(multiplier=self._user_base_backoff, min=1, max=60),
            retry=retry_if_exception_type(
                (
                    httpx.HTTPStatusError,
                    httpx.RequestError,
                    httpx.ReadTimeout,
                    httpx.ConnectTimeout,
                )
            ),
        )
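
    # Usage sketch (hedged): the decorator built above can wrap a request
    # callable; _send, endpoint, and payload below are illustrative names only.
    #
    #     retrying = self._get_retry_decorator()
    #
    #     @retrying
    #     async def _send() -> httpx.Response:
    #         return await self.async_client.request("POST", endpoint, json=payload)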

    def get_current_limits(self) -> dict[str, Any]:
        """Get current rate limit information."""
        return {
            "provider": self.provider.value,
            "request_limit": self.state.request_limit,
            "token_limit": self.state.token_limit,
            "requests_used": self.state.requests_used,
            "tokens_used": self.state.tokens_used,
            "last_request_time": self.state.last_request_time,
            "last_limit_update": self.state.last_limit_update,
            "consecutive_rate_limit_errors": self.state.consecutive_rate_limit_errors,
        }
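
    # Usage sketch (hedged): inspect limiter state at any point, e.g.
    #
    #     limits = limiter.get_current_limits()
    #     print(limits["request_limit"], limits["tokens_used"])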

    def reset_usage_tracking(self) -> None:
        """Reset usage tracking counters."""
        self.state.requests_used = 0
        self.state.tokens_used = 0
        self.state.consecutive_rate_limit_errors = 0

    # High-level chat completion methods

    async def chat_completion(
        self,
        model: str,
        messages: list[Message],
        max_tokens: int | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        stop: str | list[str] | None = None,
        stream: bool = False,
        **kwargs: Any,
    ) -> ChatCompletionResponse:
        """
        Make a high-level chat completion request.

        Args:
            model: The model to use for the completion
            messages: List of messages in the conversation
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            stop: Stop sequences
            stream: Whether to stream the response
            **kwargs: Additional provider-specific parameters

        Returns:
            ChatCompletionResponse with the completion result. Request errors
            are captured in the response (success=False, error_message set)
            rather than raised.

        Raises:
            RuntimeError: If the limiter is not active as an async context manager
        """
        if not self._async_context_active:
            raise RuntimeError("ChatLimiter must be used as an async context manager")

        # Create the request object
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            stream=stream,
            **kwargs,
        )

        # Get the appropriate adapter
        adapter = get_adapter(self.provider)

        # Format the request for the provider
        formatted_request = adapter.format_request(request)

        # Make the HTTP request with rate limiting
        try:
            # Print request initiation if enabled
            if self._print_request_initiation:
                print(f"Sending request for model {model} (attempt 1)")

            # Estimate tokens
            estimated_tokens = self._estimate_tokens(formatted_request)

            # Acquire rate limits
            async with self._acquire_rate_limits(estimated_tokens):
                # Make the request
                response = await self.async_client.request(
                    "POST", adapter.get_endpoint(), json=formatted_request
                )

                # Extract rate limit info
                from .providers import extract_rate_limit_info

                rate_limit_info = extract_rate_limit_info(
                    dict(response.headers), self.config
                )

                # Update our rate limits
                if self.enable_adaptive_limits:
                    self._update_rate_limits(rate_limit_info)

                # Handle rate limit errors
                if response.status_code == 429:
                    self.state.consecutive_rate_limit_errors += 1
                    if rate_limit_info.retry_after:
                        await asyncio.sleep(rate_limit_info.retry_after)
                    else:
                        # Exponential backoff
                        backoff = self.config.base_backoff * (
                            2**self.state.consecutive_rate_limit_errors
                        )
                        await asyncio.sleep(min(backoff, self.config.max_backoff))

                    response.raise_for_status()
                else:
                    # Reset consecutive errors on success
                    self.state.consecutive_rate_limit_errors = 0

                # Parse the response
                response_data = response.json()
                return adapter.parse_response(response_data, request)

        except Exception as e:
            # Capture the error in an error response instead of raising
            return ChatCompletionResponse(
                id="error",
                model=request.model,
                success=False,
                error_message=str(e),
                choices=[],
                usage=None,
                created=None,
            )
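
    # Usage sketch (hedged): a full async round trip; note that errors come
    # back as a response with success=False rather than as raised exceptions.
    #
    #     async with ChatLimiter.for_model("gpt-4o") as limiter:
    #         response = await limiter.chat_completion(
    #             model="gpt-4o",
    #             messages=[Message(role=MessageRole.USER, content="Hello!")],
    #             max_tokens=64,
    #         )
    #         if response.success:
    #             print(response.choices[0].message.content)
    #         else:
    #             print("Request failed:", response.error_message)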

    def chat_completion_sync(
        self,
        model: str,
        messages: list[Message],
        max_tokens: int | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        stop: str | list[str] | None = None,
        stream: bool = False,
        **kwargs: Any,
    ) -> ChatCompletionResponse:
        """
        Make a synchronous high-level chat completion request.

        Args:
            model: The model to use for the completion
            messages: List of messages in the conversation
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            stop: Stop sequences
            stream: Whether to stream the response
            **kwargs: Additional provider-specific parameters

        Returns:
            ChatCompletionResponse with the completion result. Request errors
            are captured in the response (success=False, error_message set)
            rather than raised.

        Raises:
            RuntimeError: If the limiter is not active as a sync context manager
        """
        if not self._sync_context_active:
            raise RuntimeError("ChatLimiter must be used as a sync context manager")

        # Create the request object
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            stream=stream,
            **kwargs,
        )

        # Get the appropriate adapter
        adapter = get_adapter(self.provider)

        # Format the request for the provider
        formatted_request = adapter.format_request(request)

        # Make the HTTP request with rate limiting
        try:
            # Print request initiation if enabled
            if self._print_request_initiation:
                print(f"Sending request for model {model} (attempt 1)")

            # Estimate tokens
            estimated_tokens = self._estimate_tokens(formatted_request)

            # Acquire rate limits
            with self._acquire_rate_limits_sync(estimated_tokens):
                # Make the request
                response = self.sync_client.request(
                    "POST", adapter.get_endpoint(), json=formatted_request
                )

                # Extract rate limit info
                from .providers import extract_rate_limit_info

                rate_limit_info = extract_rate_limit_info(
                    dict(response.headers), self.config
                )

                # Update our rate limits
                if self.enable_adaptive_limits:
                    self._update_rate_limits(rate_limit_info)

                # Handle rate limit errors
                if response.status_code == 429:
                    self.state.consecutive_rate_limit_errors += 1
                    if rate_limit_info.retry_after:
                        time.sleep(rate_limit_info.retry_after)
                    else:
                        # Exponential backoff
                        backoff = self.config.base_backoff * (
                            2**self.state.consecutive_rate_limit_errors
                        )
                        time.sleep(min(backoff, self.config.max_backoff))

                    response.raise_for_status()
                else:
                    # Reset consecutive errors on success
                    self.state.consecutive_rate_limit_errors = 0

                # Parse the response
                response_data = response.json()
                return adapter.parse_response(response_data, request)

        except Exception as e:
            # Capture the error in an error response instead of raising
            return ChatCompletionResponse(
                id="error",
                model=request.model,
                success=False,
                error_message=str(e),
                choices=[],
                usage=None,
                created=None,
            )

    # Convenience methods for different message types

    async def simple_chat(
        self,
        model: str,
        prompt: str,
        max_tokens: int | None = None,
        temperature: float | None = None,
        **kwargs: Any,
    ) -> str:
        """
        Simple chat completion that returns just the text response.

        Args:
            model: The model to use
            prompt: The user prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional parameters

        Returns:
            The text response from the model, or an empty string if there
            are no choices
        """
        messages = [Message(role=MessageRole.USER, content=prompt)]
        response = await self.chat_completion(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs,
        )

        if response.choices:
            return response.choices[0].message.content
        return ""

    def simple_chat_sync(
        self,
        model: str,
        prompt: str,
        max_tokens: int | None = None,
        temperature: float | None = None,
        **kwargs: Any,
    ) -> str:
        """
        Simple synchronous chat completion that returns just the text response.

        Args:
            model: The model to use
            prompt: The user prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional parameters

        Returns:
            The text response from the model, or an empty string if there
            are no choices
        """
        messages = [Message(role=MessageRole.USER, content=prompt)]
        response = self.chat_completion_sync(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs,
        )

        if response.choices:
            return response.choices[0].message.content
        return ""

    def set_print_rate_limit_info(self, enabled: bool) -> None:
        """Set whether to print rate limit information."""
        self._print_rate_limit_info = enabled

    def set_print_request_initiation(self, enabled: bool) -> None:
        """Set whether to print request initiation messages."""
        self._print_request_initiation = enabled

    def _print_rate_limit_info_details(self) -> None:
        """Print the current rate limit configuration."""
        print(f"\n=== Rate Limit Configuration for {self.provider.value.title()} ===")
        print(f"Provider: {self.provider.value}")
        print(f"Base URL: {self.config.base_url}")

        # Handle limits that have not been discovered yet
        if self.state.request_limit is not None:
            effective_req = self._effective_request_limit or "not calculated"
            print(
                f"Request Limit: {self.state.request_limit}/minute (effective: {effective_req}/minute)"
            )
        else:
            print("Request Limit: Not yet discovered (will be fetched from API)")

        if self.state.token_limit is not None:
            effective_tok = self._effective_token_limit or "not calculated"
            print(
                f"Token Limit: {self.state.token_limit}/minute (effective: {effective_tok}/minute)"
            )
        else:
            print("Token Limit: Not yet discovered (will be fetched from API)")

        print(f"Request Buffer Ratio: {self.config.request_buffer_ratio}")
        print(f"Token Buffer Ratio: {self.config.token_buffer_ratio}")
        print(f"Adaptive Limits: {self.enable_adaptive_limits}")
        print(f"Token Estimation: {self.enable_token_estimation}")
        print(f"Dynamic Discovery: {self.config.supports_dynamic_limits}")
        print(f"Limits Discovered: {self._limits_discovered}")
        print("=" * 50)