Coverage for src/chat_limiter/limiter.py: 77%
383 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-09-18 21:15 +0100
« prev ^ index » next coverage.py v7.9.2, created at 2025-09-18 21:15 +0100
1"""
2Core rate limiter implementation using PyrateLimiter.
3"""
5import asyncio
6from collections.abc import AsyncIterator, Iterator
7from contextlib import asynccontextmanager, contextmanager
8from dataclasses import dataclass, field
9import logging
10import time
11from typing import Any
13import httpx
14from pyrate_limiter import Duration, Limiter, Rate
15from tenacity import (
16 retry,
17 retry_if_exception_type,
18 stop_after_attempt,
19 wait_exponential,
20)
22from .adapters import get_adapter
23from .providers import (
24 Provider,
25 ProviderConfig,
26 RateLimitInfo,
27 detect_provider_from_url,
28 get_provider_config,
29)
30from .types import (
31 ChatCompletionRequest,
32 ChatCompletionResponse,
33 Message,
34 MessageRole,
35 detect_provider_from_model,
36)
37from .models import detect_provider_from_model_sync
39logger = logging.getLogger(__name__)
42@dataclass
43class LimiterState:
44 """Current state of the rate limiter."""
46 # Current limits (None if not yet discovered)
47 request_limit: int | None = None
48 token_limit: int | None = None
50 # Usage tracking
51 requests_used: int = 0
52 tokens_used: int = 0
54 # Timing
55 last_request_time: float = field(default_factory=time.time)
56 last_limit_update: float = field(default_factory=time.time)
58 # Rate limit info from last response
59 last_rate_limit_info: RateLimitInfo | None = None
61 # Adaptive behavior
62 consecutive_rate_limit_errors: int = 0
63 adaptive_backoff_factor: float = 1.0
66class ChatLimiter:
67 """
68 A Pythonic rate limiter for API calls supporting OpenAI, Anthropic, and OpenRouter.
70 Features:
71 - Automatic rate limit discovery and adaptation
72 - Sync and async support with context managers
73 - Intelligent retry logic with exponential backoff
74 - Token and request rate limiting
75 - Provider-specific optimizations
77 Example:
78 # High-level interface (recommended)
79 async with ChatLimiter.for_model("gpt-4o", api_key="sk-...") as limiter:
80 response = await limiter.chat_completion(
81 model="gpt-4o",
82 messages=[Message(role=MessageRole.USER, content="Hello!")]
83 )
85 # Low-level interface (for advanced users)
86 async with ChatLimiter(provider=Provider.OPENAI, api_key="sk-...") as limiter:
87 response = await limiter.request("POST", "/chat/completions", json=data)
88 """
90 def __init__(
91 self,
92 provider: Provider | None = None,
93 api_key: str | None = None,
94 base_url: str | None = None,
95 config: ProviderConfig | None = None,
96 http_client: httpx.AsyncClient | None = None,
97 sync_http_client: httpx.Client | None = None,
98 enable_adaptive_limits: bool = True,
99 enable_token_estimation: bool = True,
100 request_limit: int | None = None,
101 token_limit: int | None = None,
102 max_retries: int | None = None,
103 base_backoff: float | None = None,
104 timeout: float | None = None,
105 **kwargs: Any,
106 ):
107 """
108 Initialize the ChatLimiter.
110 Args:
111 provider: The API provider (OpenAI, Anthropic, OpenRouter)
112 api_key: API key for authentication
113 base_url: Base URL for API requests
114 config: Custom provider configuration
115 http_client: Custom async HTTP client
116 sync_http_client: Custom sync HTTP client
117 enable_adaptive_limits: Enable adaptive rate limit adjustment
118 enable_token_estimation: Enable token usage estimation
119 request_limit: Override request limit (if not provided, must be discovered from API)
120 token_limit: Override token limit (if not provided, must be discovered from API)
121 max_retries: Override max retries (defaults to 3 if not provided)
122 base_backoff: Override base backoff (defaults to 1.0 if not provided)
123 timeout: HTTP request timeout in seconds (defaults to 120.0 for better reliability)
124 **kwargs: Additional arguments passed to HTTP clients
125 """
126 # Determine provider and config
127 if config:
128 self.config = config
129 self.provider = config.provider
130 elif provider:
131 self.provider = provider
132 self.config = get_provider_config(provider)
133 elif base_url:
134 detected_provider = detect_provider_from_url(base_url)
135 if detected_provider:
136 self.provider = detected_provider
137 self.config = get_provider_config(detected_provider)
138 else:
139 raise ValueError(f"Could not detect provider from URL: {base_url}")
140 else:
141 raise ValueError("Must provide either provider, config, or base_url")
143 # Override base_url if provided
144 if base_url:
145 self.config.base_url = base_url
147 # Store configuration
148 self.api_key = api_key
149 self.enable_adaptive_limits = enable_adaptive_limits
150 self.enable_token_estimation = enable_token_estimation
152 # Store user-provided overrides
153 self._user_request_limit = request_limit
154 self._user_token_limit = token_limit
155 self._user_max_retries = max_retries or 3 # Default to 3 if not provided
156 self._user_base_backoff = base_backoff or 1.0 # Default to 1.0 if not provided
157 self._user_timeout = (
158 timeout or 120.0
159 ) # Default to 120 seconds for better reliability
161 # Determine initial limits (user override, config default, or None for discovery)
162 initial_request_limit = (
163 request_limit or self.config.default_request_limit or None
164 )
165 initial_token_limit = token_limit or self.config.default_token_limit or None
167 # Initialize state - will be None if no defaults and no discovery yet
168 self.state = LimiterState(
169 request_limit=initial_request_limit,
170 token_limit=initial_token_limit,
171 )
173 # Flag to track if we need to discover limits
174 self._limits_discovered = (
175 initial_request_limit is not None and initial_token_limit is not None
176 )
178 # Initialize HTTP clients
179 self._init_http_clients(http_client, sync_http_client, **kwargs)
181 # Initialize rate limiters
182 self._init_rate_limiters()
184 # Context manager state
185 self._async_context_active = False
186 self._sync_context_active = False
188 # Logging configuration
189 self._print_rate_limit_info = False
190 self._print_request_initiation = False
192 @classmethod
193 def for_model(
194 cls,
195 model: str,
196 api_key: str | None = None,
197 provider: str | Provider | None = None,
198 use_dynamic_discovery: bool = True,
199 request_limit: int | None = None,
200 token_limit: int | None = None,
201 max_retries: int | None = None,
202 base_backoff: float | None = None,
203 timeout: float | None = None,
204 **kwargs: Any,
205 ) -> "ChatLimiter":
206 """
207 Create a ChatLimiter instance automatically detecting the provider from the model name.
209 Args:
210 model: The model name (e.g., "gpt-4o", "claude-3-sonnet-20240229")
211 api_key: API key for the provider. If None, will be read from environment variables
212 (OPENAI_API_KEY, ANTHROPIC_API_KEY, OPENROUTER_API_KEY)
213 provider: Override provider detection. Can be "openai", "anthropic", "openrouter",
214 or Provider enum. If None, will be auto-detected from model name
215 use_dynamic_discovery: Whether to query live APIs for model availability (default: True).
216 Requires appropriate API keys to be available. Falls back to
217 hardcoded model lists when disabled or when API calls fail.
218 **kwargs: Additional arguments passed to ChatLimiter
220 Returns:
221 Configured ChatLimiter instance
223 Raises:
224 ValueError: If provider cannot be determined from model name or API key not found
226 Example:
227 # Auto-detect provider with dynamic discovery (default behavior)
228 async with ChatLimiter.for_model("gpt-4o") as limiter:
229 response = await limiter.simple_chat("gpt-4o", "Hello!")
231 # Override provider detection
232 async with ChatLimiter.for_model("custom-model", provider="openai") as limiter:
233 response = await limiter.simple_chat("custom-model", "Hello!")
235 # Disable dynamic discovery to use only hardcoded model lists
236 async with ChatLimiter.for_model("gpt-4o", use_dynamic_discovery=False) as limiter:
237 response = await limiter.simple_chat("gpt-4o", "Hello!")
238 """
239 import os
241 # Determine provider
242 if provider is not None:
243 # Use provided provider
244 if isinstance(provider, str):
245 provider_enum = Provider(provider)
246 else:
247 provider_enum = provider
248 provider_name = provider_enum.value
249 else:
250 # Auto-detect from model name
251 # If dynamic discovery is requested, we need to collect API keys first
252 api_keys_for_discovery = {}
253 if use_dynamic_discovery:
254 # Collect available API keys from environment
255 env_var_map = {
256 "openai": "OPENAI_API_KEY",
257 "anthropic": "ANTHROPIC_API_KEY",
258 "openrouter": "OPENROUTER_API_KEY",
259 }
261 for provider_key, env_var in env_var_map.items():
262 key_value = os.getenv(env_var)
263 if key_value:
264 api_keys_for_discovery[provider_key] = key_value
266 # Try dynamic discovery first to get more detailed information
267 discovery_result = None
268 detected_provider = detect_provider_from_model(
269 model, use_dynamic_discovery, api_keys_for_discovery
270 )
272 if not detected_provider:
273 discovery_msg = (
274 " with dynamic API discovery" if use_dynamic_discovery else ""
275 )
276 error_msg = f"Could not determine provider from model '{model}'{discovery_msg}. "
278 # Add detailed information about available models if we have discovery results
279 if discovery_result and discovery_result.get_total_models_found() > 0:
280 error_msg += f"\n\nFound {discovery_result.get_total_models_found()} models across providers:\n"
281 for (
282 provider_name,
283 models,
284 ) in discovery_result.get_all_models().items():
285 error_msg += f" {provider_name}: {len(models)} models\n"
286 for example in sorted(list(models)):
287 error_msg += f" - {example}\n"
288 error_msg += "\nPlease check the model name or specify the provider explicitly using the 'provider' parameter."
289 else:
290 error_msg += "Please specify the provider explicitly using the 'provider' parameter."
292 # Add information about discovery errors if any
293 if discovery_result and discovery_result.errors:
294 error_msg += "\n\nDiscovery errors encountered:\n"
295 for provider_name, error in discovery_result.errors.items():
296 error_msg += f" {provider_name}: {error}\n"
298 raise ValueError(error_msg)
299 assert detected_provider is not None # Help MyPy understand type narrowing
300 provider_name = detected_provider
301 provider_enum = Provider(provider_name)
303 # Determine API key
304 if api_key is None:
305 # Try to get from environment variables
306 env_var_map = {
307 "openai": "OPENAI_API_KEY",
308 "anthropic": "ANTHROPIC_API_KEY",
309 "openrouter": "OPENROUTER_API_KEY",
310 }
312 env_var_name: str | None = env_var_map.get(provider_name)
313 if env_var_name:
314 api_key = os.getenv(env_var_name)
315 if not api_key:
316 raise ValueError(
317 f"API key not provided and {env_var_name} environment variable not set. "
318 f"Please provide api_key parameter or set {env_var_name} environment variable."
319 )
320 else:
321 raise ValueError(
322 f"Unknown provider '{provider_name}'. Cannot determine environment variable for API key."
323 )
325 return cls(
326 provider=provider_enum,
327 api_key=api_key,
328 request_limit=request_limit,
329 token_limit=token_limit,
330 max_retries=max_retries,
331 base_backoff=base_backoff,
332 timeout=timeout,
333 **kwargs,
334 )
336 def _init_http_clients(
337 self,
338 http_client: httpx.AsyncClient | None,
339 sync_http_client: httpx.Client | None,
340 **kwargs: Any,
341 ) -> None:
342 """Initialize HTTP clients with proper headers."""
343 # Prepare headers
344 headers = {
345 "User-Agent": f"chat-limiter/0.1.0 ({self.provider.value})",
346 }
348 # Add provider-specific headers
349 if self.api_key:
350 if self.provider == Provider.OPENAI:
351 headers["Authorization"] = f"Bearer {self.api_key}"
352 elif self.provider == Provider.ANTHROPIC:
353 headers["x-api-key"] = self.api_key
354 headers["anthropic-version"] = "2023-06-01"
355 elif self.provider == Provider.OPENROUTER:
356 headers["Authorization"] = f"Bearer {self.api_key}"
357 headers["HTTP-Referer"] = "https://github.com/your-repo/chat-limiter"
359 # Merge with user-provided headers
360 if "headers" in kwargs:
361 headers.update(kwargs["headers"])
362 kwargs["headers"] = headers
364 # Initialize clients
365 if http_client:
366 self.async_client = http_client
367 else:
368 self.async_client = httpx.AsyncClient(
369 base_url=self.config.base_url,
370 timeout=httpx.Timeout(self._user_timeout), # Configurable timeout
371 **kwargs,
372 )
374 if sync_http_client:
375 self.sync_client = sync_http_client
376 else:
377 self.sync_client = httpx.Client(
378 base_url=self.config.base_url,
379 timeout=httpx.Timeout(self._user_timeout), # Configurable timeout
380 **kwargs,
381 )
383 def _init_rate_limiters(self) -> None:
384 """Initialize PyrateLimiter instances."""
385 # Only initialize if we have limits
386 if self.state.request_limit is None or self.state.token_limit is None:
387 # Cannot initialize rate limiters without limits
388 # This will be called again after limits are discovered
389 self.request_limiter = None
390 self.token_limiter = None
391 self._effective_request_limit = None
392 self._effective_token_limit = None
393 return
395 # Calculate effective limits with buffer
396 effective_request_limit = int(
397 self.state.request_limit * self.config.request_buffer_ratio
398 )
399 effective_token_limit = int(
400 self.state.token_limit * self.config.token_buffer_ratio
401 )
403 # Request rate limiter
404 self.request_limiter = Limiter(
405 Rate(
406 effective_request_limit,
407 Duration.MINUTE,
408 )
409 )
411 # Token rate limiter
412 self.token_limiter = Limiter(
413 Rate(
414 effective_token_limit,
415 Duration.MINUTE,
416 )
417 )
419 # Store effective limits for logging
420 self._effective_request_limit = effective_request_limit
421 self._effective_token_limit = effective_token_limit
423 async def __aenter__(self) -> "ChatLimiter":
424 """Async context manager entry."""
425 if self._async_context_active:
426 raise RuntimeError(
427 "ChatLimiter is already active as an async context manager"
428 )
430 self._async_context_active = True
432 # Discover rate limits if supported
433 if self.config.supports_dynamic_limits:
434 await self._discover_rate_limits()
436 # Print rate limit information if enabled
437 if self._print_rate_limit_info:
438 self._print_rate_limit_info_details()
440 return self
442 async def __aexit__(
443 self,
444 exc_type: type[BaseException] | None,
445 exc_val: BaseException | None,
446 exc_tb: object,
447 ) -> None:
448 """Async context manager exit."""
449 self._async_context_active = False
450 await self.async_client.aclose()
452 def __enter__(self) -> "ChatLimiter":
453 """Sync context manager entry."""
454 if self._sync_context_active:
455 raise RuntimeError(
456 "ChatLimiter is already active as a sync context manager"
457 )
459 self._sync_context_active = True
461 # Discover rate limits if supported
462 if self.config.supports_dynamic_limits:
463 self._discover_rate_limits_sync()
465 # Print rate limit information if enabled
466 if self._print_rate_limit_info:
467 self._print_rate_limit_info_details()
469 return self
471 def __exit__(
472 self,
473 exc_type: type[BaseException] | None,
474 exc_val: BaseException | None,
475 exc_tb: object,
476 ) -> None:
477 """Sync context manager exit."""
478 self._sync_context_active = False
479 self.sync_client.close()
481 async def _discover_rate_limits(self) -> None:
482 """Discover current rate limits from the API."""
483 try:
484 if self.provider == Provider.OPENROUTER and self.config.auth_endpoint:
485 # OpenRouter uses a special auth endpoint
486 response = await self.async_client.get(self.config.auth_endpoint)
487 response.raise_for_status()
489 data = response.json()
490 # Update limits based on response
491 # This is a simplified version - actual implementation would parse the response
492 logger.info(f"Discovered OpenRouter limits: {data}")
494 else:
495 # For other providers, we'll discover limits on first request
496 if self._print_rate_limit_info:
497 print(
498 f"Rate limit discovery will happen on first request for {self.provider.value}"
499 )
500 logger.info(
501 f"Rate limit discovery will happen on first request for {self.provider.value}"
502 )
504 except Exception as e:
505 logger.warning(f"Failed to discover rate limits: {e}")
507 def _discover_rate_limits_sync(self) -> None:
508 """Sync version of rate limit discovery."""
509 try:
510 if self.provider == Provider.OPENROUTER and self.config.auth_endpoint:
511 response = self.sync_client.get(self.config.auth_endpoint)
512 response.raise_for_status()
514 data = response.json()
515 logger.info(f"Discovered OpenRouter limits: {data}")
516 else:
517 logger.info(
518 f"Rate limit discovery will happen on first request for {self.provider.value}"
519 )
521 except Exception as e:
522 logger.warning(f"Failed to discover rate limits: {e}")
524 def _update_rate_limits(self, rate_limit_info: RateLimitInfo) -> None:
525 """Update rate limits based on response headers."""
526 updated = False
527 was_uninitialized = (
528 self.state.request_limit is None or self.state.token_limit is None
529 )
531 # Update request limits
532 if (
533 rate_limit_info.requests_limit
534 and rate_limit_info.requests_limit != self.state.request_limit
535 ):
536 old_limit = self.state.request_limit
537 self.state.request_limit = rate_limit_info.requests_limit
538 updated = True
539 if was_uninitialized:
540 message = (
541 f"Discovered request limit: {self.state.request_limit} req/min"
542 )
543 if self._print_rate_limit_info:
544 print(message)
545 logger.info(message)
546 else:
547 message = f"Updated request limit: {old_limit} -> {self.state.request_limit} req/min"
548 if self._print_rate_limit_info:
549 print(message)
550 logger.info(message)
552 # Update token limits
553 if (
554 rate_limit_info.tokens_limit
555 and rate_limit_info.tokens_limit != self.state.token_limit
556 ):
557 old_limit = self.state.token_limit
558 self.state.token_limit = rate_limit_info.tokens_limit
559 updated = True
560 if was_uninitialized:
561 message = f"Discovered token limit: {self.state.token_limit} tokens/min"
562 if self._print_rate_limit_info:
563 print(message)
564 logger.info(message)
565 else:
566 message = f"Updated token limit: {old_limit} -> {self.state.token_limit} tokens/min"
567 if self._print_rate_limit_info:
568 print(message)
569 logger.info(message)
571 if updated:
572 # Reinitialize rate limiters with new limits
573 self._init_rate_limiters()
575 # Update limits_discovered flag if both limits are now available
576 if (
577 self.state.request_limit is not None
578 and self.state.token_limit is not None
579 ):
580 self._limits_discovered = True
582 if was_uninitialized:
583 message = "Rate limiters initialized after discovery"
584 if self._print_rate_limit_info:
585 print(message)
586 # Print updated rate limit info after discovery
587 self._print_rate_limit_info_details()
588 logger.info(message)
590 # Store the rate limit info
591 self.state.last_rate_limit_info = rate_limit_info
592 self.state.last_limit_update = time.time()
594 def _estimate_tokens(self, request_data: dict[str, Any]) -> int:
595 """Estimate token usage from request data."""
596 if not self.enable_token_estimation:
597 return 0
599 # Simple token estimation
600 # This is a placeholder - real implementation would use tiktoken or similar
601 if "messages" in request_data:
602 text = ""
603 for message in request_data["messages"]:
604 if isinstance(message, dict) and "content" in message:
605 text += str(message["content"])
607 # Rough estimation: 1 token ≈ 4 characters
608 return len(text) // 4
610 return 0
612 @asynccontextmanager
613 async def _acquire_rate_limits(
614 self, estimated_tokens: int = 0
615 ) -> AsyncIterator[None]:
616 """Acquire rate limits before making a request."""
617 # Check if rate limiters are initialized
618 if self.request_limiter is None or self.token_limiter is None:
619 # Limits not yet discovered - this request will help discover them
620 logger.info(
621 "Rate limits not yet discovered, proceeding without rate limiting for discovery"
622 )
623 else:
624 # Wait for request rate limit
625 await asyncio.to_thread(self.request_limiter.try_acquire, "request")
627 # Wait for token rate limit if we have token estimation and limiters are initialized
628 if (
629 estimated_tokens > 0
630 and self.token_limiter is not None
631 and self._effective_token_limit is not None
632 ):
633 # Check if request is too large for bucket capacity
634 if estimated_tokens > self._effective_token_limit:
635 # Log warning for large requests
636 logger.warning(
637 f"Request estimated at {estimated_tokens} tokens exceeds bucket capacity "
638 f"of {self._effective_token_limit} tokens. This may cause delays."
639 )
640 # For very large requests, we'll split the acquisition
641 # Acquire tokens in chunks to avoid bucket overflow
642 remaining_tokens = estimated_tokens
643 while remaining_tokens > 0:
644 chunk_size = min(
645 remaining_tokens, self._effective_token_limit // 2
646 )
647 await asyncio.to_thread(
648 self.token_limiter.try_acquire, "token", chunk_size
649 )
650 remaining_tokens -= chunk_size
651 if remaining_tokens > 0:
652 # Brief pause to let bucket refill
653 await asyncio.sleep(0.1)
654 else:
655 # Normal acquisition for smaller requests
656 await asyncio.to_thread(
657 self.token_limiter.try_acquire, "token", estimated_tokens
658 )
660 try:
661 yield
662 finally:
663 # Update usage tracking
664 self.state.requests_used += 1
665 self.state.tokens_used += estimated_tokens
666 self.state.last_request_time = time.time()
668 @contextmanager
669 def _acquire_rate_limits_sync(self, estimated_tokens: int = 0) -> Iterator[None]:
670 """Sync version of rate limit acquisition."""
671 # Check if rate limiters are initialized
672 if self.request_limiter is None or self.token_limiter is None:
673 # Limits not yet discovered - this request will help discover them
674 logger.info(
675 "Rate limits not yet discovered, proceeding without rate limiting for discovery"
676 )
677 else:
678 # Wait for request rate limit
679 self.request_limiter.try_acquire("request")
681 # Wait for token rate limit if we have token estimation and limiters are initialized
682 if (
683 estimated_tokens > 0
684 and self.token_limiter is not None
685 and self._effective_token_limit is not None
686 ):
687 # Check if request is too large for bucket capacity
688 if estimated_tokens > self._effective_token_limit:
689 # Log warning for large requests
690 logger.warning(
691 f"Request estimated at {estimated_tokens} tokens exceeds bucket capacity "
692 f"of {self._effective_token_limit} tokens. This may cause delays."
693 )
694 # For very large requests, we'll split the acquisition
695 # Acquire tokens in chunks to avoid bucket overflow
696 remaining_tokens = estimated_tokens
697 while remaining_tokens > 0:
698 chunk_size = min(
699 remaining_tokens, self._effective_token_limit // 2
700 )
701 self.token_limiter.try_acquire("token", chunk_size)
702 remaining_tokens -= chunk_size
703 if remaining_tokens > 0:
704 # Brief pause to let bucket refill
705 time.sleep(0.1)
706 else:
707 # Normal acquisition for smaller requests
708 self.token_limiter.try_acquire("token", estimated_tokens)
710 try:
711 yield
712 finally:
713 # Update usage tracking
714 self.state.requests_used += 1
715 self.state.tokens_used += estimated_tokens
716 self.state.last_request_time = time.time()
718 def _get_retry_decorator(self) -> Any:
719 """Get retry decorator with user-configured parameters."""
720 return retry(
721 stop=stop_after_attempt(self._user_max_retries),
722 wait=wait_exponential(multiplier=self._user_base_backoff, min=1, max=60),
723 retry=retry_if_exception_type(
724 (
725 httpx.HTTPStatusError,
726 httpx.RequestError,
727 httpx.ReadTimeout,
728 httpx.ConnectTimeout,
729 )
730 ),
731 )
733 def get_current_limits(self) -> dict[str, Any]:
734 """Get current rate limit information."""
735 return {
736 "provider": self.provider.value,
737 "request_limit": self.state.request_limit,
738 "token_limit": self.state.token_limit,
739 "requests_used": self.state.requests_used,
740 "tokens_used": self.state.tokens_used,
741 "last_request_time": self.state.last_request_time,
742 "last_limit_update": self.state.last_limit_update,
743 "consecutive_rate_limit_errors": self.state.consecutive_rate_limit_errors,
744 }
746 def reset_usage_tracking(self) -> None:
747 """Reset usage tracking counters."""
748 self.state.requests_used = 0
749 self.state.tokens_used = 0
750 self.state.consecutive_rate_limit_errors = 0
752 # High-level chat completion methods
754 async def chat_completion(
755 self,
756 model: str,
757 messages: list[Message],
758 max_tokens: int | None = None,
759 temperature: float | None = None,
760 top_p: float | None = None,
761 stop: str | list[str] | None = None,
762 stream: bool = False,
763 **kwargs: Any,
764 ) -> ChatCompletionResponse:
765 """
766 Make a high-level chat completion request.
768 Args:
769 model: The model to use for completion
770 messages: List of messages in the conversation
771 max_tokens: Maximum tokens to generate
772 temperature: Sampling temperature
773 top_p: Top-p sampling parameter
774 stop: Stop sequences
775 stream: Whether to stream the response
776 **kwargs: Additional provider-specific parameters
778 Returns:
779 ChatCompletionResponse with the completion result
781 Raises:
782 ValueError: If provider cannot be determined from model
783 httpx.HTTPStatusError: For HTTP error responses
784 httpx.RequestError: For request errors
785 """
786 if not self._async_context_active:
787 raise RuntimeError("ChatLimiter must be used as an async context manager")
789 # Create request object
790 request = ChatCompletionRequest(
791 model=model,
792 messages=messages,
793 max_tokens=max_tokens,
794 temperature=temperature,
795 top_p=top_p,
796 stop=stop,
797 stream=stream,
798 **kwargs,
799 )
801 # Get the appropriate adapter
802 adapter = get_adapter(self.provider)
804 # Format the request for the provider
805 formatted_request = adapter.format_request(request)
807 # Make the HTTP request with rate limiting
808 try:
809 # Print request initiation if enabled
810 if self._print_request_initiation:
811 print(f"Sending request for model {model} (attempt 1)")
813 # Estimate tokens
814 estimated_tokens = self._estimate_tokens(formatted_request)
816 # Acquire rate limits
817 async with self._acquire_rate_limits(estimated_tokens):
818 # Make the request
819 response = await self.async_client.request(
820 "POST", adapter.get_endpoint(), json=formatted_request
821 )
823 # Extract rate limit info
824 from .providers import extract_rate_limit_info
825 rate_limit_info = extract_rate_limit_info(
826 dict(response.headers), self.config
827 )
829 # Update our rate limits
830 if self.enable_adaptive_limits:
831 self._update_rate_limits(rate_limit_info)
833 # Handle rate limit errors
834 if response.status_code == 429:
835 self.state.consecutive_rate_limit_errors += 1
836 if rate_limit_info.retry_after:
837 import asyncio
838 await asyncio.sleep(rate_limit_info.retry_after)
839 else:
840 # Exponential backoff
841 import asyncio
842 backoff = self.config.base_backoff * (
843 2**self.state.consecutive_rate_limit_errors
844 )
845 await asyncio.sleep(min(backoff, self.config.max_backoff))
847 response.raise_for_status()
848 else:
849 # Reset consecutive errors on success
850 self.state.consecutive_rate_limit_errors = 0
852 # Parse the response
853 response_data = response.json()
854 return adapter.parse_response(response_data, request)
856 except Exception as e:
857 # Handle errors and return error response
858 error_response = ChatCompletionResponse(
859 id="error",
860 model=request.model,
861 success=False,
862 error_message=str(e),
863 choices=[],
864 usage=None,
865 created=None,
866 )
867 return error_response
869 def chat_completion_sync(
870 self,
871 model: str,
872 messages: list[Message],
873 max_tokens: int | None = None,
874 temperature: float | None = None,
875 top_p: float | None = None,
876 stop: str | list[str] | None = None,
877 stream: bool = False,
878 **kwargs: Any,
879 ) -> ChatCompletionResponse:
880 """
881 Make a synchronous high-level chat completion request.
883 Args:
884 model: The model to use for completion
885 messages: List of messages in the conversation
886 max_tokens: Maximum tokens to generate
887 temperature: Sampling temperature
888 top_p: Top-p sampling parameter
889 stop: Stop sequences
890 stream: Whether to stream the response
891 **kwargs: Additional provider-specific parameters
893 Returns:
894 ChatCompletionResponse with the completion result
896 Raises:
897 ValueError: If provider cannot be determined from model
898 httpx.HTTPStatusError: For HTTP error responses
899 httpx.RequestError: For request errors
900 """
901 if not self._sync_context_active:
902 raise RuntimeError("ChatLimiter must be used as a sync context manager")
904 # Create request object
905 request = ChatCompletionRequest(
906 model=model,
907 messages=messages,
908 max_tokens=max_tokens,
909 temperature=temperature,
910 top_p=top_p,
911 stop=stop,
912 stream=stream,
913 **kwargs,
914 )
916 # Get the appropriate adapter
917 adapter = get_adapter(self.provider)
919 # Format the request for the provider
920 formatted_request = adapter.format_request(request)
922 # Make the HTTP request with rate limiting
923 try:
924 # Print request initiation if enabled
925 if self._print_request_initiation:
926 print(f"Sending request for model {model} (attempt 1)")
928 # Estimate tokens
929 estimated_tokens = self._estimate_tokens(formatted_request)
931 # Acquire rate limits
932 with self._acquire_rate_limits_sync(estimated_tokens):
933 # Make the request
934 response = self.sync_client.request(
935 "POST", adapter.get_endpoint(), json=formatted_request
936 )
938 # Extract rate limit info
939 from .providers import extract_rate_limit_info
940 rate_limit_info = extract_rate_limit_info(
941 dict(response.headers), self.config
942 )
944 # Update our rate limits
945 if self.enable_adaptive_limits:
946 self._update_rate_limits(rate_limit_info)
948 # Handle rate limit errors
949 if response.status_code == 429:
950 self.state.consecutive_rate_limit_errors += 1
951 if rate_limit_info.retry_after:
952 import time
953 time.sleep(rate_limit_info.retry_after)
954 else:
955 # Exponential backoff
956 import time
957 backoff = self.config.base_backoff * (
958 2**self.state.consecutive_rate_limit_errors
959 )
960 time.sleep(min(backoff, self.config.max_backoff))
962 response.raise_for_status()
963 else:
964 # Reset consecutive errors on success
965 self.state.consecutive_rate_limit_errors = 0
967 # Parse the response
968 response_data = response.json()
969 return adapter.parse_response(response_data, request)
971 except Exception as e:
972 # Handle errors and return error response
973 error_response = ChatCompletionResponse(
974 id="error",
975 model=request.model,
976 success=False,
977 error_message=str(e),
978 choices=[],
979 usage=None,
980 created=None,
981 )
982 return error_response
984 # Convenience methods for different message types
986 async def simple_chat(
987 self,
988 model: str,
989 prompt: str,
990 max_tokens: int | None = None,
991 temperature: float | None = None,
992 **kwargs: Any,
993 ) -> str:
994 """
995 Simple chat completion that returns just the text response.
997 Args:
998 model: The model to use
999 prompt: The user prompt
1000 max_tokens: Maximum tokens to generate
1001 temperature: Sampling temperature
1002 **kwargs: Additional parameters
1004 Returns:
1005 The text response from the model
1006 """
1007 messages = [Message(role=MessageRole.USER, content=prompt)]
1008 response = await self.chat_completion(
1009 model=model,
1010 messages=messages,
1011 max_tokens=max_tokens,
1012 temperature=temperature,
1013 **kwargs,
1014 )
1016 if response.choices:
1017 return response.choices[0].message.content
1018 return ""
1020 def simple_chat_sync(
1021 self,
1022 model: str,
1023 prompt: str,
1024 max_tokens: int | None = None,
1025 temperature: float | None = None,
1026 **kwargs: Any,
1027 ) -> str:
1028 """
1029 Simple synchronous chat completion that returns just the text response.
1031 Args:
1032 model: The model to use
1033 prompt: The user prompt
1034 max_tokens: Maximum tokens to generate
1035 temperature: Sampling temperature
1036 **kwargs: Additional parameters
1038 Returns:
1039 The text response from the model
1040 """
1041 messages = [Message(role=MessageRole.USER, content=prompt)]
1042 response = self.chat_completion_sync(
1043 model=model,
1044 messages=messages,
1045 max_tokens=max_tokens,
1046 temperature=temperature,
1047 **kwargs,
1048 )
1050 if response.choices:
1051 return response.choices[0].message.content
1052 return ""
1054 def set_print_rate_limit_info(self, enabled: bool) -> None:
1055 """Set whether to print rate limit information."""
1056 self._print_rate_limit_info = enabled
1058 def set_print_request_initiation(self, enabled: bool) -> None:
1059 """Set whether to print request initiation messages."""
1060 self._print_request_initiation = enabled
1062 def _print_rate_limit_info_details(self) -> None:
1063 """Print current rate limit configuration."""
1064 print(f"\n=== Rate Limit Configuration for {self.provider.value.title()} ===")
1065 print(f"Provider: {self.provider.value}")
1066 print(f"Base URL: {self.config.base_url}")
1068 # Handle None values for limits
1069 if self.state.request_limit is not None:
1070 effective_req = self._effective_request_limit or "not calculated"
1071 print(
1072 f"Request Limit: {self.state.request_limit}/minute (effective: {effective_req}/minute)"
1073 )
1074 else:
1075 print("Request Limit: Not yet discovered (will be fetched from API)")
1077 if self.state.token_limit is not None:
1078 effective_tok = self._effective_token_limit or "not calculated"
1079 print(
1080 f"Token Limit: {self.state.token_limit}/minute (effective: {effective_tok}/minute)"
1081 )
1082 else:
1083 print("Token Limit: Not yet discovered (will be fetched from API)")
1085 print(f"Request Buffer Ratio: {self.config.request_buffer_ratio}")
1086 print(f"Token Buffer Ratio: {self.config.token_buffer_ratio}")
1087 print(f"Adaptive Limits: {self.enable_adaptive_limits}")
1088 print(f"Token Estimation: {self.enable_token_estimation}")
1089 print(f"Dynamic Discovery: {self.config.supports_dynamic_limits}")
1090 print(f"Limits Discovered: {self._limits_discovered}")
1091 print("=" * 50)