Coverage for src/dataknobs_fsm/execution/network.py: 10%

189 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Network executor for running state networks.""" 

2 

import logging

from typing import Any, Dict, List, Tuple

from dataknobs_fsm.core.arc import PushArc, DataIsolationMode
from dataknobs_fsm.core.fsm import FSM
from dataknobs_fsm.core.network import StateNetwork
from dataknobs_fsm.execution.context import ExecutionContext
from dataknobs_fsm.execution.engine import ExecutionEngine
from dataknobs_fsm.functions.base import StateTransitionError

11 

12 

class NetworkExecutor:
    """Executor for running state networks with hierarchical support.

    This executor manages:
    - Network execution with context isolation
    - Hierarchical network push/pop operations
    - Data passing between networks
    - Resource management across networks
    - Parallel network execution
    """

    def __init__(
        self,
        fsm: FSM,
        enable_parallel: bool = False,
        max_depth: int = 10
    ):
        """Initialize network executor.

        Args:
            fsm: FSM containing networks to execute.
            enable_parallel: Enable parallel network execution.
            max_depth: Maximum network push depth.
        """
        self.fsm = fsm
        self.enable_parallel = enable_parallel
        self.max_depth = max_depth

        # The engine performs the individual state transitions.
        self.engine = ExecutionEngine(fsm)

        # Maps network name -> context for networks currently executing.
        self._active_networks: Dict[str, ExecutionContext] = {}

46 

47 def execute_network( 

48 self, 

49 network_name: str, 

50 context: ExecutionContext | None = None, 

51 data: Any = None, 

52 max_transitions: int = 1000, 

53 initial_state: str | None = None 

54 ) -> Tuple[bool, Any]: 

55 """Execute a specific network. 

56  

57 Args: 

58 network_name: Name of network to execute. 

59 context: Execution context (created if None). 

60 data: Input data for network. 

61 max_transitions: Maximum transitions allowed. 

62  

63 Returns: 

64 Tuple of (success, result). 

65 """ 

66 # Get network 

67 network = self.fsm.networks.get(network_name) 

68 if not network: 

69 return False, f"Network not found: {network_name}" 

70 

71 # Create context if needed 

72 if context is None: 

73 context = ExecutionContext() 

74 

75 # Set initial data 

76 if data is not None: 

77 context.data = data 

78 

79 # Track this network 

80 self._active_networks[network_name] = context 

81 

82 try: 

83 # For subnetworks, we always need to set the initial state 

84 # regardless of what was in the parent context 

85 if initial_state: 

86 # Use the provided initial state override 

87 # First verify it exists in the network 

88 if initial_state not in network.states: 

89 return False, f"State '{initial_state}' not found in network '{network_name}'" 

90 state_to_enter = initial_state 

91 elif network.initial_states: 

92 # Use the network's default initial state 

93 state_to_enter = next(iter(network.initial_states)) 

94 else: 

95 return False, f"No initial state in network: {network_name}" 

96 

97 # Clear any previous state and set the subnetwork's initial state 

98 context.current_state = None # Clear first 

99 # Use the engine's public enter_state method for consistent state entry 

100 if not self.engine.enter_state(context, state_to_enter, run_validators=False): 

101 return False, f"Failed to enter initial state: {state_to_enter}" 

102 

103 # Execute the network 

104 result = self._execute_network_internal( 

105 network, 

106 context, 

107 max_transitions 

108 ) 

109 

110 return result 

111 

112 finally: 

113 # Clean up tracking 

114 if network_name in self._active_networks: 

115 del self._active_networks[network_name] 

116 

117 def _execute_network_internal( 

118 self, 

119 network: StateNetwork, 

120 context: ExecutionContext, 

121 max_transitions: int 

122 ) -> Tuple[bool, Any]: 

123 """Internal network execution. 

124  

125 Args: 

126 network: Network to execute. 

127 context: Execution context. 

128 max_transitions: Maximum transitions. 

129  

130 Returns: 

131 Tuple of (success, result). 

132 """ 

133 transitions = 0 

134 

135 while transitions < max_transitions: 

136 # Check if in final state 

137 if context.current_state in network.final_states: 

138 return True, context.data 

139 

140 # Get available arcs from current state 

141 available_arcs = self._get_available_arcs( 

142 network, 

143 context.current_state 

144 ) 

145 

146 if not available_arcs: 

147 # No transitions available 

148 if context.current_state in network.final_states: 

149 return True, context.data 

150 return False, f"No valid transitions from: {context.current_state}" 

151 

152 # Process each arc 

153 transition_made = False 

154 for _arc_id, arc in available_arcs: 

155 # Evaluate arc condition/pre_test first (for all arc types including PushArcs) 

156 if hasattr(arc, 'pre_test') and arc.pre_test: 

157 # Arc has a condition (stored as pre_test) 

158 # Need to evaluate it using the function registry 

159 from dataknobs_fsm.core.arc import ArcExecution 

160 arc_exec = ArcExecution( 

161 arc, 

162 context.current_state or "", 

163 self.fsm.function_registry 

164 ) 

165 if not arc_exec.can_execute(context, context.data): 

166 continue # Skip this arc if condition is not met 

167 

168 # Check if this is a push arc 

169 if isinstance(arc, PushArc): 

170 # Debug: print push arc detection 

171 import logging 

172 logging.debug(f"Detected PushArc from {context.current_state} to network {arc.target_network}") 

173 success = self._handle_push_arc( 

174 arc, 

175 context 

176 ) 

177 elif hasattr(arc, 'metadata') and 'push_arc' in arc.metadata: 

178 # Arc with push_arc in metadata 

179 push_arc = arc.metadata['push_arc'] 

180 if isinstance(push_arc, PushArc): 

181 success = self._handle_push_arc( 

182 push_arc, 

183 context 

184 ) 

185 else: 

186 # Regular transition 

187 success = self.engine._execute_transition( 

188 context, 

189 arc 

190 ) 

191 else: 

192 # Regular transition 

193 success = self.engine._execute_transition( 

194 context, 

195 arc 

196 ) 

197 

198 if success: 

199 transition_made = True 

200 transitions += 1 

201 break 

202 

203 if not transition_made: 

204 return False, "No valid transition could be made" 

205 

206 # Check for network pop 

207 if context.current_state in network.final_states: 

208 if context.network_stack: 

209 self._handle_network_return(context) 

210 

211 return False, f"Maximum transitions ({max_transitions}) exceeded" 

212 

    def _handle_push_arc(
        self,
        arc: PushArc,
        context: ExecutionContext
    ) -> bool:
        """Handle a push arc to another network.

        Pushes the current network onto the context's network stack,
        executes the target subnetwork (with the isolation mode requested
        by the arc), and on success copies the subnetwork's result back
        into this context and enters the arc's return state.

        Args:
            arc: Push arc to execute.  ``arc.target_network`` may be either
                "network_name" or "network_name:initial_state".
            context: Execution context.

        Returns:
            True if the subnetwork executed successfully (and the return
            state, if any, was entered); False otherwise.

        Raises:
            StateTransitionError: If the network stack is already at
                ``self.max_depth``.
        """
        # Check depth limit before pushing another network frame.
        if len(context.network_stack) >= self.max_depth:
            raise StateTransitionError(
                from_state=context.current_state or "unknown",
                to_state=arc.target_network,
                message="Maximum network depth exceeded"
            )

        # Save parent state resources before pushing, so subnetwork states
        # can still reach them via parent_state_resources.
        parent_state_resources = getattr(context, 'current_state_resources', None)

        # Parse target network and optional initial state.
        # Syntax: "network_name" or "network_name:initial_state"
        if ':' in arc.target_network:
            network_name, initial_state = arc.target_network.split(':', 1)
            override_initial_state = initial_state.strip()
        else:
            network_name = arc.target_network
            override_initial_state = None

        # Push current network frame (popped again below on lookup failure,
        # or by _handle_network_return when the subnetwork finishes).
        context.push_network(
            network_name,
            arc.return_state
        )

        # Get target network; undo the push if it does not exist.
        target_network = self.fsm.networks.get(network_name)
        if not target_network:
            context.pop_network()
            return False

        # Create an isolated context if the arc requests it.
        # NOTE(review): the first branch tests `arc.isolation_mode` against
        # DataIsolationMode.COPY while the second tests a differently named
        # `arc.data_isolation_mode` against the string 'partial' -- confirm
        # which attribute PushArc actually defines; this looks inconsistent.
        if hasattr(arc, 'isolation_mode') and arc.isolation_mode == DataIsolationMode.COPY:
            # Full isolation - new context
            sub_context = ExecutionContext(
                data_mode=context.data_mode,
                transaction_mode=context.transaction_mode,
                resources=context.resource_limits
            )
            sub_context.data = context.data
            sub_context.variables = context.variables  # Share variables for tracking
            # Preserve resource manager in new context
            if hasattr(context, 'resource_manager'):
                sub_context.resource_manager = context.resource_manager
            # Preserve parent state resources in new context
            if parent_state_resources:
                sub_context.parent_state_resources = parent_state_resources
        elif hasattr(arc, 'data_isolation_mode') and arc.data_isolation_mode == 'partial':
            # Partial isolation - clone context
            sub_context = context.clone()
            # Preserve parent state resources - this needs to be accessible
            # to all subnetwork states.
            if parent_state_resources:
                sub_context.parent_state_resources = parent_state_resources
            # Also ensure resource_manager is available
            if hasattr(context, 'resource_manager'):
                sub_context.resource_manager = context.resource_manager
        else:
            # No isolation - use same context
            sub_context = context
            if parent_state_resources:
                context.parent_state_resources = parent_state_resources
        # Ensure resource_manager is available in subcontext.
        # NOTE(review): when sub_context is context itself (no-isolation
        # branch) this check can never fire; it only matters for the
        # isolated branches, which already copy resource_manager -- confirm
        # whether this fallback is intentionally redundant.
        if hasattr(context, 'resource_manager') and not hasattr(sub_context, 'resource_manager'):
            sub_context.resource_manager = context.resource_manager

        # Execute the target network (which handles initial state entry).
        # Don't pass the data parameter - sub_context already carries the
        # correctly transformed data.
        import logging
        logging.debug(f"Executing sub-network {network_name} with context type {type(sub_context)}")
        success, result = self.execute_network(
            network_name,
            sub_context,
            initial_state=override_initial_state
        )
        logging.debug(f"Sub-network execution result: success={success}, result={result}")

        if success:
            # Propagate the subnetwork's result into the parent context.
            context.data = result

            # Return to the specified state and execute its entry logic
            # via the engine's public enter_state (validators skipped).
            if arc.return_state:
                if not self.engine.enter_state(context, arc.return_state, run_validators=False):
                    return False

            return True

        return False

318 

319 def _handle_network_return( 

320 self, 

321 context: ExecutionContext 

322 ) -> None: 

323 """Handle returning from a pushed network. 

324 

325 Args: 

326 context: Execution context. 

327 """ 

328 if context.network_stack: 

329 _network_name, return_state = context.pop_network() 

330 

331 # Clean up parent_state_resources attribute if it was added 

332 if hasattr(context, 'parent_state_resources'): 

333 delattr(context, 'parent_state_resources') 

334 

335 if return_state: 

336 context.set_state(return_state) 

337 

338 def _get_available_arcs( 

339 self, 

340 network: StateNetwork, 

341 state_name: str | None 

342 ) -> List[Tuple[str, Any]]: 

343 """Get available arcs from a state. 

344 

345 Args: 

346 network: Network containing arcs. 

347 state_name: Current state name. 

348 

349 Returns: 

350 List of (arc_id, arc) tuples. 

351 """ 

352 if not state_name: 

353 return [] 

354 

355 # Get the state definition to access actual arc objects (including PushArcs) 

356 state_def = network.get_state(state_name) 

357 if not state_def: 

358 return [] 

359 

360 available = [] 

361 # Use the state's outgoing_arcs which have the proper arc types 

362 for i, arc in enumerate(state_def.outgoing_arcs): 

363 # Create an arc_id for tracking 

364 arc_id = f"{state_name}:{arc.target_state}:{i}" 

365 available.append((arc_id, arc)) 

366 

367 return available 

368 

369 def execute_parallel_networks( 

370 self, 

371 network_configs: List[Dict[str, Any]], 

372 base_context: ExecutionContext | None = None 

373 ) -> List[Tuple[bool, Any]]: 

374 """Execute multiple networks in parallel. 

375  

376 Args: 

377 network_configs: List of network configurations. 

378 Each config should have: 

379 - 'network_name': Name of network 

380 - 'data': Input data 

381 - 'max_transitions': Max transitions (optional) 

382 base_context: Base context to clone for each network. 

383  

384 Returns: 

385 List of (success, result) tuples in the same order as configs. 

386 """ 

387 if not self.enable_parallel: 

388 # Execute sequentially if parallel disabled 

389 results = [] 

390 for config in network_configs: 

391 network_name = config['network_name'] 

392 data = config.get('data') 

393 max_transitions = config.get('max_transitions', 1000) 

394 

395 # Clone context for each network 

396 if base_context: 

397 context = base_context.clone() 

398 else: 

399 context = ExecutionContext() 

400 

401 success, result = self.execute_network( 

402 network_name, 

403 context, 

404 data, 

405 max_transitions 

406 ) 

407 

408 results.append((success, result)) 

409 

410 return results 

411 

412 # Parallel execution using asyncio 

413 import asyncio 

414 

415 async def execute_async(config): 

416 network_name = config['network_name'] 

417 data = config.get('data') 

418 max_transitions = config.get('max_transitions', 1000) 

419 

420 # Clone context 

421 if base_context: 

422 context = base_context.clone() 

423 else: 

424 context = ExecutionContext() 

425 

426 # Execute in thread pool 

427 loop = asyncio.get_event_loop() 

428 success, result = await loop.run_in_executor( 

429 None, 

430 self.execute_network, 

431 network_name, 

432 context, 

433 data, 

434 max_transitions 

435 ) 

436 

437 return (success, result) 

438 

439 # Run all networks in parallel 

440 loop = asyncio.new_event_loop() 

441 asyncio.set_event_loop(loop) 

442 

443 try: 

444 tasks = [execute_async(config) for config in network_configs] 

445 results = loop.run_until_complete( 

446 asyncio.gather(*tasks) 

447 ) 

448 

449 return results 

450 

451 finally: 

452 loop.close() 

453 

454 def validate_all_networks(self) -> Dict[str, Tuple[bool, List[str]]]: 

455 """Validate all networks in the FSM. 

456  

457 Returns: 

458 Dictionary of network_name -> (valid, errors). 

459 """ 

460 results = {} 

461 

462 for network_name, network in self.fsm.networks.items(): 

463 valid, errors = network.validate() 

464 results[network_name] = (valid, errors) 

465 

466 return results 

467 

468 def get_network_stats(self, network_name: str) -> Dict[str, Any]: 

469 """Get statistics for a network. 

470  

471 Args: 

472 network_name: Name of network. 

473  

474 Returns: 

475 Network statistics. 

476 """ 

477 network = self.fsm.networks.get(network_name) 

478 if not network: 

479 return {} 

480 

481 # Count various elements 

482 state_count = len(network.states) 

483 arc_count = len(network.arcs) 

484 initial_count = len(network.initial_states) 

485 final_count = len(network.final_states) 

486 

487 # Check connectivity 

488 valid, errors = network.validate() 

489 

490 # Get resource requirements 

491 total_resources = {} 

492 for resource_type, requirements in network.resource_requirements.items(): 

493 total_resources[resource_type] = len(requirements) 

494 

495 return { 

496 'states': state_count, 

497 'arcs': arc_count, 

498 'initial_states': initial_count, 

499 'final_states': final_count, 

500 'is_valid': valid, 

501 'validation_errors': errors, 

502 'resource_requirements': total_resources, 

503 'supports_streaming': network.supports_streaming 

504 } 

505 

506 def get_active_networks(self) -> List[str]: 

507 """Get list of currently active networks. 

508  

509 Returns: 

510 List of active network names. 

511 """ 

512 return list(self._active_networks.keys())