Coverage for src/dataknobs_fsm/execution/network.py: 10%
189 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""Network executor for running state networks."""
3from typing import Any, Dict, List, Tuple
5from dataknobs_fsm.core.arc import PushArc, DataIsolationMode
6from dataknobs_fsm.core.fsm import FSM
7from dataknobs_fsm.core.network import StateNetwork
8from dataknobs_fsm.execution.context import ExecutionContext
9from dataknobs_fsm.execution.engine import ExecutionEngine
10from dataknobs_fsm.functions.base import StateTransitionError
class NetworkExecutor:
    """Executor for running state networks with hierarchical support.

    This executor manages:
    - Network execution with context isolation
    - Hierarchical network push/pop operations
    - Data passing between networks
    - Resource management across networks
    - Parallel network execution
    """

    def __init__(
        self,
        fsm: FSM,
        enable_parallel: bool = False,
        max_depth: int = 10
    ):
        """Initialize network executor.

        Args:
            fsm: FSM containing networks to execute.
            enable_parallel: Enable parallel network execution.
            max_depth: Maximum network push depth.
        """
        self.fsm = fsm
        self.enable_parallel = enable_parallel
        self.max_depth = max_depth

        # Engine used to enter states and to run ordinary (non-push) transitions.
        self.engine = ExecutionEngine(fsm)

        # Contexts of the networks currently being executed, keyed by name.
        self._active_networks: Dict[str, ExecutionContext] = {}

    def execute_network(
        self,
        network_name: str,
        context: ExecutionContext | None = None,
        data: Any = None,
        max_transitions: int = 1000,
        initial_state: str | None = None
    ) -> Tuple[bool, Any]:
        """Execute a specific network.

        Args:
            network_name: Name of network to execute.
            context: Execution context (created if None).
            data: Input data for network; when not None it replaces
                ``context.data``.
            max_transitions: Maximum transitions allowed.
            initial_state: Optional state to start from instead of the
                network's default initial state. It must exist in
                ``network.states``.

        Returns:
            Tuple of (success, result). On failure, result is an error
            message string.
        """
        network = self.fsm.networks.get(network_name)
        if not network:
            return False, f"Network not found: {network_name}"

        if context is None:
            context = ExecutionContext()

        if data is not None:
            context.data = data

        # Track this network for the duration of the run.
        self._active_networks[network_name] = context

        try:
            # The starting state is always (re)entered here, regardless of
            # whatever state a parent network left in the context.
            if initial_state:
                if initial_state not in network.states:
                    return False, f"State '{initial_state}' not found in network '{network_name}'"
                state_to_enter = initial_state
            elif network.initial_states:
                state_to_enter = next(iter(network.initial_states))
            else:
                return False, f"No initial state in network: {network_name}"

            # Clear any previous state before entering through the engine so
            # that state entry is handled consistently.
            context.current_state = None
            if not self.engine.enter_state(context, state_to_enter, run_validators=False):
                return False, f"Failed to enter initial state: {state_to_enter}"

            return self._execute_network_internal(network, context, max_transitions)
        finally:
            # pop() with a default instead of "check then del": the same
            # network name may be executed concurrently from worker threads
            # (see execute_parallel_networks), and another run may already
            # have removed the entry.
            self._active_networks.pop(network_name, None)

    def _execute_network_internal(
        self,
        network: StateNetwork,
        context: ExecutionContext,
        max_transitions: int
    ) -> Tuple[bool, Any]:
        """Internal network execution.

        Repeatedly takes the first outgoing arc of the current state whose
        condition passes and whose execution succeeds, until a final state
        is reached, no transition can be made, or the transition budget is
        used up.

        Args:
            network: Network to execute.
            context: Execution context.
            max_transitions: Maximum transitions.

        Returns:
            Tuple of (success, result).
        """
        import logging

        transitions = 0

        while transitions < max_transitions:
            if context.current_state in network.final_states:
                return True, context.data

            available_arcs = self._get_available_arcs(
                network,
                context.current_state
            )
            if not available_arcs:
                # The final-state case was already handled above, so a state
                # without outgoing arcs is a dead end.
                return False, f"No valid transitions from: {context.current_state}"

            transition_made = False
            for _arc_id, arc in available_arcs:
                # Conditions apply to every arc type, push arcs included.
                if not self._arc_condition_met(arc, context):
                    continue

                push_arc = self._resolve_push_arc(arc)
                if push_arc is not None:
                    logging.debug(
                        "Detected PushArc from %s to network %s",
                        context.current_state,
                        push_arc.target_network,
                    )
                    success = self._handle_push_arc(push_arc, context)
                else:
                    success = self.engine._execute_transition(context, arc)

                if success:
                    transition_made = True
                    transitions += 1
                    break

            if not transition_made:
                return False, "No valid transition could be made"

            # Check for network pop.
            if context.current_state in network.final_states and context.network_stack:
                self._handle_network_return(context)

        # The budget is spent, but the last permitted transition may have
        # landed in a final state (or the network started in one with a
        # budget of zero); that is a success, not an overrun.
        if context.current_state in network.final_states:
            return True, context.data

        return False, f"Maximum transitions ({max_transitions}) exceeded"

    def _arc_condition_met(self, arc: Any, context: ExecutionContext) -> bool:
        """Return True when ``arc`` has no pre_test or its pre_test passes."""
        if not getattr(arc, 'pre_test', None):
            return True

        # Imported lazily, only for arcs that actually carry a condition.
        from dataknobs_fsm.core.arc import ArcExecution

        arc_exec = ArcExecution(
            arc,
            context.current_state or "",
            self.fsm.function_registry
        )
        return bool(arc_exec.can_execute(context, context.data))

    def _resolve_push_arc(self, arc: Any) -> PushArc | None:
        """Return the PushArc represented by ``arc``, or None for a regular arc.

        An arc is a push arc when it is a PushArc itself or when its
        ``metadata`` mapping holds a PushArc under the ``'push_arc'`` key.
        """
        if isinstance(arc, PushArc):
            return arc

        metadata = getattr(arc, 'metadata', None)
        if metadata is not None and 'push_arc' in metadata:
            candidate = metadata['push_arc']
            if isinstance(candidate, PushArc):
                return candidate
        return None

    def _handle_push_arc(
        self,
        arc: PushArc,
        context: ExecutionContext
    ) -> bool:
        """Handle a push arc to another network.

        Args:
            arc: Push arc to execute.
            context: Execution context.

        Returns:
            True if the sub-network succeeded and the return state (if any)
            was entered.

        Raises:
            StateTransitionError: If the network stack already holds
                ``max_depth`` entries.
        """
        import logging

        if len(context.network_stack) >= self.max_depth:
            raise StateTransitionError(
                from_state=context.current_state or "unknown",
                to_state=arc.target_network,
                message="Maximum network depth exceeded"
            )

        # Save parent state resources before pushing.
        parent_state_resources = getattr(context, 'current_state_resources', None)

        # Target syntax: "network_name" or "network_name:initial_state".
        if ':' in arc.target_network:
            network_name, initial_state = arc.target_network.split(':', 1)
            override_initial_state = initial_state.strip()
        else:
            network_name = arc.target_network
            override_initial_state = None

        context.push_network(
            network_name,
            arc.return_state
        )

        if not self.fsm.networks.get(network_name):
            # Unknown target: undo the push so the stack stays balanced.
            context.pop_network()
            return False

        sub_context = self._build_sub_context(arc, context, parent_state_resources)

        logging.debug(
            "Executing sub-network %s with context type %s",
            network_name,
            type(sub_context),
        )
        # No data argument: sub_context already carries the data to use.
        success, result = self.execute_network(
            network_name,
            sub_context,
            initial_state=override_initial_state
        )
        logging.debug(
            "Sub-network execution result: success=%s, result=%s",
            success,
            result,
        )

        if not success:
            # NOTE(review): the entry pushed above is not popped on this
            # path; confirm whether callers rely on that before changing it.
            return False

        # Update main context with result.
        context.data = result

        # Return to the specified state through the engine so its entry
        # logic runs.
        if arc.return_state:
            if not self.engine.enter_state(context, arc.return_state, run_validators=False):
                return False

        return True

    def _build_sub_context(
        self,
        arc: PushArc,
        context: ExecutionContext,
        parent_state_resources: Any
    ) -> ExecutionContext:
        """Create (or reuse) the context a pushed sub-network runs in.

        NOTE(review): the first branch reads ``arc.isolation_mode`` (compared
        with ``DataIsolationMode.COPY``) while the second reads
        ``arc.data_isolation_mode`` (compared with the string ``'partial'``);
        confirm which attribute PushArc actually defines.
        """
        if hasattr(arc, 'isolation_mode') and arc.isolation_mode == DataIsolationMode.COPY:
            # Fresh context; note that data and variables are assigned by
            # reference here, not copied.
            sub_context = ExecutionContext(
                data_mode=context.data_mode,
                transaction_mode=context.transaction_mode,
                resources=context.resource_limits
            )
            sub_context.data = context.data
            sub_context.variables = context.variables  # Share variables for tracking
            if hasattr(context, 'resource_manager'):
                sub_context.resource_manager = context.resource_manager
            if parent_state_resources:
                sub_context.parent_state_resources = parent_state_resources
            return sub_context

        if hasattr(arc, 'data_isolation_mode') and arc.data_isolation_mode == 'partial':
            # Partial isolation - clone context.
            sub_context = context.clone()
            # Parent state resources must be reachable from sub-network states.
            if parent_state_resources:
                sub_context.parent_state_resources = parent_state_resources
            if hasattr(context, 'resource_manager'):
                sub_context.resource_manager = context.resource_manager
            return sub_context

        # No isolation - the sub-network runs in the parent's own context, so
        # its resource_manager (if any) is already in place.
        if parent_state_resources:
            context.parent_state_resources = parent_state_resources
        return context

    def _handle_network_return(
        self,
        context: ExecutionContext
    ) -> None:
        """Handle returning from a pushed network.

        Pops the top network-stack entry, removes any
        ``parent_state_resources`` attribute from the context and, when the
        popped entry names a return state, sets it as the current state.
        Does nothing when the stack is empty.

        Args:
            context: Execution context.
        """
        if not context.network_stack:
            return

        _network_name, return_state = context.pop_network()

        # Clean up parent_state_resources attribute if it was added.
        if hasattr(context, 'parent_state_resources'):
            delattr(context, 'parent_state_resources')

        if return_state:
            context.set_state(return_state)

    def _get_available_arcs(
        self,
        network: StateNetwork,
        state_name: str | None
    ) -> List[Tuple[str, Any]]:
        """Get available arcs from a state.

        Args:
            network: Network containing arcs.
            state_name: Current state name.

        Returns:
            List of (arc_id, arc) tuples, where arc_id is
            ``"<state>:<target_state>:<index>"``. Empty when the state name
            is empty or the state is not found.
        """
        if not state_name:
            return []

        # The state definition holds the actual arc objects (including
        # PushArcs), which is what the executor needs to dispatch on.
        state_def = network.get_state(state_name)
        if not state_def:
            return []

        return [
            (f"{state_name}:{arc.target_state}:{i}", arc)
            for i, arc in enumerate(state_def.outgoing_arcs)
        ]

    def execute_parallel_networks(
        self,
        network_configs: List[Dict[str, Any]],
        base_context: ExecutionContext | None = None
    ) -> List[Tuple[bool, Any]]:
        """Execute multiple networks in parallel.

        When ``enable_parallel`` is False the networks run one after another
        in the calling thread. Otherwise each network runs on a worker
        thread of a thread pool; this also works when the caller is already
        inside a running asyncio event loop and leaves the thread's event
        loop untouched.

        Args:
            network_configs: List of network configurations.
                Each config should have:
                - 'network_name': Name of network
                - 'data': Input data
                - 'max_transitions': Max transitions (optional)
            base_context: Base context to clone for each network.

        Returns:
            List of (success, result) tuples in the same order as configs.
        """
        def prepare(config: Dict[str, Any]) -> Tuple[str, Any, Any, int]:
            # Build the positional arguments for execute_network().
            network_name = config['network_name']
            data = config.get('data')
            max_transitions = config.get('max_transitions', 1000)
            # Every network gets its own context.
            if base_context:
                context = base_context.clone()
            else:
                context = ExecutionContext()
            return network_name, context, data, max_transitions

        if not self.enable_parallel:
            # Clone-then-run per config, in order.
            return [
                self.execute_network(*prepare(config))
                for config in network_configs
            ]

        from concurrent.futures import ThreadPoolExecutor

        # Contexts are cloned here, in the calling thread, so base_context is
        # never touched from the workers.
        jobs = [prepare(config) for config in network_configs]
        if not jobs:
            return []

        with ThreadPoolExecutor() as pool:
            futures = [pool.submit(self.execute_network, *job) for job in jobs]
            # Collect in submission order to keep results aligned with configs.
            return [future.result() for future in futures]

    def validate_all_networks(self) -> Dict[str, Tuple[bool, List[str]]]:
        """Validate all networks in the FSM.

        Returns:
            Dictionary of network_name -> (valid, errors).
        """
        results = {}
        for network_name, network in self.fsm.networks.items():
            valid, errors = network.validate()
            results[network_name] = (valid, errors)
        return results

    def get_network_stats(self, network_name: str) -> Dict[str, Any]:
        """Get statistics for a network.

        Args:
            network_name: Name of network.

        Returns:
            Network statistics, or an empty dict when the network is not
            found.
        """
        network = self.fsm.networks.get(network_name)
        if not network:
            return {}

        valid, errors = network.validate()

        # Number of requirements registered per resource type.
        total_resources = {
            resource_type: len(requirements)
            for resource_type, requirements in network.resource_requirements.items()
        }

        return {
            'states': len(network.states),
            'arcs': len(network.arcs),
            'initial_states': len(network.initial_states),
            'final_states': len(network.final_states),
            'is_valid': valid,
            'validation_errors': errors,
            'resource_requirements': total_resources,
            'supports_streaming': network.supports_streaming
        }

    def get_active_networks(self) -> List[str]:
        """Get list of currently active networks.

        Returns:
            List of active network names.
        """
        return list(self._active_networks)