Coverage for src/dataknobs_fsm/core/fsm.py: 28%

248 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Core FSM class for managing state machines.""" 

2 

3from typing import Any, Dict, List, Set, Tuple, Optional 

4 

5from dataknobs_fsm.core.modes import ProcessingMode, TransactionMode 

6from dataknobs_fsm.core.network import StateNetwork 

7from dataknobs_fsm.core.state import StateDefinition, StateInstance, StateType 

8from dataknobs_fsm.functions.base import FunctionRegistry 

9 

10 

class FSM:
    """Finite State Machine core class.

    This class manages:
    - Multiple state networks (one of which is the "main" network)
    - Function registry
    - Data and transaction modes
    - Resource requirements aggregated from the added networks
    - Configuration and metadata
    - Lazily created synchronous and asynchronous execution engines
    """

    def __init__(
        self,
        name: str,
        data_mode: ProcessingMode = ProcessingMode.SINGLE,
        transaction_mode: TransactionMode = TransactionMode.NONE,
        description: str | None = None,
        resource_manager: Any | None = None,
        transaction_manager: Any | None = None
    ):
        """Initialize FSM.

        Args:
            name: Name of the FSM.
            data_mode: Data processing mode.
            transaction_mode: Transaction handling mode.
            description: Optional FSM description.
            resource_manager: Optional resource manager; when set it is
                attached to every execution context this FSM prepares.
            transaction_manager: Optional transaction manager; when set it is
                attached to every execution context this FSM prepares.
        """
        self.name = name
        self.data_mode = data_mode
        self.transaction_mode = transaction_mode
        self.description = description

        # Networks
        self.networks: Dict[str, StateNetwork] = {}
        self.main_network_name: str | None = None

        # Function registry
        self.function_registry = FunctionRegistry()

        # Resource requirements (resource type -> requirements, normally a set)
        self.resource_requirements: Dict[str, Any] = {}

        # Configuration
        self.config: Dict[str, Any] = {}

        # Metadata
        self.metadata: Dict[str, Any] = {}
        self.version: str = "1.0.0"
        self.created_at: float | None = None
        self.updated_at: float | None = None

        # Execution support (from builder FSM wrapper)
        self.resource_manager = resource_manager
        self.transaction_manager = transaction_manager
        self._engine: Any | None = None  # ExecutionEngine
        self._engine_strategy: Any | None = None  # strategy the cached engine was built with
        self._async_engine: Any | None = None  # AsyncExecutionEngine

    def add_network(
        self,
        network: StateNetwork,
        is_main: bool = False
    ) -> None:
        """Add a network to the FSM.

        A network with the same name as an existing one replaces it. The first
        network added becomes the main network even when ``is_main`` is False.

        Args:
            network: Network to add.
            is_main: Whether this is the main network.
        """
        self.networks[network.name] = network

        if is_main or self.main_network_name is None:
            self.main_network_name = network.name

        # Aggregate resource requirements
        for resource_type, requirements in network.resource_requirements.items():
            self.resource_requirements.setdefault(resource_type, set()).update(requirements)

    def remove_network(self, network_name: str) -> bool:
        """Remove a network from the FSM.

        If the removed network was the main network, the first remaining
        network (in insertion order) becomes the main network, or None when no
        networks are left.

        NOTE(review): resource requirements aggregated by ``add_network`` are
        not reduced here, so ``resource_requirements`` may still list what the
        removed network needed.

        Args:
            network_name: Name of network to remove.

        Returns:
            True if removed successfully, False if no such network exists.
        """
        if network_name not in self.networks:
            return False

        del self.networks[network_name]

        # Update main network if needed
        if self.main_network_name == network_name:
            self.main_network_name = next(iter(self.networks), None)

        return True

    def get_network(self, network_name: str | None = None) -> StateNetwork | None:
        """Get a network by name.

        Args:
            network_name: Name of network (None for main network).

        Returns:
            Network or None if not found.
        """
        if network_name is None:
            network_name = self.main_network_name

        if network_name:
            return self.networks.get(network_name)
        return None

    def _resolve_network(self, network_name: str | None = None) -> StateNetwork | None:
        """Return the named network, falling back to the main network for a falsy name."""
        resolved_name = network_name or self.main_network_name
        if not resolved_name:
            return None
        return self.networks.get(resolved_name)

    def validate(self) -> Tuple[bool, List[str]]:
        """Validate the FSM.

        Checks that at least one network exists, that the main network name
        refers to a known network, that every network validates, and that
        every function referenced by an arc is registered.

        Returns:
            Tuple of (valid, list of errors).
        """
        errors: List[str] = []

        # Check for at least one network
        if not self.networks:
            errors.append("FSM has no networks")

        # Check main network exists
        if self.main_network_name and self.main_network_name not in self.networks:
            errors.append(f"Main network '{self.main_network_name}' not found")

        # Validate each network
        for network_name, network in self.networks.items():
            valid, network_errors = network.validate()
            if not valid:
                errors.extend(
                    f"Network '{network_name}': {error}" for error in network_errors
                )

        # Check function references
        for func_name in self._get_all_function_references():
            if not self.function_registry.get_function(func_name):
                errors.append(f"Function '{func_name}' not registered")

        return len(errors) == 0, errors

    def _get_all_function_references(self) -> Set[str]:
        """Get all function references from all networks.

        Only the ``pre_test`` and ``transform`` attributes of arcs are
        inspected; missing or falsy attributes are ignored.

        Returns:
            Set of function names referenced.
        """
        functions: Set[str] = set()

        for network in self.networks.values():
            for arc in network.arcs.values():
                for attribute in ('pre_test', 'transform'):
                    reference = getattr(arc, attribute, None)
                    if reference:
                        functions.add(reference)

        return functions

    def get_all_states(self) -> Dict[str, List[str]]:
        """Get all states from all networks.

        Returns:
            Dictionary of network_name -> list of state names.
        """
        return {
            network_name: list(network.states.keys())
            for network_name, network in self.networks.items()
        }

    def get_all_arcs(self) -> Dict[str, List[str]]:
        """Get all arcs from all networks.

        Returns:
            Dictionary of network_name -> list of arc IDs.
        """
        return {
            network_name: list(network.arcs.keys())
            for network_name, network in self.networks.items()
        }

    def supports_streaming(self) -> bool:
        """Check if FSM supports streaming.

        Returns:
            True if any network supports streaming.
        """
        return any(network.supports_streaming for network in self.networks.values())

    def get_resource_summary(self) -> Dict[str, Any]:
        """Get resource requirements summary.

        Returns:
            Dictionary with network/state/arc totals, the known resource
            types, streaming support, the mode values, and one
            ``<resource_type>_count`` entry per resource type.
        """
        summary: Dict[str, Any] = {
            'total_networks': len(self.networks),
            'total_states': sum(len(n.states) for n in self.networks.values()),
            'total_arcs': sum(len(n.arcs) for n in self.networks.values()),
            'resource_types': list(self.resource_requirements.keys()),
            'supports_streaming': self.supports_streaming(),
            'data_mode': self.data_mode.value,
            'transaction_mode': self.transaction_mode.value
        }

        # Add resource counts
        for resource_type, requirements in self.resource_requirements.items():
            summary[f'{resource_type}_count'] = len(requirements)

        return summary

    def clone(self) -> 'FSM':
        """Create a clone of this FSM.

        The clone is named ``<name>_clone``. Network objects, the function
        registry and the resource/transaction managers are shared by reference
        (a deep clone would need ``network.clone()``). The ``networks`` mapping,
        ``config``, ``metadata`` and every resource-requirement collection are
        copied, so adding networks to the clone does not alter this FSM.
        Execution engines are not copied; the clone creates its own on demand.

        Returns:
            Cloned FSM.
        """
        import copy

        clone = FSM(
            name=f"{self.name}_clone",
            data_mode=self.data_mode,
            transaction_mode=self.transaction_mode,
            description=self.description,
            resource_manager=self.resource_manager,
            transaction_manager=self.transaction_manager
        )

        # Shallow copy of the mapping: the network objects themselves are shared.
        clone.networks = dict(self.networks)
        clone.main_network_name = self.main_network_name
        clone.function_registry = self.function_registry

        # Copy each requirement collection individually. A plain dict.copy()
        # would leave both FSMs pointing at the same sets, so add_network() on
        # one would silently change the other.
        clone.resource_requirements = {
            resource_type: copy.copy(requirements)
            for resource_type, requirements in self.resource_requirements.items()
        }
        clone.config = self.config.copy()
        clone.metadata = self.metadata.copy()
        clone.version = self.version

        return clone

    @staticmethod
    def _serialize_requirements(requirements: Any) -> Any:
        """Turn a requirement set into a list (sorted when its items are orderable)."""
        if not isinstance(requirements, (set, frozenset)):
            return requirements
        try:
            # Sorted output keeps serialized FSMs stable between runs.
            return sorted(requirements)
        except TypeError:
            return list(requirements)

    def to_dict(self) -> Dict[str, Any]:
        """Convert FSM to dictionary representation.

        Networks are represented by name only; their contents, the function
        registry and the managers are not included. ``config`` and
        ``metadata`` are shallow copies, so editing the returned dictionary's
        top-level entries does not change this FSM.

        Returns:
            Dictionary representation.
        """
        return {
            'name': self.name,
            'description': self.description,
            'data_mode': self.data_mode.value,
            'transaction_mode': self.transaction_mode.value,
            'main_network': self.main_network_name,
            'networks': list(self.networks.keys()),
            'resource_requirements': {
                k: self._serialize_requirements(v)
                for k, v in self.resource_requirements.items()
            },
            'config': dict(self.config),
            'metadata': dict(self.metadata),
            'version': self.version
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'FSM':
        """Create FSM from dictionary representation.

        Networks are not restored (``to_dict`` stores only their names), so
        ``main_network_name`` may name a network that still has to be added.
        ``config`` and ``metadata`` are shallow-copied so the new FSM does not
        alias the caller's dictionaries; explicit None values are treated as
        empty.

        Args:
            data: Dictionary with FSM data. ``name`` is required.

        Returns:
            New FSM instance.

        Raises:
            KeyError: If ``data`` has no ``name`` entry.
        """
        fsm = cls(
            name=data['name'],
            data_mode=ProcessingMode(data.get('data_mode', 'single')),
            transaction_mode=TransactionMode(data.get('transaction_mode', 'none')),
            description=data.get('description')
        )

        fsm.main_network_name = data.get('main_network')
        fsm.config = dict(data.get('config') or {})
        fsm.metadata = dict(data.get('metadata') or {})
        fsm.version = data.get('version', '1.0.0')

        # Resource requirements: to_dict() turns sets into lists and passes any
        # other value through unchanged, so only real collections are turned
        # back into sets (set("abc") would otherwise explode a string).
        for resource_type, requirements in (data.get('resource_requirements') or {}).items():
            if isinstance(requirements, (list, tuple, set, frozenset)):
                requirements = set(requirements)
            fsm.resource_requirements[resource_type] = requirements

        return fsm

    def find_state_definition(self, state_name: str, network_name: str | None = None) -> StateDefinition | None:
        """Find a state definition by name.

        Args:
            state_name: Name of the state to find
            network_name: Optional specific network to search in; when omitted
                all networks are searched in insertion order

        Returns:
            StateDefinition if found, None otherwise
        """
        if network_name:
            # Search specific network
            network = self.networks.get(network_name)
            if network and hasattr(network, 'states'):
                return network.states.get(state_name)
            return None

        # Search all networks
        for network in self.networks.values():
            if hasattr(network, 'states') and state_name in network.states:
                return network.states[state_name]

        return None

    def create_state_instance(self, state_name: str, data: Dict[str, Any] | None = None, network_name: str | None = None) -> StateInstance:
        """Create a state instance from a state name.

        When no definition with that name exists, a minimal one is created on
        the fly; it is typed as a start state only for the names ``start``,
        ``Start`` and ``START``.

        Args:
            state_name: Name of the state
            data: Optional initial data for the state; the given dictionary is
                used as-is (even when empty), a new one is created for None
            network_name: Optional specific network to search in

        Returns:
            StateInstance object
        """
        # Try to find existing state definition
        state_def = self.find_state_definition(state_name, network_name)

        if not state_def:
            # Create minimal state definition
            state_def = StateDefinition(
                name=state_name,
                type=StateType.START if state_name in ['start', 'Start', 'START'] else StateType.NORMAL
            )

        # Create and return state instance
        return StateInstance(
            definition=state_def,
            data={} if data is None else data
        )

    def get_state(self, state_name: str, network_name: str | None = None) -> StateDefinition | None:
        """Get a state definition by name.

        This is an alias for find_state_definition for compatibility.

        Args:
            state_name: Name of the state
            network_name: Optional specific network to search in

        Returns:
            StateDefinition if found, None otherwise
        """
        return self.find_state_definition(state_name, network_name)

    def is_start_state(self, state_name: str, network_name: str | None = None) -> bool:
        """Check if a state is a start state.

        Args:
            state_name: Name of the state
            network_name: Optional specific network to check in (defaults to main network)

        Returns:
            True if the state is a start state
        """
        network = self._resolve_network(network_name)
        if network:
            return network.is_initial_state(state_name)
        return False

    def is_end_state(self, state_name: str, network_name: str | None = None) -> bool:
        """Check if a state is an end state.

        Args:
            state_name: Name of the state
            network_name: Optional specific network to check in (defaults to main network)

        Returns:
            True if the state is an end state
        """
        network = self._resolve_network(network_name)
        if network:
            return network.is_final_state(state_name)
        return False

    @staticmethod
    def _is_start_definition(state: Any) -> bool:
        """Return True if the state reports itself as a start state or has type START."""
        if hasattr(state, 'is_start_state') and state.is_start_state():
            return True
        return hasattr(state, 'type') and state.type == StateType.START

    def _find_start_in_network(self, network: Any) -> StateDefinition | None:
        """Return the first start state of one network, or None if it has none."""
        states = getattr(network, 'states', None)
        if not states:
            return None
        for state in states.values():
            if self._is_start_definition(state):
                return state
        return None

    def get_start_state(self, network_name: str | None = None) -> StateDefinition | None:
        """Get the start state definition.

        With a network name, only that network is searched. Without one, the
        main network is searched first and then every network in insertion
        order. If no start state is found, a state literally named ``start``
        is looked up as a fallback.

        Args:
            network_name: Optional specific network to search in

        Returns:
            Start state definition if found, None otherwise
        """
        if network_name:
            # Search only the requested network
            start_state = self._find_start_in_network(self.networks.get(network_name))
            if start_state is not None:
                return start_state
        else:
            # Search main network first
            if self.main_network_name:
                start_state = self.get_start_state(self.main_network_name)
                if start_state:
                    return start_state

            # Search all networks
            for network in self.networks.values():
                start_state = self._find_start_in_network(network)
                if start_state is not None:
                    return start_state

        # Fallback: look for state named 'start'
        return self.find_state_definition('start', network_name)

    @property
    def main_network(self) -> Optional['StateNetwork']:
        """Get the main network object.

        Returns:
            The main StateNetwork object or None if not set
        """
        if self.main_network_name:
            return self.networks.get(self.main_network_name)
        return None

    @property
    def states(self) -> Dict[str, StateDefinition]:
        """Get all states from the main network.

        Returns:
            Dictionary of state_name -> state_definition for the main network
            (the network's own mapping, not a copy), or an empty dictionary
            when there is no main network
        """
        if not self.main_network_name:
            return {}

        network = self.get_network(self.main_network_name)
        if network and hasattr(network, 'states'):
            return network.states
        return {}

    def get_all_states_dict(self) -> Dict[str, Dict[str, StateDefinition]]:
        """Get all states from all networks.

        Returns:
            Dictionary of network_name -> {state_name -> state_definition}
        """
        return {
            network_name: network.states
            for network_name, network in self.networks.items()
            if hasattr(network, 'states')
        }

    def get_outgoing_arcs(self, state_name: str, network_name: str | None = None) -> List[Any]:
        """Get outgoing arcs from a state.

        Args:
            state_name: Name of the state
            network_name: Optional network name (uses main network if None)

        Returns:
            List of outgoing arcs from the state, empty if the network is unknown
        """
        network = self._resolve_network(network_name)
        if network:
            return network.get_arcs_from_state(state_name)
        return []

    def get_engine(self, strategy: str | None = None):
        """Get or create the execution engine.

        The engine is created on first use and cached. Recognized strategy
        names are ``depth_first`` (the default), ``breadth_first``,
        ``resource_optimized`` and ``stream_optimized``; an unrecognized name
        is treated like None. Passing a recognized strategy that differs from
        the one the cached engine was built with replaces the cached engine,
        so the override is never silently ignored.

        Args:
            strategy: Optional execution strategy override

        Returns:
            ExecutionEngine instance.
        """
        # Fast path: nothing requested and an engine already exists.
        if self._engine is not None and strategy is None:
            return self._engine

        from dataknobs_fsm.execution.engine import ExecutionEngine, TraversalStrategy

        # Map strategy strings to enum
        strategy_map = {
            "depth_first": TraversalStrategy.DEPTH_FIRST,
            "breadth_first": TraversalStrategy.BREADTH_FIRST,
            "resource_optimized": TraversalStrategy.RESOURCE_OPTIMIZED,
            "stream_optimized": TraversalStrategy.STREAM_OPTIMIZED,
        }
        requested = strategy_map.get(strategy) if strategy else None

        # getattr keeps this safe for instances created without __init__.
        current = getattr(self, '_engine_strategy', None)
        needs_rebuild = requested is not None and requested != current

        if self._engine is None or needs_rebuild:
            chosen = requested if requested is not None else TraversalStrategy.DEPTH_FIRST
            self._engine = ExecutionEngine(
                fsm=self,
                strategy=chosen,
            )
            self._engine_strategy = chosen

        return self._engine

    def get_async_engine(self, strategy: str | None = None):
        """Get or create the async execution engine.

        The engine is created on first use and cached.

        Args:
            strategy: Accepted for symmetry with ``get_engine``; currently
                not used when building the async engine

        Returns:
            AsyncExecutionEngine instance.
        """
        if self._async_engine is None:
            from dataknobs_fsm.execution.async_engine import AsyncExecutionEngine

            self._async_engine = AsyncExecutionEngine(fsm=self)

        return self._async_engine

    def _prepare_execution_context(self, initial_data: Dict[str, Any] | None = None):
        """Prepare execution context for FSM execution.

        In batch mode the input becomes ``context.batch_data`` (wrapped in a
        list unless it already is one). In stream mode a stream context is
        created and the input, if any, is added as a single final chunk and
        also stored in ``context.data``. In single mode the context carries no
        data; the caller passes it to the engine directly.

        Args:
            initial_data: Initial data for execution.

        Returns:
            Configured ExecutionContext instance.
        """
        from dataknobs_fsm.execution.context import ExecutionContext

        # Create execution context
        context = ExecutionContext(
            data_mode=self.data_mode,
            transaction_mode=self.transaction_mode
        )

        # Set resource and transaction managers if available. Compare against
        # None rather than truthiness so a manager object that happens to be
        # falsy (e.g. an empty container-like manager) is still attached.
        if self.resource_manager is not None:
            context.resource_manager = self.resource_manager
        if self.transaction_manager is not None:
            context.transaction_manager = self.transaction_manager

        # Set up context based on data mode
        if self.data_mode == ProcessingMode.BATCH:
            # For batch mode, treat input as batch data
            if initial_data is None:
                context.batch_data = []
            elif isinstance(initial_data, list):  # type: ignore[unreachable]
                context.batch_data = initial_data  # type: ignore[unreachable]
            else:
                # If it's not already a list, make it one
                context.batch_data = [initial_data]
        elif self.data_mode == ProcessingMode.STREAM:
            # Imported here so non-streaming runs do not load the streaming module.
            from dataknobs_fsm.streaming.core import StreamContext, StreamConfig

            # For stream mode, create a stream context
            context.stream_context = StreamContext(config=StreamConfig())

            # Add initial data as a chunk if provided
            if initial_data is not None:
                # Add the data as a single chunk to the stream
                context.stream_context.add_data(initial_data, is_last=True)
                # Also set context.data for compatibility
                context.data = initial_data
        # Single mode - data passed normally

        return context

    def _format_execution_result(self, success: bool, result: Any, context: Any,
                                 duration: float, initial_data: Any = None,
                                 error: str | None = None) -> Dict[str, Any]:
        """Format the execution result in a standard format.

        Args:
            success: Whether execution succeeded.
            result: The execution result data.
            context: The execution context.
            duration: Time taken for execution.
            initial_data: Original input data.
            error: Error message if execution failed. Any value other than
                None (including an empty string) selects the error format.

        Returns:
            Formatted result dictionary. The error format has status
            ``error``, echoes ``initial_data`` and reports no duration; the
            normal format has status ``completed`` or ``failed``.
        """
        # Test against None: an exception with an empty message must still be
        # reported as an error, not as an ordinary failed run.
        if error is not None:
            return {
                "status": "error",
                "error": error,
                "data": initial_data,
                "execution_id": None,
                "transitions": 0,
                "duration": None
            }

        return {
            "status": "completed" if success else "failed",
            "data": result,
            "execution_id": getattr(context, 'execution_id', None),
            "transitions": getattr(context, 'transition_count', 0),
            "duration": duration
        }

    def _engine_input(self, initial_data: Dict[str, Any] | None) -> Dict[str, Any] | None:
        """Return the data handed to the engine: the input in single mode, else None."""
        if self.data_mode == ProcessingMode.SINGLE:
            return initial_data
        return None

    @staticmethod
    def _describe_error(exc: BaseException) -> str:
        """Return the exception message, or its class name when the message is empty."""
        return str(exc) or type(exc).__name__

    async def execute_async(self, initial_data: Dict[str, Any] | None = None) -> Any:
        """Execute the FSM asynchronously with initial data.

        Exceptions raised while preparing or running the execution are caught
        and reported through an error-format result instead of propagating.

        Args:
            initial_data: Initial data for execution.

        Returns:
            Execution result dictionary (see ``_format_execution_result``).
        """
        import time

        try:
            # Get the async execution engine
            engine = self.get_async_engine()

            # Prepare execution context
            context = self._prepare_execution_context(initial_data)

            # Track execution time with a monotonic clock so system clock
            # adjustments cannot produce a negative or skewed duration.
            start_time = time.perf_counter()

            # Execute the FSM
            success, result = await engine.execute(
                context,
                self._engine_input(initial_data)
            )

            # Calculate duration
            duration = time.perf_counter() - start_time

            return self._format_execution_result(success, result, context, duration)

        except Exception as e:
            # Handle any exception that occurs during execution
            return self._format_execution_result(
                False, None, None, 0.0, initial_data, self._describe_error(e)
            )

    def execute(self, initial_data: Dict[str, Any] | None = None) -> Any:
        """Execute the FSM synchronously with initial data.

        This is a simplified API for running the FSM. Exceptions raised while
        preparing or running the execution are caught and reported through an
        error-format result instead of propagating.

        Args:
            initial_data: Initial data for execution.

        Returns:
            Execution result dictionary (see ``_format_execution_result``).
        """
        import time

        try:
            # Get the execution engine
            engine = self.get_engine()

            # Prepare execution context
            context = self._prepare_execution_context(initial_data)

            # Track execution time with a monotonic clock so system clock
            # adjustments cannot produce a negative or skewed duration.
            start_time = time.perf_counter()

            # Execute the FSM
            success, result = engine.execute(
                context,
                self._engine_input(initial_data)
            )

            # Calculate duration
            duration = time.perf_counter() - start_time

            return self._format_execution_result(success, result, context, duration)

        except Exception as e:
            # Handle any exception that occurs during execution
            return self._format_execution_result(
                False, None, None, 0.0, initial_data, self._describe_error(e)
            )