Coverage for src/dataknobs_fsm/llm/utils.py: 0%
260 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""Utility functions for LLM operations.
3This module provides utility functions for working with LLMs.
4"""
6import re
7import json
8from typing import Any, Dict, List, Union
9from dataclasses import dataclass, field
11from .base import LLMMessage, LLMResponse
def render_conditional_template(template: str, params: Dict[str, Any]) -> str:
    """Render a template with variable substitution and conditional sections.

    Variable substitution:
    - ``{{variable}}`` marks a placeholder. Whitespace inside the braces is
      allowed and is emitted outside the value (``{{ var }}`` -> ``" value "``).
    - A variable whose value in ``params`` is not None is replaced by
      ``str(value)``.
    - Outside conditional sections, a variable missing from ``params`` is left
      in place (``{{ var }}`` becomes ``" {{var}} "``). A variable explicitly
      set to None is dropped when written without inner whitespace and left
      in place when written with it.

    Conditional sections:
    - ``((optional content))`` marks a conditional block.
    - A block that contains variables is removed when every one of them
      (including those in nested blocks) is missing, None, a blank string,
      or otherwise falsy.
    - Otherwise the block is rendered without its double parentheses, and
      missing/None variables inside it become empty strings.
    - A block without any variables is always kept.
    - Nested blocks are processed recursively.
    - Single parentheses inside a block are tracked so that ``(text))`` is
      read as ``)`` + ``)`` closing the inner group first. A lone ``)`` with
      no opening ``(`` inside the block is treated as literal text.
    - A ``((`` that has no matching ``))`` is left untouched.

    Example:
        template = "Hello {{name}}((, you have {{count}} messages))"
        params = {"name": "Alice", "count": 5}
        result = "Hello Alice, you have 5 messages"

        params = {"name": "Bob"}  # no count
        result = "Hello Bob"  # conditional section removed

    Args:
        template: The template string
        params: Dictionary of parameters to substitute

    Returns:
        The rendered template
    """
    # Compiled once; groups are (leading whitespace, name, trailing whitespace).
    var_re = re.compile(r'\{\{(\s*)(\w+)(\s*)\}\}')

    def substitute(text: str, in_conditional: bool) -> str:
        """Replace every ``{{var}}`` in ``text`` following the rules above."""

        def replace_var(match):
            prefix_ws, var_name, suffix_ws = match.groups()
            present = var_name in params
            value = params[var_name] if present else None

            if value is not None:
                # Keep the inner whitespace around the substituted value.
                return f"{prefix_ws}{value!s}{suffix_ws}"
            if in_conditional:
                # Missing/None variables vanish inside a rendered section.
                return ""
            if present and not (prefix_ws or suffix_ws):
                # Explicit None, written as a bare {{var}}: drop it.
                return ""
            # Leave the placeholder in place, whitespace moved outside it.
            return f"{prefix_ws}{{{{{var_name}}}}}{suffix_ws}"

        return var_re.sub(replace_var, text)

    def has_value(var_name: str) -> bool:
        """Return True when ``var_name`` has a usable (non-blank, truthy) value."""
        if var_name not in params:
            return False
        value = params[var_name]
        if value is None:
            return False
        if isinstance(value, str):
            return bool(value.strip())
        return bool(value)

    def find_section_end(text: str, start: int) -> int:
        """Return the index just past the ``))`` matching ``((`` at ``start``.

        Returns -1 when the section is never closed.
        """
        depth = 1        # nesting level of (( ... )) sections
        paren_depth = 0  # nesting level of ordinary ( ... ) groups
        pos = start + 2
        while pos < len(text) and depth > 0:
            pair = text[pos:pos + 2]
            if pair == '((':
                depth += 1
                pos += 2
            elif pair == '))':
                if paren_depth == 0:
                    depth -= 1
                    pos += 2
                else:
                    # First ')' closes an ordinary group; re-examine the next.
                    paren_depth -= 1
                    pos += 1
            elif text[pos] == '(':
                paren_depth += 1
                pos += 1
            elif text[pos] == ')':
                # Never go negative: a stray ')' is literal text and must not
                # prevent the closing '))' from being recognised later.
                if paren_depth > 0:
                    paren_depth -= 1
                pos += 1
            else:
                pos += 1
        return pos if depth == 0 else -1

    def process_conditionals(text: str) -> str:
        """Resolve every ``(( ... ))`` section in ``text``."""
        result = text
        changed = True
        while changed:
            changed = False
            search_from = 0
            while True:
                start = result.find('((', search_from)
                if start == -1:
                    break

                end = find_section_end(result, start)
                if end == -1:
                    # Unmatched opener: leave it and look further right.
                    search_from = start + 1
                    continue

                content = result[start + 2:end - 2]
                names = {m.group(2) for m in var_re.finditer(content)}

                if names and not any(has_value(name) for name in names):
                    # Every variable is empty/missing: drop the section.
                    replacement = ""
                else:
                    replacement = process_conditionals(content)
                    if names:
                        replacement = substitute(replacement, in_conditional=True)

                result = result[:start] + replacement + result[end:]
                # Rescan from the beginning after each rewrite.
                changed = True
                break
        return result

    # Conditional sections first, then whatever placeholders remain outside.
    return substitute(process_conditionals(template), in_conditional=False)
@dataclass
class PromptTemplate:
    """Template for generating prompts.

    Placeholders use ``str.format`` syntax (``{name}``).
    """
    # The ``str.format``-style template text.
    template: str
    # Names ``format`` requires; extracted from ``template`` when left empty.
    variables: List[str] = field(default_factory=list)

    def __post_init__(self):
        """Extract variables from template when none were supplied."""
        if not self.variables:
            # Extract {variable} patterns. dict.fromkeys de-duplicates while
            # keeping first-seen order, so a placeholder used twice is listed
            # once.
            found = re.findall(r'\{(\w+)\}', self.template)
            self.variables = list(dict.fromkeys(found))

    def format(self, **kwargs) -> str:
        """Format template with variables.

        Args:
            **kwargs: Variable values

        Returns:
            Formatted prompt

        Raises:
            ValueError: If any name in ``variables`` is not supplied.
        """
        # Check all required variables are provided
        missing = set(self.variables) - set(kwargs.keys())
        if missing:
            raise ValueError(f"Missing variables: {missing}")

        return self.template.format(**kwargs)

    def partial(self, **kwargs) -> 'PromptTemplate':
        """Create partial template with some variables filled.

        Keys that are not in ``variables`` are ignored.

        Args:
            **kwargs: Variable values to fill

        Returns:
            New template with partial values
        """
        new_template = self.template
        remaining = list(self.variables)

        for key, value in kwargs.items():
            if key in remaining:
                # str.replace fills every occurrence of the placeholder, so
                # every occurrence of the name must leave the list too.
                new_template = new_template.replace(f'{{{key}}}', str(value))
                remaining = [name for name in remaining if name != key]

        return PromptTemplate(new_template, remaining)
class MessageBuilder:
    """Fluent builder that accumulates an ordered list of messages."""

    def __init__(self):
        self.messages = []

    def _push(self, message) -> 'MessageBuilder':
        """Append ``message`` and hand back the builder for chaining."""
        self.messages.append(message)
        return self

    def system(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``'system'``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        return self._push(LLMMessage(role='system', content=content))

    def user(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``'user'``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        return self._push(LLMMessage(role='user', content=content))

    def assistant(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``'assistant'``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        return self._push(LLMMessage(role='assistant', content=content))

    def function(
        self,
        name: str,
        content: str,
        function_call: Dict[str, Any] | None = None
    ) -> 'MessageBuilder':
        """Append a message with role ``'function'``.

        Args:
            name: Function name
            content: Function result
            function_call: Function call details

        Returns:
            Self for chaining
        """
        message = LLMMessage(
            role='function',
            name=name,
            content=content,
            function_call=function_call
        )
        return self._push(message)

    def from_template(
        self,
        role: str,
        template: PromptTemplate,
        **kwargs
    ) -> 'MessageBuilder':
        """Append a message whose content is ``template.format(**kwargs)``.

        Args:
            role: Message role
            template: Prompt template
            **kwargs: Template variables

        Returns:
            Self for chaining
        """
        rendered = template.format(**kwargs)
        return self._push(LLMMessage(role=role, content=rendered))

    def build(self) -> List[LLMMessage]:
        """Return the accumulated messages.

        Returns:
            A shallow copy of the message list
        """
        return list(self.messages)

    def clear(self) -> 'MessageBuilder':
        """Remove every accumulated message (in place).

        Returns:
            Self for chaining
        """
        del self.messages[:]
        return self
class ResponseParser:
    """Parser for LLM responses."""

    @staticmethod
    def _response_text(response: Union[str, LLMResponse]) -> str:
        """Return the text to parse: the string itself, or ``response.content``."""
        return response if isinstance(response, str) else response.content

    @staticmethod
    def extract_json(response: Union[str, LLMResponse]) -> Dict[str, Any] | None:
        """Extract JSON from response.

        Candidates are tried in this order and the first that decodes wins:

        1. a JSON object starting at any ``{`` in the text,
        2. a JSON array starting at any ``[`` in the text,
        3. the body of a fenced code block (```` ```json ```` first, then any),
        4. the entire text.

        Objects and arrays are decoded with ``JSONDecoder.raw_decode`` so that
        nested structures and braces inside string values are handled; a
        character-class regex cannot find the matching closing bracket.

        Args:
            response: LLM response

        Returns:
            Extracted JSON or None
        """
        text = ResponseParser._response_text(response)
        decoder = json.JSONDecoder()

        # Objects take priority over arrays.
        for opener in ('{', '['):
            pos = text.find(opener)
            while pos != -1:
                try:
                    value, _end = decoder.raw_decode(text, pos)
                    return value
                except json.JSONDecodeError:
                    pos = text.find(opener, pos + 1)

        # Fenced blocks may hold JSON values that are neither object nor array.
        fence_patterns = (
            r'```json\s*(.*?)\s*```',  # Markdown code block
            r'```\s*(.*?)\s*```',  # Generic code block
        )
        for pattern in fence_patterns:
            for candidate in re.findall(pattern, text, re.DOTALL):
                try:
                    return json.loads(candidate)
                except json.JSONDecodeError:
                    continue

        # Try parsing the entire text as JSON
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return None

    @staticmethod
    def extract_code(
        response: Union[str, LLMResponse],
        language: str | None = None
    ) -> List[str]:
        """Extract code blocks from response.

        Args:
            response: LLM response
            language: Optional language filter; matched literally against the
                tag that follows the opening fence

        Returns:
            List of code blocks
        """
        text = ResponseParser._response_text(response)

        if language:
            # Escape the tag so names like "c++" are matched literally, and
            # require whitespace after it so "py" does not match "python".
            pattern = rf'```{re.escape(language)}\s+(.*?)\s*```'
        else:
            # All code blocks; the tag may contain +, #, . and - ("c++", "c#").
            pattern = r'```[\w+#.-]*\s*(.*?)\s*```'

        matches = re.findall(pattern, text, re.DOTALL)
        return [m.strip() for m in matches]

    @staticmethod
    def extract_list(
        response: Union[str, LLMResponse],
        numbered: bool = False
    ) -> List[str]:
        """Extract list items from response.

        Only lines that start with the list marker are matched.

        Args:
            response: LLM response
            numbered: Whether to look for numbered lists

        Returns:
            List of items
        """
        text = ResponseParser._response_text(response)

        if numbered:
            # Numbered list (1. item, 2. item, etc.)
            pattern = r'^\d+\.\s+(.+)$'
        else:
            # Bullet points (-, *, •)
            pattern = r'^[-*•]\s+(.+)$'

        matches = re.findall(pattern, text, re.MULTILINE)
        return [m.strip() for m in matches]

    @staticmethod
    def extract_sections(
        response: Union[str, LLMResponse]
    ) -> Dict[str, str]:
        """Extract sections from response.

        Text before the first header is stored under ``'main'``. A header
        followed directly by another header (no lines in between) produces no
        entry, and a repeated header name overwrites the earlier entry.

        Args:
            response: LLM response

        Returns:
            Dictionary of section name to content
        """
        text = ResponseParser._response_text(response)

        # Split by headers (# Header, ## Header, etc.)
        sections = {}
        current_section = 'main'
        current_content = []

        for line in text.split('\n'):
            header_match = re.match(r'^#+\s+(.+)$', line)
            if header_match:
                # Save previous section
                if current_content:
                    sections[current_section] = '\n'.join(current_content).strip()
                # Start new section
                current_section = header_match.group(1).strip()
                current_content = []
            else:
                current_content.append(line)

        # Save last section
        if current_content:
            sections[current_section] = '\n'.join(current_content).strip()

        return sections
class TokenCounter:
    """Rough, character-based token estimates keyed by model family."""

    # Approximate tokens per character for different models
    TOKENS_PER_CHAR = {
        'gpt-4': 0.25,
        'gpt-3.5': 0.25,
        'claude': 0.25,
        'llama': 0.3,
        'default': 0.25
    }

    @classmethod
    def estimate_tokens(
        cls,
        text: str,
        model: str = 'default'
    ) -> int:
        """Estimate how many tokens ``text`` occupies.

        The ratio of the first table key contained in the lower-cased model
        name is used; the ``'default'`` ratio applies when none matches.

        Args:
            text: Input text
            model: Model name

        Returns:
            Estimated token count
        """
        table = cls.TOKENS_PER_CHAR
        fallback = table['default']
        chosen = next(
            (per_char for key, per_char in table.items() if key in model.lower()),
            fallback,
        )
        # Character count scaled by the ratio, truncated to an int.
        return int(len(text) * chosen)

    @classmethod
    def estimate_messages_tokens(
        cls,
        messages: List[LLMMessage],
        model: str = 'default'
    ) -> int:
        """Estimate the token count of a message list.

        Each message contributes 4 tokens for its role, the estimate for its
        content, and the estimate for its name when the name is truthy.

        Args:
            messages: List of messages
            model: Model name

        Returns:
            Estimated token count
        """
        return sum(
            4
            + cls.estimate_tokens(message.content, model)
            + (cls.estimate_tokens(message.name, model) if message.name else 0)
            for message in messages
        )

    @classmethod
    def fits_in_context(
        cls,
        text: str,
        model: str,
        max_tokens: int
    ) -> bool:
        """Tell whether the estimate for ``text`` is within ``max_tokens``.

        Args:
            text: Input text
            model: Model name
            max_tokens: Maximum token limit

        Returns:
            True if fits
        """
        return cls.estimate_tokens(text, model) <= max_tokens
class CostCalculator:
    """Calculate costs for LLM usage."""

    # Cost per 1K tokens (in USD)
    PRICING = {
        'gpt-4': {'input': 0.03, 'output': 0.06},
        'gpt-4-32k': {'input': 0.06, 'output': 0.12},
        'gpt-3.5-turbo': {'input': 0.0015, 'output': 0.002},
        'claude-3-opus': {'input': 0.015, 'output': 0.075},
        'claude-3-sonnet': {'input': 0.003, 'output': 0.015},
        'claude-3-haiku': {'input': 0.00025, 'output': 0.00125},
    }

    @classmethod
    def _find_pricing(cls, model: str | None) -> Dict[str, float] | None:
        """Return the price entry for ``model``, or None when there is none.

        A table key matches when it is a substring of the lower-cased model
        name. The longest matching key wins, because some keys are prefixes of
        others: 'gpt-4' is contained in 'gpt-4-32k', and taking the first
        match would bill the 32k model at the base rate.
        """
        if not model:
            return None
        name = model.lower()
        candidates = [key for key in cls.PRICING if key in name]
        if not candidates:
            return None
        return cls.PRICING[max(candidates, key=len)]

    @classmethod
    def calculate_cost(
        cls,
        response: LLMResponse,
        model: str | None = None
    ) -> float | None:
        """Calculate cost for LLM response.

        Reads ``prompt_tokens`` and ``completion_tokens`` from
        ``response.usage``; a missing key counts as 0.

        Args:
            response: LLM response with usage info
            model: Model name (if not in response)

        Returns:
            Cost in USD or None if cannot calculate (no usage data, no model
            name, or no pricing entry for the model)
        """
        if not response.usage:
            return None

        pricing = cls._find_pricing(model or response.model)
        if not pricing:
            return None

        # Prices are quoted per 1K tokens.
        input_cost = (response.usage.get('prompt_tokens', 0) / 1000) * pricing['input']
        output_cost = (response.usage.get('completion_tokens', 0) / 1000) * pricing['output']

        return input_cost + output_cost

    @classmethod
    def estimate_cost(
        cls,
        text: str,
        model: str,
        expected_output_tokens: int = 100
    ) -> float | None:
        """Estimate cost for text completion.

        Args:
            text: Input text
            model: Model name
            expected_output_tokens: Expected output length

        Returns:
            Estimated cost in USD, or None when the model has no pricing entry
        """
        pricing = cls._find_pricing(model)
        if not pricing:
            return None

        # Input size is an estimate; see TokenCounter.estimate_tokens.
        input_tokens = TokenCounter.estimate_tokens(text, model)

        input_cost = (input_tokens / 1000) * pricing['input']
        output_cost = (expected_output_tokens / 1000) * pricing['output']

        return input_cost + output_cost
def chain_prompts(
    *templates: PromptTemplate
) -> PromptTemplate:
    """Join several prompt templates into one.

    The template texts are concatenated in order, separated by a blank line.
    The variable lists are merged in order, keeping only the first occurrence
    of each name.

    Args:
        *templates: Templates to chain

    Returns:
        Combined template
    """
    joined_text = '\n\n'.join(part.template for part in templates)

    # dict keys keep insertion order, giving an ordered de-duplication.
    unique_names = dict.fromkeys(
        name for part in templates for name in part.variables
    )

    return PromptTemplate(joined_text, list(unique_names))
def create_few_shot_prompt(
    instruction: str,
    examples: List[Dict[str, str]],
    query_key: str = 'input',
    response_key: str = 'output'
) -> PromptTemplate:
    """Create few-shot learning prompt.

    The returned template has a single ``{query}`` placeholder and is meant
    to be rendered with ``format(query=...)``. Example inputs and outputs are
    data, not template text, so any ``{`` or ``}`` they contain is doubled.
    ``str.format`` turns the doubled braces back into single literal ones;
    without this, examples containing braces (such as JSON) would make
    ``format`` raise. ``instruction`` is inserted unchanged.

    Args:
        instruction: Task instruction
        examples: List of example input/output pairs
        query_key: Key for input in examples
        response_key: Key for output in examples

    Returns:
        Few-shot prompt template

    Raises:
        KeyError: If an example lacks ``query_key`` or ``response_key``.
    """

    def literal(value) -> str:
        """Escape ``value`` so ``str.format`` reproduces it verbatim."""
        return str(value).replace('{', '{{').replace('}', '}}')

    template_parts = [instruction, '']

    # Add examples
    for i, example in enumerate(examples, 1):
        template_parts.extend((
            f"Example {i}:",
            f"Input: {literal(example[query_key])}",
            f"Output: {literal(example[response_key])}",
            '',
        ))

    # Add query placeholder
    template_parts.extend((
        "Now, process this input:",
        "Input: {query}",
        "Output:",
    ))

    return PromptTemplate('\n'.join(template_parts), ['query'])