Coverage for src/dataknobs_fsm/core/fsm.py: 28%
248 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""Core FSM class for managing state machines."""
3from typing import Any, Dict, List, Set, Tuple, Optional
5from dataknobs_fsm.core.modes import ProcessingMode, TransactionMode
6from dataknobs_fsm.core.network import StateNetwork
7from dataknobs_fsm.core.state import StateDefinition, StateInstance, StateType
8from dataknobs_fsm.functions.base import FunctionRegistry
class FSM:
    """Finite State Machine core class.

    This class manages:
    - Multiple state networks
    - Function registry
    - Data and transaction modes
    - Resource requirements
    - Configuration
    - Lazily created synchronous and asynchronous execution engines
    """

    def __init__(
        self,
        name: str,
        data_mode: ProcessingMode = ProcessingMode.SINGLE,
        transaction_mode: TransactionMode = TransactionMode.NONE,
        description: str | None = None,
        resource_manager: Any | None = None,
        transaction_manager: Any | None = None
    ):
        """Initialize FSM.

        Args:
            name: Name of the FSM.
            data_mode: Data processing mode.
            transaction_mode: Transaction handling mode.
            description: Optional FSM description.
            resource_manager: Optional resource manager; when set it is
                attached to every execution context this FSM prepares.
            transaction_manager: Optional transaction manager; when set it is
                attached to every execution context this FSM prepares.
        """
        self.name = name
        self.data_mode = data_mode
        self.transaction_mode = transaction_mode
        self.description = description

        # Networks, keyed by network name.
        self.networks: Dict[str, StateNetwork] = {}
        self.main_network_name: str | None = None

        # Function registry; validate() checks arc function references against it.
        self.function_registry = FunctionRegistry()

        # Resource requirements aggregated from networks:
        # resource type -> set of requirements (see add_network / from_dict).
        self.resource_requirements: Dict[str, Any] = {}

        # Configuration
        self.config: Dict[str, Any] = {}

        # Metadata
        self.metadata: Dict[str, Any] = {}
        self.version: str = "1.0.0"
        self.created_at: float | None = None
        self.updated_at: float | None = None

        # Execution support (from builder FSM wrapper)
        self.resource_manager = resource_manager
        self.transaction_manager = transaction_manager
        self._engine: Any | None = None  # ExecutionEngine, created lazily by get_engine()
        self._engine_strategy: Any | None = None  # TraversalStrategy that _engine was built with
        self._async_engine: Any | None = None  # AsyncExecutionEngine, created lazily

    def add_network(
        self,
        network: StateNetwork,
        is_main: bool = False
    ) -> None:
        """Add a network to the FSM.

        The first network added becomes the main network; a later network
        becomes main only if ``is_main=True``. The network's resource
        requirements are merged (set union per resource type) into
        ``self.resource_requirements``.

        Args:
            network: Network to add. Replaces any existing network registered
                under the same name.
            is_main: Whether this is the main network.
        """
        self.networks[network.name] = network

        if is_main or self.main_network_name is None:
            self.main_network_name = network.name

        # Aggregate resource requirements.
        for resource_type, requirements in network.resource_requirements.items():
            self.resource_requirements.setdefault(resource_type, set()).update(requirements)

    def remove_network(self, network_name: str) -> bool:
        """Remove a network from the FSM.

        If the removed network was the main network, the first remaining
        network (in insertion order) becomes main, or ``None`` if no networks
        remain.

        Args:
            network_name: Name of network to remove.

        Returns:
            True if removed successfully, False if no such network exists.
        """
        if network_name not in self.networks:
            return False

        del self.networks[network_name]

        # NOTE(review): aggregated resource_requirements are not recomputed
        # here, so requirements contributed only by the removed network remain.
        if self.main_network_name == network_name:
            self.main_network_name = next(iter(self.networks), None)

        return True

    def get_network(self, network_name: str | None = None) -> StateNetwork | None:
        """Get a network by name.

        Args:
            network_name: Name of network (None for main network).

        Returns:
            Network or None if not found.
        """
        target = self.main_network_name if network_name is None else network_name
        if not target:
            return None
        return self.networks.get(target)

    def validate(self) -> Tuple[bool, List[str]]:
        """Validate the FSM.

        Checks that at least one network exists, that the main network name
        resolves, that every network validates, and that every function name
        referenced by an arc is registered.

        Returns:
            Tuple of (valid, list of errors).
        """
        errors: List[str] = []

        if not self.networks:
            errors.append("FSM has no networks")

        if self.main_network_name and self.main_network_name not in self.networks:
            errors.append(f"Main network '{self.main_network_name}' not found")

        for network_name, network in self.networks.items():
            valid, network_errors = network.validate()
            if not valid:
                errors.extend(f"Network '{network_name}': {error}" for error in network_errors)

        errors.extend(
            f"Function '{func_name}' not registered"
            for func_name in self._get_all_function_references()
            if not self.function_registry.get_function(func_name)
        )

        return not errors, errors

    def _get_all_function_references(self) -> Set[str]:
        """Get all function references from all networks.

        Collects the truthy ``pre_test`` and ``transform`` attributes of every
        arc in every network.

        Returns:
            Set of function names referenced.
        """
        functions = set()

        for network in self.networks.values():
            for arc in network.arcs.values():
                for attribute in ('pre_test', 'transform'):
                    reference = getattr(arc, attribute, None)
                    if reference:
                        functions.add(reference)

        return functions

    def get_all_states(self) -> Dict[str, List[str]]:
        """Get all states from all networks.

        Returns:
            Dictionary of network_name -> list of state names.
        """
        return {
            network_name: list(network.states.keys())
            for network_name, network in self.networks.items()
        }

    def get_all_arcs(self) -> Dict[str, List[str]]:
        """Get all arcs from all networks.

        Returns:
            Dictionary of network_name -> list of arc IDs.
        """
        return {
            network_name: list(network.arcs.keys())
            for network_name, network in self.networks.items()
        }

    def supports_streaming(self) -> bool:
        """Check if FSM supports streaming.

        Returns:
            True if any network supports streaming.
        """
        return any(network.supports_streaming for network in self.networks.values())

    def get_resource_summary(self) -> Dict[str, Any]:
        """Get resource requirements summary.

        Returns:
            Resource requirements summary: network/state/arc totals, resource
            types, streaming support, modes, plus a ``<type>_count`` entry per
            resource type.
        """
        summary = {
            'total_networks': len(self.networks),
            'total_states': sum(len(n.states) for n in self.networks.values()),
            'total_arcs': sum(len(n.arcs) for n in self.networks.values()),
            'resource_types': list(self.resource_requirements.keys()),
            'supports_streaming': self.supports_streaming(),
            'data_mode': self.data_mode.value,
            'transaction_mode': self.transaction_mode.value
        }

        # Add resource counts
        for resource_type, requirements in self.resource_requirements.items():
            summary[f'{resource_type}_count'] = len(requirements)

        return summary

    def clone(self) -> 'FSM':
        """Create a clone of this FSM.

        Networks and the function registry are shared with the original
        (shallow clone). Resource requirements, config and metadata are
        copied. The per-type requirement sets are copied too, so
        ``add_network`` on the clone cannot mutate the original's
        requirements.

        Returns:
            Cloned FSM named ``"<name>_clone"``.
        """
        clone = FSM(
            name=f"{self.name}_clone",
            data_mode=self.data_mode,
            transaction_mode=self.transaction_mode,
            description=self.description
        )
        # NOTE(review): resource_manager / transaction_manager are not carried
        # over to the clone - confirm this is intended.

        # Note: networks are shared, not copied - a deep clone would need
        # network.clone().
        clone.networks = dict(self.networks)

        clone.main_network_name = self.main_network_name
        clone.function_registry = self.function_registry
        # Copy the inner sets: a plain dict.copy() would share them, and
        # clone.add_network() updates those sets in place.
        clone.resource_requirements = {
            resource_type: set(requirements) if isinstance(requirements, set) else requirements
            for resource_type, requirements in self.resource_requirements.items()
        }
        clone.config = self.config.copy()
        clone.metadata = self.metadata.copy()
        clone.version = self.version

        return clone

    def to_dict(self) -> Dict[str, Any]:
        """Convert FSM to dictionary representation.

        Networks are represented by name only; requirement sets become lists.

        Returns:
            Dictionary representation.
        """
        return {
            'name': self.name,
            'description': self.description,
            'data_mode': self.data_mode.value,
            'transaction_mode': self.transaction_mode.value,
            'main_network': self.main_network_name,
            'networks': list(self.networks.keys()),
            'resource_requirements': {
                k: list(v) if isinstance(v, set) else v
                for k, v in self.resource_requirements.items()
            },
            'config': self.config,
            'metadata': self.metadata,
            'version': self.version
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'FSM':
        """Create FSM from dictionary representation.

        Networks themselves are not restored (``to_dict`` stores only their
        names), so ``main_network_name`` may refer to a network that has not
        been added yet.

        Args:
            data: Dictionary with FSM data; ``name`` is required.

        Returns:
            New FSM instance.
        """
        fsm = cls(
            name=data['name'],
            data_mode=ProcessingMode(data.get('data_mode', 'single')),
            transaction_mode=TransactionMode(data.get('transaction_mode', 'none')),
            description=data.get('description')
        )

        fsm.main_network_name = data.get('main_network')
        fsm.config = data.get('config', {})
        fsm.metadata = data.get('metadata', {})
        fsm.version = data.get('version', '1.0.0')

        # Resource requirements (tolerate an explicit None).
        for resource_type, requirements in (data.get('resource_requirements') or {}).items():
            fsm.resource_requirements[resource_type] = set(requirements)

        return fsm

    def find_state_definition(self, state_name: str, network_name: str | None = None) -> StateDefinition | None:
        """Find a state definition by name.

        Args:
            state_name: Name of the state to find
            network_name: Optional specific network to search in; when
                omitted, networks are searched in insertion order.

        Returns:
            StateDefinition if found, None otherwise
        """
        if network_name:
            # Search specific network
            network = self.networks.get(network_name)
            if network and hasattr(network, 'states'):
                return network.states.get(state_name)
        else:
            # Search all networks
            for network in self.networks.values():
                if hasattr(network, 'states') and state_name in network.states:
                    return network.states[state_name]

        return None

    def create_state_instance(self, state_name: str, data: Dict[str, Any] | None = None, network_name: str | None = None) -> StateInstance:
        """Create a state instance from a state name.

        If no definition exists, a minimal one is synthesised. Its type is
        START when the name is 'start'/'Start'/'START', otherwise NORMAL.

        Args:
            state_name: Name of the state
            data: Optional initial data for the state
            network_name: Optional specific network to search in

        Returns:
            StateInstance object
        """
        state_def = self.find_state_definition(state_name, network_name)

        if not state_def:
            state_def = StateDefinition(
                name=state_name,
                type=StateType.START if state_name in ('start', 'Start', 'START') else StateType.NORMAL
            )

        return StateInstance(
            definition=state_def,
            data=data or {}
        )

    def get_state(self, state_name: str, network_name: str | None = None) -> StateDefinition | None:
        """Get a state definition by name.

        This is an alias for find_state_definition for compatibility.

        Args:
            state_name: Name of the state
            network_name: Optional specific network to search in

        Returns:
            StateDefinition if found, None otherwise
        """
        return self.find_state_definition(state_name, network_name)

    def is_start_state(self, state_name: str, network_name: str | None = None) -> bool:
        """Check if a state is a start state.

        Args:
            state_name: Name of the state
            network_name: Optional specific network to check in (defaults to main network)

        Returns:
            True if the state is a start state
        """
        network_name = network_name or self.main_network_name
        if network_name:
            network = self.networks.get(network_name)
            if network:
                return network.is_initial_state(state_name)
        return False

    def is_end_state(self, state_name: str, network_name: str | None = None) -> bool:
        """Check if a state is an end state.

        Args:
            state_name: Name of the state
            network_name: Optional specific network to check in (defaults to main network)

        Returns:
            True if the state is an end state
        """
        network_name = network_name or self.main_network_name
        if network_name:
            network = self.networks.get(network_name)
            if network:
                return network.is_final_state(state_name)
        return False

    @staticmethod
    def _is_start_definition(state: Any) -> bool:
        """Return True if ``state`` is marked as a start state.

        Accepts either a truthy ``is_start_state()`` method or a ``type``
        attribute equal to ``StateType.START``.
        """
        if hasattr(state, 'is_start_state') and state.is_start_state():
            return True
        return hasattr(state, 'type') and state.type == StateType.START

    def get_start_state(self, network_name: str | None = None) -> StateDefinition | None:
        """Get the start state definition.

        Search order when ``network_name`` is omitted:
        1. the main network (including its 'start'-named fallback),
        2. every network, for a state marked as start,
        3. any state named 'start'.

        Args:
            network_name: Optional specific network to search in

        Returns:
            Start state definition if found, None otherwise
        """
        if network_name:
            network = self.networks.get(network_name)
            if network and hasattr(network, 'states'):
                for state in network.states.values():
                    if self._is_start_definition(state):
                        return state
        else:
            if self.main_network_name:
                start_state = self.get_start_state(self.main_network_name)
                if start_state:
                    return start_state

            for network in self.networks.values():
                if hasattr(network, 'states'):
                    for state in network.states.values():
                        if self._is_start_definition(state):
                            return state

        # Fallback: look for state named 'start'
        return self.find_state_definition('start', network_name)

    @property
    def main_network(self) -> Optional['StateNetwork']:
        """Get the main network object.

        Returns:
            The main StateNetwork object or None if not set
        """
        if self.main_network_name:
            return self.networks.get(self.main_network_name)
        return None

    @property
    def states(self) -> Dict[str, StateDefinition]:
        """Get all states from the main network.

        Returns:
            Dictionary of state_name -> state_definition for the main network
            (the network's own dict, not a copy), or ``{}`` if there is no
            main network.
        """
        if not self.main_network_name:
            return {}

        network = self.get_network(self.main_network_name)
        if network and hasattr(network, 'states'):
            return network.states
        return {}

    def get_all_states_dict(self) -> Dict[str, Dict[str, StateDefinition]]:
        """Get all states from all networks.

        Returns:
            Dictionary of network_name -> {state_name -> state_definition}
        """
        return {
            network_name: network.states
            for network_name, network in self.networks.items()
            if hasattr(network, 'states')
        }

    def get_outgoing_arcs(self, state_name: str, network_name: str | None = None) -> List[Any]:
        """Get outgoing arcs from a state.

        Args:
            state_name: Name of the state
            network_name: Optional network name (uses main network if None)

        Returns:
            List of outgoing arcs from the state
        """
        network_name = network_name or self.main_network_name
        if not network_name:
            return []

        network = self.get_network(network_name)
        if network:
            return network.get_arcs_from_state(state_name)
        return []

    def get_engine(self, strategy: str | None = None):
        """Get or create the execution engine.

        The engine is created lazily and cached. Passing a recognised
        ``strategy`` that differs from the one the cached engine was built
        with rebuilds the engine, so the override takes effect. Unrecognised
        strategy strings are ignored (depth-first is used when building).

        Args:
            strategy: Optional execution strategy override: one of
                ``"depth_first"``, ``"breadth_first"``,
                ``"resource_optimized"`` or ``"stream_optimized"``.

        Returns:
            ExecutionEngine instance.
        """
        if self._engine is not None and strategy is None:
            return self._engine

        from dataknobs_fsm.execution.engine import ExecutionEngine, TraversalStrategy

        # Map strategy strings to enum
        strategy_map = {
            "depth_first": TraversalStrategy.DEPTH_FIRST,
            "breadth_first": TraversalStrategy.BREADTH_FIRST,
            "resource_optimized": TraversalStrategy.RESOURCE_OPTIMIZED,
            "stream_optimized": TraversalStrategy.STREAM_OPTIMIZED,
        }
        requested = strategy_map.get(strategy) if strategy else None

        if self._engine is not None:
            built_with = getattr(self, '_engine_strategy', None)
            # Reuse the cached engine unless a recognised, different strategy
            # was requested. An engine not built here (built_with is None,
            # e.g. assigned directly to _engine) is never replaced.
            if requested is None or built_with is None or requested == built_with:
                return self._engine

        strat = requested if requested is not None else TraversalStrategy.DEPTH_FIRST
        self._engine = ExecutionEngine(
            fsm=self,
            strategy=strat,
        )
        self._engine_strategy = strat
        return self._engine

    def get_async_engine(self, strategy: str | None = None):
        """Get or create the async execution engine.

        Args:
            strategy: Accepted for API symmetry with get_engine() but
                currently unused; the async engine is built with its defaults.

        Returns:
            AsyncExecutionEngine instance.
        """
        if self._async_engine is None:
            from dataknobs_fsm.execution.async_engine import AsyncExecutionEngine

            self._async_engine = AsyncExecutionEngine(fsm=self)

        return self._async_engine

    def _prepare_execution_context(self, initial_data: Dict[str, Any] | None = None):
        """Prepare execution context for FSM execution.

        BATCH mode stores the input in ``context.batch_data``. A single record
        is wrapped in a list; ``None`` becomes an empty list. STREAM mode
        creates a ``StreamContext`` and pushes the input as one final chunk,
        also mirroring it in ``context.data``. SINGLE mode leaves the context
        untouched; the caller passes the data to the engine directly.

        Args:
            initial_data: Initial data for execution.

        Returns:
            Configured ExecutionContext instance.
        """
        from dataknobs_fsm.execution.context import ExecutionContext
        from dataknobs_fsm.streaming.core import StreamContext, StreamConfig

        context = ExecutionContext(
            data_mode=self.data_mode,
            transaction_mode=self.transaction_mode
        )

        # Set resource and transaction managers if available
        if self.resource_manager:
            context.resource_manager = self.resource_manager
        if self.transaction_manager:
            context.transaction_manager = self.transaction_manager

        if self.data_mode == ProcessingMode.BATCH:
            if initial_data is None:
                context.batch_data = []
            elif isinstance(initial_data, list):
                context.batch_data = initial_data  # type: ignore[unreachable]
            else:
                context.batch_data = [initial_data]
        elif self.data_mode == ProcessingMode.STREAM:
            context.stream_context = StreamContext(config=StreamConfig())

            if initial_data is not None:
                # Deliver the whole input as a single, final chunk.
                context.stream_context.add_data(initial_data, is_last=True)
                # Also set context.data for compatibility
                context.data = initial_data
        # SINGLE mode: nothing to prepare.

        return context

    def _format_execution_result(self, success: bool, result: Any, context: Any,
                                 duration: float, initial_data: Any = None,
                                 error: str | None = None) -> Dict[str, Any]:
        """Format the execution result in a standard format.

        Args:
            success: Whether execution succeeded.
            result: The execution result data.
            context: The execution context (may be None on error).
            duration: Time taken for execution, in seconds.
            initial_data: Original input data (echoed back on error).
            error: Error message if execution raised; any non-None value,
                even an empty string, produces an error result.

        Returns:
            Formatted result dictionary with ``status``, ``data``,
            ``execution_id``, ``transitions`` and ``duration`` keys, plus
            ``error`` for error results.
        """
        # Compare against None: an exception with an empty message yields ""
        # and must still be reported as an error, not as a plain failure.
        if error is not None:
            return {
                "status": "error",
                "error": error,
                "data": initial_data,
                "execution_id": None,
                "transitions": 0,
                "duration": None
            }

        return {
            "status": "completed" if success else "failed",
            "data": result,
            "execution_id": getattr(context, 'execution_id', None),
            "transitions": getattr(context, 'transition_count', 0),
            "duration": duration
        }

    async def execute_async(self, initial_data: Dict[str, Any] | None = None) -> Any:
        """Execute the FSM asynchronously with initial data.

        Exceptions raised while preparing or running the execution are caught
        and reported in the result dictionary instead of propagating.

        Args:
            initial_data: Initial data for execution.

        Returns:
            Result dictionary as produced by _format_execution_result():
            ``status`` is "completed", "failed" or "error".
        """
        from time import perf_counter

        try:
            engine = self.get_async_engine()
            context = self._prepare_execution_context(initial_data)

            # Monotonic clock: duration is unaffected by wall-clock changes.
            start_time = perf_counter()
            success, result = await engine.execute(
                context,
                # Only SINGLE mode hands the data to the engine directly; the
                # other modes receive it via the prepared context.
                initial_data if self.data_mode == ProcessingMode.SINGLE else None
            )
            duration = perf_counter() - start_time

            return self._format_execution_result(success, result, context, duration)

        except Exception as e:
            # Top-level boundary: report instead of raising. Fall back to the
            # exception type name so the error message is never empty.
            return self._format_execution_result(
                False, None, None, 0.0, initial_data, str(e) or type(e).__name__
            )

    def execute(self, initial_data: Dict[str, Any] | None = None) -> Any:
        """Execute the FSM synchronously with initial data.

        This is a simplified API for running the FSM. Exceptions raised while
        preparing or running the execution are caught and reported in the
        result dictionary instead of propagating.

        Args:
            initial_data: Initial data for execution.

        Returns:
            Result dictionary as produced by _format_execution_result():
            ``status`` is "completed", "failed" or "error".
        """
        from time import perf_counter

        try:
            engine = self.get_engine()
            context = self._prepare_execution_context(initial_data)

            # Monotonic clock: duration is unaffected by wall-clock changes.
            start_time = perf_counter()
            success, result = engine.execute(
                context,
                # Only SINGLE mode hands the data to the engine directly; the
                # other modes receive it via the prepared context.
                initial_data if self.data_mode == ProcessingMode.SINGLE else None
            )
            duration = perf_counter() - start_time

            return self._format_execution_result(success, result, context, duration)

        except Exception as e:
            # Top-level boundary: report instead of raising. Fall back to the
            # exception type name so the error message is never empty.
            return self._format_execution_result(
                False, None, None, 0.0, initial_data, str(e) or type(e).__name__
            )