Coverage for src/dataknobs_fsm/llm/providers.py: 0%

480 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""LLM provider implementations. 

2 

3This module provides implementations for various LLM providers. 

4""" 

5 

6import os 

7import json 

8from typing import Any, Dict, List, Union, AsyncIterator 

9 

10from .base import ( 

11 LLMConfig, LLMMessage, LLMResponse, LLMStreamResponse, 

12 AsyncLLMProvider, SyncLLMProvider, ModelCapability, 

13 LLMAdapter 

14) 

15 

16 

class SyncProviderAdapter:
    """Sync adapter that exposes an async LLM provider through a blocking API.

    Every public method drives the corresponding coroutine of the wrapped
    provider to completion on an event loop.

    NOTE(review): this adapter must not be called from a thread that already
    has a *running* event loop -- ``run_until_complete`` would raise.
    """

    def __init__(self, async_provider: AsyncLLMProvider):
        """Initialize with async provider.

        Args:
            async_provider: The async provider to wrap.
        """
        self.async_provider = async_provider

    @staticmethod
    def _get_loop():
        """Return the current event loop, creating one if none exists.

        Centralizes the loop bootstrap that was previously copy-pasted into
        every public method.
        """
        import asyncio
        try:
            return asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop

    def _run(self, coro):
        """Run ``coro`` to completion on this thread's event loop."""
        return self._get_loop().run_until_complete(coro)

    def initialize(self) -> None:
        """Initialize the provider synchronously."""
        return self._run(self.async_provider.initialize())

    def close(self) -> None:
        """Close the provider synchronously."""
        return self._run(self.async_provider.close())

    def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion synchronously."""
        return self._run(self.async_provider.complete(messages, **kwargs))

    def stream(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ):
        """Stream completion synchronously.

        Drains the provider's async generator one item at a time, yielding
        each chunk to the (synchronous) caller.
        """
        loop = self._get_loop()
        async_gen = self.async_provider.stream(messages, **kwargs)
        try:
            while True:
                try:
                    yield loop.run_until_complete(async_gen.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            # Always close the async generator, even if the sync consumer
            # abandons iteration early.
            loop.run_until_complete(async_gen.aclose())

    def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings synchronously."""
        return self._run(self.async_provider.embed(texts, **kwargs))

    def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Make function call synchronously."""
        return self._run(self.async_provider.function_call(messages, functions, **kwargs))

    def validate_model(self) -> bool:
        """Validate model synchronously."""
        return self._run(self.async_provider.validate_model())  # type: ignore

    def get_capabilities(self) -> List[ModelCapability]:
        """Get capabilities synchronously (already a sync call on the provider)."""
        return self.async_provider.get_capabilities()

    @property
    def is_initialized(self) -> bool:
        """Check if provider is initialized."""
        return self.async_provider.is_initialized

144 

class OpenAIAdapter(LLMAdapter):
    """Adapter translating between the neutral LLM types and OpenAI's API."""

    def adapt_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Convert messages to OpenAI chat format (role/content dicts).

        ``name`` and ``function_call`` are only included when set.
        """
        adapted = []
        for msg in messages:
            message: Dict[str, Any] = {
                'role': msg.role,
                'content': msg.content
            }
            if msg.name:
                message['name'] = msg.name
            if msg.function_call:
                message['function_call'] = msg.function_call
            adapted.append(message)
        return adapted

    def adapt_response(self, response: Any) -> LLMResponse:
        """Convert OpenAI response to standard format.

        Only the first choice is used; usage is passed through when present.
        """
        choice = response.choices[0]
        message = choice.message

        return LLMResponse(
            content=message.content or '',
            model=response.model,
            finish_reason=choice.finish_reason,
            usage={
                'prompt_tokens': response.usage.prompt_tokens,
                'completion_tokens': response.usage.completion_tokens,
                'total_tokens': response.usage.total_tokens
            } if response.usage else None,
            function_call=message.function_call if hasattr(message, 'function_call') else None
        )

    def adapt_config(self, config: LLMConfig) -> Dict[str, Any]:
        """Convert config to OpenAI request parameters.

        Optional fields are included only when configured.
        """
        params: Dict[str, Any] = {
            'model': config.model,
            'temperature': config.temperature,
            'top_p': config.top_p,
            'frequency_penalty': config.frequency_penalty,
            'presence_penalty': config.presence_penalty,
        }

        if config.max_tokens:
            params['max_tokens'] = config.max_tokens
        if config.stop_sequences:
            params['stop'] = config.stop_sequences
        # Explicit None check: seed=0 is a valid value and must not be dropped.
        if config.seed is not None:
            params['seed'] = config.seed
        if config.logit_bias:
            params['logit_bias'] = config.logit_bias
        if config.user_id:
            params['user'] = config.user_id
        if config.response_format == 'json':
            params['response_format'] = {'type': 'json_object'}
        if config.functions:
            params['functions'] = config.functions
        if config.function_call:
            params['function_call'] = config.function_call

        return params

208 

209 

class OpenAIProvider(AsyncLLMProvider):
    """OpenAI LLM provider backed by the official async ``openai`` client."""

    def __init__(self, config: LLMConfig):
        """Initialize with configuration and an OpenAI format adapter."""
        super().__init__(config)
        self.adapter = OpenAIAdapter()

    async def initialize(self) -> None:
        """Initialize OpenAI client.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If no API key is configured or in ``OPENAI_API_KEY``.
        """
        # Keep the try narrow: only the import can raise ImportError.
        try:
            import openai
        except ImportError as e:
            raise ImportError("openai package not installed. Install with: pip install openai") from e

        api_key = self.config.api_key or os.environ.get('OPENAI_API_KEY')
        if not api_key:
            raise ValueError("OpenAI API key not provided")

        self._client = openai.AsyncOpenAI(
            api_key=api_key,
            base_url=self.config.api_base,
            timeout=self.config.timeout
        )
        self._is_initialized = True

    async def close(self) -> None:
        """Close OpenAI client."""
        if self._client:
            await self._client.close()  # type: ignore[unreachable]
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Return True if the configured model is listed for this account."""
        try:
            models = await self._client.models.list()
            model_ids = [m.id for m in models.data]
            return self.config.model in model_ids
        except Exception:
            # Any API/connectivity failure is treated as "not valid".
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Get OpenAI model capabilities, inferred from the model name."""
        capabilities = [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.CHAT,
            ModelCapability.STREAMING
        ]

        if 'gpt-4' in self.config.model or 'gpt-3.5' in self.config.model:
            capabilities.extend([
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.JSON_MODE
            ])

        if 'vision' in self.config.model:
            capabilities.append(ModelCapability.VISION)

        if 'embedding' in self.config.model:
            capabilities.append(ModelCapability.EMBEDDINGS)

        return capabilities

    def _prepare_messages(
        self,
        messages: Union[str, List[LLMMessage]]
    ) -> List[LLMMessage]:
        """Normalize input to a message list with the system prompt prepended.

        Always returns a NEW list: the caller's list is never mutated
        (the previous implementation used ``insert`` on the caller's list).
        Also tolerates an empty message list.
        """
        if isinstance(messages, str):
            messages = [LLMMessage(role='user', content=messages)]
        else:
            messages = list(messages)

        if self.config.system_prompt and (not messages or messages[0].role != 'system'):
            messages.insert(0, LLMMessage(role='system', content=self.config.system_prompt))
        return messages

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion."""
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self.adapter.adapt_messages(self._prepare_messages(messages))
        params = self.adapter.adapt_config(self.config)
        params.update(kwargs)

        response = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        return self.adapter.adapt_response(response)

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Generate streaming completion.

        Yields a chunk for each content delta, and also yields the terminal
        chunk carrying ``finish_reason`` even when its delta is empty
        (previously the final chunk was dropped, so ``is_final`` was never
        observed by consumers).
        """
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self.adapter.adapt_messages(self._prepare_messages(messages))
        params = self.adapter.adapt_config(self.config)
        params['stream'] = True
        params.update(kwargs)

        stream = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        async for chunk in stream:
            choice = chunk.choices[0]
            if choice.delta.content or choice.finish_reason is not None:
                yield LLMStreamResponse(
                    delta=choice.delta.content or '',
                    is_final=choice.finish_reason is not None,
                    finish_reason=choice.finish_reason
                )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings.

        A single string returns one vector; a list returns a list of vectors.
        """
        if not self._is_initialized:
            await self.initialize()

        if isinstance(texts, str):
            texts = [texts]
            single = True
        else:
            single = False

        response = await self._client.embeddings.create(
            input=texts,
            model=self.config.model or 'text-embedding-ada-002'
        )

        embeddings = [e.embedding for e in response.data]
        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Execute function calling via the chat completions endpoint."""
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self.adapter.adapt_messages(self._prepare_messages(messages))
        params = self.adapter.adapt_config(self.config)
        params['functions'] = functions
        params['function_call'] = kwargs.get('function_call', 'auto')
        params.update(kwargs)

        response = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        return self.adapter.adapt_response(response)

391 

392 

class AnthropicProvider(AsyncLLMProvider):
    """Anthropic Claude LLM provider."""

    def __init__(self, config: LLMConfig):
        """Initialize with configuration."""
        super().__init__(config)

    async def initialize(self) -> None:
        """Initialize Anthropic client.

        Raises:
            ImportError: If the ``anthropic`` package is not installed.
            ValueError: If no API key is configured or in ``ANTHROPIC_API_KEY``.
        """
        # Keep the try narrow: only the import can raise ImportError.
        try:
            import anthropic
        except ImportError as e:
            raise ImportError("anthropic package not installed. Install with: pip install anthropic") from e

        api_key = self.config.api_key or os.environ.get('ANTHROPIC_API_KEY')
        if not api_key:
            raise ValueError("Anthropic API key not provided")

        self._client = anthropic.AsyncAnthropic(
            api_key=api_key,
            base_url=self.config.api_base,
            timeout=self.config.timeout
        )
        self._is_initialized = True

    async def close(self) -> None:
        """Close Anthropic client."""
        if self._client:
            await self._client.close()  # type: ignore[unreachable]
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Check the configured model name against known Claude families."""
        valid_models = [
            'claude-3-opus', 'claude-3-sonnet', 'claude-3-haiku',
            'claude-2.1', 'claude-2.0', 'claude-instant-1.2'
        ]
        return any(m in self.config.model for m in valid_models)

    def get_capabilities(self) -> List[ModelCapability]:
        """Get Anthropic model capabilities.

        Previously this returned a list containing ``None`` for non-claude-3
        models, violating the declared ``List[ModelCapability]`` type; VISION
        is now appended conditionally instead.
        """
        capabilities = [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.CHAT,
            ModelCapability.STREAMING,
            ModelCapability.CODE,
        ]
        if 'claude-3' in self.config.model:
            capabilities.append(ModelCapability.VISION)
        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion.

        NOTE(review): extra ``kwargs`` are accepted for interface parity but
        are not forwarded to the API call (matching previous behavior).
        """
        if not self._is_initialized:
            await self.initialize()

        # Reuse the shared prompt builder instead of duplicating it inline.
        if isinstance(messages, str):
            prompt = messages
        else:
            prompt = self._build_prompt(messages)

        response = await self._client.messages.create(
            model=self.config.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=self.config.max_tokens or 1024,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            stop_sequences=self.config.stop_sequences
        )

        return LLMResponse(
            content=response.content[0].text,
            model=response.model,
            finish_reason=response.stop_reason,
            usage={
                'prompt_tokens': response.usage.input_tokens,
                'completion_tokens': response.usage.output_tokens,
                'total_tokens': response.usage.input_tokens + response.usage.output_tokens
            } if hasattr(response, 'usage') else None
        )

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Generate streaming completion.

        Yields one chunk per content delta, then a final empty chunk carrying
        the stop reason.
        """
        if not self._is_initialized:
            await self.initialize()

        if isinstance(messages, str):
            prompt = messages
        else:
            prompt = self._build_prompt(messages)

        async with self._client.messages.stream(
            model=self.config.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=self.config.max_tokens or 1024,
            temperature=self.config.temperature
        ) as stream:
            async for chunk in stream:
                if chunk.type == 'content_block_delta':
                    yield LLMStreamResponse(
                        delta=chunk.delta.text,
                        is_final=False
                    )

            # Final message carries the stop reason.
            message = await stream.get_final_message()
            yield LLMStreamResponse(
                delta='',
                is_final=True,
                finish_reason=message.stop_reason
            )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Anthropic doesn't provide embeddings."""
        raise NotImplementedError("Anthropic doesn't provide embedding models")

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Emulate function calling through prompting.

        The model is instructed to emit a ``FUNCTION_CALL:`` marker followed
        by JSON; if that marker is found and parses, it is attached to the
        response as ``function_call``.
        """
        function_descriptions = "\n".join([
            f"- {f['name']}: {f['description']}"
            for f in functions
        ])

        system_prompt = f"""You have access to the following functions:
{function_descriptions}

When you need to call a function, respond with:
FUNCTION_CALL: {{
    "name": "function_name",
    "arguments": {{...}}
}}"""

        messages_with_system = [
            LLMMessage(role='system', content=system_prompt)
        ] + messages

        response = await self.complete(messages_with_system, **kwargs)

        # Best-effort parse: on malformed JSON the raw response is returned.
        if 'FUNCTION_CALL:' in response.content:
            try:
                func_json = response.content.split('FUNCTION_CALL:')[1].strip()
                function_call = json.loads(func_json)
                response.function_call = function_call
            except (json.JSONDecodeError, IndexError):
                pass

        return response

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Build an Anthropic-style Human/Assistant prompt from messages.

        System content is hoisted to the top; the prompt ends with an open
        ``Assistant:`` turn for the model to complete.
        """
        prompt = ""
        for msg in messages:
            if msg.role == 'system':
                prompt = msg.content + "\n\n" + prompt
            elif msg.role == 'user':
                prompt += f"\n\nHuman: {msg.content}"
            elif msg.role == 'assistant':
                prompt += f"\n\nAssistant: {msg.content}"
        prompt += "\n\nAssistant:"
        return prompt

582 

583 

class OllamaProvider(AsyncLLMProvider):
    """Ollama local LLM provider.

    Talks to a local (or Docker-host) Ollama daemon over its HTTP API using
    ``aiohttp``. Model names may be adjusted at initialization time to match
    what the daemon actually has pulled.
    """

    def __init__(self, config: LLMConfig):
        """Initialize, resolving the daemon base URL.

        Precedence: ``config.api_base`` > ``OLLAMA_BASE_URL`` env var >
        Docker-aware default.
        """
        super().__init__(config)
        # Check for Docker environment and adjust URL accordingly
        default_url = 'http://localhost:11434'
        if os.path.exists('/.dockerenv'):
            # Running in Docker, use host.docker.internal
            default_url = 'http://host.docker.internal:11434'

        # Allow environment variable override
        self.base_url = config.api_base or os.environ.get('OLLAMA_BASE_URL', default_url)

    def _build_options(self) -> Dict[str, Any]:
        """Build options dict for Ollama API calls.

        Returns:
            Dictionary of options for the API request. Note that Ollama's
            name for the max-token limit is ``num_predict``.
        """
        options: Dict[str, Any] = {
            'temperature': self.config.temperature,
            'top_p': self.config.top_p
        }

        if self.config.seed is not None:
            options['seed'] = self.config.seed

        if self.config.max_tokens:
            options['num_predict'] = self.config.max_tokens  # type: ignore

        if self.config.stop_sequences:
            options['stop'] = self.config.stop_sequences  # type: ignore

        return options

    async def initialize(self) -> None:
        """Initialize Ollama client.

        Opens an aiohttp session, then probes ``/api/tags`` to verify the
        configured model exists. If the exact name is missing, falls back to
        the first model whose name shares the same base (tag stripped) and
        rewrites ``self.config.model`` accordingly. Connectivity problems
        only log a warning -- initialization still succeeds so that later
        calls can retry against the daemon.

        Raises:
            ImportError: If the ``aiohttp`` package is not installed.
        """
        try:
            import aiohttp
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self.config.timeout or 30.0)
            )

            # Test connection and verify model availability
            try:
                async with self._session.get(f"{self.base_url}/api/tags") as response:
                    if response.status == 200:
                        data = await response.json()
                        models = [m['name'] for m in data.get('models', [])]
                        if models:
                            # Check if configured model is available
                            if self.config.model not in models:
                                # Try without tag (e.g., 'llama2' instead of 'llama2:latest')
                                base_model = self.config.model.split(':')[0]
                                matching_models = [m for m in models if m.startswith(base_model)]
                                if matching_models:
                                    # Use first matching model
                                    self.config.model = matching_models[0]
                                    import logging
                                    logging.info(f"Ollama: Using model {self.config.model}")
                                else:
                                    import logging
                                    logging.warning(f"Ollama: Model {self.config.model} not found. Available: {models}")
                        else:
                            import logging
                            logging.warning("Ollama: No models found. Please pull a model first.")
                    else:
                        import logging
                        logging.warning(f"Ollama: API returned status {response.status}")
            except aiohttp.ClientError as e:
                import logging
                logging.warning(f"Ollama: Could not connect to {self.base_url}: {e}")

            self._is_initialized = True
        except ImportError as e:
            raise ImportError("aiohttp package not installed. Install with: pip install aiohttp") from e

    async def close(self) -> None:
        """Close the aiohttp session if one was created."""
        if hasattr(self, '_session') and self._session:
            await self._session.close()
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Validate model availability.

        Returns True when the exact model name, or any model sharing its
        base name (tag stripped), is listed by the daemon. Returns False
        when uninitialized or on any API/connectivity error.
        """
        if not self._is_initialized or not hasattr(self, '_session'):
            return False

        try:
            async with self._session.get(f"{self.base_url}/api/tags") as response:
                if response.status == 200:
                    data = await response.json()
                    models = [m['name'] for m in data.get('models', [])]
                    # Check exact match or base model match
                    if self.config.model in models:
                        return True
                    base_model = self.config.model.split(':')[0]
                    return any(m.startswith(base_model) for m in models)
        except Exception:
            return False
        return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Get Ollama model capabilities.

        Capabilities depend on the specific model; only llava (vision) and
        codellama (code) are special-cased here.
        """
        capabilities = [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.STREAMING
        ]

        if 'llava' in self.config.model.lower():
            capabilities.append(ModelCapability.VISION)

        if 'codellama' in self.config.model.lower():
            capabilities.append(ModelCapability.CODE)

        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion via ``/api/generate`` (non-streaming).

        NOTE(review): extra ``kwargs`` are accepted but not forwarded to the
        API request.
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to Ollama format
        if isinstance(messages, str):
            prompt = messages
        else:
            prompt = self._build_prompt(messages)

        # Make API call
        payload = {
            'model': self.config.model,
            'prompt': prompt,
            'stream': False,
            'options': self._build_options()
        }

        async with self._session.post(f"{self.base_url}/api/generate", json=payload) as response:
            response.raise_for_status()
            data = await response.json()

            return LLMResponse(
                content=data['response'],
                model=self.config.model,
                finish_reason='stop' if data.get('done') else 'length',
                usage={
                    'prompt_tokens': data.get('prompt_eval_count', 0),
                    'completion_tokens': data.get('eval_count', 0),
                    'total_tokens': data.get('prompt_eval_count', 0) + data.get('eval_count', 0)
                } if 'eval_count' in data else None,
                metadata={
                    'eval_duration': data.get('eval_duration'),
                    'total_duration': data.get('total_duration')
                }
            )

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Generate streaming completion.

        Ollama streams newline-delimited JSON objects; each line is decoded
        and yielded as a chunk, with ``done`` marking the final one.
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to Ollama format
        if isinstance(messages, str):
            prompt = messages
        else:
            prompt = self._build_prompt(messages)

        # Stream API call
        payload = {
            'model': self.config.model,
            'prompt': prompt,
            'stream': True,
            'options': self._build_options()
        }

        async with self._session.post(f"{self.base_url}/api/generate", json=payload) as response:
            response.raise_for_status()

            # aiohttp's StreamReader iterates line by line here.
            async for line in response.content:
                if line:
                    data = json.loads(line.decode('utf-8'))
                    yield LLMStreamResponse(
                        delta=data.get('response', ''),
                        is_final=data.get('done', False),
                        finish_reason='stop' if data.get('done') else None
                    )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings, one API call per input text.

        A single string returns one vector; a list returns a list of vectors.
        """
        if not self._is_initialized:
            await self.initialize()

        if isinstance(texts, str):
            texts = [texts]
            single = True
        else:
            single = False

        embeddings = []
        for text in texts:
            payload = {
                'model': self.config.model,
                'prompt': text
            }

            async with self._session.post(f"{self.base_url}/api/embeddings", json=payload) as response:
                response.raise_for_status()
                data = await response.json()
                embeddings.append(data['embedding'])

        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Emulate function calling through prompting.

        The model is instructed to answer with a JSON object naming the
        function; if the whole response parses as such JSON, it is attached
        as ``function_call``. Otherwise the raw response is returned as-is.
        """
        # Similar to Anthropic, implement through prompting
        function_descriptions = json.dumps(functions, indent=2)

        system_prompt = f"""You have access to these functions:
{function_descriptions}

To call a function, respond with JSON:
{{"function": "name", "arguments": {{...}}}}"""

        messages_with_system = [
            LLMMessage(role='system', content=system_prompt)
        ] + messages

        response = await self.complete(messages_with_system, **kwargs)

        # Try to parse function call
        try:
            func_data = json.loads(response.content)
            if 'function' in func_data:
                response.function_call = {
                    'name': func_data['function'],
                    'arguments': func_data.get('arguments', {})
                }
        except json.JSONDecodeError:
            pass

        return response

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Build a flat ``System:/User:/Assistant:`` prompt from messages."""
        prompt = ""
        for msg in messages:
            if msg.role == 'system':
                prompt += f"System: {msg.content}\n\n"
            elif msg.role == 'user':
                prompt += f"User: {msg.content}\n\n"
            elif msg.role == 'assistant':
                prompt += f"Assistant: {msg.content}\n\n"
        return prompt

855 

856 

class HuggingFaceProvider(AsyncLLMProvider):
    """HuggingFace Inference API provider."""

    def __init__(self, config: LLMConfig):
        """Initialize, resolving the inference endpoint base URL."""
        super().__init__(config)
        self.base_url = config.api_base or 'https://api-inference.huggingface.co/models'

    async def initialize(self) -> None:
        """Initialize HuggingFace client.

        Raises:
            ImportError: If the ``aiohttp`` package is not installed.
            ValueError: If no API key is configured or in ``HUGGINGFACE_API_KEY``.
        """
        # Keep the try narrow: only the import can raise ImportError.
        try:
            import aiohttp
        except ImportError as e:
            raise ImportError("aiohttp package not installed. Install with: pip install aiohttp") from e

        api_key = self.config.api_key or os.environ.get('HUGGINGFACE_API_KEY')
        if not api_key:
            raise ValueError("HuggingFace API key not provided")

        self._session = aiohttp.ClientSession(
            headers={'Authorization': f'Bearer {api_key}'},
            timeout=aiohttp.ClientTimeout(total=self.config.timeout)
        )
        self._is_initialized = True

    async def close(self) -> None:
        """Close the aiohttp session if one was created."""
        if hasattr(self, '_session') and self._session:
            await self._session.close()
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Return True if the model endpoint responds with HTTP 200."""
        try:
            url = f"{self.base_url}/{self.config.model}"
            async with self._session.get(url) as response:
                return response.status == 200
        except Exception:
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Get HuggingFace model capabilities.

        Previously this returned a list containing ``None`` for
        non-embedding models, violating the declared
        ``List[ModelCapability]`` type; EMBEDDINGS is now appended
        conditionally instead.
        """
        capabilities = [ModelCapability.TEXT_GENERATION]
        if 'embedding' in self.config.model:
            capabilities.append(ModelCapability.EMBEDDINGS)
        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion via the Inference API.

        NOTE(review): extra ``kwargs`` are accepted but not forwarded to the
        API request (matching previous behavior).
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to a flat prompt
        if isinstance(messages, str):
            prompt = messages
        else:
            prompt = self._build_prompt(messages)

        url = f"{self.base_url}/{self.config.model}"
        payload = {
            'inputs': prompt,
            'parameters': {
                'temperature': self.config.temperature,
                'top_p': self.config.top_p,
                'max_new_tokens': self.config.max_tokens or 100,
                'return_full_text': False
            }
        }

        async with self._session.post(url, json=payload) as response:
            response.raise_for_status()
            data = await response.json()

            # Response is normally a list of generation dicts.
            if isinstance(data, list) and len(data) > 0:
                text = data[0].get('generated_text', '')
            else:
                text = str(data)

            return LLMResponse(
                content=text,
                model=self.config.model,
                finish_reason='stop'
            )

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """HuggingFace Inference API doesn't support streaming.

        Simulates streaming by yielding the complete response as one final
        chunk.
        """
        response = await self.complete(messages, **kwargs)
        yield LLMStreamResponse(
            delta=response.content,
            is_final=True,
            finish_reason=response.finish_reason
        )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings.

        A single string returns one vector; a list returns a list of vectors.
        """
        if not self._is_initialized:
            await self.initialize()

        if isinstance(texts, str):
            texts = [texts]
            single = True
        else:
            single = False

        url = f"{self.base_url}/{self.config.model}"
        payload = {'inputs': texts}

        async with self._session.post(url, json=payload) as response:
            response.raise_for_status()
            embeddings = await response.json()

        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """HuggingFace doesn't have native function calling."""
        raise NotImplementedError("Function calling not supported for HuggingFace models")

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Build a flat ``User:/Assistant:`` prompt; system content leads."""
        prompt = ""
        for msg in messages:
            if msg.role == 'system':
                prompt += f"{msg.content}\n\n"
            elif msg.role == 'user':
                prompt += f"User: {msg.content}\n"
            elif msg.role == 'assistant':
                prompt += f"Assistant: {msg.content}\n"
        return prompt

1005 

1006 

def create_llm_provider(
    config: LLMConfig,
    is_async: bool = True
) -> Union[AsyncLLMProvider, SyncLLMProvider]:
    """Create appropriate LLM provider based on configuration.

    Args:
        config: LLM configuration
        is_async: Whether to create async provider

    Returns:
        LLM provider instance

    Raises:
        ValueError: If ``config.provider`` names an unknown provider.
    """
    registry = {
        'openai': OpenAIProvider,
        'anthropic': AnthropicProvider,
        'ollama': OllamaProvider,
        'huggingface': HuggingFaceProvider,
    }

    name = config.provider.lower()
    if name not in registry:
        raise ValueError(f"Unknown provider: {config.provider}")

    provider = registry[name](config)  # type: ignore
    if is_async:
        return provider  # type: ignore

    # Wrap async provider in sync adapter for blocking callers
    return SyncProviderAdapter(provider)  # type: ignore