Coverage for src/chat_limiter/models.py: 91%

184 statements  

coverage.py v7.9.2, created at 2025-09-18 21:15 +0100

1""" 

2Dynamic model discovery from provider APIs. 

3 

4This module provides functionality to query provider APIs for available models 

5instead of relying on hardcoded lists. 

6""" 

7 

8import asyncio 

9import logging 

10from dataclasses import dataclass 

11from datetime import datetime, timedelta 

12from typing import Any 

13 

14import httpx 

15 

16logger = logging.getLogger(__name__) 

17 

18# Cache for model lists to avoid hitting APIs too frequently 

19_model_cache: dict[str, dict[str, Any]] = {} 

20_cache_duration = timedelta(hours=1) # Cache models for 1 hour 

21 

22 

23from .utils import run_coro_blocking 

24 

25 

@dataclass
class ModelDiscoveryResult:
    """Result of model discovery process."""

    # Discovery result
    found_provider: str | None = None
    model_found: bool = False

    # All models found for each provider
    openai_models: set[str] | None = None
    anthropic_models: set[str] | None = None
    openrouter_models: set[str] | None = None

    # Errors encountered during discovery
    errors: dict[str, str] | None = None

    def get_all_models(self) -> dict[str, set[str]]:
        """Get all models organized by provider."""
        result = {}
        if self.openai_models is not None:
            result["openai"] = self.openai_models
        if self.anthropic_models is not None:
            result["anthropic"] = self.anthropic_models
        if self.openrouter_models is not None:
            result["openrouter"] = self.openrouter_models
        return result

    def get_total_models_found(self) -> int:
        """Get total number of models found across all providers."""
        total = 0
        if self.openai_models:
            total += len(self.openai_models)
        if self.anthropic_models:
            total += len(self.anthropic_models)
        if self.openrouter_models:
            total += len(self.openrouter_models)
        return total

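
# Illustrative only (not part of the measured module): a populated result can be
# inspected through the helper methods above, e.g.
#
#     result = ModelDiscoveryResult(
#         found_provider="openai",
#         model_found=True,
#         openai_models={"gpt-4o", "gpt-4o-mini"},
#     )
#     result.get_all_models()          # {"openai": {"gpt-4o", "gpt-4o-mini"}}
#     result.get_total_models_found()  # 2
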

class ModelDiscovery:
    """Dynamic model discovery from provider APIs."""

    @staticmethod
    async def get_openai_models(api_key: str) -> set[str]:
        """Get available OpenAI models from the API."""
        cache_key = f"openai_models_{hash(api_key)}"

        # Check cache first
        if _model_cache.get(cache_key):
            cache_entry = _model_cache[cache_key]
            if datetime.now() - cache_entry["timestamp"] < _cache_duration:
                return cache_entry["models"]  # type: ignore[no-any-return]

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.openai.com/v1/models",
                    headers={"Authorization": f"Bearer {api_key}"},
                    timeout=10.0
                )
                response.raise_for_status()

                data = response.json()
                models = set()

                for model in data.get("data", []):
                    model_id = model.get("id", "")
                    models.add(model_id)

                # Cache the result
                _model_cache[cache_key] = {
                    "models": models,
                    "timestamp": datetime.now()
                }

                logger.info(f"Retrieved {len(models)} OpenAI models from API")
                return models

        except Exception as e:
            logger.warning(f"Failed to fetch OpenAI models: {e}")
            raise

    @staticmethod
    async def get_anthropic_models(api_key: str) -> set[str]:
        """Get available Anthropic models from the API."""
        cache_key = f"anthropic_models_{hash(api_key)}"

        # Check cache first
        if _model_cache.get(cache_key):
            cache_entry = _model_cache[cache_key]
            if datetime.now() - cache_entry["timestamp"] < _cache_duration:
                return cache_entry["models"]  # type: ignore[no-any-return]

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.anthropic.com/v1/models",
                    headers={
                        "x-api-key": api_key,
                        "anthropic-version": "2023-06-01"
                    },
                    timeout=10.0
                )
                response.raise_for_status()

                data = response.json()
                models = set()

                for model in data.get("data", []):
                    model_id = model.get("id", "")
                    models.add(model_id)

                # Cache the result
                _model_cache[cache_key] = {
                    "models": models,
                    "timestamp": datetime.now()
                }

                logger.info(f"Retrieved {len(models)} Anthropic models from API")
                return models

        except Exception as e:
            logger.warning(f"Failed to fetch Anthropic models: {e}")
            raise

    @staticmethod
    async def get_openrouter_models(api_key: str | None = None) -> set[str]:
        """Get available OpenRouter models from the API."""
        cache_key = "openrouter_models"

        # Check cache first
        if _model_cache.get(cache_key):
            cache_entry = _model_cache[cache_key]
            if datetime.now() - cache_entry["timestamp"] < _cache_duration:
                return cache_entry["models"]  # type: ignore[no-any-return]

        try:
            headers = {}
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://openrouter.ai/api/v1/models",
                    headers=headers,
                    timeout=10.0
                )
                response.raise_for_status()

                data = response.json()
                models = set()

                for model in data.get("data", []):
                    model_id = model.get("id", "")
                    if model_id:
                        models.add(model_id)

                # Cache the result
                _model_cache[cache_key] = {
                    "models": models,
                    "timestamp": datetime.now()
                }

                logger.info(f"Retrieved {len(models)} OpenRouter models from API")
                return models

        except Exception as e:
            logger.warning(f"Failed to fetch OpenRouter models: {e}")
            raise

    @staticmethod
    def get_openai_models_sync(api_key: str) -> set[str]:
        """Synchronous version of get_openai_models."""
        return run_coro_blocking(ModelDiscovery.get_openai_models(api_key))

    @staticmethod
    def get_anthropic_models_sync(api_key: str) -> set[str]:
        """Synchronous version of get_anthropic_models."""
        return run_coro_blocking(ModelDiscovery.get_anthropic_models(api_key))

    @staticmethod
    def get_openrouter_models_sync(api_key: str | None = None) -> set[str]:
        """Synchronous version of get_openrouter_models."""
        return run_coro_blocking(ModelDiscovery.get_openrouter_models(api_key))

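
# Illustrative usage (not part of the measured module): the sync wrappers above can
# be called directly. OpenRouter's model listing needs no API key, so a minimal
# sketch, assuming network access, looks like:
#
#     models = ModelDiscovery.get_openrouter_models_sync()
#     print(f"OpenRouter exposes {len(models)} models")
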

async def detect_provider_from_model_async(
    model: str,
    api_keys: dict[str, str] | None = None
) -> ModelDiscoveryResult:
    """
    Detect provider from model name using live API queries.

    Args:
        model: The model name to check
        api_keys: Dictionary of API keys {"openai": "sk-...", "anthropic": "sk-ant-..."}

    Returns:
        ModelDiscoveryResult with discovery information
    """
    if not api_keys:
        api_keys = {}

    result = ModelDiscoveryResult(errors={})

    # Handle provider-prefixed models (e.g., "openai/o3", "anthropic/claude-3-sonnet")
    preferred_provider = None
    base_model = model

    if "/" in model:
        parts = model.split("/", 1)
        if len(parts) == 2:
            provider_prefix, base_model = parts
            if provider_prefix == "openai":
                preferred_provider = "openai"
            elif provider_prefix == "anthropic":
                preferred_provider = "anthropic"

    # Create all tasks
    tasks = []

    if api_keys.get("openai"):
        tasks.append(("openai", ModelDiscovery.get_openai_models(api_keys["openai"])))

    if api_keys.get("anthropic"):
        tasks.append(("anthropic", ModelDiscovery.get_anthropic_models(api_keys["anthropic"])))

    if api_keys.get("openrouter"):
        tasks.append(("openrouter", ModelDiscovery.get_openrouter_models(api_keys["openrouter"])))
    else:
        # OpenRouter doesn't require an API key for model listing
        tasks.append(("openrouter", ModelDiscovery.get_openrouter_models()))

    # Use asyncio.gather to run all tasks concurrently and properly handle them
    try:
        # Extract just the coroutines for gather
        coroutines = [task[1] for task in tasks]
        provider_names = [task[0] for task in tasks]

        # Wait for all results
        results = await asyncio.gather(*coroutines, return_exceptions=True)

        # Process results and store all model information
        for provider_name, models_result in zip(provider_names, results, strict=False):
            if isinstance(models_result, Exception):
                logger.debug(f"Failed to check {provider_name} for model {model}: {models_result}")
                if result.errors is not None:
                    result.errors[provider_name] = str(models_result)
                continue

            # Store models in result
            if provider_name == "openai" and isinstance(models_result, set):
                result.openai_models = models_result
            elif provider_name == "anthropic" and isinstance(models_result, set):
                result.anthropic_models = models_result
            elif provider_name == "openrouter" and isinstance(models_result, set):
                result.openrouter_models = models_result

        # Determine the best provider to use
        if preferred_provider and not result.model_found:
            # Check if the base model exists in the preferred provider
            provider_models = None
            if preferred_provider == "openai" and result.openai_models:
                provider_models = result.openai_models
            elif preferred_provider == "anthropic" and result.anthropic_models:
                provider_models = result.anthropic_models

            if provider_models and base_model in provider_models:
                result.found_provider = preferred_provider
                result.model_found = True
            elif result.openrouter_models and model in result.openrouter_models:
                # Fall back to OpenRouter if the base model is not in the preferred provider
                result.found_provider = "openrouter"
                result.model_found = True

        # For models without a provider prefix, use the original logic
        if not result.model_found:
            for provider_name, models_result in zip(provider_names, results, strict=False):
                if isinstance(models_result, Exception):
                    continue

                # Check if our target model was found
                if isinstance(models_result, set) and model in models_result:
                    result.found_provider = provider_name
                    result.model_found = True
                    break

    except Exception as e:
        logger.debug(f"Failed to run dynamic discovery for model {model}: {e}")
        if result.errors is not None:
            result.errors["general"] = str(e)

    return result


def detect_provider_from_model_sync(
    model: str,
    api_keys: dict[str, str] | None = None
) -> ModelDiscoveryResult:
    """Synchronous version of detect_provider_from_model_async."""
    return run_coro_blocking(detect_provider_from_model_async(model, api_keys))


def clear_model_cache() -> None:
    """Clear the model cache to force fresh API queries."""
    global _model_cache
    _model_cache.clear()
    logger.info("Model cache cleared")
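
# Illustrative usage (not part of the measured module): a minimal sketch of provider
# detection, assuming the keys shown are placeholders supplied by the caller:
#
#     discovery = detect_provider_from_model_sync(
#         "anthropic/claude-3-sonnet",
#         api_keys={"openai": "sk-...", "anthropic": "sk-ant-..."},
#     )
#     if discovery.model_found:
#         print(f"Use provider: {discovery.found_provider}")
#     else:
#         print(f"Model not found; errors: {discovery.errors}")
#
#     clear_model_cache()  # force fresh API queries on the next lookup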