Coverage for src/dataknobs_fsm/resources/llm.py: 0% (301 statements)
coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""LLM resource provider for language model interactions."""
3import json
4import os
5import time
6from dataclasses import dataclass, field as dataclass_field
7from typing import Any, Dict, List, Union
8from enum import Enum
10from dataknobs_fsm.functions.base import ResourceError
11from dataknobs_fsm.resources.base import (
12 BaseResourceProvider,
13 ResourceHealth,
14 ResourceStatus,
15)
class LLMProvider(Enum):
    """Supported LLM providers."""

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    OLLAMA = "ollama"
    HUGGINGFACE = "huggingface"
    HUGGINGFACE_INFERENCE = "huggingface_inference"  # HF Inference API
    CUSTOM = "custom"

@dataclass
class LLMSession:
    """LLM session with configuration and state."""

    provider: LLMProvider
    model_name: str
    api_key: str | None = None
    endpoint: str | None = None
    temperature: float = 0.7
    max_tokens: int = 1000
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    # Rate limiting (mainly for commercial APIs)
    requests_per_minute: int = 60
    tokens_per_minute: int = 90000
    request_count: int = 0
    token_count: int = 0
    window_start: float = dataclass_field(default_factory=time.time)

    # Token tracking
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    total_requests: int = 0

    # Provider-specific settings
    provider_config: Dict[str, Any] = dataclass_field(default_factory=dict)

    def check_rate_limits(self, estimated_tokens: int = 0) -> bool:
        """Check if request would exceed rate limits.

        Args:
            estimated_tokens: Estimated tokens for the request.

        Returns:
            True if request can proceed, False if rate limited.
        """
        # Local providers don't have rate limits
        if self.provider in [LLMProvider.OLLAMA, LLMProvider.HUGGINGFACE]:
            return True

        current_time = time.time()
        window_elapsed = current_time - self.window_start

        # Reset window if a minute has passed
        if window_elapsed >= 60:
            self.request_count = 0
            self.token_count = 0
            self.window_start = current_time
            return True

        # Check limits
        if self.request_count >= self.requests_per_minute:
            return False

        if self.token_count + estimated_tokens > self.tokens_per_minute:
            return False

        return True

    def record_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Record token usage.

        Args:
            prompt_tokens: Number of prompt tokens used.
            completion_tokens: Number of completion tokens generated.
        """
        total_tokens = prompt_tokens + completion_tokens

        self.request_count += 1
        self.token_count += total_tokens
        self.total_requests += 1
        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens

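
# Illustrative sketch, not part of the original module: how a caller is meant
# to pair check_rate_limits() with record_usage(). The limits below are
# hypothetical.
def _example_rate_limit_flow() -> None:
    session = LLMSession(
        provider=LLMProvider.OPENAI,
        model_name="gpt-4",
        requests_per_minute=2,
        tokens_per_minute=100,
    )
    # The first request fits inside the one-minute window.
    assert session.check_rate_limits(estimated_tokens=40)
    session.record_usage(prompt_tokens=30, completion_tokens=10)
    # A request that would push the window past 100 tokens is refused.
    assert not session.check_rate_limits(estimated_tokens=80)
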

class LLMResource(BaseResourceProvider):
    """LLM resource provider for language model operations.

    Supports multiple providers:
    - OpenAI: GPT models via OpenAI API
    - Anthropic: Claude models via Anthropic API
    - Ollama: Local models via Ollama
    - HuggingFace: Local transformers or Inference API
    """

    def __init__(
        self,
        name: str,
        provider: Union[str, LLMProvider] = "ollama",
        model: str = "llama2",
        api_key: str | None = None,
        endpoint: str | None = None,
        **config
    ):
        """Initialize LLM resource.

        Args:
            name: Resource name.
            provider: LLM provider (ollama, openai, anthropic, huggingface, etc.).
            model: Model name/identifier.
            api_key: API key for commercial providers.
            endpoint: Custom endpoint URL.
            **config: Additional configuration.
        """
        super().__init__(name, config)

        # Convert string to enum; unknown names fall back to CUSTOM
        if isinstance(provider, str):
            try:
                self.provider = LLMProvider(provider.lower())
            except ValueError:
                self.provider = LLMProvider.CUSTOM
        else:
            self.provider = provider

        self.model = model
        self.api_key = api_key
        self.endpoint = endpoint or self._get_default_endpoint()

        # Initialize provider-specific clients
        self._client = None
        self._initialize_client()

        self._sessions: Dict[int, LLMSession] = {}
        self.status = ResourceStatus.IDLE

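    # Illustrative construction sketch (hypothetical names and values):
    #
    #   local = LLMResource("local-llm", provider="ollama", model="llama2")
    #   cloud = LLMResource(
    #       "cloud-llm",
    #       provider="openai",
    #       model="gpt-4",
    #       api_key=os.getenv("OPENAI_API_KEY"),
    #   )
    #
    # An unrecognized provider string falls back to LLMProvider.CUSTOM rather
    # than raising, so a typo surfaces later as NotImplementedError from
    # _custom_complete().
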
    def _get_default_endpoint(self) -> str | None:
        """Get default endpoint for provider.

        Returns:
            Default endpoint URL or None.
        """
        defaults = {
            LLMProvider.OPENAI: "https://api.openai.com/v1",
            LLMProvider.ANTHROPIC: "https://api.anthropic.com/v1",
            LLMProvider.OLLAMA: "http://localhost:11434",
            LLMProvider.HUGGINGFACE_INFERENCE: "https://api-inference.huggingface.co/models",
        }
        return defaults.get(self.provider)

    def _initialize_client(self) -> None:
        """Initialize provider-specific client."""
        try:
            if self.provider == LLMProvider.OLLAMA:
                # Ollama uses an HTTP API, so no special client is needed;
                # just verify the endpoint is accessible.
                import urllib.request
                try:
                    req = urllib.request.Request(f"{self.endpoint}/api/tags")
                    with urllib.request.urlopen(req, timeout=5) as response:
                        if response.status == 200:
                            self.status = ResourceStatus.IDLE
                except Exception:
                    # Ollama might not be running yet; that's OK.
                    self.status = ResourceStatus.IDLE

            elif self.provider == LLMProvider.HUGGINGFACE:
                # Local HuggingFace transformers: the model is lazy-loaded
                # when needed.
                self.status = ResourceStatus.IDLE

            elif self.provider in [LLMProvider.OPENAI, LLMProvider.ANTHROPIC]:
                # Commercial APIs: just verify we have an API key.
                if not self.api_key:
                    raise ResourceError(
                        f"{self.provider.value} requires an API key",
                        resource_name=self.name,
                        operation="initialize"
                    )
                self.status = ResourceStatus.IDLE

            else:
                self.status = ResourceStatus.IDLE

        except ResourceError:
            # Don't re-wrap errors we raised ourselves (e.g. missing API key)
            raise
        except Exception as e:
            self.status = ResourceStatus.ERROR
            raise ResourceError(
                f"Failed to initialize {self.provider.value} client: {e}",
                resource_name=self.name,
                operation="initialize"
            ) from e

    def acquire(self, **kwargs) -> LLMSession:
        """Acquire an LLM session.

        Args:
            **kwargs: Session configuration overrides.

        Returns:
            LLMSession instance.

        Raises:
            ResourceError: If acquisition fails.
        """
        try:
            # Set provider-specific defaults
            if self.provider == LLMProvider.OLLAMA:
                # Ollama defaults
                kwargs.setdefault("temperature", 0.8)
                kwargs.setdefault("requests_per_minute", 0)  # No limit
                kwargs.setdefault("tokens_per_minute", 0)  # No limit

            elif self.provider == LLMProvider.HUGGINGFACE:
                # HuggingFace local defaults; the device setting must live in
                # provider_config, which is where _huggingface_complete reads it.
                kwargs.setdefault("provider_config", {}).setdefault(
                    "device", "cpu"  # or "cuda" if available
                )
                kwargs.setdefault("requests_per_minute", 0)  # No limit

            session = LLMSession(
                provider=self.provider,
                model_name=kwargs.get("model", self.model),
                api_key=kwargs.get("api_key", self.api_key),
                endpoint=kwargs.get("endpoint", self.endpoint),
                temperature=kwargs.get("temperature", 0.7),
                max_tokens=kwargs.get("max_tokens", 1000),
                top_p=kwargs.get("top_p", 1.0),
                frequency_penalty=kwargs.get("frequency_penalty", 0.0),
                presence_penalty=kwargs.get("presence_penalty", 0.0),
                requests_per_minute=kwargs.get("requests_per_minute", 60),
                tokens_per_minute=kwargs.get("tokens_per_minute", 90000),
                provider_config=kwargs.get("provider_config", {})
            )

            session_id = id(session)
            self._sessions[session_id] = session
            self._resources.append(session)

            self.status = ResourceStatus.ACTIVE
            return session

        except Exception as e:
            self.status = ResourceStatus.ERROR
            raise ResourceError(
                f"Failed to acquire LLM session: {e}",
                resource_name=self.name,
                operation="acquire"
            ) from e

    def release(self, resource: Any) -> None:
        """Release an LLM session.

        Args:
            resource: The LLMSession to release.
        """
        if isinstance(resource, LLMSession):
            session_id = id(resource)
            if session_id in self._sessions:
                del self._sessions[session_id]

            if resource in self._resources:
                self._resources.remove(resource)

            if not self._resources:
                self.status = ResourceStatus.IDLE

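    # Illustrative acquire/release lifecycle (hypothetical resource):
    #
    #   session = resource.acquire(temperature=0.2)
    #   try:
    #       ...  # use the session
    #   finally:
    #       resource.release(session)
    #
    # complete() and embed() perform this pairing automatically when called
    # without an explicit session.
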
    def validate(self, resource: Any) -> bool:
        """Validate an LLM session.

        Args:
            resource: The LLMSession to validate.

        Returns:
            True if the session is valid.
        """
        if not isinstance(resource, LLMSession):
            return False

        # Check if API key is set for commercial providers
        if resource.provider in [LLMProvider.OPENAI, LLMProvider.ANTHROPIC,
                                 LLMProvider.HUGGINGFACE_INFERENCE]:
            if not resource.api_key:
                return False

        return True

    def health_check(self) -> ResourceHealth:
        """Check LLM service health.

        Returns:
            Health status.
        """
        session = None
        try:
            session = self.acquire()

            if session.provider == LLMProvider.OLLAMA:
                # Check Ollama API
                import urllib.request
                req = urllib.request.Request(f"{session.endpoint}/api/tags")
                with urllib.request.urlopen(req, timeout=5) as response:
                    if response.status == 200:
                        self.metrics.record_health_check(True)
                        return ResourceHealth.HEALTHY

            elif session.provider == LLMProvider.HUGGINGFACE:
                # For local HF, just check if transformers is available
                try:
                    import importlib.util
                    if importlib.util.find_spec('transformers'):
                        self.metrics.record_health_check(True)
                        return ResourceHealth.HEALTHY
                    else:
                        self.metrics.record_health_check(False)
                        return ResourceHealth.UNHEALTHY
                except ImportError:
                    self.metrics.record_health_check(False)
                    return ResourceHealth.UNHEALTHY

            else:
                # For commercial APIs, assume healthy if session is valid
                if self.validate(session):
                    self.metrics.record_health_check(True)
                    return ResourceHealth.HEALTHY

        except Exception:
            self.metrics.record_health_check(False)
            return ResourceHealth.UNHEALTHY
        finally:
            if session:
                self.release(session)

        return ResourceHealth.UNKNOWN

    def complete(
        self,
        prompt: str,
        session: LLMSession | None = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Generate a completion for the given prompt.

        Args:
            prompt: Input prompt.
            session: Optional session to use.
            **kwargs: Additional parameters.

        Returns:
            Completion response with text and metadata.
        """
        if session is None:
            session = self.acquire()
            should_release = True
        else:
            should_release = False

        try:
            # Route to appropriate provider
            if session.provider == LLMProvider.OLLAMA:
                response = self._ollama_complete(session, prompt, **kwargs)
            elif session.provider == LLMProvider.HUGGINGFACE:
                response = self._huggingface_complete(session, prompt, **kwargs)
            elif session.provider == LLMProvider.OPENAI:
                response = self._openai_complete(session, prompt, **kwargs)
            elif session.provider == LLMProvider.ANTHROPIC:
                response = self._anthropic_complete(session, prompt, **kwargs)
            else:
                response = self._custom_complete(session, prompt, **kwargs)

            # Record usage if available
            if "usage" in response:
                prompt_tokens = response["usage"].get("prompt_tokens", 0)
                completion_tokens = response["usage"].get("completion_tokens", 0)
                session.record_usage(prompt_tokens, completion_tokens)

            return response

        finally:
            if should_release:
                self.release(session)

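    # Illustrative call and response shape (hypothetical text and counts):
    #
    #   result = resource.complete("Say hello", max_tokens=20)
    #   result["choices"][0]["text"]       # e.g. "Hello!"
    #   result["usage"]["total_tokens"]    # e.g. 12, when the provider reports usage
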
    def _ollama_complete(
        self,
        session: LLMSession,
        prompt: str,
        **kwargs
    ) -> Dict[str, Any]:
        """Ollama completion.

        Args:
            session: LLM session.
            prompt: Input prompt.
            **kwargs: Additional parameters.

        Returns:
            Completion response.
        """
        import urllib.request

        data = {
            "model": session.model_name,
            "prompt": prompt,
            # Generation parameters go under "options" in Ollama's generate
            # API; num_predict is Ollama's name for the max-token cap.
            "options": {
                "temperature": kwargs.get("temperature", session.temperature),
                "num_predict": kwargs.get("max_tokens", session.max_tokens),
            },
            "stream": False
        }

        req = urllib.request.Request(
            f"{session.endpoint}/api/generate",
            data=json.dumps(data).encode("utf-8"),
            headers={"Content-Type": "application/json"}
        )

        with urllib.request.urlopen(req) as response:
            result = json.loads(response.read())

        return {
            "choices": [{
                "text": result.get("response", ""),
                "index": 0,
                "finish_reason": "stop" if result.get("done") else "length"
            }],
            "model": session.model_name,
            "usage": {
                "prompt_tokens": result.get("prompt_eval_count", 0),
                "completion_tokens": result.get("eval_count", 0),
                "total_tokens": result.get("prompt_eval_count", 0) + result.get("eval_count", 0)
            }
        }

    def _huggingface_complete(
        self,
        session: LLMSession,
        prompt: str,
        **kwargs
    ) -> Dict[str, Any]:
        """HuggingFace local completion.

        Args:
            session: LLM session.
            prompt: Input prompt.
            **kwargs: Additional parameters.

        Returns:
            Completion response.
        """
        # Uses the transformers library for local inference.
        try:
            from transformers import pipeline

            # Lazy-load the model
            pipe = pipeline(
                "text-generation",
                model=session.model_name,
                device=session.provider_config.get("device", "cpu")
            )

            result = pipe(
                prompt,
                # max_new_tokens caps generated tokens only; max_length would
                # also count the prompt, which is not what max_tokens means here.
                max_new_tokens=kwargs.get("max_tokens", session.max_tokens),
                temperature=kwargs.get("temperature", session.temperature),
                top_p=kwargs.get("top_p", session.top_p),
            )

            generated_text = result[0]["generated_text"]
            # Remove the prompt from the output
            if generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):]

            return {
                "choices": [{
                    "text": generated_text,
                    "index": 0,
                    "finish_reason": "stop"
                }],
                "model": session.model_name
            }

        except ImportError as e:
            raise ResourceError(
                "HuggingFace transformers library not installed. "
                "Install with: pip install transformers torch",
                resource_name=self.name,
                operation="complete"
            ) from e

    def _openai_complete(
        self,
        session: LLMSession,
        prompt: str,
        **kwargs
    ) -> Dict[str, Any]:
        """OpenAI completion using the provider system."""
        from dataknobs_fsm.llm.base import LLMConfig, LLMMessage
        from dataknobs_fsm.llm.providers import create_provider

        # Create config from session; prefer an explicit key, then the
        # session's key, then the environment.
        config = LLMConfig(
            provider="openai",
            model=session.model_name,
            api_key=kwargs.get('api_key', session.api_key or os.getenv('OPENAI_API_KEY')),
            temperature=kwargs.get('temperature', session.temperature),
            max_tokens=kwargs.get('max_tokens', session.max_tokens)
        )

        try:
            # Create provider and execute
            provider = create_provider(config, is_async=False)
            provider.initialize()

            # Convert prompt to message format
            if isinstance(prompt, str):
                messages = [LLMMessage(role="user", content=prompt)]
            else:
                messages = prompt  # type: ignore[unreachable]

            response = provider.complete(messages, **kwargs)
            provider.close()

            # Convert to expected format
            return {
                "choices": [{
                    "text": response.content,
                    "index": 0,
                    "finish_reason": response.finish_reason or "stop"
                }],
                "model": response.model,
                "usage": response.usage
            }
        except Exception as e:
            # Fall back to an error placeholder rather than raising
            return {
                "choices": [{
                    "text": f"Error: {e!s}",
                    "index": 0,
                    "finish_reason": "error"
                }],
                "model": session.model_name
            }

    def _anthropic_complete(
        self,
        session: LLMSession,
        prompt: str,
        **kwargs
    ) -> Dict[str, Any]:
        """Anthropic completion using the provider system."""
        from dataknobs_fsm.llm.base import LLMConfig, LLMMessage
        from dataknobs_fsm.llm.providers import create_provider

        # Create config from session; prefer an explicit key, then the
        # session's key, then the environment.
        config = LLMConfig(
            provider="anthropic",
            model=session.model_name,
            api_key=kwargs.get('api_key', session.api_key or os.getenv('ANTHROPIC_API_KEY')),
            temperature=kwargs.get('temperature', session.temperature),
            max_tokens=kwargs.get('max_tokens', session.max_tokens)
        )

        try:
            # Create provider and execute
            provider = create_provider(config, is_async=False)
            provider.initialize()

            # Convert prompt to message format
            if isinstance(prompt, str):
                messages = [LLMMessage(role="user", content=prompt)]
            else:
                messages = prompt  # type: ignore[unreachable]

            response = provider.complete(messages, **kwargs)
            provider.close()

            # Convert to expected format
            return {
                "choices": [{
                    "text": response.content,
                    "index": 0,
                    "finish_reason": response.finish_reason or "stop"
                }],
                "model": response.model,
                "usage": response.usage
            }
        except Exception as e:
            # Fall back to an error placeholder rather than raising
            return {
                "choices": [{
                    "text": f"Error: {e!s}",
                    "index": 0,
                    "finish_reason": "error"
                }],
                "model": session.model_name
            }

    def _custom_complete(
        self,
        session: LLMSession,
        prompt: str,
        **kwargs
    ) -> Dict[str, Any]:
        """Custom provider completion.

        For custom/unknown providers.
        """
        raise NotImplementedError(
            f"Custom provider {session.provider.value} not implemented"
        )

    def embed(
        self,
        text: Union[str, List[str]],
        session: LLMSession | None = None,
        **kwargs
    ) -> List[List[float]]:
        """Generate embeddings for text.

        Args:
            text: Text or list of texts to embed.
            session: Optional session to use.
            **kwargs: Additional parameters.

        Returns:
            List of embedding vectors.
        """
        if session is None:
            session = self.acquire()
            should_release = True
        else:
            should_release = False

        try:
            if isinstance(text, str):
                texts = [text]
            else:
                texts = text

            # Route to appropriate provider
            if session.provider == LLMProvider.OLLAMA:
                embeddings = self._ollama_embed(session, texts, **kwargs)
            elif session.provider == LLMProvider.HUGGINGFACE:
                embeddings = self._huggingface_embed(session, texts, **kwargs)
            elif session.provider == LLMProvider.OPENAI:
                embeddings = self._openai_embed(session, texts, **kwargs)
            else:
                # Fallback to fake embeddings
                embeddings = [[0.1] * 768 for _ in texts]

            return embeddings

        finally:
            if should_release:
                self.release(session)

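    # Illustrative embed usage (hypothetical texts; dimensions vary by model,
    # e.g. 768 for Ollama's nomic-embed-text default):
    #
    #   vectors = resource.embed(["first text", "second text"])
    #   len(vectors)     # -> 2
    #   len(vectors[0])  # -> embedding dimension
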
    def _ollama_embed(
        self,
        session: LLMSession,
        texts: List[str],
        **kwargs
    ) -> List[List[float]]:
        """Generate embeddings using Ollama.

        Args:
            session: LLM session.
            texts: Texts to embed.
            **kwargs: Additional parameters.

        Returns:
            List of embeddings.
        """
        import urllib.request

        embeddings = []
        for text in texts:
            data = {
                "model": kwargs.get("embed_model", "nomic-embed-text"),
                "prompt": text
            }

            req = urllib.request.Request(
                f"{session.endpoint}/api/embeddings",
                data=json.dumps(data).encode("utf-8"),
                headers={"Content-Type": "application/json"}
            )

            with urllib.request.urlopen(req) as response:
                result = json.loads(response.read())
                embeddings.append(result.get("embedding", []))

        return embeddings

    def _huggingface_embed(
        self,
        session: LLMSession,
        texts: List[str],
        **kwargs
    ) -> List[List[float]]:
        """Generate embeddings using HuggingFace.

        Args:
            session: LLM session.
            texts: Texts to embed.
            **kwargs: Additional parameters.

        Returns:
            List of embeddings.
        """
        try:
            from transformers import AutoTokenizer, AutoModel
            import torch

            model_name = kwargs.get("embed_model", "sentence-transformers/all-MiniLM-L6-v2")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModel.from_pretrained(model_name)

            embeddings = []
            for text in texts:
                inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
                with torch.no_grad():
                    outputs = model(**inputs)
                # Use mean pooling over the token dimension
                embedding = outputs.last_hidden_state.mean(dim=1).squeeze().tolist()
                embeddings.append(embedding)

            return embeddings

        except ImportError as e:
            raise ResourceError(
                "HuggingFace transformers library not installed",
                resource_name=self.name,
                operation="embed"
            ) from e

    def _openai_embed(
        self,
        session: LLMSession,
        texts: List[str],
        **kwargs
    ) -> List[List[float]]:
        """Generate embeddings using the OpenAI provider system."""
        from dataknobs_fsm.llm.base import LLMConfig
        from dataknobs_fsm.llm.providers import create_provider

        # Create config for embeddings
        config = LLMConfig(
            provider="openai",
            model=kwargs.get('embed_model', 'text-embedding-ada-002'),
            api_key=kwargs.get('api_key', session.api_key or os.getenv('OPENAI_API_KEY'))
        )

        try:
            # Create provider and generate embeddings
            provider = create_provider(config, is_async=False)
            provider.initialize()

            embeddings = provider.embed(texts, **kwargs)
            provider.close()

            # Ensure we return List[List[float]]; guard the empty case before
            # inspecting the first element
            if embeddings and isinstance(embeddings[0], list):
                return embeddings
            else:
                return [embeddings]  # Single-text case

        except Exception:
            # Fallback to placeholder dimensions on error
            return [[0.1] * 1536 for _ in texts]  # OpenAI ada-002 dimension

    def get_usage_stats(self, session: LLMSession) -> Dict[str, Any]:
        """Get usage statistics for a session.

        Args:
            session: LLM session.

        Returns:
            Usage statistics.
        """
        stats = {
            "provider": session.provider.value,
            "model": session.model_name,
            "total_requests": session.total_requests,
        }

        # Add token stats for providers that track them
        if session.provider in [LLMProvider.OPENAI, LLMProvider.ANTHROPIC,
                                LLMProvider.OLLAMA]:
            stats.update({
                "total_prompt_tokens": session.total_prompt_tokens,
                "total_completion_tokens": session.total_completion_tokens,
                "total_tokens": session.total_prompt_tokens + session.total_completion_tokens,
            })

        # Add rate limit info for commercial providers
        if session.provider in [LLMProvider.OPENAI, LLMProvider.ANTHROPIC]:
            stats["rate_limits"] = {
                "requests_per_minute": session.requests_per_minute,
                "tokens_per_minute": session.tokens_per_minute,
                "current_window": {
                    "requests": session.request_count,
                    "tokens": session.token_count,
                    "window_start": session.window_start
                }
            }

        return stats
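

# Illustrative end-to-end sketch, not part of the original module. It assumes
# a local Ollama server on the default endpoint with the llama2 model pulled;
# adjust the resource name and model to your setup.
if __name__ == "__main__":
    resource = LLMResource("demo-llm", provider="ollama", model="llama2")
    session = resource.acquire(temperature=0.2)
    try:
        result = resource.complete("Name one prime number.", session=session)
        print(result["choices"][0]["text"])
        # Accounting accumulated by record_usage() during complete()
        print(resource.get_usage_stats(session))
    finally:
        resource.release(session)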