Coverage for src/chat_limiter/models.py: 91%
184 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-09-18 21:15 +0100
1"""
2Dynamic model discovery from provider APIs.
4This module provides functionality to query provider APIs for available models
5instead of relying on hardcoded lists.
6"""
import asyncio
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any

import httpx

# Moved up from below the module-level state: imports belong at the top of
# the file (PEP 8); keeping it after executable statements risks confusing
# import-order tooling and readers alike.
from .utils import run_coro_blocking

logger = logging.getLogger(__name__)

# Cache for model lists to avoid hitting APIs too frequently.
_model_cache: dict[str, dict[str, Any]] = {}
_cache_duration = timedelta(hours=1)  # Cache models for 1 hour
26@dataclass
27class ModelDiscoveryResult:
28 """Result of model discovery process."""
30 # Discovery result
31 found_provider: str | None = None
32 model_found: bool = False
34 # All models found for each provider
35 openai_models: set[str] | None = None
36 anthropic_models: set[str] | None = None
37 openrouter_models: set[str] | None = None
39 # Errors encountered during discovery
40 errors: dict[str, str] | None = None
42 def get_all_models(self) -> dict[str, set[str]]:
43 """Get all models organized by provider."""
44 result = {}
45 if self.openai_models is not None:
46 result["openai"] = self.openai_models
47 if self.anthropic_models is not None:
48 result["anthropic"] = self.anthropic_models
49 if self.openrouter_models is not None:
50 result["openrouter"] = self.openrouter_models
51 return result
53 def get_total_models_found(self) -> int:
54 """Get total number of models found across all providers."""
55 total = 0
56 if self.openai_models:
57 total += len(self.openai_models)
58 if self.anthropic_models:
59 total += len(self.anthropic_models)
60 if self.openrouter_models:
61 total += len(self.openrouter_models)
62 return total
class ModelDiscovery:
    """Dynamic model discovery from provider APIs.

    Fetched listings are cached in the module-level ``_model_cache`` for
    ``_cache_duration`` to avoid hitting provider endpoints repeatedly.
    """

    @staticmethod
    def _cached_models(cache_key: str) -> set[str] | None:
        """Return the cached model set for *cache_key*, or None if absent or stale."""
        entry = _model_cache.get(cache_key)
        if entry and datetime.now() - entry["timestamp"] < _cache_duration:
            return entry["models"]  # type: ignore[no-any-return]
        return None

    @staticmethod
    def _store_models(cache_key: str, models: set[str]) -> None:
        """Cache *models* under *cache_key* with the current timestamp."""
        _model_cache[cache_key] = {"models": models, "timestamp": datetime.now()}

    @staticmethod
    async def get_openai_models(api_key: str) -> set[str]:
        """Get available OpenAI models from the API.

        Args:
            api_key: OpenAI API key, sent as a Bearer token.

        Returns:
            Set of model ids reported by ``GET /v1/models``.

        Raises:
            httpx.HTTPError: re-raised after logging a warning on any
                network or HTTP-status failure.
        """
        cache_key = f"openai_models_{hash(api_key)}"
        cached = ModelDiscovery._cached_models(cache_key)
        if cached is not None:
            return cached
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.openai.com/v1/models",
                    headers={"Authorization": f"Bearer {api_key}"},
                    timeout=10.0,
                )
                response.raise_for_status()

            data = response.json()
            # Skip entries without an id (consistent with the OpenRouter path).
            models = {m.get("id", "") for m in data.get("data", []) if m.get("id")}

            ModelDiscovery._store_models(cache_key, models)
            logger.info(f"Retrieved {len(models)} OpenAI models from API")
            return models

        except Exception as e:
            logger.warning(f"Failed to fetch OpenAI models: {e}")
            raise

    @staticmethod
    async def get_anthropic_models(api_key: str) -> set[str]:
        """Get available Anthropic models from the API.

        Args:
            api_key: Anthropic API key, sent via the ``x-api-key`` header.

        Returns:
            Set of model ids reported by ``GET /v1/models``.

        Raises:
            httpx.HTTPError: re-raised after logging a warning on any
                network or HTTP-status failure.
        """
        cache_key = f"anthropic_models_{hash(api_key)}"
        cached = ModelDiscovery._cached_models(cache_key)
        if cached is not None:
            return cached
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.anthropic.com/v1/models",
                    headers={
                        "x-api-key": api_key,
                        "anthropic-version": "2023-06-01",
                    },
                    timeout=10.0,
                )
                response.raise_for_status()

            data = response.json()
            # Skip entries without an id (consistent with the OpenRouter path).
            models = {m.get("id", "") for m in data.get("data", []) if m.get("id")}

            ModelDiscovery._store_models(cache_key, models)
            logger.info(f"Retrieved {len(models)} Anthropic models from API")
            return models

        except Exception as e:
            logger.warning(f"Failed to fetch Anthropic models: {e}")
            raise

    @staticmethod
    async def get_openrouter_models(api_key: str | None = None) -> set[str]:
        """Get available OpenRouter models from the API.

        Args:
            api_key: Optional OpenRouter API key; the model listing is
                public, so the request is made unauthenticated when absent.

        Returns:
            Set of model ids reported by ``GET /api/v1/models``.

        Raises:
            httpx.HTTPError: re-raised after logging a warning on any
                network or HTTP-status failure.
        """
        cache_key = "openrouter_models"
        cached = ModelDiscovery._cached_models(cache_key)
        if cached is not None:
            return cached
        try:
            headers = {}
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://openrouter.ai/api/v1/models",
                    headers=headers,
                    timeout=10.0,
                )
                response.raise_for_status()

            data = response.json()
            models = {m.get("id", "") for m in data.get("data", []) if m.get("id")}

            ModelDiscovery._store_models(cache_key, models)
            logger.info(f"Retrieved {len(models)} OpenRouter models from API")
            return models

        except Exception as e:
            logger.warning(f"Failed to fetch OpenRouter models: {e}")
            raise

    @staticmethod
    def get_openai_models_sync(api_key: str) -> set[str]:
        """Synchronous version of get_openai_models."""
        return run_coro_blocking(ModelDiscovery.get_openai_models(api_key))

    @staticmethod
    def get_anthropic_models_sync(api_key: str) -> set[str]:
        """Synchronous version of get_anthropic_models."""
        return run_coro_blocking(ModelDiscovery.get_anthropic_models(api_key))

    @staticmethod
    def get_openrouter_models_sync(api_key: str | None = None) -> set[str]:
        """Synchronous version of get_openrouter_models."""
        return run_coro_blocking(ModelDiscovery.get_openrouter_models(api_key))
def _parse_model_prefix(model: str) -> tuple[str | None, str]:
    """Split an optional ``provider/`` prefix off *model*.

    Only ``openai`` and ``anthropic`` prefixes select a preferred provider;
    any other prefix (e.g. an OpenRouter vendor prefix) leaves the
    preference unset. Returns ``(preferred_provider, base_model)``.
    """
    if "/" not in model:
        return None, model
    prefix, base_model = model.split("/", 1)
    preferred = prefix if prefix in ("openai", "anthropic") else None
    return preferred, base_model


async def detect_provider_from_model_async(
    model: str,
    api_keys: dict[str, str] | None = None
) -> ModelDiscoveryResult:
    """
    Detect provider from model name using live API queries.

    Args:
        model: The model name to check
        api_keys: Dictionary of API keys {"openai": "sk-...", "anthropic": "sk-ant-..."}

    Returns:
        ModelDiscoveryResult with discovery information
    """
    api_keys = api_keys or {}
    result = ModelDiscoveryResult(errors={})

    # Handle provider-prefixed models (e.g., "openai/o3", "anthropic/claude-3-sonnet")
    preferred_provider, base_model = _parse_model_prefix(model)

    # One fetch task per queryable provider. OpenRouter's model listing is
    # public, so it is queried even without an API key (a falsy key behaves
    # the same as no key inside get_openrouter_models).
    tasks: list[tuple[str, Any]] = []
    if api_keys.get("openai"):
        tasks.append(("openai", ModelDiscovery.get_openai_models(api_keys["openai"])))
    if api_keys.get("anthropic"):
        tasks.append(("anthropic", ModelDiscovery.get_anthropic_models(api_keys["anthropic"])))
    tasks.append(("openrouter", ModelDiscovery.get_openrouter_models(api_keys.get("openrouter"))))

    try:
        provider_names = [name for name, _ in tasks]
        # Run all provider queries concurrently; exceptions are captured
        # per-provider so one failing API does not abort the others.
        results = await asyncio.gather(
            *(coro for _, coro in tasks), return_exceptions=True
        )

        # Record each provider's listing (or its error) on the result.
        for provider_name, models_result in zip(provider_names, results, strict=False):
            if isinstance(models_result, Exception):
                logger.debug(f"Failed to check {provider_name} for model {model}: {models_result}")
                if result.errors is not None:
                    result.errors[provider_name] = str(models_result)
            elif provider_name == "openai" and isinstance(models_result, set):
                result.openai_models = models_result
            elif provider_name == "anthropic" and isinstance(models_result, set):
                result.anthropic_models = models_result
            elif provider_name == "openrouter" and isinstance(models_result, set):
                result.openrouter_models = models_result

        # Prefer the provider named in the prefix when its listing contains
        # the base model; otherwise fall back to OpenRouter, whose listing
        # uses the fully prefixed name.
        if preferred_provider:
            provider_models = None
            if preferred_provider == "openai" and result.openai_models:
                provider_models = result.openai_models
            elif preferred_provider == "anthropic" and result.anthropic_models:
                provider_models = result.anthropic_models

            if provider_models and base_model in provider_models:
                result.found_provider = preferred_provider
                result.model_found = True
            elif result.openrouter_models and model in result.openrouter_models:
                result.found_provider = "openrouter"
                result.model_found = True

        # Unprefixed (or still-unresolved) names: first provider whose
        # listing contains the exact name wins, in task order.
        if not result.model_found:
            for provider_name, models_result in zip(provider_names, results, strict=False):
                if isinstance(models_result, set) and model in models_result:
                    result.found_provider = provider_name
                    result.model_found = True
                    break

    except Exception as e:
        logger.debug(f"Failed to run dynamic discovery for model {model}: {e}")
        if result.errors is not None:
            result.errors["general"] = str(e)

    return result
def detect_provider_from_model_sync(
    model: str,
    api_keys: dict[str, str] | None = None
) -> ModelDiscoveryResult:
    """Synchronous version of detect_provider_from_model_async."""
    # Build the coroutine, then drive it to completion on a blocking runner.
    discovery = detect_provider_from_model_async(model, api_keys)
    return run_coro_blocking(discovery)
def clear_model_cache() -> None:
    """Clear the model cache to force fresh API queries.

    Mutates the shared ``_model_cache`` dict in place so any existing
    references stay valid; no ``global`` declaration is needed because
    the name is never rebound (it was redundant in the original).
    """
    _model_cache.clear()
    logger.info("Model cache cleared")