Coverage for src/dataknobs_fsm/execution/async_engine.py: 12%

251 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

"""Asynchronous execution engine for FSM processing."""

2 

import asyncio
import logging
import time
from typing import Any, Dict, List, Tuple

from dataknobs_fsm.core.arc import ArcDefinition
from dataknobs_fsm.core.data_wrapper import ensure_dict
from dataknobs_fsm.core.fsm import FSM
from dataknobs_fsm.core.modes import ProcessingMode
from dataknobs_fsm.core.network import StateNetwork
from dataknobs_fsm.core.state import StateType
from dataknobs_fsm.execution.base_engine import BaseExecutionEngine
from dataknobs_fsm.execution.common import (
    NetworkSelector,
    TransitionSelectionMode,
)
from dataknobs_fsm.execution.context import ExecutionContext
from dataknobs_fsm.execution.engine import TraversalStrategy
from dataknobs_fsm.functions.base import FunctionContext

21 

22 

class AsyncExecutionEngine(BaseExecutionEngine):
    """Asynchronous execution engine for FSM.

    This engine handles:
    - True async execution of state functions
    - Parallel arc evaluation
    - Async resource management
    - Non-blocking state transitions
    """

    def __init__(
        self,
        fsm: FSM,
        strategy: TraversalStrategy = TraversalStrategy.DEPTH_FIRST,
        selection_mode: TransitionSelectionMode = TransitionSelectionMode.HYBRID
    ):
        """Initialize async execution engine.

        Args:
            fsm: FSM to execute.
            strategy: Traversal strategy for execution.
            selection_mode: Transition selection mode (strategy, scoring, or hybrid).
        """
        # The base class requires retry settings; the async engine simply
        # forwards the base defaults (3 retries, 1.0s delay) and does not
        # perform retries of its own.
        super().__init__(fsm, strategy, selection_mode, max_retries=3, retry_delay=1.0)

48 

49 async def execute( 

50 self, 

51 context: ExecutionContext, 

52 data: Any = None, 

53 max_transitions: int = 1000, 

54 arc_name: str | None = None 

55 ) -> Tuple[bool, Any]: 

56 """Execute the FSM asynchronously with given context. 

57  

58 Args: 

59 context: Execution context. 

60 data: Input data to process. 

61 max_transitions: Maximum transitions before stopping. 

62 arc_name: Optional specific arc name to follow. 

63  

64 Returns: 

65 Tuple of (success, result). 

66 """ 

67 start_time = time.time() 

68 self._execution_count += 1 

69 

70 # Only override context.data if data was explicitly provided 

71 if data is not None: 

72 context.data = data 

73 

74 # Initialize state if needed 

75 if not context.current_state: 

76 initial_state = await self._find_initial_state() 

77 if not initial_state: 

78 return False, "No initial state found" 

79 context.set_state(initial_state) 

80 # Execute transforms for the initial state 

81 await self._execute_state_transforms(context) 

82 

83 try: 

84 # Execute based on data mode 

85 if context.data_mode == ProcessingMode.SINGLE: 

86 result = await self._execute_single(context, max_transitions, arc_name) 

87 elif context.data_mode == ProcessingMode.BATCH: 

88 result = await self._execute_batch(context, max_transitions) 

89 elif context.data_mode == ProcessingMode.STREAM: 

90 result = await self._execute_stream(context, max_transitions) 

91 else: 

92 result = False, f"Unknown data mode: {context.data_mode}" 

93 

94 self._total_execution_time += time.time() - start_time 

95 return result 

96 

97 except Exception as e: 

98 self._error_count += 1 

99 self._total_execution_time += time.time() - start_time 

100 return False, str(e) 

101 

102 async def _execute_single( 

103 self, 

104 context: ExecutionContext, 

105 max_transitions: int, 

106 arc_name: str | None = None 

107 ) -> Tuple[bool, Any]: 

108 """Execute in single record mode asynchronously. 

109  

110 Args: 

111 context: Execution context. 

112 max_transitions: Maximum transitions. 

113 arc_name: Optional specific arc name to follow. 

114  

115 Returns: 

116 Tuple of (success, result). 

117 """ 

118 transitions = 0 

119 

120 while transitions < max_transitions: 

121 # Check if we're in a final state 

122 if await self._is_final_state(context.current_state): 

123 return True, context.data 

124 

125 # Get available transitions 

126 transitions_available = await self._get_available_transitions( 

127 context.current_state, 

128 context, 

129 arc_name 

130 ) 

131 

132 if not transitions_available: 

133 # No valid transitions - check if this is a final state 

134 if await self._is_final_state(context.current_state): 

135 return True, context.data 

136 return False, f"No valid transitions from state: {context.current_state}" 

137 

138 # Choose transition based on strategy 

139 next_transition = await self._choose_transition( 

140 transitions_available, 

141 context 

142 ) 

143 

144 if not next_transition: 

145 return False, "No transition selected" 

146 

147 # Execute transition 

148 success = await self._execute_transition( 

149 next_transition, 

150 context 

151 ) 

152 

153 if not success: 

154 return False, f"Transition failed: {next_transition}" 

155 

156 transitions += 1 

157 self._transition_count += 1 

158 

159 return False, f"Maximum transitions ({max_transitions}) exceeded" 

160 

161 async def _execute_batch( 

162 self, 

163 context: ExecutionContext, 

164 max_transitions: int 

165 ) -> Tuple[bool, Any]: 

166 """Execute in batch mode asynchronously. 

167  

168 Args: 

169 context: Execution context. 

170 max_transitions: Maximum transitions per item. 

171  

172 Returns: 

173 Tuple of (success, results). 

174 """ 

175 if not context.batch_data: 

176 return False, "No batch data to process" 

177 

178 # Process items in parallel 

179 tasks = [] 

180 for i, item in enumerate(context.batch_data): 

181 # Create child context for this item 

182 item_context = context.create_child_context(f"batch_{i}") 

183 item_context.data = item 

184 

185 # Reset to initial state for each item 

186 initial_state = await self._find_initial_state() 

187 if initial_state: 

188 item_context.set_state(initial_state) 

189 

190 # Create task for this item 

191 task = asyncio.create_task( 

192 self._execute_single(item_context, max_transitions) 

193 ) 

194 tasks.append(task) 

195 

196 # Wait for all tasks 

197 results = await asyncio.gather(*tasks, return_exceptions=True) 

198 

199 # Process results 

200 batch_results = [] 

201 batch_errors = [] 

202 for i, result in enumerate(results): 

203 if isinstance(result, Exception): 

204 batch_errors.append((i, result)) 

205 else: 

206 # Result is a tuple[bool, Any] at this point 

207 success, value = result # type: ignore 

208 if success: # success 

209 batch_results.append(value) 

210 else: 

211 batch_errors.append((i, Exception(value))) 

212 

213 return len(batch_errors) == 0, { 

214 'results': batch_results, 

215 'errors': batch_errors 

216 } 

217 

218 async def _execute_stream( 

219 self, 

220 context: ExecutionContext, 

221 max_transitions: int 

222 ) -> Tuple[bool, Any]: 

223 """Execute in stream mode asynchronously. 

224  

225 Args: 

226 context: Execution context. 

227 max_transitions: Maximum transitions per chunk. 

228  

229 Returns: 

230 Tuple of (success, stream_stats). 

231 """ 

232 if not context.stream_context: 

233 return False, "No stream context provided" 

234 

235 chunks_processed = 0 

236 total_records = 0 

237 errors = [] 

238 

239 # Process each chunk 

240 while True: 

241 # Get next chunk from stream 

242 chunk = context.stream_context.get_next_chunk() 

243 if not chunk: 

244 break 

245 

246 context.set_stream_chunk(chunk) 

247 

248 # Process chunk data 

249 for record in chunk.data: 

250 record_context = context.create_child_context( 

251 f"stream_{chunks_processed}_{total_records}" 

252 ) 

253 record_context.data = record 

254 

255 # Reset to initial state 

256 initial_state = await self._find_initial_state() 

257 if initial_state: 

258 record_context.set_state(initial_state) 

259 

260 # Execute for this record 

261 success, result = await self._execute_single( 

262 record_context, 

263 max_transitions 

264 ) 

265 

266 if not success: 

267 errors.append((total_records, result)) 

268 

269 # Merge context 

270 context.merge_child_context( 

271 f"stream_{chunks_processed}_{total_records}" 

272 ) 

273 

274 total_records += 1 

275 

276 chunks_processed += 1 

277 

278 # Check if this was the last chunk 

279 if chunk.is_last: 

280 break 

281 

282 return len(errors) == 0, { 

283 'chunks_processed': chunks_processed, 

284 'records_processed': total_records, 

285 'errors': errors 

286 } 

287 

288 async def _get_available_transitions( 

289 self, 

290 state_name: str, 

291 context: ExecutionContext, 

292 arc_name: str | None = None 

293 ) -> List[ArcDefinition]: 

294 """Get available transitions from current state asynchronously. 

295  

296 This evaluates pre-conditions in parallel. 

297  

298 Args: 

299 state_name: Current state name. 

300 context: Execution context. 

301 arc_name: Optional specific arc name to filter by. 

302  

303 Returns: 

304 List of available arc definitions. 

305 """ 

306 network = await self._get_current_network(context) 

307 if not network or state_name not in network.states: 

308 return [] 

309 

310 state = network.states[state_name] 

311 available = [] 

312 

313 # Filter arcs by name if specified 

314 arcs_to_evaluate = state.outgoing_arcs 

315 if arc_name: 

316 arcs_to_evaluate = [arc for arc in state.outgoing_arcs 

317 if hasattr(arc, 'name') and arc.name == arc_name] 

318 # If no arcs match the specified name, return empty list 

319 if not arcs_to_evaluate: 

320 return [] 

321 

322 # Evaluate all arc pre-conditions in parallel 

323 tasks = [] 

324 for arc in arcs_to_evaluate: 

325 task = asyncio.create_task(self._evaluate_arc(arc, context)) 

326 tasks.append((arc, task)) 

327 

328 # Wait for all evaluations 

329 for arc, task in tasks: 

330 can_execute = await task 

331 if can_execute: 

332 available.append(arc) 

333 

334 # Sort by priority (higher first) 

335 available.sort(key=lambda a: a.priority, reverse=True) 

336 

337 return available 

338 

339 async def _evaluate_arc( 

340 self, 

341 arc: ArcDefinition, 

342 context: ExecutionContext 

343 ) -> bool: 

344 """Evaluate if an arc can be executed. 

345  

346 Args: 

347 arc: Arc definition. 

348 context: Execution context. 

349  

350 Returns: 

351 True if arc can be executed. 

352 """ 

353 if not arc.pre_test: 

354 return True 

355 

356 # Get the function registry 

357 function_registry = getattr(self.fsm, 'function_registry', {}) 

358 if hasattr(function_registry, 'functions'): 

359 functions = function_registry.functions 

360 else: 

361 functions = function_registry 

362 

363 if arc.pre_test not in functions: 

364 return False 

365 

366 # Execute pre-test function 

367 pre_test_func = functions[arc.pre_test] 

368 

369 # Check if it's async 

370 if asyncio.iscoroutinefunction(pre_test_func): 

371 result = await pre_test_func(context.data, context) 

372 else: 

373 # Run sync function in executor 

374 loop = asyncio.get_event_loop() 

375 result = await loop.run_in_executor( 

376 None, 

377 pre_test_func, 

378 context.data, 

379 context 

380 ) 

381 

382 # Handle tuple return from test functions (bool, reason) 

383 if isinstance(result, tuple): 

384 return bool(result[0]) 

385 return bool(result) 

386 

387 async def _choose_transition( 

388 self, 

389 available: List[ArcDefinition], 

390 context: ExecutionContext 

391 ) -> ArcDefinition | None: 

392 """Choose transition using common transition selector. 

393  

394 Args: 

395 available: Available transitions. 

396 context: Execution context. 

397  

398 Returns: 

399 Selected arc or None. 

400 """ 

401 return self.transition_selector.select_transition( 

402 available, 

403 context, 

404 strategy=self.strategy 

405 ) 

406 

407 async def _execute_transition( 

408 self, 

409 arc: ArcDefinition, 

410 context: ExecutionContext 

411 ) -> bool: 

412 """Execute a state transition asynchronously. 

413  

414 Args: 

415 arc: Arc to execute. 

416 context: Execution context. 

417  

418 Returns: 

419 True if successful. 

420 """ 

421 try: 

422 # Execute arc transform if defined 

423 if arc.transform: 

424 function_registry = getattr(self.fsm, 'function_registry', {}) 

425 if hasattr(function_registry, 'functions'): 

426 functions = function_registry.functions 

427 else: 

428 functions = function_registry 

429 

430 if arc.transform in functions: 

431 transform_func = functions[arc.transform] 

432 

433 # Check if it's async - check both the function and its __call__ method 

434 is_async = asyncio.iscoroutinefunction(transform_func) 

435 if not is_async and callable(transform_func) and callable(transform_func): 

436 # Check if the __call__ method is async (for wrapped functions) 

437 is_async = asyncio.iscoroutinefunction(transform_func.__call__) 

438 

439 if is_async: 

440 context.data = await transform_func(context.data, context) 

441 else: 

442 # Run sync function in executor 

443 loop = asyncio.get_event_loop() 

444 context.data = await loop.run_in_executor( 

445 None, 

446 transform_func, 

447 context.data, 

448 context 

449 ) 

450 

451 # Update state (history is automatically tracked by set_state) 

452 context.set_state(arc.target_state) 

453 

454 # Execute state transforms when entering the new state 

455 await self._execute_state_transforms(context) 

456 

457 return True 

458 

459 except Exception: 

460 return False 

461 

462 async def _execute_state_transforms( 

463 self, 

464 context: ExecutionContext 

465 ) -> None: 

466 """Execute state functions (validators and transforms) when in a state. 

467 

468 This should be called before evaluating arc conditions to ensure 

469 that state functions can update the data that conditions depend on. 

470 

471 Args: 

472 context: Execution context. 

473 """ 

474 network = await self._get_current_network(context) 

475 if not network or context.current_state not in network.states: 

476 return 

477 

478 state = network.states[context.current_state] 

479 state_name = context.current_state 

480 

481 # Use base class logic to prepare transforms 

482 transform_functions, state_obj = self.prepare_state_transform(state, context) 

483 

484 # Execute validation functions first (async-specific) 

485 if hasattr(state, 'validation_functions') and state.validation_functions: 

486 for validator in state.validation_functions: 

487 try: 

488 # Handle both async and sync validators 

489 if asyncio.iscoroutinefunction(validator.validate): 

490 # Try with state object first (for inline lambdas) 

491 try: 

492 result = await validator.validate(state_obj) 

493 except (TypeError, AttributeError): 

494 # Fall back to standard signature 

495 result = await validator.validate(ensure_dict(context.data), context) 

496 else: 

497 # Run sync function in executor 

498 loop = asyncio.get_event_loop() 

499 try: 

500 result = await loop.run_in_executor(None, validator.validate, state_obj) 

501 except (TypeError, AttributeError): 

502 # Fall back to standard signature 

503 result = await loop.run_in_executor(None, validator.validate, ensure_dict(context.data), context) 

504 

505 if isinstance(result, dict): 

506 # Merge validation results into context data 

507 context.data.update(result) 

508 except Exception: 

509 # Log but don't fail - validators are optional 

510 pass 

511 

512 # Execute transform functions using base class helpers 

513 import logging 

514 logger = logging.getLogger(__name__) 

515 if transform_functions: 

516 logger.debug(f"Executing {len(transform_functions)} transform functions for state {state_name}") 

517 for transform_func in transform_functions: 

518 try: 

519 # Create function context 

520 func_context = FunctionContext( 

521 state_name=state_name, 

522 function_name=getattr(transform_func, '__name__', 'transform'), 

523 metadata={'state': state_name}, 

524 resources={} 

525 ) 

526 

527 # Handle both async and sync transforms 

528 # For InterfaceWrapper objects, use the transform method 

529 actual_func = transform_func 

530 if hasattr(transform_func, 'transform'): 

531 actual_func = transform_func.transform 

532 

533 # Check if it's async - check both the function and its __call__ method 

534 is_async = asyncio.iscoroutinefunction(actual_func) 

535 if not is_async and callable(actual_func) and callable(actual_func): 

536 # Check if the __call__ method is async (for wrapped functions) 

537 is_async = asyncio.iscoroutinefunction(actual_func.__call__) 

538 

539 # Also check for _is_async attribute (for wrapped functions) 

540 if not is_async and hasattr(transform_func, '_is_async'): 

541 is_async = transform_func._is_async 

542 

543 if is_async: 

544 # Try with state object first (for inline lambdas) 

545 try: 

546 result = await actual_func(state_obj) 

547 except (TypeError, AttributeError): 

548 # Fall back to standard signature 

549 result = await actual_func(ensure_dict(context.data), func_context) 

550 else: 

551 # Run sync function in executor 

552 loop = asyncio.get_event_loop() 

553 try: 

554 result = await loop.run_in_executor(None, actual_func, state_obj) 

555 except (TypeError, AttributeError): 

556 # Fall back to standard signature 

557 result = await loop.run_in_executor(None, actual_func, ensure_dict(context.data), func_context) 

558 

559 # Process result using base class logic 

560 self.process_transform_result(result, context, state_name) 

561 

562 except Exception as e: 

563 # Handle error using base class logic 

564 self.handle_transform_error(e, context, state_name) 

565 

566 async def _find_initial_state(self) -> str | None: 

567 """Find initial state in FSM. 

568 

569 Returns: 

570 Initial state name or None. 

571 """ 

572 # Use base class implementation (it's synchronous but that's fine) 

573 return self.find_initial_state_common() 

574 

575 async def _is_final_state(self, state_name: str | None) -> bool: 

576 """Check if state is a final state. 

577 

578 Args: 

579 state_name: Name of state to check. 

580 

581 Returns: 

582 True if final state. 

583 """ 

584 # Use base class implementation 

585 return self.is_final_state_common(state_name) 

586 

587 async def _is_final_state_legacy(self, state_name: str | None) -> bool: 

588 """Legacy implementation kept for reference.""" 

589 if not state_name: 

590 return False 

591 

592 # Get the main network - could be a string or object 

593 main_network_ref = getattr(self.fsm, 'main_network', None) 

594 

595 if main_network_ref is None: 

596 # If no main network specified, check all networks 

597 for network in self.fsm.networks.values(): 

598 if state_name in network.states: 

599 state = network.states[state_name] 

600 if state.is_end_state() if hasattr(state, 'is_end_state') else state.type == StateType.END: 

601 return True 

602 return False 

603 

604 # Handle case where main_network is already a network object (FSM wrapper) 

605 if hasattr(main_network_ref, 'states'): 

606 main_network = main_network_ref 

607 # Handle case where main_network is a string (core FSM) 

608 elif isinstance(main_network_ref, str) and main_network_ref in self.fsm.networks: 

609 main_network = self.fsm.networks[main_network_ref] 

610 else: 

611 return False 

612 

613 # Check if the state exists and is an end state 

614 if state_name in main_network.states: 

615 state = main_network.states[state_name] 

616 return state.is_end_state() if hasattr(state, 'is_end_state') else state.type == StateType.END 

617 

618 return False 

619 

620 async def _get_current_network( 

621 self, 

622 context: ExecutionContext 

623 ) -> StateNetwork | None: 

624 """Get the current network from context using common network selector. 

625  

626 Args: 

627 context: Execution context. 

628  

629 Returns: 

630 Current network or None. 

631 """ 

632 # Use intelligent selection for async engine by default 

633 return NetworkSelector.get_current_network( 

634 self.fsm, 

635 context, 

636 enable_intelligent_selection=True 

637 ) 

638 

639 def get_statistics(self) -> Dict[str, Any]: 

640 """Get execution statistics. 

641  

642 Returns: 

643 Dictionary of statistics. 

644 """ 

645 return { 

646 'execution_count': self._execution_count, 

647 'transition_count': self._transition_count, 

648 'error_count': self._error_count, 

649 'total_execution_time': self._total_execution_time, 

650 'average_execution_time': ( 

651 self._total_execution_time / self._execution_count 

652 if self._execution_count > 0 else 0.0 

653 ) 

654 }