Coverage for src/dataknobs_fsm/functions/library/llm.py: 0%
180 statements
« prev ^ index » next — coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""Built-in LLM functions for FSM.
3This module provides LLM-related functions that can be referenced
4in FSM configurations for AI-powered workflows.
5"""
7import asyncio
8import json
9from typing import Any, Callable, Dict, List
11from dataknobs_fsm.functions.base import (
12 ITransformFunction,
13 IValidationFunction,
14 TransformFunctionError,
15 ValidationFunctionError,
16)
17from dataknobs_fsm.resources.llm import LLMResource
class PromptBuilder(ITransformFunction):
    """Build prompts for LLM calls.

    Fills ``{variable}`` placeholders in a template from the input data,
    supporting dotted names for nested-dict lookups, and optionally appends
    an output-format instruction.
    """

    def __init__(
        self,
        template: str,
        system_prompt: str | None = None,
        variables: List[str] | None = None,
        format_spec: str | None = None,  # "json", "markdown", "plain"
    ):
        """Initialize the prompt builder.

        Args:
            template: Prompt template with {variable} placeholders.
            system_prompt: Optional system prompt.
            variables: List of variable names to extract from data. Dotted
                names (e.g. "user.name") traverse nested dicts.
            format_spec: Output format specification ("plain" or any other
                unrecognized value adds no instruction).
        """
        self.template = template
        self.system_prompt = system_prompt
        self.variables = variables or []
        self.format_spec = format_spec

    @staticmethod
    def _lookup_nested(data: Dict[str, Any], dotted: str) -> Any:
        """Walk a dotted path through nested dicts; None when any hop fails."""
        value: Any = data
        for part in dotted.split("."):
            if isinstance(value, dict) and part in value:
                value = value[part]
            else:
                return None
        return value

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by building a prompt.

        Args:
            data: Input data containing variables for the prompt.

        Returns:
            A copy of ``data`` with ``prompt`` (and ``system_prompt`` when
            configured) added.

        Raises:
            TransformFunctionError: If a template placeholder has no value.
        """
        # Collect template variables: direct keys win (even when the value
        # is None); otherwise fall back to a dotted-path lookup, silently
        # skipping misses so str.format reports the missing placeholder.
        variables: Dict[str, Any] = {}
        for var in self.variables:
            if var in data:
                variables[var] = data[var]
            else:
                value = self._lookup_nested(data, var)
                if value is not None:
                    variables[var] = value

        # NOTE(review): dotted variables are stored under their dotted key,
        # which str.format cannot reference ("{a.b}" parses "." as attribute
        # access) — confirm templates only use flat placeholder names.
        try:
            prompt = self.template.format(**variables)
        except KeyError as e:
            raise TransformFunctionError(f"Missing variable for prompt: {e}") from e

        # Append an output-format instruction if requested.
        if self.format_spec == "json":
            prompt += "\n\nPlease respond with valid JSON only."
        elif self.format_spec == "markdown":
            prompt += "\n\nPlease format your response using Markdown."

        result = {
            **data,
            "prompt": prompt,
        }
        if self.system_prompt:
            result["system_prompt"] = self.system_prompt
        return result
class LLMCaller(ITransformFunction):
    """Call an LLM with a prompt and store the response in the data dict."""

    def __init__(
        self,
        resource_name: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        stream: bool = False,
        response_field: str = "llm_response",
    ):
        """Initialize the LLM caller.

        Args:
            resource_name: Name of the LLM resource to use.
            model: Model to use (if None, use resource default).
            temperature: Temperature for generation.
            max_tokens: Maximum tokens to generate.
            stream: Whether to stream the response.
            response_field: Field to store response in.
        """
        self.resource_name = resource_name
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.response_field = response_field

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by calling the LLM.

        Args:
            data: Input data containing ``prompt`` (and optionally
                ``system_prompt`` and the injected ``_resources`` registry).

        Returns:
            A copy of ``data`` with the response under
            ``self.response_field``. Streaming responses also set
            ``is_streaming``; non-streaming ones set ``tokens_used``.

        Raises:
            TransformFunctionError: If the resource or prompt is missing,
                or the LLM call fails.
        """
        # The FSM runtime injects shared resources under "_resources".
        resource = data.get("_resources", {}).get(self.resource_name)
        if not resource or not isinstance(resource, LLMResource):
            raise TransformFunctionError(f"LLM resource '{self.resource_name}' not found")

        prompt = data.get("prompt")
        if not prompt:
            raise TransformFunctionError("No prompt found in data")

        system_prompt = data.get("system_prompt")

        # Keep the try narrow: only the LLM call should be wrapped.
        try:
            response = await resource.generate(
                prompt=prompt,
                system_prompt=system_prompt,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stream=self.stream,
            )
        except Exception as e:
            raise TransformFunctionError(f"LLM call failed: {e}") from e

        if self.stream:
            # For streaming, the response is an async generator consumed
            # downstream.
            return {
                **data,
                self.response_field: response,
                "is_streaming": True,
            }

        # Only dict-shaped responses carry usage metadata; a plain string
        # response previously crashed on .get().
        tokens_used = (
            response.get("usage", {}).get("total_tokens")
            if isinstance(response, dict)
            else None
        )
        return {
            **data,
            self.response_field: response,
            "tokens_used": tokens_used,
        }
class ResponseValidator(IValidationFunction):
    """Validate LLM responses against length, format, and schema rules."""

    def __init__(
        self,
        response_field: str = "llm_response",
        format: str | None = None,  # "json", "markdown", etc.
        schema: Dict[str, Any] | None = None,
        min_length: int | None = None,
        max_length: int | None = None,
        required_fields: List[str] | None = None,
    ):
        """Initialize the response validator.

        Args:
            response_field: Field containing LLM response.
            format: Expected response format.
            schema: JSON schema for validation (if format is JSON).
            min_length: Minimum response length.
            max_length: Maximum response length.
            required_fields: Required fields in parsed response.
        """
        self.response_field = response_field
        self.format = format
        self.schema = schema
        self.min_length = min_length
        self.max_length = max_length
        self.required_fields = required_fields or []

    def validate(self, data: Dict[str, Any]) -> bool:
        """Validate the LLM response stored in ``data``.

        Args:
            data: Data containing LLM response.

        Returns:
            True if valid.

        Raises:
            ValidationFunctionError: If validation fails.
        """
        response = data.get(self.response_field)
        if response is None:
            raise ValidationFunctionError(f"Response field '{self.response_field}' not found")

        # Extract text from a dict-shaped response if needed; `text` is
        # always a str from here on.
        if isinstance(response, dict):
            text = response.get("text", response.get("content", str(response)))
        else:
            text = str(response)

        # Length constraints (a None/0 bound disables the check).
        if self.min_length and len(text) < self.min_length:
            raise ValidationFunctionError(
                f"Response too short: {len(text)} < {self.min_length}"
            )

        if self.max_length and len(text) > self.max_length:
            raise ValidationFunctionError(
                f"Response too long: {len(text)} > {self.max_length}"
            )

        if self.format == "json":
            # Keep the try narrow: only json.loads raises JSONDecodeError.
            # The original also wrapped the schema/required-field checks,
            # obscuring which errors the handler could actually catch.
            try:
                parsed = json.loads(text)
            except json.JSONDecodeError as e:
                raise ValidationFunctionError(f"Invalid JSON response: {e}") from e

            if self.schema:
                # NOTE(review): create_model expects pydantic field
                # definitions, not a JSON Schema document — confirm callers
                # pass the former.
                from pydantic import ValidationError, create_model
                model = create_model("ResponseSchema", **self.schema)
                try:
                    model(**parsed)
                except ValidationError as e:
                    raise ValidationFunctionError(f"Schema validation failed: {e}") from e

            # Check required top-level fields in the parsed payload.
            for field in self.required_fields:
                if field not in parsed:
                    raise ValidationFunctionError(f"Required field missing: {field}")

        return True
class FunctionCaller(ITransformFunction):
    """Call functions/tools based on LLM output."""

    def __init__(
        self,
        response_field: str = "llm_response",
        function_registry: Dict[str, Callable] | None = None,
        result_field: str = "function_result",
    ):
        """Initialize the function caller.

        Args:
            response_field: Field containing LLM response with function call.
            function_registry: Registry of available functions.
            result_field: Field to store function result.
        """
        self.response_field = response_field
        self.function_registry = function_registry or {}
        self.result_field = result_field

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by calling the function named in the LLM response.

        Args:
            data: Input data whose response field may encode a function call
                as ``{"function": ..., "arguments": {...}}``.

        Returns:
            ``data`` unchanged when no function call is present; otherwise a
            copy with the result and ``function_called`` added.

        Raises:
            TransformFunctionError: If the named function is unknown or the
                call itself fails.
        """
        response = data.get(self.response_field)
        if not response:
            return data

        # A string response may be a JSON-encoded function call.
        if isinstance(response, str):
            try:
                response = json.loads(response)
            except json.JSONDecodeError:
                # Not a JSON response, no function to call
                return data

        # Only JSON objects can describe a function call; lists/numbers/
        # strings previously crashed on .get() with a raw AttributeError.
        if not isinstance(response, dict):
            return data

        function_name = response.get("function")
        function_args = response.get("arguments", {})

        if not function_name:
            return data

        if function_name not in self.function_registry:
            raise TransformFunctionError(f"Function not found: {function_name}")

        func = self.function_registry[function_name]

        try:
            # Support both sync and async registered functions.
            if asyncio.iscoroutinefunction(func):
                result = await func(**function_args)
            else:
                result = func(**function_args)

            return {
                **data,
                self.result_field: result,
                "function_called": function_name,
            }

        except Exception as e:
            raise TransformFunctionError(f"Function call failed: {e}") from e
class ConversationManager(ITransformFunction):
    """Manage conversation history for multi-turn interactions."""

    def __init__(
        self,
        max_history: int = 10,
        history_field: str = "conversation_history",
        role_field: str = "role",
        content_field: str = "content",
    ):
        """Initialize the conversation manager.

        Args:
            max_history: Maximum number of messages to keep.
            history_field: Field to store conversation history.
            role_field: Field for message role.
            content_field: Field for message content.
        """
        self.max_history = max_history
        self.history_field = history_field
        self.role_field = role_field
        self.content_field = content_field

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by updating the conversation history.

        Args:
            data: Input data, optionally carrying ``prompt`` (user turn),
                ``llm_response`` (assistant turn), and an existing history.

        Returns:
            A copy of ``data`` with the updated, length-capped history.
        """
        # Copy the existing history so the caller's list is not mutated in
        # place (the original appended directly to the shared list).
        history = list(data.get(self.history_field, []))

        # Record the user turn.
        if "prompt" in data:
            history.append({
                self.role_field: "user",
                self.content_field: data["prompt"],
            })

        # Record the assistant turn, extracting text from dict responses.
        if "llm_response" in data:
            response = data["llm_response"]
            if isinstance(response, dict):
                content = response.get("text", response.get("content", str(response)))
            else:
                content = str(response)

            history.append({
                self.role_field: "assistant",
                self.content_field: content,
            })

        # Keep only the most recent messages.
        if len(history) > self.max_history:
            history = history[-self.max_history:]

        return {
            **data,
            self.history_field: history,
        }
class EmbeddingGenerator(ITransformFunction):
    """Generate embeddings for text using LLM."""

    def __init__(
        self,
        resource_name: str,
        text_field: str = "text",
        embedding_field: str = "embedding",
        model: str | None = None,
        batch_size: int = 100,
    ):
        """Initialize the embedding generator.

        Args:
            resource_name: Name of the LLM resource to use.
            text_field: Field containing text to embed.
            embedding_field: Field to store embeddings.
            model: Embedding model to use.
            batch_size: Batch size for embedding generation.
        """
        self.resource_name = resource_name
        self.text_field = text_field
        self.embedding_field = embedding_field
        self.model = model
        self.batch_size = batch_size

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by generating embeddings for the configured field.

        Args:
            data: Input data containing text (a string or list of strings).

        Returns:
            A copy of ``data`` with embeddings added, or ``data`` unchanged
            when there is nothing to embed.

        Raises:
            TransformFunctionError: If the resource is missing or the
                embedding call fails.
        """
        # Look up the shared LLM resource injected by the runtime.
        registry = data.get("_resources", {})
        resource = registry.get(self.resource_name)
        if not resource or not isinstance(resource, LLMResource):
            raise TransformFunctionError(f"LLM resource '{self.resource_name}' not found")

        text = data.get(self.text_field)
        if not text:
            # Nothing to embed.
            return data

        try:
            if isinstance(text, list):
                # Embed fixed-size slices and collect the results.
                embeddings: List[Any] = []
                start = 0
                while start < len(text):
                    chunk = text[start:start + self.batch_size]
                    embeddings.extend(await resource.embed(chunk, model=self.model))
                    start += self.batch_size
            else:
                # A single string is embedded in one call.
                embeddings = await resource.embed(text, model=self.model)

            return {
                **data,
                self.embedding_field: embeddings,
            }

        except Exception as e:
            raise TransformFunctionError(f"Embedding generation failed: {e}") from e
464# Convenience functions for creating LLM functions
def build_prompt(template: str, **kwargs) -> PromptBuilder:
    """Convenience factory for a PromptBuilder over *template*."""
    builder = PromptBuilder(template, **kwargs)
    return builder
def call_llm(resource: str, **kwargs) -> LLMCaller:
    """Convenience factory for an LLMCaller bound to *resource*."""
    caller = LLMCaller(resource, **kwargs)
    return caller
def validate_response(**kwargs) -> ResponseValidator:
    """Convenience factory for a ResponseValidator."""
    validator = ResponseValidator(**kwargs)
    return validator
def call_function(**kwargs) -> FunctionCaller:
    """Convenience factory for a FunctionCaller."""
    caller = FunctionCaller(**kwargs)
    return caller
def manage_conversation(**kwargs) -> ConversationManager:
    """Convenience factory for a ConversationManager."""
    manager = ConversationManager(**kwargs)
    return manager
def generate_embeddings(resource: str, **kwargs) -> EmbeddingGenerator:
    """Convenience factory for an EmbeddingGenerator bound to *resource*."""
    generator = EmbeddingGenerator(resource, **kwargs)
    return generator