Coverage for src/dataknobs_fsm/patterns/llm_workflow.py: 0%
320 statements
1"""LLM workflow pattern implementation.
3This module provides pre-configured FSM patterns for LLM-based workflows,
4including RAG pipelines, chain-of-thought reasoning, and multi-agent systems.
5"""
7from typing import Any, Dict, List, Union, Callable
8from dataclasses import dataclass
9from enum import Enum
10import asyncio
12from ..api.simple import SimpleFSM
13from ..core.data_modes import DataHandlingMode
14from ..llm.base import LLMConfig, LLMMessage, LLMResponse
15from ..llm.providers import create_llm_provider
16from ..llm.utils import (
17 PromptTemplate, MessageBuilder, ResponseParser
18)
21class WorkflowType(Enum):
22 """LLM workflow types."""
23 SIMPLE = "simple" # Single LLM call
24 CHAIN = "chain" # Sequential chain of LLM calls
25 RAG = "rag" # Retrieval-augmented generation
26 COT = "cot" # Chain-of-thought reasoning
27 TREE = "tree" # Tree-of-thought reasoning
28 AGENT = "agent" # Agent with tools
29 MULTI_AGENT = "multi_agent" # Multiple cooperating agents
32@dataclass
33class LLMStep:
34 """Single step in LLM workflow."""
35 name: str
36 prompt_template: PromptTemplate
37 model_config: LLMConfig | None = None # Override default
39 # Processing
40 pre_processor: Callable[[Any], Any] | None = None
41 post_processor: Callable[[LLMResponse], Any] | None = None
43 # Validation
44 validator: Callable[[Any], bool] | None = None
45 retry_on_failure: bool = True
46 max_retries: int = 3
48 # Dependencies
49 depends_on: List[str] | None = None
50 pass_context: bool = True # Pass previous results
52 # Output
53 output_key: str | None = None # Key in results dict
54 parse_json: bool = False
55 extract_code: bool = False
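

# Illustrative sketch (not part of the original module): an LLMStep that asks
# for JSON output, parses it, and validates that the response has non-empty
# content. The prompt text and names are placeholders chosen for this example.
def _example_summary_step() -> LLMStep:
    return LLMStep(
        name='summarize',
        prompt_template=PromptTemplate('Summarize this text as JSON: {text}'),
        parse_json=True,
        validator=lambda response: bool(getattr(response, 'content', '')),
        output_key='summary',
        max_retries=2,
    )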


@dataclass
class RAGConfig:
    """Configuration for RAG (Retrieval-Augmented Generation)."""
    retriever_type: str  # 'vector', 'keyword', 'hybrid'
    index_path: str | None = None
    embedding_model: str | None = None
    # Provider configuration used by VectorRetriever to generate real
    # embeddings; when None, the retriever falls back to mock embeddings.
    provider_config: LLMConfig | None = None

    # Retrieval settings
    top_k: int = 5
    similarity_threshold: float = 0.7
    rerank: bool = False
    rerank_model: str | None = None

    # Context settings
    max_context_length: int = 2000
    context_template: PromptTemplate | None = None

    # Chunking settings
    chunk_size: int = 500
    chunk_overlap: int = 50


@dataclass
class AgentConfig:
    """Configuration for agent-based workflows."""
    agent_name: str
    role: str
    capabilities: List[str]

    # Tools
    tools: List[Dict[str, Any]] | None = None
    tool_descriptions: str | None = None

    # Memory
    memory_type: str | None = None  # 'buffer', 'summary', 'vector'
    memory_size: int = 10

    # Planning
    planning_enabled: bool = False
    planning_steps: int = 5

    # Reflection
    reflection_enabled: bool = False
    reflection_prompt: PromptTemplate | None = None


@dataclass
class LLMWorkflowConfig:
    """Configuration for LLM workflow."""
    workflow_type: WorkflowType
    steps: List[LLMStep]
    default_model_config: LLMConfig

    # Workflow settings
    max_iterations: int = 10
    early_stop_condition: Callable[[Dict[str, Any]], bool] | None = None

    # RAG settings (if applicable)
    rag_config: RAGConfig | None = None

    # Agent settings (if applicable)
    agent_configs: List[AgentConfig] | None = None

    # Memory and context
    maintain_history: bool = True
    max_history_length: int = 20
    context_window: int = 4000

    # Output settings
    aggregate_outputs: bool = False
    output_formatter: Callable[[Dict[str, Any]], Any] | None = None

    # Error handling
    error_handler: Callable[[Exception, str], Any] | None = None
    fallback_response: str | None = None

    # Monitoring
    log_prompts: bool = False
    log_responses: bool = False
    track_tokens: bool = True
    track_cost: bool = False


class VectorRetriever:
    """Simple vector-based retriever for RAG."""

    def __init__(self, config: RAGConfig):
        self.config = config
        self.documents = []
        self.embeddings = []

    async def index_documents(self, documents: List[str]) -> None:
        """Index documents for retrieval.

        Generates embeddings for documents using the configured LLM provider.
        In production, these would be stored in a vector database.

        Args:
            documents: List of documents to index
        """
        from dataknobs_fsm.llm.providers import get_provider

        self.documents = documents

        # Try to use a real embedding provider if available
        if self.config.provider_config:
            try:
                provider = get_provider(self.config.provider_config)

                # Check if provider supports embeddings
                if hasattr(provider, 'embed'):
                    # Generate embeddings for all documents
                    self.embeddings = await provider.embed(documents)

                    # Normalize embeddings for cosine similarity
                    self.embeddings = [
                        self._normalize_embedding(emb) for emb in self.embeddings
                    ]
                else:
                    # Fall back to mock embeddings if provider doesn't support them
                    self.embeddings = self._generate_mock_embeddings(documents)
            except Exception as e:
                # Log error and fall back to mock embeddings
                import logging
                logger = logging.getLogger(__name__)
                logger.warning(f"Failed to generate real embeddings: {e}. Using mock embeddings.")
                self.embeddings = self._generate_mock_embeddings(documents)
        else:
            # No provider configured, use mock embeddings
            self.embeddings = self._generate_mock_embeddings(documents)

    def _normalize_embedding(self, embedding: List[float]) -> List[float]:
        """Normalize an embedding vector for cosine similarity.

        Args:
            embedding: Raw embedding vector

        Returns:
            Normalized embedding vector
        """
        import math

        norm = math.sqrt(sum(x * x for x in embedding))
        if norm == 0:
            return embedding
        return [x / norm for x in embedding]

    def _generate_mock_embeddings(self, documents: List[str]) -> List[List[float]]:
        """Generate mock embeddings for testing.

        Args:
            documents: Documents to generate embeddings for

        Returns:
            Mock embedding vectors
        """
        import hashlib

        embeddings = []
        for doc in documents:
            # Create deterministic mock embedding based on document content
            doc_hash = hashlib.sha256(doc.encode()).digest()
            # Convert hash to 768-dimensional embedding (standard size)
            embedding = []
            for i in range(96):  # 768 / 8 = 96
                # Take 8 bytes at a time and convert to float
                if i * 8 < len(doc_hash):
                    byte_chunk = doc_hash[i*8:(i+1)*8]
                    value = sum(b for b in byte_chunk) / 2040.0  # Normalize to ~[0, 1]
                else:
                    # Pad with deterministic values if needed
                    value = (i % 10) / 10.0

                # Expand to 8 dimensions
                for j in range(8):
                    embedding.append(value * (1 + j * 0.1))

            embeddings.append(self._normalize_embedding(embedding))

        return embeddings
    async def retrieve(self, query: str, top_k: int | None = None) -> List[str]:
        """Retrieve relevant documents using semantic similarity.

        Args:
            query: Query string
            top_k: Number of documents to retrieve

        Returns:
            List of most relevant documents
        """
        from dataknobs_fsm.llm.providers import get_provider

        top_k = top_k or self.config.top_k

        if not self.documents:
            return []

        # Generate embedding for query
        query_embedding = None

        if self.config.provider_config:
            try:
                provider = get_provider(self.config.provider_config)
                if hasattr(provider, 'embed'):
                    # embed() takes a list of texts (as in index_documents);
                    # take the single embedding for the query
                    query_embedding = (await provider.embed([query]))[0]
                    query_embedding = self._normalize_embedding(query_embedding)
            except Exception:
                pass

        if query_embedding is None:
            # Fall back to mock embedding
            query_embedding = self._generate_mock_embeddings([query])[0]

        # Calculate cosine similarities
        similarities = []
        for i, doc_embedding in enumerate(self.embeddings):
            similarity = self._cosine_similarity(query_embedding, doc_embedding)
            similarities.append((similarity, i))

        # Sort by similarity and return top-k documents
        similarities.sort(reverse=True)
        top_indices = [idx for _, idx in similarities[:top_k]]

        return [self.documents[idx] for idx in top_indices]

    def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """Calculate cosine similarity between two vectors.

        Args:
            vec1: First vector
            vec2: Second vector

        Returns:
            Cosine similarity score
        """
        if len(vec1) != len(vec2):
            # Handle dimension mismatch by truncating to the shorter length
            min_len = min(len(vec1), len(vec2))
            vec1 = vec1[:min_len]
            vec2 = vec2[:min_len]

        dot_product = sum(a * b for a, b in zip(vec1, vec2, strict=False))
        return dot_product  # Vectors are already normalized
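

# A minimal usage sketch (illustrative, not part of the original module). With
# no provider_config set on RAGConfig, VectorRetriever falls back to the
# deterministic mock embeddings, so this runs without any LLM credentials.
async def _example_vector_retrieval() -> List[str]:
    retriever = VectorRetriever(RAGConfig(retriever_type='vector', top_k=2))
    await retriever.index_documents([
        "FSMs model workflows as states and transitions.",
        "RAG augments prompts with retrieved context.",
        "Chain-of-thought decomposes a problem into steps.",
    ])
    return await retriever.retrieve("How does retrieval-augmented generation work?")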


class LLMWorkflow:
    """LLM workflow orchestrator using FSM pattern."""

    def __init__(self, config: LLMWorkflowConfig):
        """Initialize LLM workflow.

        Args:
            config: Workflow configuration
        """
        self.config = config
        self._fsm = self._build_fsm()
        self._providers = {}
        self._history = []
        self._context = {}
        self._retriever = None

        # Initialize retriever if RAG
        if config.workflow_type == WorkflowType.RAG and config.rag_config:
            self._retriever = VectorRetriever(config.rag_config)

    def _build_fsm(self) -> SimpleFSM:
        """Build FSM for LLM workflow."""
        # Add start state
        states = [{'name': 'start', 'type': 'initial', 'is_start': True}]
        arcs = []

        if self.config.workflow_type == WorkflowType.SIMPLE:
            # Single LLM call
            states.append({'name': 'llm_call', 'type': 'task'})
            arcs.append({'from': 'start', 'to': 'llm_call', 'name': 'init'})
            arcs.append({'from': 'llm_call', 'to': 'end', 'name': 'complete'})

        elif self.config.workflow_type == WorkflowType.CHAIN:
            # Sequential chain
            for i, step in enumerate(self.config.steps):
                state_name = f"step_{step.name}"
                states.append({'name': state_name, 'type': 'task'})

                if i == 0:
                    arcs.append({'from': 'start', 'to': state_name, 'name': f'init_{step.name}'})
                else:
                    prev_state = f"step_{self.config.steps[i-1].name}"
                    arcs.append({
                        'from': prev_state,
                        'to': state_name,
                        'name': f'{self.config.steps[i-1].name}_to_{step.name}'
                    })

                if i == len(self.config.steps) - 1:
                    arcs.append({'from': state_name, 'to': 'end', 'name': f'{step.name}_complete'})

        elif self.config.workflow_type == WorkflowType.RAG:
            # RAG pipeline
            states.extend([
                {'name': 'retrieve', 'type': 'task'},
                {'name': 'augment', 'type': 'task'},
                {'name': 'generate', 'type': 'task'}
            ])

            arcs.extend([
                {'from': 'start', 'to': 'retrieve', 'name': 'init_retrieval'},
                {'from': 'retrieve', 'to': 'augment', 'name': 'retrieve_to_augment'},
                {'from': 'augment', 'to': 'generate', 'name': 'augment_to_generate'},
                {'from': 'generate', 'to': 'end', 'name': 'generation_complete'}
            ])

        elif self.config.workflow_type == WorkflowType.COT:
            # Chain-of-thought reasoning
            states.extend([
                {'name': 'decompose', 'type': 'task'},
                {'name': 'reason', 'type': 'task'},
                {'name': 'synthesize', 'type': 'task'}
            ])

            arcs.extend([
                {'from': 'start', 'to': 'decompose', 'name': 'init_decompose'},
                {'from': 'decompose', 'to': 'reason', 'name': 'decompose_to_reason'},
                {'from': 'reason', 'to': 'synthesize', 'name': 'reason_to_synthesize'},
                {'from': 'synthesize', 'to': 'end', 'name': 'synthesis_complete'}
            ])

        # Add end state
        states.append({
            'name': 'end',
            'type': 'terminal'
        })

        # Build FSM configuration
        fsm_config = {
            'name': 'LLM_Workflow',
            'data_mode': DataHandlingMode.REFERENCE.value,
            'states': states,
            'arcs': arcs,
            'resources': []
        }

        return SimpleFSM(fsm_config)
    async def _get_provider(self, step: LLMStep | None = None):
        """Get LLM provider for step."""
        config = step.model_config if step and step.model_config else self.config.default_model_config

        key = f"{config.provider}_{config.model}"
        if key not in self._providers:
            self._providers[key] = create_llm_provider(config, is_async=True)
            await self._providers[key].initialize()

        return self._providers[key]

    async def _execute_step(
        self,
        step: LLMStep,
        input_data: Dict[str, Any]
    ) -> Any:
        """Execute a single workflow step.

        Args:
            step: Workflow step
            input_data: Input data with template variables

        Returns:
            Step output
        """
        # Pre-process input
        if step.pre_processor:
            input_data = step.pre_processor(input_data)

        # Format prompt
        prompt = step.prompt_template.format(**input_data)

        # Build messages
        builder = MessageBuilder()
        if self.config.default_model_config.system_prompt:
            builder.system(self.config.default_model_config.system_prompt)

        # Add history if maintaining
        if self.config.maintain_history and self._history:
            for msg in self._history[-self.config.max_history_length:]:
                builder.messages.append(msg)

        builder.user(prompt)
        messages = builder.build()

        # Get provider and generate
        provider = await self._get_provider(step)

        retry_count = 0
        while retry_count <= step.max_retries:
            try:
                # Generate response
                if self.config.default_model_config.stream:
                    response_text = ""
                    async for chunk in provider.stream_complete(messages):
                        response_text += chunk.delta
                        if self.config.default_model_config.stream_callback:
                            self.config.default_model_config.stream_callback(chunk)
                    response = LLMResponse(content=response_text, model=provider.config.model)
                else:
                    response = await provider.complete(messages)

                # Validate response
                if step.validator and not step.validator(response):
                    if not step.retry_on_failure or retry_count >= step.max_retries:
                        raise ValueError(f"Validation failed for step {step.name}")
                    retry_count += 1
                    continue

                # Parse response if needed
                result = response.content
                if step.parse_json:
                    result = ResponseParser.extract_json(response)
                elif step.extract_code:
                    result = ResponseParser.extract_code(response)

                # Post-process
                if step.post_processor:
                    result = step.post_processor(result)  # type: ignore

                # Update history
                if self.config.maintain_history:
                    self._history.append(LLMMessage(role='user', content=prompt))
                    self._history.append(LLMMessage(role='assistant', content=response.content))

                # Track tokens and cost
                if self.config.track_tokens and response.usage:
                    self._context['total_tokens'] = (
                        self._context.get('total_tokens', 0)
                        + response.usage.get('total_tokens', 0)
                    )

                return result

            except Exception as e:
                if retry_count >= step.max_retries:
                    if self.config.error_handler:
                        return self.config.error_handler(e, step.name)
                    raise
                retry_count += 1
                await asyncio.sleep(1.0 * retry_count)  # Linear backoff between retries
    async def _execute_rag(self, query: str) -> str:
        """Execute RAG workflow.

        Args:
            query: User query

        Returns:
            Generated response
        """
        if not self._retriever:
            raise ValueError("RAG configuration not provided")

        # Retrieve relevant documents
        documents = await self._retriever.retrieve(query)

        # Build augmented prompt
        context = "\n\n".join(documents)
        if self.config.rag_config.context_template:
            augmented_prompt = self.config.rag_config.context_template.format(
                context=context,
                query=query
            )
        else:
            augmented_prompt = f"""Context:
{context}

Question: {query}

Answer based on the context provided:"""

        # Generate response
        provider = await self._get_provider()
        response = await provider.complete(augmented_prompt)

        return response.content

    async def _execute_cot(self, problem: str) -> str:
        """Execute chain-of-thought reasoning.

        Args:
            problem: Problem to solve

        Returns:
            Solution
        """
        provider = await self._get_provider()

        # Step 1: Decompose problem
        decompose_prompt = f"""Break down this problem into smaller steps:
{problem}

List the steps needed to solve this:"""

        decompose_response = await provider.complete(decompose_prompt)
        steps = ResponseParser.extract_list(decompose_response)

        # Step 2: Reason through each step
        reasoning = []
        for i, step in enumerate(steps, 1):
            reason_prompt = f"""Problem: {problem}
Step {i}: {step}

Explain how to complete this step:"""

            reason_response = await provider.complete(reason_prompt)
            reasoning.append(f"Step {i}: {step}\n{reason_response.content}")

        # Step 3: Synthesize solution
        synthesis_prompt = f"""Problem: {problem}

Reasoning:
{chr(10).join(reasoning)}

Based on the reasoning above, provide the final solution:"""

        synthesis_response = await provider.complete(synthesis_prompt)

        return synthesis_response.content
    async def execute(
        self,
        input_data: Union[str, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Execute LLM workflow.

        Args:
            input_data: Input data or query

        Returns:
            Workflow results
        """
        # Normalize input
        if isinstance(input_data, str):
            input_data = {'query': input_data}

        results = {}

        if self.config.workflow_type == WorkflowType.SIMPLE:
            # Single step execution
            if self.config.steps:
                output = await self._execute_step(self.config.steps[0], input_data)
                results[self.config.steps[0].output_key or 'output'] = output
            else:
                # Direct LLM call
                provider = await self._get_provider()
                response = await provider.complete(input_data.get('query', ''))
                results['output'] = response.content

        elif self.config.workflow_type == WorkflowType.CHAIN:
            # Sequential chain execution
            current_context = input_data.copy()

            for step in self.config.steps:
                # Add dependencies to context
                if step.depends_on:
                    for dep in step.depends_on:
                        if dep in results:
                            current_context[dep] = results[dep]

                # Execute step
                output = await self._execute_step(step, current_context)

                # Store result
                output_key = step.output_key or step.name
                results[output_key] = output

                # Update context if passing
                if step.pass_context:
                    current_context[output_key] = output

        elif self.config.workflow_type == WorkflowType.RAG:
            # RAG pipeline
            output = await self._execute_rag(input_data.get('query', ''))
            results['output'] = output

        elif self.config.workflow_type == WorkflowType.COT:
            # Chain-of-thought
            output = await self._execute_cot(input_data.get('problem', input_data.get('query', '')))
            results['output'] = output

        # Format output if configured
        if self.config.output_formatter:
            results = self.config.output_formatter(results)

        # Add metadata
        if self.config.track_tokens:
            results['_tokens'] = self._context.get('total_tokens', 0)

        return results

    async def index_documents(self, documents: List[str]) -> None:
        """Index documents for RAG.

        Args:
            documents: Documents to index
        """
        if not self._retriever:
            raise ValueError("RAG configuration not provided")
        await self._retriever.index_documents(documents)

    async def close(self) -> None:
        """Close all providers."""
        for provider in self._providers.values():
            await provider.close()
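

# Illustrative sketch (not part of the original module): driving LLMWorkflow
# directly with a chain-of-thought configuration. The provider/model values
# are placeholders; real use requires valid credentials for that provider.
async def _example_cot_workflow(problem: str) -> Dict[str, Any]:
    config = LLMWorkflowConfig(
        workflow_type=WorkflowType.COT,
        steps=[],
        default_model_config=LLMConfig(provider='openai', model='gpt-3.5-turbo'),
    )
    workflow = LLMWorkflow(config)
    try:
        return await workflow.execute({'problem': problem})
    finally:
        await workflow.close()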


def create_simple_llm_workflow(
    prompt_template: str,
    model: str = 'gpt-3.5-turbo',
    provider: str = 'openai',
    **kwargs
) -> LLMWorkflow:
    """Create simple LLM workflow.

    Args:
        prompt_template: Prompt template string
        model: Model name
        provider: Provider name
        **kwargs: Additional configuration

    Returns:
        Configured LLM workflow
    """
    template = PromptTemplate(prompt_template)

    config = LLMWorkflowConfig(
        workflow_type=WorkflowType.SIMPLE,
        steps=[
            LLMStep(
                name='generate',
                prompt_template=template
            )
        ],
        default_model_config=LLMConfig(
            provider=provider,
            model=model,
            **kwargs
        )
    )

    return LLMWorkflow(config)
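

# Usage sketch (illustrative, not part of the original module): a single-step
# workflow whose prompt template has one variable, {query}. Assumes the default
# OpenAI provider can be initialized from the environment.
async def _example_simple_workflow() -> Dict[str, Any]:
    workflow = create_simple_llm_workflow(
        "Summarize the following text in one sentence: {query}"
    )
    try:
        return await workflow.execute("FSMs coordinate multi-step LLM pipelines.")
    finally:
        await workflow.close()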


def create_rag_workflow(
    model: str = 'gpt-3.5-turbo',
    provider: str = 'openai',
    retriever_type: str = 'vector',
    top_k: int = 5,
    **kwargs
) -> LLMWorkflow:
    """Create RAG workflow.

    Args:
        model: Model name
        provider: Provider name
        retriever_type: Type of retriever
        top_k: Number of documents to retrieve
        **kwargs: Additional configuration

    Returns:
        Configured RAG workflow
    """
    config = LLMWorkflowConfig(
        workflow_type=WorkflowType.RAG,
        steps=[],
        default_model_config=LLMConfig(
            provider=provider,
            model=model,
            **kwargs
        ),
        rag_config=RAGConfig(
            retriever_type=retriever_type,
            top_k=top_k
        )
    )

    return LLMWorkflow(config)
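

# Usage sketch (illustrative, not part of the original module): index a few
# documents, then answer a query against them. Without provider_config on
# RAGConfig the retriever uses mock embeddings, but the generation step still
# needs a working LLM provider.
async def _example_rag_workflow() -> Dict[str, Any]:
    workflow = create_rag_workflow(top_k=2)
    await workflow.index_documents([
        "The dataknobs FSM executes states and arcs defined in a config dict.",
        "RAG workflows retrieve context documents before generating an answer.",
    ])
    try:
        return await workflow.execute("How does the RAG workflow build its prompt?")
    finally:
        await workflow.close()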


def create_chain_workflow(
    steps: List[Dict[str, Any]],
    model: str = 'gpt-3.5-turbo',
    provider: str = 'openai',
    **kwargs
) -> LLMWorkflow:
    """Create chain workflow.

    Args:
        steps: List of step configurations
        model: Model name
        provider: Provider name
        **kwargs: Additional configuration

    Returns:
        Configured chain workflow
    """
    llm_steps = []
    for step_config in steps:
        llm_steps.append(LLMStep(
            name=step_config['name'],
            prompt_template=PromptTemplate(step_config['prompt']),
            output_key=step_config.get('output_key'),
            parse_json=step_config.get('parse_json', False),
            depends_on=step_config.get('depends_on')
        ))

    config = LLMWorkflowConfig(
        workflow_type=WorkflowType.CHAIN,
        steps=llm_steps,
        default_model_config=LLMConfig(
            provider=provider,
            model=model,
            **kwargs
        )
    )

    return LLMWorkflow(config)
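

# Usage sketch (illustrative, not part of the original module): a two-step
# chain where the second step consumes the first step's output via its
# output_key. The template variable names are placeholders for this example.
async def _example_chain_workflow() -> Dict[str, Any]:
    workflow = create_chain_workflow([
        {
            'name': 'outline',
            'prompt': 'Write a short outline about: {query}',
            'output_key': 'outline',
        },
        {
            'name': 'draft',
            'prompt': 'Expand this outline into a paragraph:\n{outline}',
            'depends_on': ['outline'],
        },
    ])
    try:
        return await workflow.execute("finite state machines for LLM pipelines")
    finally:
        await workflow.close()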