Coverage for src/dataknobs_fsm/execution/async_engine.py: 12%
251 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""Asynchronous execution engine for FSM processing."""
3import asyncio
4import time
5from typing import Any, Dict, List, Tuple
7from dataknobs_fsm.core.arc import ArcDefinition
8from dataknobs_fsm.core.fsm import FSM
9from dataknobs_fsm.core.modes import ProcessingMode
10from dataknobs_fsm.core.network import StateNetwork
11from dataknobs_fsm.core.state import StateType
12from dataknobs_fsm.execution.context import ExecutionContext
13from dataknobs_fsm.execution.engine import TraversalStrategy
14from dataknobs_fsm.execution.common import (
15 NetworkSelector,
16 TransitionSelectionMode
17)
18from dataknobs_fsm.execution.base_engine import BaseExecutionEngine
19from dataknobs_fsm.functions.base import FunctionContext
20from dataknobs_fsm.core.data_wrapper import ensure_dict
class AsyncExecutionEngine(BaseExecutionEngine):
    """Asynchronous execution engine for FSM.

    This engine handles:
    - True async execution of state functions
    - Parallel arc evaluation
    - Async resource management
    - Non-blocking state transitions
    """

    def __init__(
        self,
        fsm: FSM,
        strategy: TraversalStrategy = TraversalStrategy.DEPTH_FIRST,
        selection_mode: TransitionSelectionMode = TransitionSelectionMode.HYBRID
    ):
        """Initialize async execution engine.

        Args:
            fsm: FSM to execute.
            strategy: Traversal strategy for execution.
            selection_mode: Transition selection mode (strategy, scoring, or hybrid).
        """
        # Initialize base class. The async engine does not expose retry knobs
        # of its own; it forwards fixed max_retries/retry_delay values that
        # the base class requires.
        super().__init__(fsm, strategy, selection_mode, max_retries=3, retry_delay=1.0)
49 async def execute(
50 self,
51 context: ExecutionContext,
52 data: Any = None,
53 max_transitions: int = 1000,
54 arc_name: str | None = None
55 ) -> Tuple[bool, Any]:
56 """Execute the FSM asynchronously with given context.
58 Args:
59 context: Execution context.
60 data: Input data to process.
61 max_transitions: Maximum transitions before stopping.
62 arc_name: Optional specific arc name to follow.
64 Returns:
65 Tuple of (success, result).
66 """
67 start_time = time.time()
68 self._execution_count += 1
70 # Only override context.data if data was explicitly provided
71 if data is not None:
72 context.data = data
74 # Initialize state if needed
75 if not context.current_state:
76 initial_state = await self._find_initial_state()
77 if not initial_state:
78 return False, "No initial state found"
79 context.set_state(initial_state)
80 # Execute transforms for the initial state
81 await self._execute_state_transforms(context)
83 try:
84 # Execute based on data mode
85 if context.data_mode == ProcessingMode.SINGLE:
86 result = await self._execute_single(context, max_transitions, arc_name)
87 elif context.data_mode == ProcessingMode.BATCH:
88 result = await self._execute_batch(context, max_transitions)
89 elif context.data_mode == ProcessingMode.STREAM:
90 result = await self._execute_stream(context, max_transitions)
91 else:
92 result = False, f"Unknown data mode: {context.data_mode}"
94 self._total_execution_time += time.time() - start_time
95 return result
97 except Exception as e:
98 self._error_count += 1
99 self._total_execution_time += time.time() - start_time
100 return False, str(e)
102 async def _execute_single(
103 self,
104 context: ExecutionContext,
105 max_transitions: int,
106 arc_name: str | None = None
107 ) -> Tuple[bool, Any]:
108 """Execute in single record mode asynchronously.
110 Args:
111 context: Execution context.
112 max_transitions: Maximum transitions.
113 arc_name: Optional specific arc name to follow.
115 Returns:
116 Tuple of (success, result).
117 """
118 transitions = 0
120 while transitions < max_transitions:
121 # Check if we're in a final state
122 if await self._is_final_state(context.current_state):
123 return True, context.data
125 # Get available transitions
126 transitions_available = await self._get_available_transitions(
127 context.current_state,
128 context,
129 arc_name
130 )
132 if not transitions_available:
133 # No valid transitions - check if this is a final state
134 if await self._is_final_state(context.current_state):
135 return True, context.data
136 return False, f"No valid transitions from state: {context.current_state}"
138 # Choose transition based on strategy
139 next_transition = await self._choose_transition(
140 transitions_available,
141 context
142 )
144 if not next_transition:
145 return False, "No transition selected"
147 # Execute transition
148 success = await self._execute_transition(
149 next_transition,
150 context
151 )
153 if not success:
154 return False, f"Transition failed: {next_transition}"
156 transitions += 1
157 self._transition_count += 1
159 return False, f"Maximum transitions ({max_transitions}) exceeded"
161 async def _execute_batch(
162 self,
163 context: ExecutionContext,
164 max_transitions: int
165 ) -> Tuple[bool, Any]:
166 """Execute in batch mode asynchronously.
168 Args:
169 context: Execution context.
170 max_transitions: Maximum transitions per item.
172 Returns:
173 Tuple of (success, results).
174 """
175 if not context.batch_data:
176 return False, "No batch data to process"
178 # Process items in parallel
179 tasks = []
180 for i, item in enumerate(context.batch_data):
181 # Create child context for this item
182 item_context = context.create_child_context(f"batch_{i}")
183 item_context.data = item
185 # Reset to initial state for each item
186 initial_state = await self._find_initial_state()
187 if initial_state:
188 item_context.set_state(initial_state)
190 # Create task for this item
191 task = asyncio.create_task(
192 self._execute_single(item_context, max_transitions)
193 )
194 tasks.append(task)
196 # Wait for all tasks
197 results = await asyncio.gather(*tasks, return_exceptions=True)
199 # Process results
200 batch_results = []
201 batch_errors = []
202 for i, result in enumerate(results):
203 if isinstance(result, Exception):
204 batch_errors.append((i, result))
205 else:
206 # Result is a tuple[bool, Any] at this point
207 success, value = result # type: ignore
208 if success: # success
209 batch_results.append(value)
210 else:
211 batch_errors.append((i, Exception(value)))
213 return len(batch_errors) == 0, {
214 'results': batch_results,
215 'errors': batch_errors
216 }
218 async def _execute_stream(
219 self,
220 context: ExecutionContext,
221 max_transitions: int
222 ) -> Tuple[bool, Any]:
223 """Execute in stream mode asynchronously.
225 Args:
226 context: Execution context.
227 max_transitions: Maximum transitions per chunk.
229 Returns:
230 Tuple of (success, stream_stats).
231 """
232 if not context.stream_context:
233 return False, "No stream context provided"
235 chunks_processed = 0
236 total_records = 0
237 errors = []
239 # Process each chunk
240 while True:
241 # Get next chunk from stream
242 chunk = context.stream_context.get_next_chunk()
243 if not chunk:
244 break
246 context.set_stream_chunk(chunk)
248 # Process chunk data
249 for record in chunk.data:
250 record_context = context.create_child_context(
251 f"stream_{chunks_processed}_{total_records}"
252 )
253 record_context.data = record
255 # Reset to initial state
256 initial_state = await self._find_initial_state()
257 if initial_state:
258 record_context.set_state(initial_state)
260 # Execute for this record
261 success, result = await self._execute_single(
262 record_context,
263 max_transitions
264 )
266 if not success:
267 errors.append((total_records, result))
269 # Merge context
270 context.merge_child_context(
271 f"stream_{chunks_processed}_{total_records}"
272 )
274 total_records += 1
276 chunks_processed += 1
278 # Check if this was the last chunk
279 if chunk.is_last:
280 break
282 return len(errors) == 0, {
283 'chunks_processed': chunks_processed,
284 'records_processed': total_records,
285 'errors': errors
286 }
288 async def _get_available_transitions(
289 self,
290 state_name: str,
291 context: ExecutionContext,
292 arc_name: str | None = None
293 ) -> List[ArcDefinition]:
294 """Get available transitions from current state asynchronously.
296 This evaluates pre-conditions in parallel.
298 Args:
299 state_name: Current state name.
300 context: Execution context.
301 arc_name: Optional specific arc name to filter by.
303 Returns:
304 List of available arc definitions.
305 """
306 network = await self._get_current_network(context)
307 if not network or state_name not in network.states:
308 return []
310 state = network.states[state_name]
311 available = []
313 # Filter arcs by name if specified
314 arcs_to_evaluate = state.outgoing_arcs
315 if arc_name:
316 arcs_to_evaluate = [arc for arc in state.outgoing_arcs
317 if hasattr(arc, 'name') and arc.name == arc_name]
318 # If no arcs match the specified name, return empty list
319 if not arcs_to_evaluate:
320 return []
322 # Evaluate all arc pre-conditions in parallel
323 tasks = []
324 for arc in arcs_to_evaluate:
325 task = asyncio.create_task(self._evaluate_arc(arc, context))
326 tasks.append((arc, task))
328 # Wait for all evaluations
329 for arc, task in tasks:
330 can_execute = await task
331 if can_execute:
332 available.append(arc)
334 # Sort by priority (higher first)
335 available.sort(key=lambda a: a.priority, reverse=True)
337 return available
339 async def _evaluate_arc(
340 self,
341 arc: ArcDefinition,
342 context: ExecutionContext
343 ) -> bool:
344 """Evaluate if an arc can be executed.
346 Args:
347 arc: Arc definition.
348 context: Execution context.
350 Returns:
351 True if arc can be executed.
352 """
353 if not arc.pre_test:
354 return True
356 # Get the function registry
357 function_registry = getattr(self.fsm, 'function_registry', {})
358 if hasattr(function_registry, 'functions'):
359 functions = function_registry.functions
360 else:
361 functions = function_registry
363 if arc.pre_test not in functions:
364 return False
366 # Execute pre-test function
367 pre_test_func = functions[arc.pre_test]
369 # Check if it's async
370 if asyncio.iscoroutinefunction(pre_test_func):
371 result = await pre_test_func(context.data, context)
372 else:
373 # Run sync function in executor
374 loop = asyncio.get_event_loop()
375 result = await loop.run_in_executor(
376 None,
377 pre_test_func,
378 context.data,
379 context
380 )
382 # Handle tuple return from test functions (bool, reason)
383 if isinstance(result, tuple):
384 return bool(result[0])
385 return bool(result)
387 async def _choose_transition(
388 self,
389 available: List[ArcDefinition],
390 context: ExecutionContext
391 ) -> ArcDefinition | None:
392 """Choose transition using common transition selector.
394 Args:
395 available: Available transitions.
396 context: Execution context.
398 Returns:
399 Selected arc or None.
400 """
401 return self.transition_selector.select_transition(
402 available,
403 context,
404 strategy=self.strategy
405 )
407 async def _execute_transition(
408 self,
409 arc: ArcDefinition,
410 context: ExecutionContext
411 ) -> bool:
412 """Execute a state transition asynchronously.
414 Args:
415 arc: Arc to execute.
416 context: Execution context.
418 Returns:
419 True if successful.
420 """
421 try:
422 # Execute arc transform if defined
423 if arc.transform:
424 function_registry = getattr(self.fsm, 'function_registry', {})
425 if hasattr(function_registry, 'functions'):
426 functions = function_registry.functions
427 else:
428 functions = function_registry
430 if arc.transform in functions:
431 transform_func = functions[arc.transform]
433 # Check if it's async - check both the function and its __call__ method
434 is_async = asyncio.iscoroutinefunction(transform_func)
435 if not is_async and callable(transform_func) and callable(transform_func):
436 # Check if the __call__ method is async (for wrapped functions)
437 is_async = asyncio.iscoroutinefunction(transform_func.__call__)
439 if is_async:
440 context.data = await transform_func(context.data, context)
441 else:
442 # Run sync function in executor
443 loop = asyncio.get_event_loop()
444 context.data = await loop.run_in_executor(
445 None,
446 transform_func,
447 context.data,
448 context
449 )
451 # Update state (history is automatically tracked by set_state)
452 context.set_state(arc.target_state)
454 # Execute state transforms when entering the new state
455 await self._execute_state_transforms(context)
457 return True
459 except Exception:
460 return False
462 async def _execute_state_transforms(
463 self,
464 context: ExecutionContext
465 ) -> None:
466 """Execute state functions (validators and transforms) when in a state.
468 This should be called before evaluating arc conditions to ensure
469 that state functions can update the data that conditions depend on.
471 Args:
472 context: Execution context.
473 """
474 network = await self._get_current_network(context)
475 if not network or context.current_state not in network.states:
476 return
478 state = network.states[context.current_state]
479 state_name = context.current_state
481 # Use base class logic to prepare transforms
482 transform_functions, state_obj = self.prepare_state_transform(state, context)
484 # Execute validation functions first (async-specific)
485 if hasattr(state, 'validation_functions') and state.validation_functions:
486 for validator in state.validation_functions:
487 try:
488 # Handle both async and sync validators
489 if asyncio.iscoroutinefunction(validator.validate):
490 # Try with state object first (for inline lambdas)
491 try:
492 result = await validator.validate(state_obj)
493 except (TypeError, AttributeError):
494 # Fall back to standard signature
495 result = await validator.validate(ensure_dict(context.data), context)
496 else:
497 # Run sync function in executor
498 loop = asyncio.get_event_loop()
499 try:
500 result = await loop.run_in_executor(None, validator.validate, state_obj)
501 except (TypeError, AttributeError):
502 # Fall back to standard signature
503 result = await loop.run_in_executor(None, validator.validate, ensure_dict(context.data), context)
505 if isinstance(result, dict):
506 # Merge validation results into context data
507 context.data.update(result)
508 except Exception:
509 # Log but don't fail - validators are optional
510 pass
512 # Execute transform functions using base class helpers
513 import logging
514 logger = logging.getLogger(__name__)
515 if transform_functions:
516 logger.debug(f"Executing {len(transform_functions)} transform functions for state {state_name}")
517 for transform_func in transform_functions:
518 try:
519 # Create function context
520 func_context = FunctionContext(
521 state_name=state_name,
522 function_name=getattr(transform_func, '__name__', 'transform'),
523 metadata={'state': state_name},
524 resources={}
525 )
527 # Handle both async and sync transforms
528 # For InterfaceWrapper objects, use the transform method
529 actual_func = transform_func
530 if hasattr(transform_func, 'transform'):
531 actual_func = transform_func.transform
533 # Check if it's async - check both the function and its __call__ method
534 is_async = asyncio.iscoroutinefunction(actual_func)
535 if not is_async and callable(actual_func) and callable(actual_func):
536 # Check if the __call__ method is async (for wrapped functions)
537 is_async = asyncio.iscoroutinefunction(actual_func.__call__)
539 # Also check for _is_async attribute (for wrapped functions)
540 if not is_async and hasattr(transform_func, '_is_async'):
541 is_async = transform_func._is_async
543 if is_async:
544 # Try with state object first (for inline lambdas)
545 try:
546 result = await actual_func(state_obj)
547 except (TypeError, AttributeError):
548 # Fall back to standard signature
549 result = await actual_func(ensure_dict(context.data), func_context)
550 else:
551 # Run sync function in executor
552 loop = asyncio.get_event_loop()
553 try:
554 result = await loop.run_in_executor(None, actual_func, state_obj)
555 except (TypeError, AttributeError):
556 # Fall back to standard signature
557 result = await loop.run_in_executor(None, actual_func, ensure_dict(context.data), func_context)
559 # Process result using base class logic
560 self.process_transform_result(result, context, state_name)
562 except Exception as e:
563 # Handle error using base class logic
564 self.handle_transform_error(e, context, state_name)
566 async def _find_initial_state(self) -> str | None:
567 """Find initial state in FSM.
569 Returns:
570 Initial state name or None.
571 """
572 # Use base class implementation (it's synchronous but that's fine)
573 return self.find_initial_state_common()
575 async def _is_final_state(self, state_name: str | None) -> bool:
576 """Check if state is a final state.
578 Args:
579 state_name: Name of state to check.
581 Returns:
582 True if final state.
583 """
584 # Use base class implementation
585 return self.is_final_state_common(state_name)
587 async def _is_final_state_legacy(self, state_name: str | None) -> bool:
588 """Legacy implementation kept for reference."""
589 if not state_name:
590 return False
592 # Get the main network - could be a string or object
593 main_network_ref = getattr(self.fsm, 'main_network', None)
595 if main_network_ref is None:
596 # If no main network specified, check all networks
597 for network in self.fsm.networks.values():
598 if state_name in network.states:
599 state = network.states[state_name]
600 if state.is_end_state() if hasattr(state, 'is_end_state') else state.type == StateType.END:
601 return True
602 return False
604 # Handle case where main_network is already a network object (FSM wrapper)
605 if hasattr(main_network_ref, 'states'):
606 main_network = main_network_ref
607 # Handle case where main_network is a string (core FSM)
608 elif isinstance(main_network_ref, str) and main_network_ref in self.fsm.networks:
609 main_network = self.fsm.networks[main_network_ref]
610 else:
611 return False
613 # Check if the state exists and is an end state
614 if state_name in main_network.states:
615 state = main_network.states[state_name]
616 return state.is_end_state() if hasattr(state, 'is_end_state') else state.type == StateType.END
618 return False
620 async def _get_current_network(
621 self,
622 context: ExecutionContext
623 ) -> StateNetwork | None:
624 """Get the current network from context using common network selector.
626 Args:
627 context: Execution context.
629 Returns:
630 Current network or None.
631 """
632 # Use intelligent selection for async engine by default
633 return NetworkSelector.get_current_network(
634 self.fsm,
635 context,
636 enable_intelligent_selection=True
637 )
639 def get_statistics(self) -> Dict[str, Any]:
640 """Get execution statistics.
642 Returns:
643 Dictionary of statistics.
644 """
645 return {
646 'execution_count': self._execution_count,
647 'transition_count': self._transition_count,
648 'error_count': self._error_count,
649 'total_execution_time': self._total_execution_time,
650 'average_execution_time': (
651 self._total_execution_time / self._execution_count
652 if self._execution_count > 0 else 0.0
653 )
654 }