Coverage for src/chat_limiter/adapters.py: 92%

180 statements  

coverage.py v7.9.2, created at 2025-09-18 21:15 +0100

1""" 

2Provider-specific adapters for converting between our unified types and provider APIs. 

3""" 

4 

5import time 

6import warnings 

7from abc import ABC, abstractmethod 

8from typing import Any 

9 

10from .providers import Provider 

11from .types import ( 

12 ChatCompletionRequest, 

13 ChatCompletionResponse, 

14 Choice, 

15 Message, 

16 MessageRole, 

17 Usage, 

18) 

19 

20 


class ProviderAdapter(ABC):
    """Abstract base class for provider-specific adapters."""

    def is_reasoning_model(self, model_name: str) -> bool:
        """Check if the model is a reasoning model (o1, o3, o4 series)."""
        # Handle prefixed models (e.g., "openai/o3-mini")
        if "/" in model_name:
            # Extract the base model name after the "/"
            base_model = model_name.split("/", 1)[1]
            return base_model.startswith(("o1", "o3", "o4"))

        # Handle non-prefixed models
        return model_name.startswith(("o1", "o3", "o4"))

    @abstractmethod
    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert our request format to provider-specific format."""
        pass

    @abstractmethod
    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Convert provider response to our unified format."""
        pass

    @abstractmethod
    def get_endpoint(self) -> str:
        """Get the API endpoint for this provider."""
        pass
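For illustration, a minimal sketch of how is_reasoning_model treats prefixed and bare model names (OpenAIAdapter is defined further down this module; the model strings are only examples):

adapter = OpenAIAdapter()
assert adapter.is_reasoning_model("openai/o3-mini")    # prefixed reasoning model
assert adapter.is_reasoning_model("o1-preview")        # bare reasoning model
assert not adapter.is_reasoning_model("gpt-4o")        # regular chat model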

class OpenAIAdapter(ProviderAdapter):
    """Adapter for OpenAI API."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert to OpenAI format."""
        # Convert messages
        messages: list[dict[str, Any]] = []
        for msg in request.messages:
            messages.append({
                "role": msg.role.value,
                "content": msg.content
            })

        model = request.model.strip()
        if model.startswith("openai/"):
            # Remove the "openai/" prefix, since we are already using the OpenAI API
            model = model.split("openai/", 1)[1]

        # Build request
        openai_request: dict[str, Any] = {
            "model": model,
            "messages": messages,
        }

        # Add optional parameters
        if request.max_tokens is not None:
            # Use max_completion_tokens for reasoning models (o1, o3, o4)
            if self.is_reasoning_model(model):
                openai_request["max_completion_tokens"] = request.max_tokens
            else:
                openai_request["max_tokens"] = request.max_tokens

        # Handle temperature for reasoning models
        if self.is_reasoning_model(model):
            # For reasoning models, default to temperature=1
            default_temperature = 1.0

            if request.temperature is not None:
                # If user provided a different temperature, warn them and use temperature=1
                if request.temperature != default_temperature:
                    warnings.warn(
                        f"WARNING: Model '{model}' is a reasoning model that requires temperature=1. "
                        f"Your specified temperature={request.temperature} will be overridden to temperature=1.",
                        UserWarning
                    )
                    print(f"WARNING: Model '{model}' is a reasoning model that requires temperature=1. "
                          f"Your specified temperature={request.temperature} will be overridden to temperature=1.")

            # Always use temperature=1 for reasoning models
            openai_request["temperature"] = default_temperature
        else:
            # For non-reasoning models, use the provided temperature
            if request.temperature is not None:
                openai_request["temperature"] = request.temperature

        if request.top_p is not None:
            openai_request["top_p"] = request.top_p
        if request.stop is not None:
            openai_request["stop"] = request.stop
        if request.stream:
            openai_request["stream"] = request.stream
        if request.frequency_penalty is not None:
            openai_request["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            openai_request["presence_penalty"] = request.presence_penalty
        if request.seed is not None:
            openai_request["seed"] = request.seed

        # Add reasoning parameter for thinking models
        if (request.reasoning_effort is not None and
                self.is_reasoning_model(model)):
            openai_request["reasoning"] = {"effort": request.reasoning_effort}

        return openai_request

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Parse OpenAI response."""
        # Check for errors first
        success = True
        error_message = None

        if "error" in response_data:
            success = False
            error_data = response_data["error"]
            error_message = error_data.get("message", "Unknown error")

        choices = []
        for choice_data in response_data.get("choices", []):
            message_data = choice_data.get("message", {})
            message = Message(
                role=MessageRole(message_data.get("role", "assistant")),
                content=message_data.get("content", "")
            )
            choice = Choice(
                index=choice_data.get("index", 0),
                message=message,
                finish_reason=choice_data.get("finish_reason")
            )
            choices.append(choice)

        # Parse usage
        usage = None
        if "usage" in response_data:
            usage_data = response_data["usage"]
            usage = Usage(
                prompt_tokens=usage_data.get("prompt_tokens", 0),
                completion_tokens=usage_data.get("completion_tokens", 0),
                total_tokens=usage_data.get("total_tokens", 0)
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=choices,
            usage=usage,
            created=response_data.get("created"),
            success=success,
            error_message=error_message,
            provider="openai",
            raw_response=response_data
        )

    def get_endpoint(self) -> str:
        return "/chat/completions"
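A hedged sketch of what OpenAIAdapter.format_request produces for a reasoning model. The ChatCompletionRequest keyword arguments are assumed from the .types import above and may not match the real constructor exactly:

request = ChatCompletionRequest(
    model="openai/o3-mini",
    messages=[Message(role=MessageRole.USER, content="Hello")],
    max_tokens=256,
    temperature=0.2,
)
payload = OpenAIAdapter().format_request(request)
# Expected shape (the non-default temperature triggers a warning and is overridden):
# {"model": "o3-mini", "messages": [...], "max_completion_tokens": 256, "temperature": 1.0}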

class AnthropicAdapter(ProviderAdapter):
    """Adapter for Anthropic API."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert to Anthropic format."""
        # Anthropic has a different message format
        messages: list[dict[str, Any]] = []
        system_message: str | None = None

        for msg in request.messages:
            if msg.role == MessageRole.SYSTEM:
                # Anthropic handles system messages separately
                system_message = msg.content
            else:
                messages.append({
                    "role": msg.role.value,
                    "content": msg.content
                })

        model = request.model.strip()
        if model.startswith("anthropic/"):
            # Remove the "anthropic/" prefix, since we are already using the Anthropic API
            model = model.split("anthropic/", 1)[1]

        # Build request
        anthropic_request: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "max_tokens": request.max_tokens or 1024,  # Required for Anthropic
        }

        # Add system message if present
        if system_message:
            anthropic_request["system"] = system_message

        # Add optional parameters
        if request.temperature is not None:
            anthropic_request["temperature"] = request.temperature
        if request.top_p is not None:
            anthropic_request["top_p"] = request.top_p
        if request.stop is not None:
            anthropic_request["stop_sequences"] = (
                [request.stop] if isinstance(request.stop, str) else request.stop
            )
        if request.stream:
            anthropic_request["stream"] = request.stream
        if request.top_k is not None:
            anthropic_request["top_k"] = request.top_k
        if request.seed is not None:
            anthropic_request["seed"] = request.seed

        return anthropic_request

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Parse Anthropic response."""
        # Check for errors first
        success = True
        error_message = None

        if "error" in response_data:
            success = False
            error_data = response_data["error"]
            error_message = error_data.get("message", "Unknown error")

        # Anthropic returns content differently
        content_blocks = response_data.get("content", [])
        content = ""
        if content_blocks:
            # Extract text from content blocks
            for block in content_blocks:
                if block.get("type") == "text":
                    content += block.get("text", "")

        message = Message(
            role=MessageRole.ASSISTANT,
            content=content
        )

        choice = Choice(
            index=0,
            message=message,
            finish_reason=response_data.get("stop_reason")
        )

        # Parse usage
        usage = None
        if "usage" in response_data:
            usage_data = response_data["usage"]
            usage = Usage(
                prompt_tokens=usage_data.get("input_tokens", 0),
                completion_tokens=usage_data.get("output_tokens", 0),
                total_tokens=usage_data.get("input_tokens", 0) + usage_data.get("output_tokens", 0)
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=[choice],
            usage=usage,
            created=int(time.time()),  # Anthropic doesn't provide a created timestamp
            success=success,
            error_message=error_message,
            provider="anthropic",
            raw_response=response_data
        )

    def get_endpoint(self) -> str:
        return "/messages"
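A similar sketch for the Anthropic path, showing the system message lifted into the top-level "system" field and the max_tokens fallback of 1024 (constructor arguments again assumed; the model name is illustrative):

request = ChatCompletionRequest(
    model="anthropic/claude-3-5-sonnet-20240620",
    messages=[
        Message(role=MessageRole.SYSTEM, content="You are terse."),
        Message(role=MessageRole.USER, content="Hi"),
    ],
)
payload = AnthropicAdapter().format_request(request)
# Expected shape:
# {"model": "claude-3-5-sonnet-20240620",
#  "messages": [{"role": "user", "content": "Hi"}],
#  "max_tokens": 1024,
#  "system": "You are terse."}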

class OpenRouterAdapter(ProviderAdapter):
    """Adapter for OpenRouter API."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert to OpenRouter format (similar to OpenAI)."""
        # OpenRouter uses OpenAI-compatible format
        messages: list[dict[str, Any]] = []
        for msg in request.messages:
            messages.append({
                "role": msg.role.value,
                "content": msg.content
            })

        model = request.model.strip()

        # Build request
        openrouter_request: dict[str, Any] = {
            "model": model,
            "messages": messages,
        }

        # Add optional parameters
        if request.max_tokens is not None:
            openrouter_request["max_tokens"] = request.max_tokens
        if request.temperature is not None:
            openrouter_request["temperature"] = request.temperature
        if request.top_p is not None:
            openrouter_request["top_p"] = request.top_p
        if request.stop is not None:
            openrouter_request["stop"] = request.stop
        if request.stream:
            openrouter_request["stream"] = request.stream
        if request.frequency_penalty is not None:
            openrouter_request["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            openrouter_request["presence_penalty"] = request.presence_penalty
        if request.top_k is not None:
            openrouter_request["top_k"] = request.top_k
        if request.seed is not None:
            openrouter_request["seed"] = request.seed

        # Add reasoning parameter for thinking models
        if (request.reasoning_effort is not None and
                self.is_reasoning_model(model)):
            openrouter_request["reasoning"] = {"effort": request.reasoning_effort}

        # Add provider routing if specified
        if request.providers is not None:
            openrouter_request["provider"] = {
                "order": request.providers,
                "allow_fallbacks": False
            }

        return openrouter_request

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Parse OpenRouter response (similar to OpenAI)."""
        # Check for errors first
        success = True
        error_message = None

        if "error" in response_data:
            success = False
            error_data = response_data["error"]
            error_message = error_data.get("message", "Unknown error")

        choices = []
        for choice_data in response_data.get("choices", []):
            message_data = choice_data.get("message", {})
            message = Message(
                role=MessageRole(message_data.get("role", "assistant")),
                content=message_data.get("content", "")
            )
            choice = Choice(
                index=choice_data.get("index", 0),
                message=message,
                finish_reason=choice_data.get("finish_reason")
            )
            choices.append(choice)

        # Parse usage
        usage = None
        if "usage" in response_data:
            usage_data = response_data["usage"]
            usage = Usage(
                prompt_tokens=usage_data.get("prompt_tokens", 0),
                completion_tokens=usage_data.get("completion_tokens", 0),
                total_tokens=usage_data.get("total_tokens", 0)
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=choices,
            usage=usage,
            created=response_data.get("created"),
            success=success,
            error_message=error_message,
            provider="openrouter",
            raw_response=response_data
        )

    def get_endpoint(self) -> str:
        return "/chat/completions"
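A short sketch of the provider-routing branch: when request.providers is set, the adapter pins the routing order and disables fallbacks (the provider names here are illustrative):

request = ChatCompletionRequest(
    model="meta-llama/llama-3.1-70b-instruct",
    messages=[Message(role=MessageRole.USER, content="Hi")],
    providers=["together", "fireworks"],
)
payload = OpenRouterAdapter().format_request(request)
# payload["provider"] == {"order": ["together", "fireworks"], "allow_fallbacks": False}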

# Provider adapter registry
PROVIDER_ADAPTERS = {
    Provider.OPENAI: OpenAIAdapter(),
    Provider.ANTHROPIC: AnthropicAdapter(),
    Provider.OPENROUTER: OpenRouterAdapter(),
}


def get_adapter(provider: Provider) -> ProviderAdapter:
    """Get the appropriate adapter for a provider."""
    return PROVIDER_ADAPTERS[provider]
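A minimal end-to-end sketch of how a caller might use the registry: pick an adapter, format the request, send it to the provider's endpoint, and parse the body back into the unified response type. The HTTP layer shown here (httpx, base URL, auth header) is an assumption for illustration and is not part of this module:

import httpx

adapter = get_adapter(Provider.OPENAI)
request = ChatCompletionRequest(
    model="gpt-4o-mini",
    messages=[Message(role=MessageRole.USER, content="Say hello")],
)
payload = adapter.format_request(request)

# Hypothetical HTTP call; the real base URL and auth handling live elsewhere in the library.
raw = httpx.post(
    "https://api.openai.com/v1" + adapter.get_endpoint(),
    headers={"Authorization": "Bearer <OPENAI_API_KEY>"},
    json=payload,
    timeout=30.0,
).json()

response = adapter.parse_response(raw, request)
print(response.choices[0].message.content if response.success else response.error_message)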