Coverage for src/dataknobs_fsm/core/arc.py: 38%
206 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""Arc implementation for FSM state transitions."""
3import logging
4from dataclasses import dataclass, field
5from enum import Enum
6from typing import Any, Callable, Dict, TYPE_CHECKING
8from dataknobs_fsm.core.exceptions import FunctionError, ResourceError
9from dataknobs_fsm.functions.base import FunctionContext
11if TYPE_CHECKING:
12 from dataknobs_fsm.execution.context import ExecutionContext
14logger = logging.getLogger(__name__)
class DataIsolationMode(Enum):
    """How data is handed to a sub-network when a push arc is traversed.

    COPY deep-copies the data, REFERENCE shares the same object, and
    SERIALIZE round-trips it through a serialized form for full isolation.
    """

    COPY = "copy"            # Deep copy data when pushing
    REFERENCE = "reference"  # Pass data by reference
    SERIALIZE = "serialize"  # Serialize/deserialize for isolation
24@dataclass
25class ArcDefinition:
26 """Definition of an arc between states.
28 This class defines the static properties of an arc,
29 including the transition logic and resource requirements.
30 """
32 target_state: str
33 pre_test: str | None = None
34 transform: str | None = None
35 priority: int = 0 # Higher priority arcs are evaluated first
36 definition_order: int = 0 # Track definition order for stable sorting
37 metadata: Dict[str, Any] = field(default_factory=dict)
39 # Resource requirements for this arc
40 required_resources: Dict[str, str] = field(default_factory=dict)
41 # e.g., {'database': 'main_db', 'llm': 'gpt4'}
43 def __hash__(self) -> int:
44 """Make ArcDefinition hashable."""
45 return hash((
46 self.target_state,
47 self.pre_test,
48 self.transform,
49 self.priority
50 ))
@dataclass
class PushArc(ArcDefinition):
    """Arc that pushes to a sub-network.

    Push arcs allow hierarchical state machine composition
    by pushing execution to a sub-network and returning
    when the sub-network completes.

    Inherits all :class:`ArcDefinition` fields (target_state, pre_test,
    transform, priority, ...) and adds sub-network routing and data
    isolation/mapping options.
    """

    target_network: str = ""  # Name of the target network
    return_state: str | None = None  # State to return to after sub-network
    # How data crosses the parent/child boundary (see DataIsolationMode)
    isolation_mode: DataIsolationMode = DataIsolationMode.COPY
    pass_context: bool = True  # Whether to pass execution context

    # Mapping of data from parent to child network
    data_mapping: Dict[str, str] = field(default_factory=dict)
    # e.g., {'parent_field': 'child_field'}

    # Mapping of results from child to parent network
    result_mapping: Dict[str, str] = field(default_factory=dict)
    # e.g., {'child_result': 'parent_field'}
class ArcExecution:
    """Handles the execution of arc transitions.

    This class manages the runtime execution of arcs,
    including resource allocation, streaming support,
    and transaction participation.
    """

    def __init__(
        self,
        arc_def: ArcDefinition,
        source_state: str,
        function_registry
    ):
        """Initialize arc execution.

        Args:
            arc_def: Arc definition.
            source_state: Source state name.
            function_registry: Registry of available functions
                (FunctionRegistry or dict).
        """
        self.arc_def = arc_def
        self.source_state = source_state
        self.function_registry = function_registry

        # Execution statistics
        self.execution_count = 0
        self.success_count = 0
        self.failure_count = 0
        self.total_execution_time = 0.0

    def _log_warning(self, message: str) -> None:
        """Log a warning message.

        Args:
            message: Warning message to log.
        """
        logger.warning(message)

    def _log_error(self, message: str) -> None:
        """Log an error message.

        Args:
            message: Error message to log.
        """
        logger.error(message)

    def _lookup_function(self, name: str) -> Callable | None:
        """Resolve a function by name from the registry.

        Supports both FunctionRegistry objects (via ``get_function``)
        and plain dictionaries (via ``get``). Factored out because the
        same dispatch was previously duplicated in ``can_execute`` and
        ``execute``.

        Args:
            name: Registered function name.

        Returns:
            The function, or None if not found or the registry type
            is unsupported.
        """
        if hasattr(self.function_registry, 'get_function'):
            # FunctionRegistry object
            return self.function_registry.get_function(name)
        if isinstance(self.function_registry, dict):
            # Plain dictionary
            return self.function_registry.get(name)
        return None

    def can_execute(
        self,
        context: "ExecutionContext",
        data: Any = None
    ) -> bool:
        """Check if arc can be executed.

        This runs the pre-test function if defined.

        Args:
            context: Execution context.
            data: Current data.

        Returns:
            True if arc can be executed.

        Raises:
            FunctionError: If the pre-test function is missing or raises.
        """
        if not self.arc_def.pre_test:
            return True

        pre_test_func = self._lookup_function(self.arc_def.pre_test)
        if pre_test_func is None:
            raise FunctionError(
                f"Pre-test function '{self.arc_def.pre_test}' not found",
                from_state=self.source_state,
                to_state=self.arc_def.target_state
            )

        try:
            # Create function context with resources
            func_context = self._create_function_context(context)

            # Execute pre-test
            result = pre_test_func(data, func_context)

            # Handle tuple return from InterfaceWrapper (returns (result, error))
            if isinstance(result, tuple) and len(result) == 2:
                return bool(result[0])
            return bool(result)

        except Exception as e:
            raise FunctionError(
                f"Pre-test execution failed: {e}",
                from_state=self.source_state,
                to_state=self.arc_def.target_state
            ) from e

    def execute(
        self,
        context: "ExecutionContext",
        data: Any = None,
        stream_enabled: bool = False
    ) -> Any:
        """Execute the arc transition.

        This runs the transform function if defined and
        manages resource allocation.

        Args:
            context: Execution context.
            data: Current data.
            stream_enabled: Whether streaming is enabled.

        Returns:
            Transformed data.

        Raises:
            FunctionError: If the transform is missing or fails.
        """
        import time
        start_time = time.time()

        # Initialized up front so the finally-clause can test it directly;
        # the original used a fragile `'resources' in locals()` check.
        resources: Dict[str, Any] | None = None

        try:
            # Get state resources from context if available
            state_resources = getattr(context, 'current_state_resources', None)

            # Allocate required resources (merging with state resources)
            resources = self._allocate_resources(context, state_resources)

            # Execute transform if defined
            if self.arc_def.transform:
                transform_func = self._lookup_function(self.arc_def.transform)
                if transform_func is None:
                    raise FunctionError(
                        f"Transform function '{self.arc_def.transform}' not found",
                        from_state=self.source_state,
                        to_state=self.arc_def.target_state
                    )

                # Create function context with resources
                func_context = self._create_function_context(
                    context,
                    resources,
                    stream_enabled
                )

                # Handle streaming vs non-streaming execution
                if stream_enabled and hasattr(transform_func, 'stream_capable'):
                    result = self._execute_streaming(
                        transform_func,
                        data,
                        func_context
                    )
                else:
                    # Call the transform function properly.
                    # Check if it has a transform method (wrapped function).
                    if hasattr(transform_func, 'transform'):
                        result = transform_func.transform(data, func_context)
                    elif callable(transform_func):
                        result = transform_func(data, func_context)
                    else:
                        raise ValueError(f"Transform {self.arc_def.transform} is not callable")

                # Handle ExecutionResult objects
                from dataknobs_fsm.functions.base import ExecutionResult
                if isinstance(result, ExecutionResult):
                    if result.success:
                        result = result.data
                    else:
                        raise FunctionError(
                            result.error or "Transform failed",
                            from_state=self.source_state,
                            to_state=self.arc_def.target_state
                        )
            else:
                # No transform, pass data through
                result = data

            # Update statistics
            self.execution_count += 1
            self.success_count += 1

            return result

        except Exception as e:
            self.execution_count += 1
            self.failure_count += 1

            raise FunctionError(
                f"Arc execution failed: {e}",
                from_state=self.source_state,
                to_state=self.arc_def.target_state
            ) from e
        finally:
            elapsed = time.time() - start_time
            self.total_execution_time += elapsed

            # Release resources (only if allocation actually happened)
            if resources is not None:
                self._release_resources(context, resources)

    def execute_with_transaction(
        self,
        context: "ExecutionContext",
        data: Any = None,
        transaction_id: str | None = None
    ) -> Any:
        """Execute arc within a transaction context.

        Args:
            context: Execution context.
            data: Current data.
            transaction_id: Transaction identifier; generated if None.

        Returns:
            Transformed data.
        """
        # Get or create transaction
        if transaction_id is None:
            import uuid
            transaction_id = str(uuid.uuid4())

        try:
            # Begin transaction on required resources
            self._begin_transaction(context, transaction_id)

            # Execute arc
            result = self.execute(context, data)

            # Commit transaction
            self._commit_transaction(context, transaction_id)

            return result

        except Exception:
            # Rollback transaction
            self._rollback_transaction(context, transaction_id)
            raise

    def execute_push(
        self,
        push_arc: PushArc,
        context: "ExecutionContext",
        data: Any = None
    ) -> Any:
        """Execute a push arc to a sub-network.

        Args:
            push_arc: Push arc definition.
            context: Execution context.
            data: Current data.

        Returns:
            Result from sub-network execution.
        """
        # Prepare data for sub-network based on isolation mode
        if push_arc.isolation_mode == DataIsolationMode.COPY:
            import copy
            sub_data = copy.deepcopy(data)
        elif push_arc.isolation_mode == DataIsolationMode.SERIALIZE:
            import json
            serialized = json.dumps(data)
            sub_data = json.loads(serialized)
        else:
            sub_data = data

        # Apply data mapping; when present, the mapped dict replaces the
        # isolation-mode copy entirely.
        if push_arc.data_mapping:
            mapped_data = {}
            for parent_field, child_field in push_arc.data_mapping.items():
                if hasattr(data, parent_field):
                    mapped_data[child_field] = getattr(data, parent_field)
                elif isinstance(data, dict) and parent_field in data:
                    mapped_data[child_field] = data[parent_field]
            sub_data = mapped_data

        # Push context to sub-network
        context.push_network(push_arc.target_network, push_arc.return_state)

        # Execute sub-network (this would be handled by execution engine)
        # For now, we just return the data
        result = sub_data

        # Apply result mapping (writes child results back into parent data)
        if push_arc.result_mapping:
            for child_field, parent_field in push_arc.result_mapping.items():
                if isinstance(result, dict) and child_field in result:
                    if isinstance(data, dict):
                        data[parent_field] = result[child_field]
                    elif hasattr(data, parent_field):
                        setattr(data, parent_field, result[child_field])

        return result

    def _create_function_context(
        self,
        exec_context: "ExecutionContext",
        resources: Dict[str, Any] | None = None,
        stream_enabled: bool = False
    ) -> FunctionContext:
        """Create function context for execution.

        Args:
            exec_context: Execution context.
            resources: Allocated resources.
            stream_enabled: Whether streaming is enabled.

        Returns:
            Function context.
        """
        return FunctionContext(
            state_name=self.source_state,
            function_name=self.arc_def.transform or self.arc_def.pre_test,
            metadata={
                'source_state': self.source_state,
                'target_state': self.arc_def.target_state,
                'arc_priority': self.arc_def.priority,
                'stream_enabled': stream_enabled
            },
            resources=resources or {}
        )

    def _allocate_resources(
        self,
        context: "ExecutionContext",
        state_resources: Dict[str, Any] | None = None
    ) -> Dict[str, Any]:
        """Allocate required resources for arc execution, merging with state resources.

        Args:
            context: Execution context.
            state_resources: Already allocated state resources to merge with.

        Returns:
            Dictionary of merged resources (state + arc-specific).

        Raises:
            ResourceError: If an arc-specific resource cannot be acquired.
        """
        # Start with state resources if provided
        resources = dict(state_resources) if state_resources else {}

        # Get resource manager from context
        resource_manager = getattr(context, 'resource_manager', None)
        if not resource_manager:
            # No resource manager available - return existing resources
            return resources

        # Generate unique owner ID for this arc execution.
        # Create an arc identifier from source and target states.
        arc_identifier = f"{self.source_state}_to_{self.arc_def.target_state}"
        owner_id = f"arc_{arc_identifier}_{getattr(context, 'execution_id', 'unknown')}"

        for resource_type, resource_name in self.arc_def.required_resources.items():
            # Skip if already have this resource from state
            if resource_type in resources:
                self._log_warning(
                    f"Arc resource '{resource_type}' already allocated by state, skipping"
                )
                continue

            try:
                # Acquire arc-specific resource
                resource = resource_manager.acquire(
                    name=resource_name,
                    owner_id=owner_id,
                    timeout=30.0  # 30 second timeout
                )
                resources[resource_type] = resource

                # Track for cleanup (only arc-specific resources)
                if not hasattr(context, '_arc_acquired_resources'):
                    context._arc_acquired_resources = {}
                context._arc_acquired_resources[resource_name] = owner_id

            except Exception as e:
                # Resource acquisition failed - clean up only arc-specific resources
                self._release_arc_resources(context, getattr(context, '_arc_acquired_resources', {}))
                raise ResourceError(
                    resource_id=resource_name,
                    message=f"Failed to acquire arc resource: {e}",
                    details={"operation": "acquire", "error": str(e)}
                ) from e

        return resources

    def _release_arc_resources(
        self,
        context: "ExecutionContext",
        arc_resources: Dict[str, str]
    ) -> None:
        """Release only arc-specific resources, not state resources.

        Args:
            context: Execution context.
            arc_resources: Map of resource_name -> owner_id for arc resources only.
        """
        if not arc_resources:
            return

        resource_manager = getattr(context, 'resource_manager', None)
        if not resource_manager:
            return

        for resource_name, owner_id in arc_resources.items():
            try:
                resource_manager.release(resource_name, owner_id)
            except Exception as e:
                self._log_error(f"Failed to release arc resource {resource_name}: {e}")

        # Clear arc resources tracking
        if hasattr(context, '_arc_acquired_resources'):
            context._arc_acquired_resources = {}

    def _release_resources(
        self,
        context: "ExecutionContext",
        resources: Dict[str, Any]
    ) -> None:
        """Release allocated resources.

        Args:
            context: Execution context.
            resources: Resources to release.
        """
        # Get resource manager from context
        resource_manager = getattr(context, 'resource_manager', None)
        if not resource_manager:
            return

        # BUG FIX: allocation tracks arc-owned acquisitions in
        # `_arc_acquired_resources`; the previous code read
        # `_acquired_resources`, which is never populated, so
        # arc-acquired resources were leaked on every execution.
        acquired_resources = getattr(context, '_arc_acquired_resources', {})

        # Release each arc-acquired resource (state resources are not ours)
        for resource_type in resources.keys():
            # Find the resource name configured for this resource type
            resource_name = self.arc_def.required_resources.get(resource_type)

            if resource_name and resource_name in acquired_resources:
                owner_id = acquired_resources[resource_name]
                try:
                    resource_manager.release(resource_name, owner_id)
                    # Remove from tracking
                    del acquired_resources[resource_name]
                except Exception:
                    # Best effort cleanup - don't propagate release errors
                    pass

    def _execute_streaming(
        self,
        func: Callable,
        data: Any,
        context: FunctionContext
    ) -> Any:
        """Execute function with streaming support.

        Args:
            func: Function to execute.
            data: Input data.
            context: Function context.

        Returns:
            Streamed result.
        """
        # This would integrate with the streaming system.
        # For now, we just execute normally.
        return func(data, context)

    def _begin_transaction(
        self,
        context: "ExecutionContext",
        transaction_id: str
    ) -> None:
        """Begin transaction on required resources.

        Args:
            context: Execution context.
            transaction_id: Transaction ID.
        """
        # This would interface with transactional resources
        pass

    def _commit_transaction(
        self,
        context: "ExecutionContext",
        transaction_id: str
    ) -> None:
        """Commit transaction on resources.

        Args:
            context: Execution context.
            transaction_id: Transaction ID.
        """
        # This would interface with transactional resources
        pass

    def _rollback_transaction(
        self,
        context: "ExecutionContext",
        transaction_id: str
    ) -> None:
        """Rollback transaction on resources.

        Args:
            context: Execution context.
            transaction_id: Transaction ID.
        """
        # This would interface with transactional resources
        pass

    def get_statistics(self) -> Dict[str, Any]:
        """Get execution statistics.

        Returns:
            Dictionary of statistics.
        """
        avg_time = 0.0
        if self.execution_count > 0:
            avg_time = self.total_execution_time / self.execution_count

        return {
            'source_state': self.source_state,
            'target_state': self.arc_def.target_state,
            'execution_count': self.execution_count,
            'success_count': self.success_count,
            'failure_count': self.failure_count,
            'total_execution_time': self.total_execution_time,
            'average_execution_time': avg_time,
            'success_rate': (
                self.success_count / self.execution_count
                if self.execution_count > 0 else 0.0
            )
        }