Coverage for src/dataknobs_fsm/llm/providers.py: 0%
480 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""LLM provider implementations.
3This module provides implementations for various LLM providers.
4"""
6import os
7import json
8from typing import Any, Dict, List, Union, AsyncIterator
10from .base import (
11 LLMConfig, LLMMessage, LLMResponse, LLMStreamResponse,
12 AsyncLLMProvider, SyncLLMProvider, ModelCapability,
13 LLMAdapter
14)
class SyncProviderAdapter:
    """Sync adapter for async LLM providers.

    Wraps an ``AsyncLLMProvider`` and exposes blocking equivalents of its
    coroutine methods by running each coroutine to completion on an
    asyncio event loop.
    """

    def __init__(self, async_provider: AsyncLLMProvider):
        """Initialize with async provider.

        Args:
            async_provider: The async provider to wrap.
        """
        self.async_provider = async_provider

    def _get_loop(self):
        """Return the current event loop, creating (and installing) one if needed."""
        import asyncio
        try:
            return asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop

    def _run(self, coro):
        """Run *coro* to completion on the event loop and return its result."""
        return self._get_loop().run_until_complete(coro)

    def initialize(self) -> None:
        """Initialize the provider synchronously."""
        return self._run(self.async_provider.initialize())

    def close(self) -> None:
        """Close the provider synchronously."""
        return self._run(self.async_provider.close())

    def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion synchronously."""
        return self._run(self.async_provider.complete(messages, **kwargs))

    def stream(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ):
        """Stream completion synchronously.

        Steps the wrapped provider's async generator one item at a time on
        the event loop, yielding each chunk to the (sync) caller.
        """
        loop = self._get_loop()

        async def _stream():
            async for chunk in self.async_provider.stream(messages, **kwargs):
                yield chunk

        # Convert async generator to sync generator
        async_gen = _stream()
        try:
            while True:
                try:
                    yield loop.run_until_complete(async_gen.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            # Close the async generator even if the caller abandons us early.
            loop.run_until_complete(async_gen.aclose())

    def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings synchronously."""
        return self._run(self.async_provider.embed(texts, **kwargs))

    def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Make function call synchronously."""
        return self._run(self.async_provider.function_call(messages, functions, **kwargs))

    def validate_model(self) -> bool:
        """Validate model synchronously."""
        return self._run(self.async_provider.validate_model())  # type: ignore

    def get_capabilities(self) -> List[ModelCapability]:
        """Get capabilities synchronously (already a sync call on the provider)."""
        return self.async_provider.get_capabilities()

    @property
    def is_initialized(self) -> bool:
        """Check if provider is initialized."""
        return self.async_provider.is_initialized
class OpenAIAdapter(LLMAdapter):
    """Adapter for OpenAI API format."""

    def adapt_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Convert messages to OpenAI chat format.

        Args:
            messages: Messages to convert.

        Returns:
            List of dicts with ``role``/``content`` plus optional
            ``name`` and ``function_call`` entries.
        """
        adapted = []
        for msg in messages:
            message = {
                'role': msg.role,
                'content': msg.content
            }
            if msg.name:
                message['name'] = msg.name
            if msg.function_call:
                message['function_call'] = msg.function_call
            adapted.append(message)
        return adapted

    def adapt_response(self, response: Any) -> LLMResponse:
        """Convert an OpenAI response object to the standard ``LLMResponse``.

        Only the first choice is used; usage is included when present.
        """
        choice = response.choices[0]
        message = choice.message

        return LLMResponse(
            content=message.content or '',
            model=response.model,
            finish_reason=choice.finish_reason,
            usage={
                'prompt_tokens': response.usage.prompt_tokens,
                'completion_tokens': response.usage.completion_tokens,
                'total_tokens': response.usage.total_tokens
            } if response.usage else None,
            function_call=message.function_call if hasattr(message, 'function_call') else None
        )

    def adapt_config(self, config: LLMConfig) -> Dict[str, Any]:
        """Convert config to OpenAI request parameters.

        Optional fields are only included when set, matching the API's
        expectations for absent-vs-null parameters.
        """
        params = {
            'model': config.model,
            'temperature': config.temperature,
            'top_p': config.top_p,
            'frequency_penalty': config.frequency_penalty,
            'presence_penalty': config.presence_penalty,
        }

        if config.max_tokens:
            params['max_tokens'] = config.max_tokens
        if config.stop_sequences:
            params['stop'] = config.stop_sequences
        # seed may legitimately be 0, so test against None rather than truthiness.
        if config.seed is not None:
            params['seed'] = config.seed
        if config.logit_bias:
            params['logit_bias'] = config.logit_bias
        if config.user_id:
            params['user'] = config.user_id
        if config.response_format == 'json':
            params['response_format'] = {'type': 'json_object'}
        if config.functions:
            params['functions'] = config.functions
        if config.function_call:
            params['function_call'] = config.function_call

        return params
class OpenAIProvider(AsyncLLMProvider):
    """OpenAI LLM provider."""

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self.adapter = OpenAIAdapter()

    async def initialize(self) -> None:
        """Initialize OpenAI client.

        Raises:
            ValueError: If no API key is configured or in ``OPENAI_API_KEY``.
            ImportError: If the ``openai`` package is not installed.
        """
        try:
            import openai

            api_key = self.config.api_key or os.environ.get('OPENAI_API_KEY')
            if not api_key:
                raise ValueError("OpenAI API key not provided")

            self._client = openai.AsyncOpenAI(
                api_key=api_key,
                base_url=self.config.api_base,
                timeout=self.config.timeout
            )
            self._is_initialized = True
        except ImportError as e:
            raise ImportError("openai package not installed. Install with: pip install openai") from e

    async def close(self) -> None:
        """Close OpenAI client."""
        if self._client:
            await self._client.close()  # type: ignore[unreachable]
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Validate model availability against the account's model list."""
        try:
            # List available models
            models = await self._client.models.list()
            model_ids = [m.id for m in models.data]
            return self.config.model in model_ids
        except Exception:
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Get OpenAI model capabilities, inferred from the model name."""
        capabilities = [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.CHAT,
            ModelCapability.STREAMING
        ]

        if 'gpt-4' in self.config.model or 'gpt-3.5' in self.config.model:
            capabilities.extend([
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.JSON_MODE
            ])

        if 'vision' in self.config.model:
            capabilities.append(ModelCapability.VISION)

        if 'embedding' in self.config.model:
            capabilities.append(ModelCapability.EMBEDDINGS)

        return capabilities

    def _prepare_messages(
        self,
        messages: Union[str, List[LLMMessage]]
    ) -> List[LLMMessage]:
        """Normalize input to a message list with the system prompt prepended.

        Copies the caller's list so the prepend never mutates their data.
        """
        if isinstance(messages, str):
            messages = [LLMMessage(role='user', content=messages)]
        else:
            messages = list(messages)

        if self.config.system_prompt and (not messages or messages[0].role != 'system'):
            messages.insert(0, LLMMessage(role='system', content=self.config.system_prompt))
        return messages

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion."""
        if not self._is_initialized:
            await self.initialize()

        # Adapt messages and config; kwargs override config-derived params.
        adapted_messages = self.adapter.adapt_messages(self._prepare_messages(messages))
        params = self.adapter.adapt_config(self.config)
        params.update(kwargs)

        # Make API call
        response = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        return self.adapter.adapt_response(response)

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Generate streaming completion.

        Yields a chunk per content delta, and a final chunk carrying the
        ``finish_reason`` (the API's terminal chunk has an empty delta, so
        it must not be filtered on content alone).
        """
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self.adapter.adapt_messages(self._prepare_messages(messages))
        params = self.adapter.adapt_config(self.config)
        params['stream'] = True
        params.update(kwargs)

        # Stream API call
        stream = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        async for chunk in stream:
            choice = chunk.choices[0]
            if choice.delta.content or choice.finish_reason is not None:
                yield LLMStreamResponse(
                    delta=choice.delta.content or '',
                    is_final=choice.finish_reason is not None,
                    finish_reason=choice.finish_reason
                )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings.

        Returns a single vector for a string input, a list of vectors
        for a list input.
        """
        if not self._is_initialized:
            await self.initialize()

        if isinstance(texts, str):
            texts = [texts]
            single = True
        else:
            single = False

        response = await self._client.embeddings.create(
            input=texts,
            model=self.config.model or 'text-embedding-ada-002'
        )

        embeddings = [e.embedding for e in response.data]
        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Execute function calling."""
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self.adapter.adapt_messages(self._prepare_messages(messages))
        params = self.adapter.adapt_config(self.config)
        params['functions'] = functions
        params['function_call'] = kwargs.get('function_call', 'auto')
        params.update(kwargs)

        # Make API call
        response = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        return self.adapter.adapt_response(response)
class AnthropicProvider(AsyncLLMProvider):
    """Anthropic Claude LLM provider."""

    def __init__(self, config: LLMConfig):
        super().__init__(config)

    async def initialize(self) -> None:
        """Initialize Anthropic client.

        Raises:
            ValueError: If no API key is configured or in ``ANTHROPIC_API_KEY``.
            ImportError: If the ``anthropic`` package is not installed.
        """
        try:
            import anthropic

            api_key = self.config.api_key or os.environ.get('ANTHROPIC_API_KEY')
            if not api_key:
                raise ValueError("Anthropic API key not provided")

            self._client = anthropic.AsyncAnthropic(
                api_key=api_key,
                base_url=self.config.api_base,
                timeout=self.config.timeout
            )
            self._is_initialized = True
        except ImportError as e:
            raise ImportError("anthropic package not installed. Install with: pip install anthropic") from e

    async def close(self) -> None:
        """Close Anthropic client."""
        if self._client:
            await self._client.close()  # type: ignore[unreachable]
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Validate model name against known Claude model families (no API call)."""
        valid_models = [
            'claude-3-opus', 'claude-3-sonnet', 'claude-3-haiku',
            'claude-2.1', 'claude-2.0', 'claude-instant-1.2'
        ]
        return any(m in self.config.model for m in valid_models)

    def get_capabilities(self) -> List[ModelCapability]:
        """Get Anthropic model capabilities.

        Vision is only appended for Claude 3 models; older models previously
        got a literal ``None`` placeholder in the list, which broke callers
        iterating the capabilities.
        """
        capabilities = [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.CHAT,
            ModelCapability.STREAMING,
            ModelCapability.CODE,
        ]
        if 'claude-3' in self.config.model:
            capabilities.append(ModelCapability.VISION)
        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion.

        NOTE(review): extra ``kwargs`` are accepted but not currently
        forwarded to the API call — confirm whether that is intentional.
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to Anthropic prompt format (shared with stream_complete).
        if isinstance(messages, str):
            prompt = messages
        else:
            prompt = self._build_prompt(messages)

        # Make API call
        response = await self._client.messages.create(
            model=self.config.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=self.config.max_tokens or 1024,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            stop_sequences=self.config.stop_sequences
        )

        return LLMResponse(
            content=response.content[0].text,
            model=response.model,
            finish_reason=response.stop_reason,
            usage={
                'prompt_tokens': response.usage.input_tokens,
                'completion_tokens': response.usage.output_tokens,
                'total_tokens': response.usage.input_tokens + response.usage.output_tokens
            } if hasattr(response, 'usage') else None
        )

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Generate streaming completion.

        Yields one chunk per content-block delta, then a final empty chunk
        carrying the stop reason.
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to Anthropic format
        if isinstance(messages, str):
            prompt = messages
        else:
            prompt = self._build_prompt(messages)

        # Stream API call
        async with self._client.messages.stream(
            model=self.config.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=self.config.max_tokens or 1024,
            temperature=self.config.temperature
        ) as stream:
            async for chunk in stream:
                if chunk.type == 'content_block_delta':
                    yield LLMStreamResponse(
                        delta=chunk.delta.text,
                        is_final=False
                    )

            # Final message
            message = await stream.get_final_message()
            yield LLMStreamResponse(
                delta='',
                is_final=True,
                finish_reason=message.stop_reason
            )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Anthropic doesn't provide embeddings."""
        raise NotImplementedError("Anthropic doesn't provide embedding models")

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Emulate function calling through prompting (no native support).

        Describes the available functions in a system prompt and parses a
        ``FUNCTION_CALL:`` marker out of the model's reply; unparseable
        replies are returned as-is.
        """
        function_descriptions = "\n".join([
            f"- {f['name']}: {f['description']}"
            for f in functions
        ])

        system_prompt = f"""You have access to the following functions:
{function_descriptions}

When you need to call a function, respond with:
FUNCTION_CALL: {{
    "name": "function_name",
    "arguments": {{...}}
}}"""

        messages_with_system = [
            LLMMessage(role='system', content=system_prompt)
        ] + messages

        response = await self.complete(messages_with_system, **kwargs)

        # Parse function call from response; best-effort, so parse failures
        # leave the plain-text response untouched.
        if 'FUNCTION_CALL:' in response.content:
            try:
                func_json = response.content.split('FUNCTION_CALL:')[1].strip()
                function_call = json.loads(func_json)
                response.function_call = function_call
            except (json.JSONDecodeError, IndexError):
                pass

        return response

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Build Anthropic-style Human/Assistant prompt from messages.

        System content is prepended; the prompt ends with an open
        "Assistant:" turn for the model to complete.
        """
        prompt = ""
        for msg in messages:
            if msg.role == 'system':
                prompt = msg.content + "\n\n" + prompt
            elif msg.role == 'user':
                prompt += f"\n\nHuman: {msg.content}"
            elif msg.role == 'assistant':
                prompt += f"\n\nAssistant: {msg.content}"
        prompt += "\n\nAssistant:"
        return prompt
class OllamaProvider(AsyncLLMProvider):
    """Ollama local LLM provider."""

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        # Check for Docker environment and adjust URL accordingly
        default_url = 'http://localhost:11434'
        if os.path.exists('/.dockerenv'):
            # Running in Docker, use host.docker.internal
            default_url = 'http://host.docker.internal:11434'

        # Allow environment variable override
        self.base_url = config.api_base or os.environ.get('OLLAMA_BASE_URL', default_url)

    def _build_options(self) -> Dict[str, Any]:
        """Build options dict for Ollama API calls.

        Returns:
            Dictionary of options for the API request.
        """
        options: Dict[str, Any] = {
            'temperature': self.config.temperature,
            'top_p': self.config.top_p
        }

        if self.config.seed is not None:
            options['seed'] = self.config.seed

        if self.config.max_tokens:
            options['num_predict'] = self.config.max_tokens

        if self.config.stop_sequences:
            options['stop'] = self.config.stop_sequences

        return options

    async def initialize(self) -> None:
        """Initialize Ollama client and verify model availability.

        Connection/model problems are logged rather than raised, so the
        provider can be constructed even while the Ollama daemon is down.

        Raises:
            ImportError: If the ``aiohttp`` package is not installed.
        """
        try:
            import aiohttp
        except ImportError as e:
            raise ImportError("aiohttp package not installed. Install with: pip install aiohttp") from e

        import logging

        self._session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=self.config.timeout or 30.0)
        )

        # Test connection and verify model availability (best-effort).
        try:
            async with self._session.get(f"{self.base_url}/api/tags") as response:
                if response.status == 200:
                    data = await response.json()
                    models = [m['name'] for m in data.get('models', [])]
                    if not models:
                        logging.warning("Ollama: No models found. Please pull a model first.")
                    elif self.config.model not in models:
                        # Try without tag (e.g., 'llama2' instead of 'llama2:latest')
                        base_model = self.config.model.split(':')[0]
                        matching_models = [m for m in models if m.startswith(base_model)]
                        if matching_models:
                            # Use first matching model
                            self.config.model = matching_models[0]
                            logging.info(f"Ollama: Using model {self.config.model}")
                        else:
                            logging.warning(f"Ollama: Model {self.config.model} not found. Available: {models}")
                else:
                    logging.warning(f"Ollama: API returned status {response.status}")
        except aiohttp.ClientError as e:
            logging.warning(f"Ollama: Could not connect to {self.base_url}: {e}")

        self._is_initialized = True

    async def close(self) -> None:
        """Close Ollama client."""
        if hasattr(self, '_session') and self._session:
            await self._session.close()
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Validate model availability against the daemon's tag list."""
        if not self._is_initialized or not hasattr(self, '_session'):
            return False

        try:
            async with self._session.get(f"{self.base_url}/api/tags") as response:
                if response.status == 200:
                    data = await response.json()
                    models = [m['name'] for m in data.get('models', [])]
                    # Check exact match or base model match
                    if self.config.model in models:
                        return True
                    base_model = self.config.model.split(':')[0]
                    return any(m.startswith(base_model) for m in models)
        except Exception:
            return False
        return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Get Ollama model capabilities (inferred from the model name)."""
        capabilities = [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.STREAMING
        ]

        if 'llava' in self.config.model.lower():
            capabilities.append(ModelCapability.VISION)

        if 'codellama' in self.config.model.lower():
            capabilities.append(ModelCapability.CODE)

        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion.

        NOTE(review): extra ``kwargs`` are accepted but not forwarded to
        the API payload — confirm whether that is intentional.
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to Ollama format
        if isinstance(messages, str):
            prompt = messages
        else:
            prompt = self._build_prompt(messages)

        # Make API call
        payload = {
            'model': self.config.model,
            'prompt': prompt,
            'stream': False,
            'options': self._build_options()
        }

        async with self._session.post(f"{self.base_url}/api/generate", json=payload) as response:
            response.raise_for_status()
            data = await response.json()

        return LLMResponse(
            content=data['response'],
            model=self.config.model,
            finish_reason='stop' if data.get('done') else 'length',
            usage={
                'prompt_tokens': data.get('prompt_eval_count', 0),
                'completion_tokens': data.get('eval_count', 0),
                'total_tokens': data.get('prompt_eval_count', 0) + data.get('eval_count', 0)
            } if 'eval_count' in data else None,
            metadata={
                'eval_duration': data.get('eval_duration'),
                'total_duration': data.get('total_duration')
            }
        )

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Generate streaming completion.

        Parses one JSON object per line from the response body; blank or
        whitespace-only chunks are skipped so they can't crash json.loads.
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to Ollama format
        if isinstance(messages, str):
            prompt = messages
        else:
            prompt = self._build_prompt(messages)

        # Stream API call
        payload = {
            'model': self.config.model,
            'prompt': prompt,
            'stream': True,
            'options': self._build_options()
        }

        async with self._session.post(f"{self.base_url}/api/generate", json=payload) as response:
            response.raise_for_status()

            async for line in response.content:
                text = line.decode('utf-8').strip()
                if not text:
                    continue
                data = json.loads(text)
                yield LLMStreamResponse(
                    delta=data.get('response', ''),
                    is_final=data.get('done', False),
                    finish_reason='stop' if data.get('done') else None
                )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings (one API call per text).

        Returns a single vector for a string input, a list of vectors
        for a list input.
        """
        if not self._is_initialized:
            await self.initialize()

        if isinstance(texts, str):
            texts = [texts]
            single = True
        else:
            single = False

        embeddings = []
        for text in texts:
            payload = {
                'model': self.config.model,
                'prompt': text
            }

            async with self._session.post(f"{self.base_url}/api/embeddings", json=payload) as response:
                response.raise_for_status()
                data = await response.json()
                embeddings.append(data['embedding'])

        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Emulate function calling through prompting (no native support).

        Asks the model to answer with a JSON function-call object and
        parses it best-effort; unparseable replies are returned as-is.
        """
        function_descriptions = json.dumps(functions, indent=2)

        system_prompt = f"""You have access to these functions:
{function_descriptions}

To call a function, respond with JSON:
{{"function": "name", "arguments": {{...}}}}"""

        messages_with_system = [
            LLMMessage(role='system', content=system_prompt)
        ] + messages

        response = await self.complete(messages_with_system, **kwargs)

        # Try to parse function call
        try:
            func_data = json.loads(response.content)
            if 'function' in func_data:
                response.function_call = {
                    'name': func_data['function'],
                    'arguments': func_data.get('arguments', {})
                }
        except json.JSONDecodeError:
            pass

        return response

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Build a flat "Role: content" prompt from messages."""
        prompt = ""
        for msg in messages:
            if msg.role == 'system':
                prompt += f"System: {msg.content}\n\n"
            elif msg.role == 'user':
                prompt += f"User: {msg.content}\n\n"
            elif msg.role == 'assistant':
                prompt += f"Assistant: {msg.content}\n\n"
        return prompt
class HuggingFaceProvider(AsyncLLMProvider):
    """HuggingFace Inference API provider."""

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self.base_url = config.api_base or 'https://api-inference.huggingface.co/models'

    async def initialize(self) -> None:
        """Initialize HuggingFace client.

        Raises:
            ValueError: If no API key is configured or in ``HUGGINGFACE_API_KEY``.
            ImportError: If the ``aiohttp`` package is not installed.
        """
        try:
            import aiohttp

            api_key = self.config.api_key or os.environ.get('HUGGINGFACE_API_KEY')
            if not api_key:
                raise ValueError("HuggingFace API key not provided")

            self._session = aiohttp.ClientSession(
                headers={'Authorization': f'Bearer {api_key}'},
                timeout=aiohttp.ClientTimeout(total=self.config.timeout)
            )
            self._is_initialized = True
        except ImportError as e:
            raise ImportError("aiohttp package not installed. Install with: pip install aiohttp") from e

    async def close(self) -> None:
        """Close HuggingFace client."""
        if hasattr(self, '_session') and self._session:
            await self._session.close()
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Validate model availability by probing the model endpoint."""
        try:
            url = f"{self.base_url}/{self.config.model}"
            async with self._session.get(url) as response:
                return response.status == 200
        except Exception:
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Get HuggingFace model capabilities.

        Embeddings are only appended for embedding models; previously a
        literal ``None`` placeholder landed in the list for other models,
        which broke callers iterating the capabilities.
        """
        capabilities = [ModelCapability.TEXT_GENERATION]
        if 'embedding' in self.config.model:
            capabilities.append(ModelCapability.EMBEDDINGS)
        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion."""
        if not self._is_initialized:
            await self.initialize()

        # Convert to prompt
        if isinstance(messages, str):
            prompt = messages
        else:
            prompt = self._build_prompt(messages)

        # Make API call
        url = f"{self.base_url}/{self.config.model}"
        payload = {
            'inputs': prompt,
            'parameters': {
                'temperature': self.config.temperature,
                'top_p': self.config.top_p,
                'max_new_tokens': self.config.max_tokens or 100,
                'return_full_text': False
            }
        }

        async with self._session.post(url, json=payload) as response:
            response.raise_for_status()
            data = await response.json()

        # Parse response: text-generation endpoints return a list of
        # {'generated_text': ...} dicts; anything else is stringified.
        if isinstance(data, list) and len(data) > 0:
            text = data[0].get('generated_text', '')
        else:
            text = str(data)

        return LLMResponse(
            content=text,
            model=self.config.model,
            finish_reason='stop'
        )

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """HuggingFace Inference API doesn't support streaming."""
        # Simulate streaming by yielding complete response
        response = await self.complete(messages, **kwargs)
        yield LLMStreamResponse(
            delta=response.content,
            is_final=True,
            finish_reason=response.finish_reason
        )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings.

        Returns a single vector for a string input, a list of vectors
        for a list input.
        """
        if not self._is_initialized:
            await self.initialize()

        if isinstance(texts, str):
            texts = [texts]
            single = True
        else:
            single = False

        url = f"{self.base_url}/{self.config.model}"
        payload = {'inputs': texts}

        async with self._session.post(url, json=payload) as response:
            response.raise_for_status()
            embeddings = await response.json()

        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """HuggingFace doesn't have native function calling."""
        raise NotImplementedError("Function calling not supported for HuggingFace models")

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Build a flat prompt: bare system text, then "User:"/"Assistant:" turns."""
        prompt = ""
        for msg in messages:
            if msg.role == 'system':
                prompt += f"{msg.content}\n\n"
            elif msg.role == 'user':
                prompt += f"User: {msg.content}\n"
            elif msg.role == 'assistant':
                prompt += f"Assistant: {msg.content}\n"
        return prompt
def create_llm_provider(
    config: LLMConfig,
    is_async: bool = True
) -> Union[AsyncLLMProvider, SyncLLMProvider]:
    """Create appropriate LLM provider based on configuration.

    Args:
        config: LLM configuration
        is_async: Whether to create async provider

    Returns:
        LLM provider instance

    Raises:
        ValueError: If the configured provider name is unknown.
    """
    registry = {
        'openai': OpenAIProvider,
        'anthropic': AnthropicProvider,
        'ollama': OllamaProvider,
        'huggingface': HuggingFaceProvider,
    }

    provider_cls = registry.get(config.provider.lower())
    if provider_cls is None:
        raise ValueError(f"Unknown provider: {config.provider}")

    provider = provider_cls(config)  # type: ignore
    if is_async:
        return provider  # type: ignore
    # Wrap the async provider so callers get a blocking interface.
    return SyncProviderAdapter(provider)  # type: ignore