Coverage for src/dataknobs_fsm/functions/library/llm.py: 0%

180 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Built-in LLM functions for FSM. 

2 

3This module provides LLM-related functions that can be referenced 

4in FSM configurations for AI-powered workflows. 

5""" 

6 

7import asyncio 

8import json 

9from typing import Any, Callable, Dict, List 

10 

11from dataknobs_fsm.functions.base import ( 

12 ITransformFunction, 

13 IValidationFunction, 

14 TransformFunctionError, 

15 ValidationFunctionError, 

16) 

17from dataknobs_fsm.resources.llm import LLMResource 

18 

19 

class PromptBuilder(ITransformFunction):
    """Build prompts for LLM calls."""

    def __init__(
        self,
        template: str,
        system_prompt: str | None = None,
        variables: List[str] | None = None,
        format_spec: str | None = None,  # "json", "markdown", "plain"
    ):
        """Initialize the prompt builder.

        Args:
            template: Prompt template with {variable} placeholders.
            system_prompt: Optional system prompt.
            variables: List of variable names to extract from data.
                Dotted names (e.g. "user.name") are resolved by nested
                dict traversal on the input data.
            format_spec: Output format specification.
        """
        self.template = template
        self.system_prompt = system_prompt
        self.variables = variables or []
        self.format_spec = format_spec

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by building prompt.

        Args:
            data: Input data containing variables for prompt.

        Returns:
            Data with built prompt (and "system_prompt" if configured).

        Raises:
            TransformFunctionError: If a template variable is missing.
        """
        variables = self._extract_variables(data)

        # Dotted placeholders such as "{user.name}" cannot go through
        # str.format(**variables): format() treats "." as attribute
        # access and raises KeyError on the first segment.  Substitute
        # dotted names textually first, then format the plain ones.
        prompt = self.template
        for name, value in variables.items():
            if "." in name:
                prompt = prompt.replace("{" + name + "}", str(value))
        plain = {k: v for k, v in variables.items() if "." not in k}
        try:
            prompt = prompt.format(**plain)
        except KeyError as e:
            raise TransformFunctionError(f"Missing variable for prompt: {e}") from e

        # Append an output-format instruction if requested.
        if self.format_spec == "json":
            prompt += "\n\nPlease respond with valid JSON only."
        elif self.format_spec == "markdown":
            prompt += "\n\nPlease format your response using Markdown."

        result = {
            **data,
            "prompt": prompt,
        }

        if self.system_prompt:
            result["system_prompt"] = self.system_prompt

        return result

    def _extract_variables(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Collect declared variables from data, supporting dotted paths."""
        variables: Dict[str, Any] = {}
        for var in self.variables:
            if var in data:
                variables[var] = data[var]
                continue
            # Fall back to nested dict traversal for dotted names;
            # missing paths are simply omitted (format() then reports them).
            value: Any = data
            for part in var.split("."):
                if isinstance(value, dict) and part in value:
                    value = value[part]
                else:
                    value = None
                    break
            if value is not None:
                variables[var] = value
        return variables

92 

93 

class LLMCaller(ITransformFunction):
    """Call an LLM with a prompt."""

    def __init__(
        self,
        resource_name: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        stream: bool = False,
        response_field: str = "llm_response",
    ):
        """Initialize the LLM caller.

        Args:
            resource_name: Name of the LLM resource to use.
            model: Model to use (if None, use resource default).
            temperature: Temperature for generation.
            max_tokens: Maximum tokens to generate.
            stream: Whether to stream the response.
            response_field: Field to store response in.
        """
        self.resource_name = resource_name
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.response_field = response_field

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Call the configured LLM resource and attach its response.

        Args:
            data: Input data; expects "prompt", an optional
                "system_prompt", and a "_resources" mapping that holds
                the named LLM resource.

        Returns:
            Data augmented with the response ("is_streaming" flag when
            streaming, "tokens_used" otherwise).

        Raises:
            TransformFunctionError: If the resource or prompt is missing,
                or the LLM call fails.
        """
        resources = data.get("_resources", {})
        resource = resources.get(self.resource_name)
        if not resource or not isinstance(resource, LLMResource):
            raise TransformFunctionError(f"LLM resource '{self.resource_name}' not found")

        prompt = data.get("prompt")
        if not prompt:
            raise TransformFunctionError("No prompt found in data")

        try:
            response = await resource.generate(
                prompt=prompt,
                system_prompt=data.get("system_prompt"),
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stream=self.stream,
            )
            if self.stream:
                # Streaming mode: the response is left as-is (an async
                # generator) and flagged so downstream states iterate it.
                return {**data, self.response_field: response, "is_streaming": True}
            # Non-streaming: attach the full response plus token usage.
            return {
                **data,
                self.response_field: response,
                "tokens_used": response.get("usage", {}).get("total_tokens"),
            }
        except Exception as e:
            raise TransformFunctionError(f"LLM call failed: {e}") from e

172 

173 

class ResponseValidator(IValidationFunction):
    """Validate LLM responses."""

    def __init__(
        self,
        response_field: str = "llm_response",
        format: str | None = None,  # "json", "markdown", etc.
        schema: Dict[str, Any] | None = None,
        min_length: int | None = None,
        max_length: int | None = None,
        required_fields: List[str] | None = None,
    ):
        """Initialize the response validator.

        Args:
            response_field: Field containing LLM response.
            format: Expected response format.
            schema: Field definitions passed to pydantic ``create_model``
                for validation (when format is "json").
            min_length: Minimum response length.
            max_length: Maximum response length.
            required_fields: Required fields in parsed response.
        """
        self.response_field = response_field
        self.format = format
        self.schema = schema
        self.min_length = min_length
        self.max_length = max_length
        self.required_fields = required_fields or []

    def validate(self, data: Dict[str, Any]) -> bool:
        """Validate LLM response.

        Args:
            data: Data containing LLM response.

        Returns:
            True if valid.

        Raises:
            ValidationFunctionError: If validation fails.
        """
        response = data.get(self.response_field)
        if response is None:
            raise ValidationFunctionError(f"Response field '{self.response_field}' not found")

        # Extract text from response object if needed.
        if isinstance(response, dict):
            text = response.get("text", response.get("content", str(response)))
        else:
            text = str(response)

        # Check length constraints.  Compare against None explicitly so
        # a configured limit of 0 is not silently skipped by truthiness.
        if self.min_length is not None and len(text) < self.min_length:
            raise ValidationFunctionError(
                f"Response too short: {len(text)} < {self.min_length}"
            )

        if self.max_length is not None and len(text) > self.max_length:
            raise ValidationFunctionError(
                f"Response too long: {len(text)} > {self.max_length}"
            )

        # Validate format.  Only the parse itself is guarded, so
        # validation errors raised below are never masked.
        if self.format == "json":
            try:
                parsed = json.loads(text)
            except json.JSONDecodeError as e:
                raise ValidationFunctionError(f"Invalid JSON response: {e}") from e

            # Validate against schema if provided.
            if self.schema:
                from pydantic import ValidationError, create_model
                model = create_model("ResponseSchema", **self.schema)
                try:
                    model(**parsed)
                except ValidationError as e:
                    raise ValidationFunctionError(f"Schema validation failed: {e}") from e

            # Check required fields in the parsed payload.
            for field in self.required_fields:
                if field not in parsed:
                    raise ValidationFunctionError(f"Required field missing: {field}")

        return True

259 

260 

class FunctionCaller(ITransformFunction):
    """Call functions/tools based on LLM output."""

    def __init__(
        self,
        response_field: str = "llm_response",
        function_registry: Dict[str, Callable] | None = None,
        result_field: str = "function_result",
    ):
        """Initialize the function caller.

        Args:
            response_field: Field containing LLM response with function call.
            function_registry: Registry of available functions.
            result_field: Field to store function result.
        """
        self.response_field = response_field
        self.function_registry = function_registry or {}
        self.result_field = result_field

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch the function call described by the LLM response.

        Args:
            data: Input data containing LLM response with function call.

        Returns:
            Data with the function result and "function_called" set, or
            the input unchanged when no function call is present.

        Raises:
            TransformFunctionError: If the named function is unknown or
                the call itself fails.
        """
        response = data.get(self.response_field)
        if not response:
            return data

        # A string response must be JSON to describe a function call;
        # plain text means there is nothing to dispatch.
        if isinstance(response, str):
            try:
                response = json.loads(response)
            except json.JSONDecodeError:
                return data

        function_name = response.get("function")
        function_args = response.get("arguments", {})
        if not function_name:
            return data

        if function_name not in self.function_registry:
            raise TransformFunctionError(f"Function not found: {function_name}")
        func = self.function_registry[function_name]

        try:
            # Await coroutine functions; call plain callables directly.
            if asyncio.iscoroutinefunction(func):
                result = await func(**function_args)
            else:
                result = func(**function_args)
            return {
                **data,
                self.result_field: result,
                "function_called": function_name,
            }
        except Exception as e:
            raise TransformFunctionError(f"Function call failed: {e}") from e

330 

331 

class ConversationManager(ITransformFunction):
    """Manage conversation history for multi-turn interactions."""

    def __init__(
        self,
        max_history: int = 10,
        history_field: str = "conversation_history",
        role_field: str = "role",
        content_field: str = "content",
    ):
        """Initialize the conversation manager.

        Args:
            max_history: Maximum number of messages to keep.
            history_field: Field to store conversation history.
            role_field: Field for message role.
            content_field: Field for message content.
        """
        self.max_history = max_history
        self.history_field = history_field
        self.role_field = role_field
        self.content_field = content_field

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by managing conversation history.

        Args:
            data: Input data with new message.

        Returns:
            Data with updated conversation history.
        """
        # Work on a copy so the list stored inside the caller's data is
        # never mutated in place (appending directly would leak the new
        # messages into every other reference to that list).
        history = list(data.get(self.history_field, []))

        # Add user message if present.
        if "prompt" in data:
            history.append({
                self.role_field: "user",
                self.content_field: data["prompt"],
            })

        # Add assistant message if present.
        if "llm_response" in data:
            response = data["llm_response"]
            if isinstance(response, dict):
                content = response.get("text", response.get("content", str(response)))
            else:
                content = str(response)

            history.append({
                self.role_field: "assistant",
                self.content_field: content,
            })

        # Keep only the most recent max_history messages.
        if len(history) > self.max_history:
            history = history[-self.max_history:]

        return {
            **data,
            self.history_field: history,
        }

395 

396 

class EmbeddingGenerator(ITransformFunction):
    """Generate embeddings for text using LLM."""

    def __init__(
        self,
        resource_name: str,
        text_field: str = "text",
        embedding_field: str = "embedding",
        model: str | None = None,
        batch_size: int = 100,
    ):
        """Initialize the embedding generator.

        Args:
            resource_name: Name of the LLM resource to use.
            text_field: Field containing text to embed.
            embedding_field: Field to store embeddings.
            model: Embedding model to use.
            batch_size: Batch size for embedding generation.
        """
        self.resource_name = resource_name
        self.text_field = text_field
        self.embedding_field = embedding_field
        self.model = model
        self.batch_size = batch_size

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the configured text field via the LLM resource.

        Args:
            data: Input data containing text (str or list of str).

        Returns:
            Data with embeddings attached, or the input unchanged when
            the text field is absent/empty.

        Raises:
            TransformFunctionError: If the resource is missing or the
                embedding call fails.
        """
        resource = data.get("_resources", {}).get(self.resource_name)
        if not resource or not isinstance(resource, LLMResource):
            raise TransformFunctionError(f"LLM resource '{self.resource_name}' not found")

        text = data.get(self.text_field)
        if not text:
            return data

        try:
            if isinstance(text, list):
                # Batch processing: embed batch_size items per call.
                vectors = []
                for start in range(0, len(text), self.batch_size):
                    chunk = text[start:start + self.batch_size]
                    vectors.extend(await resource.embed(chunk, model=self.model))
            else:
                # Single text value.
                vectors = await resource.embed(text, model=self.model)

            return {
                **data,
                self.embedding_field: vectors,
            }
        except Exception as e:
            raise TransformFunctionError(f"Embedding generation failed: {e}") from e

462 

463 

464# Convenience functions for creating LLM functions 

def build_prompt(template: str, **kwargs) -> PromptBuilder:
    """Factory helper: construct a PromptBuilder for *template*."""
    builder = PromptBuilder(template, **kwargs)
    return builder

468 

469 

def call_llm(resource: str, **kwargs) -> LLMCaller:
    """Factory helper: construct an LLMCaller bound to *resource*."""
    caller = LLMCaller(resource, **kwargs)
    return caller

473 

474 

def validate_response(**kwargs) -> ResponseValidator:
    """Factory helper: construct a ResponseValidator."""
    validator = ResponseValidator(**kwargs)
    return validator

478 

479 

def call_function(**kwargs) -> FunctionCaller:
    """Factory helper: construct a FunctionCaller."""
    caller = FunctionCaller(**kwargs)
    return caller

483 

484 

def manage_conversation(**kwargs) -> ConversationManager:
    """Factory helper: construct a ConversationManager."""
    manager = ConversationManager(**kwargs)
    return manager

488 

489 

def generate_embeddings(resource: str, **kwargs) -> EmbeddingGenerator:
    """Factory helper: construct an EmbeddingGenerator bound to *resource*."""
    generator = EmbeddingGenerator(resource, **kwargs)
    return generator