Coverage for src/dataknobs_fsm/execution/base_engine.py: 56%
87 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""Base execution engine with shared logic for sync and async engines.
3This module provides a base class that contains common logic shared between
4the synchronous (ExecutionEngine) and asynchronous (AsyncExecutionEngine)
5implementations, reducing code duplication and ensuring feature parity.
6"""
8from abc import ABC, abstractmethod
9from typing import Any, Dict, List, Tuple, TYPE_CHECKING
10from types import SimpleNamespace
12from dataknobs_fsm.core.data_wrapper import (
13 ensure_dict,
14 wrap_for_lambda
15)
16from dataknobs_fsm.core.arc import ArcDefinition
17from dataknobs_fsm.core.fsm import FSM
18from dataknobs_fsm.core.network import StateNetwork
19from dataknobs_fsm.core.state import StateType
20from dataknobs_fsm.execution.context import ExecutionContext
21from dataknobs_fsm.functions.base import FunctionContext
22from dataknobs_fsm.execution.common import (
23 NetworkSelector,
24 TransitionSelector,
25 TransitionSelectionMode
26)
28if TYPE_CHECKING:
29 from dataknobs_fsm.execution.engine import TraversalStrategy
class BaseExecutionEngine(ABC):
    """Base class for execution engines with shared logic.

    This class provides common functionality for both sync and async engines:
    - Initial state finding
    - Network selection
    - State transform preparation
    - Arc evaluation logic
    - Error handling patterns
    - Statistics tracking
    """

    def __init__(
        self,
        fsm: FSM,
        strategy: 'TraversalStrategy',
        selection_mode: TransitionSelectionMode = TransitionSelectionMode.HYBRID,
        max_retries: int = 3,
        retry_delay: float = 1.0
    ):
        """Initialize base execution engine.

        Args:
            fsm: FSM instance to execute.
            strategy: Traversal strategy to use.
            selection_mode: Transition selection mode.
            max_retries: Maximum retry attempts for failures.
            retry_delay: Delay between retries in seconds.
        """
        self.fsm = fsm
        self.strategy = strategy
        self.selection_mode = selection_mode
        self.max_retries = max_retries
        self.retry_delay = retry_delay

        # Transition selector shared by the sync and async implementations.
        self.transition_selector = TransitionSelector(
            mode=selection_mode,
            default_strategy=strategy
        )

        # Execution statistics, updated by subclasses during execution.
        self._execution_count = 0
        self._transition_count = 0
        self._error_count = 0
        self._total_execution_time = 0.0

    @staticmethod
    def _first_initial_state(network: Any) -> str | None:
        """Return an arbitrary initial state of ``network``, or None.

        Networks lacking an ``initial_states`` attribute, or whose
        ``initial_states`` is empty, yield None.  ``next(iter(...))`` picks
        whichever state the container happens to yield first.
        """
        initial_states = getattr(network, 'initial_states', None)
        if initial_states:
            return next(iter(initial_states))
        return None

    def find_initial_state_common(self) -> str | None:
        """Find the initial state in the FSM (common logic).

        This method contains the shared logic for finding an initial state,
        used by both sync and async engines.  Resolution order:

        1. ``fsm.main_network`` (either a key into ``fsm.networks`` or a
           network object directly);
        2. the network registered under ``fsm.name`` (compatibility path);
        3. any network that declares an initial state.

        Returns:
            Name of initial state or None.
        """
        main_network = getattr(self.fsm, 'main_network', None)

        # main_network may be a string key into fsm.networks ...
        if isinstance(main_network, str):
            if main_network in self.fsm.networks:
                state = self._first_initial_state(self.fsm.networks[main_network])
                if state is not None:
                    return state
        # ... or a network object held directly on the FSM.
        elif main_network is not None:
            state = self._first_initial_state(main_network)
            if state is not None:
                return state

        # Fallback to the network registered under fsm.name for compatibility.
        name = getattr(self.fsm, 'name', None)
        if name in self.fsm.networks:
            state = self._first_initial_state(self.fsm.networks[name])
            if state is not None:
                return state

        # Last resort: scan every network for any initial state.
        for network in self.fsm.networks.values():
            state = self._first_initial_state(network)
            if state is not None:
                return state

        return None

    def is_final_state_common(self, state_name: str | None) -> bool:
        """Check if state is a final state (common logic).

        A state is final if any network lists it in ``final_states``, or if
        the corresponding state object has ``type == StateType.END``.

        Args:
            state_name: Name of state to check.

        Returns:
            True if state is final.
        """
        if not state_name:
            return False

        for network in self.fsm.networks.values():
            # Networks without final_states contribute an empty default.
            if state_name in getattr(network, 'final_states', ()):
                return True
            # Also check the state definitions directly for an END type.
            states = getattr(network, 'states', None)
            if states and state_name in states:
                state = states[state_name]
                if getattr(state, 'type', None) == StateType.END:
                    return True

        return False

    def get_current_network_common(self, context: ExecutionContext) -> StateNetwork | None:
        """Get current network using common selection logic.

        Delegates to :class:`NetworkSelector` with intelligent selection
        enabled.

        Args:
            context: Execution context.

        Returns:
            Current network or None.
        """
        return NetworkSelector.get_current_network(
            self.fsm,
            context,
            enable_intelligent_selection=True
        )

    def prepare_state_transform(
        self,
        state_def: Any,
        context: ExecutionContext
    ) -> Tuple[List[Any], SimpleNamespace]:
        """Prepare state transform execution (common logic).

        Args:
            state_def: State definition.
            context: Execution context.

        Returns:
            Tuple of (transform functions, state object for inline lambdas).
        """
        # Prefer the plural transform_functions attribute; fall back to a
        # single transform_function wrapped in a list; default to no-op.
        transforms = getattr(state_def, 'transform_functions', None)
        if transforms:
            transform_functions = transforms
        else:
            single = getattr(state_def, 'transform_function', None)
            transform_functions = [single] if single else []

        # Wrapper for transforms that expect the state.data access pattern;
        # it provides both dict-style and attribute-style access.
        state_obj = wrap_for_lambda(context.data)

        return transform_functions, state_obj

    def process_transform_result(
        self,
        result: Any,
        context: ExecutionContext,
        state_name: str
    ) -> None:
        """Process transform result (common logic).

        Stores the transform's output back onto ``context.data`` as a plain
        dict, or records a failure via :meth:`handle_transform_error`.

        Args:
            result: Result from transform function.
            context: Execution context.
            state_name: Name of current state.
        """
        if result is None:
            return

        # Local import mirrors the original module layout; keeps this base
        # module importable without pulling ExecutionResult eagerly.
        from dataknobs_fsm.functions.base import ExecutionResult
        if isinstance(result, ExecutionResult):
            if result.success:
                # Ensure we store plain dict data, not wrappers.
                context.data = ensure_dict(result.data)
            else:
                # Transform failed - record the error against this state.
                self.handle_transform_error(
                    Exception(result.error or "Transform failed"),
                    context,
                    state_name
                )
        else:
            # Ensure we always store plain dict data, not wrappers.
            context.data = ensure_dict(result)

    def handle_transform_error(
        self,
        error: Exception,
        context: ExecutionContext,
        state_name: str
    ) -> None:
        """Handle transform error (common logic).

        A failing state transform doesn't stop the FSM; it only marks the
        state as failed on the context.  The ``error`` argument is accepted
        for subclass overrides but is not inspected here.

        Args:
            error: Exception that occurred.
            context: Execution context.
            state_name: Name of current state.
        """
        if not hasattr(context, 'failed_states'):
            context.failed_states = set()
        context.failed_states.add(state_name)

    def evaluate_arc_condition_common(
        self,
        arc: ArcDefinition,
        context: ExecutionContext
    ) -> bool:
        """Evaluate arc condition (common logic).

        Args:
            arc: Arc definition.
            context: Execution context.

        Returns:
            True if arc condition is met (or the arc has no condition);
            False if the condition is falsy or its evaluation raises.
        """
        # An arc without a condition is always valid.
        condition = getattr(arc, 'condition', None)
        if not condition:
            return True

        try:
            func_context = FunctionContext(
                state_name=context.current_state or "",
                function_name="arc_condition",
                metadata={'arc': getattr(arc, 'name', None)},
                resources={}
            )

            # Support both (data, context) and (data,) condition signatures.
            try:
                return bool(condition(context.data, func_context))
            except TypeError:
                # NOTE(review): a TypeError raised *inside* the condition body
                # also lands here and triggers this single-argument retry —
                # preserved for compatibility, but it can mask bugs.
                return bool(condition(context.data))
        except Exception:
            # Condition evaluation failed - treat the arc as not valid.
            return False

    def get_execution_statistics(self) -> Dict[str, Any]:
        """Get execution statistics (common implementation).

        Returns:
            Dictionary of execution statistics, including a derived
            average execution time (0 when nothing has executed yet).
        """
        return {
            'execution_count': self._execution_count,
            'transition_count': self._transition_count,
            'error_count': self._error_count,
            'total_execution_time': self._total_execution_time,
            'average_execution_time': (
                self._total_execution_time / self._execution_count
                if self._execution_count > 0 else 0
            )
        }

    @abstractmethod
    def execute(self, context: ExecutionContext, data: Any = None,
                max_transitions: int = 1000, arc_name: str | None = None) -> Tuple[bool, Any]:
        """Execute the FSM with given context.

        This method must be implemented by sync and async engines.

        Args:
            context: Execution context.
            data: Input data to process.
            max_transitions: Maximum transitions before stopping.
            arc_name: Optional specific arc name to follow.

        Returns:
            Tuple of (success, result).
        """
        pass

    @abstractmethod
    def _execute_single(self, context: ExecutionContext,
                        max_transitions: int, arc_name: str | None = None) -> Any:
        """Execute single mode processing.

        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def _execute_batch(self, context: ExecutionContext, max_transitions: int) -> Any:
        """Execute batch mode processing.

        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def _execute_stream(self, context: ExecutionContext, max_transitions: int) -> Any:
        """Execute stream mode processing.

        Must be implemented by subclasses.
        """
        pass