Coverage for src/dataknobs_fsm/llm/utils.py: 0%

260 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Utility functions for LLM operations. 

2 

3This module provides utility functions for working with LLMs. 

4""" 

5 

6import re 

7import json 

8from typing import Any, Dict, List, Union 

9from dataclasses import dataclass, field 

10 

11from .base import LLMMessage, LLMResponse 

12 

13 

def render_conditional_template(template: str, params: Dict[str, Any]) -> str:
    """Expand ``{{variable}}`` placeholders and ``((optional))`` sections.

    Placeholders:
    - ``{{name}}`` is replaced by ``str(params["name"])``. Whitespace written
      inside the braces is emitted around the value
      (``{{ name }}`` -> ``" value "``).
    - Outside optional sections, a name that is absent from ``params`` stays
      in the output as ``{{name}}``; inner whitespace is moved outside the
      braces (``{{ name }}`` -> ``" {{name}} "``).
    - Outside optional sections, a ``None`` value becomes an empty string
      when the placeholder has no inner whitespace; with inner whitespace the
      placeholder is kept, whitespace moved outside the braces.

    Optional sections:
    - ``((text))`` is removed entirely when it contains placeholders
      (nested ones included) and none of them has a usable value. A value is
      unusable when it is missing, ``None``, a blank string, or falsy.
    - Otherwise the double parentheses are dropped, nested sections are
      resolved, and missing/``None`` placeholders inside become empty strings.
    - A section without any placeholder is kept, minus its parentheses.
    - An opening ``((`` without a matching ``))`` is left untouched.

    Example:
        template = "Hello {{name}}((, you have {{count}} messages))"
        params = {"name": "Alice", "count": 5}
        result = "Hello Alice, you have 5 messages"

        params = {"name": "Bob"}  # no count
        result = "Hello Bob"  # optional section removed

    Note:
        After each section is rewritten the text is scanned again from the
        start, and a final placeholder pass runs over the whole result.
        Values substituted from ``params`` are therefore interpreted as
        template syntax themselves if they contain ``((``/``))`` or
        ``{{...}}``; a value that reproduces its own section (for example
        ``{"x": "(({{x}}))"}`` with template ``"(({{x}}))"``) never
        terminates.

    Args:
        template: The template string.
        params: Mapping of placeholder names to values.

    Returns:
        The rendered string.
    """
    # One compiled pattern shared by every helper: optional inner whitespace
    # is captured so it can be re-emitted around the substituted value.
    placeholder = re.compile(r'\{\{(\s*)(\w+)(\s*)\}\}')

    def replace_variable(text: str, params: Dict[str, Any], in_conditional: bool = False) -> str:
        """Substitute every placeholder in ``text`` using ``params``."""

        def swap(found):
            lead, key, trail = found.groups()
            # Placeholder re-emitted with its inner whitespace moved outside.
            kept = f"{lead}{{{{{key}}}}}{trail}"

            if key not in params:
                return "" if in_conditional else kept

            current = params[key]
            if current is not None:
                return f"{lead}{current!s}{trail}"
            # None: vanishes inside sections and for tight placeholders,
            # otherwise the placeholder itself is preserved.
            if in_conditional or not (lead or trail):
                return ""
            return kept

        return placeholder.sub(swap, text)

    def find_all_variables(text: str) -> set:
        """Collect the names of all placeholders found anywhere in ``text``."""
        return {found.group(2) for found in placeholder.finditer(text)}

    def has_usable_value(key: str, values: Dict[str, Any]) -> bool:
        """Tell whether ``key`` maps to something that keeps a section alive."""
        if key not in values:
            return False
        candidate = values[key]
        if candidate is None:
            return False
        if isinstance(candidate, str):
            # Strings count only when something remains after stripping.
            return bool(candidate.strip())
        return bool(candidate)

    def find_section_end(text: str, opening: int):
        """Return the index just past the ``))`` closing the ``((`` at ``opening``.

        Returns ``None`` when the section is never closed. Single parentheses
        are counted separately so that ``)`` characters belonging to ordinary
        parenthesised text are not mistaken for half of a ``))`` marker.
        """
        nesting = 1
        singles = 0
        pos = opening + 2
        size = len(text)
        while nesting and pos < size:
            pair = text[pos:pos + 2]
            if pair == '((':
                nesting += 1
                pos += 2
            elif pair == '))' and singles == 0:
                nesting -= 1
                pos += 2
            else:
                # Either an ordinary character, a lone parenthesis, or the
                # first ')' of a '))' that closes an open single parenthesis.
                char = text[pos]
                if char == '(':
                    singles += 1
                elif char == ')':
                    singles -= 1
                pos += 1
        return pos if nesting == 0 else None

    def process_conditionals(text: str, params: Dict[str, Any]) -> str:
        """Resolve ``((...))`` sections, restarting the scan after each rewrite."""
        result = text
        while True:
            cursor = 0
            rewritten = False
            while True:
                start = result.find('((', cursor)
                if start == -1:
                    break

                end = find_section_end(result, start)
                if end is None:
                    # Unmatched opener: leave it alone and look further right.
                    cursor = start + 1
                    continue

                inner = result[start + 2:end - 2]
                names = find_all_variables(inner)

                if names and not any(has_usable_value(name, params) for name in names):
                    # Every placeholder is empty/missing: drop the section.
                    replacement = ""
                else:
                    # Nested sections first, then (only when the section has
                    # placeholders) the variables themselves.
                    replacement = process_conditionals(inner, params)
                    if names:
                        replacement = replace_variable(replacement, params, in_conditional=True)

                result = result[:start] + replacement + result[end:]
                rewritten = True
                break

            if not rewritten:
                return result

    # Sections first, then whatever placeholders remain outside of them.
    rendered = process_conditionals(template, params)
    return replace_variable(rendered, params, in_conditional=False)

186 

187 

@dataclass
class PromptTemplate:
    """A ``str.format``-style prompt together with its required variable names."""

    # Format string using ``{name}`` placeholders.
    template: str
    # Names that ``format`` requires; detected from ``template`` when left empty.
    variables: List[str] = field(default_factory=list)

    def __post_init__(self):
        """Detect ``{name}`` placeholders when no variable list was supplied.

        Each name is recorded once, in order of first appearance, even when
        the placeholder occurs several times in the template.
        """
        if not self.variables:
            found = re.findall(r'\{(\w+)\}', self.template)
            # dict.fromkeys de-duplicates while preserving order; duplicates
            # would otherwise survive a later partial() and make format()
            # report an already-filled variable as missing.
            self.variables = list(dict.fromkeys(found))

    def format(self, **kwargs) -> str:
        """Render the template with ``str.format``.

        Args:
            **kwargs: Values for the template variables.

        Returns:
            The formatted prompt.

        Raises:
            ValueError: If any name listed in ``variables`` is not supplied.
        """
        missing = set(self.variables) - set(kwargs.keys())
        if missing:
            raise ValueError(f"Missing variables: {missing}")

        return self.template.format(**kwargs)

    def partial(self, **kwargs) -> 'PromptTemplate':
        """Return a new template with some variables already filled in.

        Only names currently listed in ``variables`` are substituted; other
        keyword arguments are ignored. Every ``{name}`` occurrence in the text
        is replaced by ``str(value)`` and the name is removed from the list.
        When no variables remain, the new template re-detects placeholders
        from its text (see ``__post_init__``).

        Args:
            **kwargs: Values for the variables to fill.

        Returns:
            A new ``PromptTemplate``; this instance is left unchanged.
        """
        new_template = self.template
        remaining = list(self.variables)

        for key, value in kwargs.items():
            if key in remaining:
                new_template = new_template.replace(f'{{{key}}}', str(value))
                # Drop every occurrence of the name: str.replace above has
                # already filled all matching placeholders in the text.
                remaining = [name for name in remaining if name != key]

        return PromptTemplate(new_template, remaining)

234 

235 

class MessageBuilder:
    """Fluent builder that collects ``LLMMessage`` objects in call order."""

    def __init__(self):
        # Messages accumulated so far, oldest first.
        self.messages = []

    def _push(self, **fields) -> 'MessageBuilder':
        """Create an ``LLMMessage`` from ``fields``, store it, and return self."""
        self.messages.append(LLMMessage(**fields))
        return self

    def system(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``'system'``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        return self._push(role='system', content=content)

    def user(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``'user'``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        return self._push(role='user', content=content)

    def assistant(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``'assistant'``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        return self._push(role='assistant', content=content)

    def function(
        self,
        name: str,
        content: str,
        function_call: Dict[str, Any] | None = None
    ) -> 'MessageBuilder':
        """Append a message with role ``'function'``.

        Args:
            name: Function name
            content: Function result
            function_call: Function call details

        Returns:
            Self for chaining
        """
        return self._push(
            role='function',
            name=name,
            content=content,
            function_call=function_call,
        )

    def from_template(
        self,
        role: str,
        template: PromptTemplate,
        **kwargs
    ) -> 'MessageBuilder':
        """Append a message whose content is ``template.format(**kwargs)``.

        Args:
            role: Message role
            template: Prompt template
            **kwargs: Template variables

        Returns:
            Self for chaining
        """
        rendered = template.format(**kwargs)
        return self._push(role=role, content=rendered)

    def build(self) -> List[LLMMessage]:
        """Return a shallow copy of the collected messages.

        Returns:
            List of messages
        """
        return list(self.messages)

    def clear(self) -> 'MessageBuilder':
        """Empty the collected message list in place.

        Returns:
            Self for chaining
        """
        del self.messages[:]
        return self

338 

339 

340class ResponseParser: 

341 """Parser for LLM responses.""" 

342 

343 @staticmethod 

344 def extract_json(response: Union[str, LLMResponse]) -> Dict[str, Any] | None: 

345 """Extract JSON from response. 

346  

347 Args: 

348 response: LLM response 

349  

350 Returns: 

351 Extracted JSON or None 

352 """ 

353 text = response.content if isinstance(response, LLMResponse) else response 

354 

355 # Try to find JSON in the text 

356 json_patterns = [ 

357 r'\{[^}]*\}', # Simple object 

358 r'\[[^\]]*\]', # Array 

359 r'```json\s*(.*?)\s*```', # Markdown code block 

360 r'```\s*(.*?)\s*```', # Generic code block 

361 ] 

362 

363 for pattern in json_patterns: 

364 matches = re.findall(pattern, text, re.DOTALL) 

365 for match in matches: 

366 try: 

367 return json.loads(match) 

368 except json.JSONDecodeError: 

369 continue 

370 

371 # Try parsing the entire text as JSON 

372 try: 

373 return json.loads(text) 

374 except json.JSONDecodeError: 

375 return None 

376 

377 @staticmethod 

378 def extract_code( 

379 response: Union[str, LLMResponse], 

380 language: str | None = None 

381 ) -> List[str]: 

382 """Extract code blocks from response. 

383  

384 Args: 

385 response: LLM response 

386 language: Optional language filter 

387  

388 Returns: 

389 List of code blocks 

390 """ 

391 text = response.content if isinstance(response, LLMResponse) else response 

392 

393 if language: 

394 # Language-specific code blocks 

395 pattern = rf'```{language}\s*(.*?)\s*```' 

396 else: 

397 # All code blocks 

398 pattern = r'```(?:\w+)?\s*(.*?)\s*```' 

399 

400 matches = re.findall(pattern, text, re.DOTALL) 

401 return [m.strip() for m in matches] 

402 

403 @staticmethod 

404 def extract_list( 

405 response: Union[str, LLMResponse], 

406 numbered: bool = False 

407 ) -> List[str]: 

408 """Extract list items from response. 

409  

410 Args: 

411 response: LLM response 

412 numbered: Whether to look for numbered lists 

413  

414 Returns: 

415 List of items 

416 """ 

417 text = response.content if isinstance(response, LLMResponse) else response 

418 

419 if numbered: 

420 # Numbered list (1. item, 2. item, etc.) 

421 pattern = r'^\d+\.\s+(.+)$' 

422 else: 

423 # Bullet points (-, *, •) 

424 pattern = r'^[-*•]\s+(.+)$' 

425 

426 matches = re.findall(pattern, text, re.MULTILINE) 

427 return [m.strip() for m in matches] 

428 

429 @staticmethod 

430 def extract_sections( 

431 response: Union[str, LLMResponse] 

432 ) -> Dict[str, str]: 

433 """Extract sections from response. 

434  

435 Args: 

436 response: LLM response 

437  

438 Returns: 

439 Dictionary of section name to content 

440 """ 

441 text = response.content if isinstance(response, LLMResponse) else response 

442 

443 # Split by headers (# Header, ## Header, etc.) 

444 sections = {} 

445 current_section = 'main' 

446 current_content = [] 

447 

448 for line in text.split('\n'): 

449 header_match = re.match(r'^#+\s+(.+)$', line) 

450 if header_match: 

451 # Save previous section 

452 if current_content: 

453 sections[current_section] = '\n'.join(current_content).strip() 

454 # Start new section 

455 current_section = header_match.group(1).strip() 

456 current_content = [] 

457 else: 

458 current_content.append(line) 

459 

460 # Save last section 

461 if current_content: 

462 sections[current_section] = '\n'.join(current_content).strip() 

463 

464 return sections 

465 

466 

class TokenCounter:
    """Rough, character-based token estimates for different model families."""

    # Approximate tokens per character, keyed by a substring of the model name.
    TOKENS_PER_CHAR = {
        'gpt-4': 0.25,
        'gpt-3.5': 0.25,
        'claude': 0.25,
        'llama': 0.3,
        'default': 0.25
    }

    @classmethod
    def estimate_tokens(
        cls,
        text: str,
        model: str = 'default'
    ) -> int:
        """Estimate the token count of ``text``.

        The first ``TOKENS_PER_CHAR`` key that occurs as a substring of the
        lower-cased model name selects the ratio; unknown models use the
        ``'default'`` ratio.

        Args:
            text: Input text
            model: Model name

        Returns:
            ``len(text)`` times the ratio, truncated to an int.
        """
        lowered = model.lower()
        ratio = next(
            (value for pattern, value in cls.TOKENS_PER_CHAR.items() if pattern in lowered),
            cls.TOKENS_PER_CHAR['default'],
        )
        return int(len(text) * ratio)

    @classmethod
    def estimate_messages_tokens(
        cls,
        messages: List['LLMMessage'],
        model: str = 'default'
    ) -> int:
        """Estimate the total token count of a message list.

        Each message contributes a flat 4 tokens for its role plus the
        estimates for its ``content`` and, when set, its ``name``.

        Args:
            messages: List of messages
            model: Model name

        Returns:
            Estimated token count
        """
        total = 0
        for msg in messages:
            # Flat allowance for the role marker (approximately 4 tokens).
            total += 4
            # Missing content counts as empty text instead of raising on
            # len(None); `name` below is guarded the same way.
            total += cls.estimate_tokens(msg.content or '', model)
            if msg.name:
                total += cls.estimate_tokens(msg.name, model)

        return total

    @classmethod
    def fits_in_context(
        cls,
        text: str,
        model: str,
        max_tokens: int
    ) -> bool:
        """Check whether the estimated size of ``text`` is within a limit.

        Args:
            text: Input text
            model: Model name
            max_tokens: Maximum token limit

        Returns:
            True when the estimate is less than or equal to ``max_tokens``.
        """
        return cls.estimate_tokens(text, model) <= max_tokens

550 

551 

552class CostCalculator: 

553 """Calculate costs for LLM usage.""" 

554 

555 # Cost per 1K tokens (in USD) 

556 PRICING = { 

557 'gpt-4': {'input': 0.03, 'output': 0.06}, 

558 'gpt-4-32k': {'input': 0.06, 'output': 0.12}, 

559 'gpt-3.5-turbo': {'input': 0.0015, 'output': 0.002}, 

560 'claude-3-opus': {'input': 0.015, 'output': 0.075}, 

561 'claude-3-sonnet': {'input': 0.003, 'output': 0.015}, 

562 'claude-3-haiku': {'input': 0.00025, 'output': 0.00125}, 

563 } 

564 

565 @classmethod 

566 def calculate_cost( 

567 cls, 

568 response: LLMResponse, 

569 model: str | None = None 

570 ) -> float | None: 

571 """Calculate cost for LLM response. 

572  

573 Args: 

574 response: LLM response with usage info 

575 model: Model name (if not in response) 

576  

577 Returns: 

578 Cost in USD or None if cannot calculate 

579 """ 

580 if not response.usage: 

581 return None 

582 

583 model = model or response.model 

584 

585 # Find matching pricing 

586 pricing = None 

587 for pattern, prices in cls.PRICING.items(): 

588 if pattern in model.lower(): 

589 pricing = prices 

590 break 

591 

592 if not pricing: 

593 return None 

594 

595 # Calculate cost 

596 input_cost = (response.usage.get('prompt_tokens', 0) / 1000) * pricing['input'] 

597 output_cost = (response.usage.get('completion_tokens', 0) / 1000) * pricing['output'] 

598 

599 return input_cost + output_cost 

600 

601 @classmethod 

602 def estimate_cost( 

603 cls, 

604 text: str, 

605 model: str, 

606 expected_output_tokens: int = 100 

607 ) -> float | None: 

608 """Estimate cost for text completion. 

609  

610 Args: 

611 text: Input text 

612 model: Model name 

613 expected_output_tokens: Expected output length 

614  

615 Returns: 

616 Estimated cost in USD 

617 """ 

618 # Find matching pricing 

619 pricing = None 

620 for pattern, prices in cls.PRICING.items(): 

621 if pattern in model.lower(): 

622 pricing = prices 

623 break 

624 

625 if not pricing: 

626 return None 

627 

628 # Estimate tokens 

629 input_tokens = TokenCounter.estimate_tokens(text, model) 

630 

631 # Calculate cost 

632 input_cost = (input_tokens / 1000) * pricing['input'] 

633 output_cost = (expected_output_tokens / 1000) * pricing['output'] 

634 

635 return input_cost + output_cost 

636 

637 

def chain_prompts(
    *templates: PromptTemplate
) -> PromptTemplate:
    """Join several prompt templates into one.

    The template texts are concatenated in argument order, separated by a
    blank line. The variable list is the union of the inputs' variables,
    each name kept once at the position of its first appearance.

    Args:
        *templates: Templates to chain

    Returns:
        Combined template
    """
    merged_text = '\n\n'.join(part.template for part in templates)
    # dict.fromkeys keeps first-seen order while dropping repeats.
    ordered_names = list(
        dict.fromkeys(name for part in templates for name in part.variables)
    )
    return PromptTemplate(merged_text, ordered_names)

660 

661 

def create_few_shot_prompt(
    instruction: str,
    examples: List[Dict[str, str]],
    query_key: str = 'input',
    response_key: str = 'output'
) -> PromptTemplate:
    """Create a few-shot prompt template with a single ``{query}`` variable.

    The template consists of the instruction, one numbered
    ``Input:``/``Output:`` pair per example, and a final ``Input: {query}``
    block for the item to process.

    Curly braces in example text are doubled so that ``str.format`` (used by
    ``PromptTemplate.format``) renders them as literal braces; without this,
    examples containing JSON or code make formatting fail. The instruction is
    inserted unchanged.

    Args:
        instruction: Task instruction
        examples: List of example input/output pairs
        query_key: Key for input in examples
        response_key: Key for output in examples

    Returns:
        Few-shot prompt template whose only declared variable is ``query``.

    Raises:
        KeyError: If an example lacks ``query_key`` or ``response_key``.
    """
    def literal(value) -> str:
        # Escape braces so str.format treats the example text as literal.
        return str(value).replace('{', '{{').replace('}', '}}')

    lines = [instruction, '']

    for number, example in enumerate(examples, 1):
        lines.append(f"Example {number}:")
        lines.append(f"Input: {literal(example[query_key])}")
        lines.append(f"Output: {literal(example[response_key])}")
        lines.append('')

    # Placeholder for the item the model should actually process.
    lines.append("Now, process this input:")
    lines.append("Input: {query}")
    lines.append("Output:")

    return PromptTemplate('\n'.join(lines), ['query'])