Coverage for src/chat_limiter/adapters.py: 92%
180 statements
coverage.py v7.9.2, created at 2025-09-18 21:15 +0100
1"""
2Provider-specific adapters for converting between our unified types and provider APIs.
3"""
5import time
6import warnings
7from abc import ABC, abstractmethod
8from typing import Any
10from .providers import Provider
11from .types import (
12 ChatCompletionRequest,
13 ChatCompletionResponse,
14 Choice,
15 Message,
16 MessageRole,
17 Usage,
18)


class ProviderAdapter(ABC):
    """Abstract base class for provider-specific adapters."""

    def is_reasoning_model(self, model_name: str) -> bool:
        """Check if the model is a reasoning model (o1, o3, o4 series)."""
        # Handle prefixed models (e.g., "openai/o3-mini")
        if "/" in model_name:
            # Extract the base model name after the "/"
            base_model = model_name.split("/", 1)[1]
            return base_model.startswith(("o1", "o3", "o4"))

        # Handle non-prefixed models
        return model_name.startswith(("o1", "o3", "o4"))
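
    # A minimal sketch of the check above (illustrative values, not from tests):
    #
    #   adapter.is_reasoning_model("openai/o3-mini")  # -> True (prefix stripped first)
    #   adapter.is_reasoning_model("o1-preview")      # -> True
    #   adapter.is_reasoning_model("gpt-4o")          # -> False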

    @abstractmethod
    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert our request format to provider-specific format."""
        pass

    @abstractmethod
    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Convert provider response to our unified format."""
        pass

    @abstractmethod
    def get_endpoint(self) -> str:
        """Get the API endpoint for this provider."""
        pass


class OpenAIAdapter(ProviderAdapter):
    """Adapter for the OpenAI API."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert to OpenAI format."""
        # Convert messages
        messages: list[dict[str, Any]] = []
        for msg in request.messages:
            messages.append({
                "role": msg.role.value,
                "content": msg.content,
            })

        model = request.model.strip()
        if model.startswith("openai/"):
            # Remove the "openai/" prefix, since we are already talking to the OpenAI API
            model = model.split("openai/", 1)[1]

        # Build request
        openai_request: dict[str, Any] = {
            "model": model,
            "messages": messages,
        }

        # Add optional parameters
        if request.max_tokens is not None:
            # Reasoning models (o1, o3, o4) take max_completion_tokens instead of max_tokens
            if self.is_reasoning_model(model):
                openai_request["max_completion_tokens"] = request.max_tokens
            else:
                openai_request["max_tokens"] = request.max_tokens

        # Handle temperature for reasoning models
        if self.is_reasoning_model(model):
            # Reasoning models only accept the default temperature of 1
            default_temperature = 1.0

            if request.temperature is not None and request.temperature != default_temperature:
                # The user asked for a different temperature: warn and override
                warnings.warn(
                    f"Model '{model}' is a reasoning model that requires temperature=1. "
                    f"Your specified temperature={request.temperature} will be overridden to temperature=1.",
                    UserWarning,
                )

            # Always use temperature=1 for reasoning models
            openai_request["temperature"] = default_temperature
        else:
            # For non-reasoning models, use the provided temperature
            if request.temperature is not None:
                openai_request["temperature"] = request.temperature

        if request.top_p is not None:
            openai_request["top_p"] = request.top_p
        if request.stop is not None:
            openai_request["stop"] = request.stop
        if request.stream:
            openai_request["stream"] = request.stream
        if request.frequency_penalty is not None:
            openai_request["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            openai_request["presence_penalty"] = request.presence_penalty
        if request.seed is not None:
            openai_request["seed"] = request.seed

        # Add the reasoning-effort parameter for reasoning models. The Chat
        # Completions API takes the flat "reasoning_effort" field; the nested
        # {"reasoning": {"effort": ...}} form belongs to OpenRouter (see below).
        if request.reasoning_effort is not None and self.is_reasoning_model(model):
            openai_request["reasoning_effort"] = request.reasoning_effort

        return openai_request
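
    # A sketch of the payload this produces for a reasoning model (illustrative,
    # assuming ChatCompletionRequest(model="o3-mini", max_tokens=256,
    # reasoning_effort="low")):
    #
    #   {
    #       "model": "o3-mini",
    #       "messages": [...],
    #       "max_completion_tokens": 256,
    #       "temperature": 1.0,        # forced for reasoning models
    #       "reasoning_effort": "low",
    #   }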

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Parse an OpenAI response."""
        # Check for errors first
        success = True
        error_message = None

        if "error" in response_data:
            success = False
            error_data = response_data["error"]
            error_message = error_data.get("message", "Unknown error")

        choices = []
        for choice_data in response_data.get("choices", []):
            message_data = choice_data.get("message", {})
            message = Message(
                role=MessageRole(message_data.get("role", "assistant")),
                content=message_data.get("content", ""),
            )
            choice = Choice(
                index=choice_data.get("index", 0),
                message=message,
                finish_reason=choice_data.get("finish_reason"),
            )
            choices.append(choice)

        # Parse usage
        usage = None
        if "usage" in response_data:
            usage_data = response_data["usage"]
            usage = Usage(
                prompt_tokens=usage_data.get("prompt_tokens", 0),
                completion_tokens=usage_data.get("completion_tokens", 0),
                total_tokens=usage_data.get("total_tokens", 0),
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=choices,
            usage=usage,
            created=response_data.get("created"),
            success=success,
            error_message=error_message,
            provider="openai",
            raw_response=response_data,
        )
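
    # Error payloads short-circuit into success=False (sketch, not from tests):
    #
    #   parse_response({"error": {"message": "Rate limit"}}, req)
    #   # -> ChatCompletionResponse(success=False, error_message="Rate limit",
    #   #                           choices=[], usage=None, ...)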

    def get_endpoint(self) -> str:
        return "/chat/completions"


class AnthropicAdapter(ProviderAdapter):
    """Adapter for the Anthropic API."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert to Anthropic format."""
        # Anthropic has a different message format
        messages: list[dict[str, Any]] = []
        system_message: str | None = None

        for msg in request.messages:
            if msg.role == MessageRole.SYSTEM:
                # Anthropic handles system messages separately (the last one wins)
                system_message = msg.content
            else:
                messages.append({
                    "role": msg.role.value,
                    "content": msg.content,
                })

        model = request.model.strip()
        if model.startswith("anthropic/"):
            # Remove the "anthropic/" prefix, since we are already talking to the Anthropic API
            model = model.split("anthropic/", 1)[1]

        # Build request
        anthropic_request: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "max_tokens": request.max_tokens or 1024,  # Required by Anthropic
        }

        # Add system message if present
        if system_message:
            anthropic_request["system"] = system_message

        # Add optional parameters
        if request.temperature is not None:
            anthropic_request["temperature"] = request.temperature
        if request.top_p is not None:
            anthropic_request["top_p"] = request.top_p
        if request.stop is not None:
            anthropic_request["stop_sequences"] = (
                [request.stop] if isinstance(request.stop, str) else request.stop
            )
        if request.stream:
            anthropic_request["stream"] = request.stream
        if request.top_k is not None:
            anthropic_request["top_k"] = request.top_k
        if request.seed is not None:
            # Note: seed is not a documented Messages API parameter; it is
            # passed through unchanged here.
            anthropic_request["seed"] = request.seed

        return anthropic_request
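
    # A sketch of the translation above (illustrative request, not from tests):
    # a system message is lifted out of the message list into the top-level
    # "system" field, and a string stop becomes a one-element stop_sequences list:
    #
    #   [{"role": "system", ...}, {"role": "user", ...}], stop="END"
    #   -> {"system": "...", "messages": [{"role": "user", ...}],
    #       "stop_sequences": ["END"], "max_tokens": 1024, ...}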

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Parse an Anthropic response."""
        # Check for errors first
        success = True
        error_message = None

        if "error" in response_data:
            success = False
            error_data = response_data["error"]
            error_message = error_data.get("message", "Unknown error")

        # Anthropic returns content as a list of blocks
        content_blocks = response_data.get("content", [])
        content = ""
        if content_blocks:
            # Extract text from content blocks
            for block in content_blocks:
                if block.get("type") == "text":
                    content += block.get("text", "")

        message = Message(
            role=MessageRole.ASSISTANT,
            content=content,
        )

        choice = Choice(
            index=0,
            message=message,
            finish_reason=response_data.get("stop_reason"),
        )

        # Parse usage
        usage = None
        if "usage" in response_data:
            usage_data = response_data["usage"]
            usage = Usage(
                prompt_tokens=usage_data.get("input_tokens", 0),
                completion_tokens=usage_data.get("output_tokens", 0),
                total_tokens=usage_data.get("input_tokens", 0) + usage_data.get("output_tokens", 0),
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=[choice],
            usage=usage,
            created=int(time.time()),  # Anthropic doesn't provide a created timestamp
            success=success,
            error_message=error_message,
            provider="anthropic",
            raw_response=response_data,
        )
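
    # Content-block flattening, sketched (illustrative values):
    #
    #   {"content": [{"type": "text", "text": "Hel"},
    #                {"type": "text", "text": "lo"}], ...}
    #   -> choices[0].message.content == "Hello"  (non-text blocks are skipped)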

    def get_endpoint(self) -> str:
        return "/messages"


class OpenRouterAdapter(ProviderAdapter):
    """Adapter for the OpenRouter API."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert to OpenRouter format (OpenAI-compatible)."""
        # OpenRouter uses an OpenAI-compatible format
        messages: list[dict[str, Any]] = []
        for msg in request.messages:
            messages.append({
                "role": msg.role.value,
                "content": msg.content,
            })

        model = request.model.strip()

        # Build request
        openrouter_request: dict[str, Any] = {
            "model": model,
            "messages": messages,
        }

        # Add optional parameters
        if request.max_tokens is not None:
            openrouter_request["max_tokens"] = request.max_tokens
        if request.temperature is not None:
            openrouter_request["temperature"] = request.temperature
        if request.top_p is not None:
            openrouter_request["top_p"] = request.top_p
        if request.stop is not None:
            openrouter_request["stop"] = request.stop
        if request.stream:
            openrouter_request["stream"] = request.stream
        if request.frequency_penalty is not None:
            openrouter_request["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            openrouter_request["presence_penalty"] = request.presence_penalty
        if request.top_k is not None:
            openrouter_request["top_k"] = request.top_k
        if request.seed is not None:
            openrouter_request["seed"] = request.seed

        # Add the reasoning parameter for reasoning models (OpenRouter uses the
        # nested {"reasoning": {"effort": ...}} form)
        if request.reasoning_effort is not None and self.is_reasoning_model(model):
            openrouter_request["reasoning"] = {"effort": request.reasoning_effort}

        # Add provider routing if specified
        if request.providers is not None:
            openrouter_request["provider"] = {
                "order": request.providers,
                "allow_fallbacks": False,
            }

        return openrouter_request
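
    # Provider routing, sketched (illustrative values): providers=["anthropic",
    # "openai"] pins the routing order and disables fallbacks:
    #
    #   {"provider": {"order": ["anthropic", "openai"], "allow_fallbacks": False}}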

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Parse an OpenRouter response (OpenAI-compatible)."""
        # Check for errors first
        success = True
        error_message = None

        if "error" in response_data:
            success = False
            error_data = response_data["error"]
            error_message = error_data.get("message", "Unknown error")

        choices = []
        for choice_data in response_data.get("choices", []):
            message_data = choice_data.get("message", {})
            message = Message(
                role=MessageRole(message_data.get("role", "assistant")),
                content=message_data.get("content", ""),
            )
            choice = Choice(
                index=choice_data.get("index", 0),
                message=message,
                finish_reason=choice_data.get("finish_reason"),
            )
            choices.append(choice)

        # Parse usage
        usage = None
        if "usage" in response_data:
            usage_data = response_data["usage"]
            usage = Usage(
                prompt_tokens=usage_data.get("prompt_tokens", 0),
                completion_tokens=usage_data.get("completion_tokens", 0),
                total_tokens=usage_data.get("total_tokens", 0),
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=choices,
            usage=usage,
            created=response_data.get("created"),
            success=success,
            error_message=error_message,
            provider="openrouter",
            raw_response=response_data,
        )

    def get_endpoint(self) -> str:
        return "/chat/completions"


# Provider adapter registry
PROVIDER_ADAPTERS = {
    Provider.OPENAI: OpenAIAdapter(),
    Provider.ANTHROPIC: AnthropicAdapter(),
    Provider.OPENROUTER: OpenRouterAdapter(),
}


def get_adapter(provider: Provider) -> ProviderAdapter:
    """Get the appropriate adapter for a provider."""
    return PROVIDER_ADAPTERS[provider]
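

# A minimal usage sketch (assumes ChatCompletionRequest/Message accept these
# fields; adjust to the actual constructors in .types):
#
#   from chat_limiter.adapters import get_adapter
#   from chat_limiter.providers import Provider
#   from chat_limiter.types import ChatCompletionRequest, Message, MessageRole
#
#   request = ChatCompletionRequest(
#       model="gpt-4o",
#       messages=[Message(role=MessageRole.USER, content="Hello!")],
#       max_tokens=64,
#   )
#   adapter = get_adapter(Provider.OPENAI)
#   payload = adapter.format_request(request)   # dict ready to POST
#   endpoint = adapter.get_endpoint()           # "/chat/completions"
#   # ...after the HTTP call:
#   # response = adapter.parse_response(response_json, request)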