Coverage for src/dataknobs_fsm/core/network.py: 44%
247 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""State network implementation for FSM."""
3from dataclasses import dataclass, field
4from typing import Any, Dict, List, Set, Tuple
6from dataknobs_fsm.core.state import State
9@dataclass
10class Arc:
11 """Represents an arc (transition) between states.
13 Attributes:
14 source_state: Name of the source state.
15 target_state: Name of the target state.
16 pre_test: Optional pre-test function name.
17 transform: Optional transform function name.
18 metadata: Additional arc metadata.
19 """
20 source_state: str
21 target_state: str
22 pre_test: str | None = None
23 transform: str | None = None
24 metadata: Dict[str, Any] = field(default_factory=dict)
26 def __hash__(self) -> int:
27 """Make Arc hashable for use in sets."""
28 return hash((self.source_state, self.target_state, self.pre_test, self.transform))
30 def __eq__(self, other: object) -> bool:
31 """Check equality."""
32 if not isinstance(other, Arc):
33 return False
34 return (
35 self.source_state == other.source_state and
36 self.target_state == other.target_state and
37 self.pre_test == other.pre_test and
38 self.transform == other.transform
39 )
41 @property
42 def name(self) -> str:
43 """Generate a name for the arc."""
44 # Use metadata name if available, otherwise generate from states
45 if 'name' in self.metadata:
46 return self.metadata['name']
47 return f"{self.source_state}->{self.target_state}"
@dataclass
class NetworkResourceRequirements:
    """Aggregated resource requirements for a network.

    Attributes:
        databases: Set of required database resources.
        filesystems: Set of required filesystem resources.
        http_services: Set of required HTTP service resources.
        llms: Set of required LLM resources.
        custom: Dictionary of custom resource requirements.
        streaming_enabled: Whether any state requires streaming.
        estimated_memory_mb: Estimated memory requirement in MB.
    """

    databases: Set[str] = field(default_factory=set)
    filesystems: Set[str] = field(default_factory=set)
    http_services: Set[str] = field(default_factory=set)
    llms: Set[str] = field(default_factory=set)
    custom: Dict[str, Set[str]] = field(default_factory=dict)
    streaming_enabled: bool = False
    estimated_memory_mb: int = 0

    def merge(self, other: "NetworkResourceRequirements") -> None:
        """Fold another set of requirements into this one, in place.

        Resource sets are unioned, custom entries are unioned per key, the
        streaming flag is OR-ed, and the memory estimate becomes the larger
        of the two values (it is not summed).

        Args:
            other: Requirements to merge. It is not modified.
        """
        self.databases.update(other.databases)
        self.filesystems.update(other.filesystems)
        self.http_services.update(other.http_services)
        self.llms.update(other.llms)

        # A fresh set is created for unseen keys so the two objects never
        # share a set instance.
        for category, names in other.custom.items():
            self.custom.setdefault(category, set()).update(names)

        if not self.streaming_enabled:
            self.streaming_enabled = other.streaming_enabled
        self.estimated_memory_mb = max(self.estimated_memory_mb, other.estimated_memory_mb)

    def is_empty(self) -> bool:
        """Check if there are no resource requirements.

        Only the resource collections are inspected; ``streaming_enabled``
        and ``estimated_memory_mb`` do not affect the result.

        Returns:
            True if no resources are required.
        """
        return not (
            self.databases
            or self.filesystems
            or self.http_services
            or self.llms
            or self.custom
        )
class StateNetwork:
    """Represents a network of states and their transitions.

    A state network is a directed graph where nodes are states
    and edges are arcs (transitions) between states. At most one state can
    be the initial state; any number can be final. Aggregated resource
    requirements and a cached validation result are maintained as states
    and arcs are added or removed through this class's methods.
    """

    def __init__(self, name: str, description: str | None = None) -> None:
        """Initialize state network.

        Args:
            name: Network name/identifier.
            description: Optional network description.
        """
        self.name = name
        self.description = description

        # State management
        self._states: Dict[str, State] = {}
        self._initial_state: str | None = None
        self._final_states: Set[str] = set()

        # Arc management
        self._arcs: List[Arc] = []
        self._arc_index: Dict[str, List[Arc]] = {}  # source_state -> [arcs]

        # Resource tracking
        self._resource_requirements = NetworkResourceRequirements()
        self._streaming_enabled = False

        # Validation cache; None means "not validated since last change"
        self._validation_cache: Dict[str, Any] | None = None

    @property
    def states(self) -> Dict[str, State]:
        """Get all states in the network.

        This is the live internal mapping. Mutating it directly bypasses
        the arc, resource and validation-cache bookkeeping done by
        ``add_state`` / ``remove_state``.
        """
        return self._states

    @property
    def arcs(self) -> Dict[str, Any]:
        """Get all arcs in the network as ``ArcDefinition`` objects.

        Returns:
            Dict keyed by ``"<source>:<target>"``. When several arcs share
            the same source and target, only the last one added appears,
            because later entries overwrite earlier ones under that key.
        """
        # Import here to avoid circular dependency
        from dataknobs_fsm.core.arc import ArcDefinition

        arc_dict = {}
        for arc in self._arcs:
            # Convert Arc to ArcDefinition for compatibility
            arc_def = ArcDefinition(
                target_state=arc.target_state,
                pre_test=arc.pre_test,
                transform=arc.transform
            )
            # Copy metadata so callers cannot mutate the arc's own dict
            if arc.metadata:
                arc_def.metadata = arc.metadata.copy()
            arc_dict[f"{arc.source_state}:{arc.target_state}"] = arc_def
        return arc_dict

    @property
    def initial_states(self) -> Set[str]:
        """Get initial states (returns set for compatibility).

        The set holds at most one name, since the network tracks a single
        initial state.
        """
        if self._initial_state:
            return {self._initial_state}
        return set()

    @property
    def final_states(self) -> Set[str]:
        """Get a copy of the set of final state names."""
        return self._final_states.copy()

    def is_initial_state(self, state_name: str) -> bool:
        """Check if a state is an initial state.

        Args:
            state_name: Name of the state to check

        Returns:
            True if the state is an initial state
        """
        return self._initial_state == state_name

    def is_final_state(self, state_name: str) -> bool:
        """Check if a state is a final state.

        Args:
            state_name: Name of the state to check

        Returns:
            True if the state is a final state
        """
        return state_name in self._final_states

    @property
    def resource_requirements(self) -> Dict[str, Any]:
        """Get resource requirements as a plain dict.

        The values are the live internal collections, not copies. The
        streaming flag and memory estimate are not included here; use
        ``get_resource_requirements`` for the full object.
        """
        reqs = self._resource_requirements
        return {
            'databases': reqs.databases,
            'filesystems': reqs.filesystems,
            'http_services': reqs.http_services,
            'llms': reqs.llms,
            'custom': reqs.custom
        }

    @property
    def supports_streaming(self) -> bool:
        """Check if network supports streaming."""
        return self._streaming_enabled

    def add_state(
        self,
        state: State,
        initial: bool = False,
        final: bool = False
    ) -> None:
        """Add a state to the network.

        All checks run before any bookkeeping is touched, so a rejected
        call leaves the network exactly as it was.

        Args:
            state: State to add.
            initial: Mark as initial state.
            final: Mark as final state.

        Raises:
            ValueError: If state with same name already exists, or if
                ``initial`` is requested while an initial state is
                already set.
        """
        if state.name in self._states:
            raise ValueError(f"State '{state.name}' already exists in network")
        if initial and self._initial_state:
            raise ValueError(
                f"Initial state already set to '{self._initial_state}'"
            )

        self._states[state.name] = state
        if initial:
            self._initial_state = state.name
        if final:
            self._final_states.add(state.name)

        # Update resource requirements
        self._update_resource_requirements(state)

        # Invalidate validation cache
        self._validation_cache = None

    def remove_state(self, state_name: str) -> None:
        """Remove a state, and every arc touching it, from the network.

        Args:
            state_name: Name of state to remove.

        Raises:
            KeyError: If state doesn't exist.
        """
        if state_name not in self._states:
            raise KeyError(f"State '{state_name}' not found in network")

        del self._states[state_name]

        # Remove from initial/final if needed
        if self._initial_state == state_name:
            self._initial_state = None
        self._final_states.discard(state_name)

        # Drop arcs that start or end at the removed state
        self._arcs = [
            arc for arc in self._arcs
            if arc.source_state != state_name and arc.target_state != state_name
        ]
        self._rebuild_arc_index()

        # Requirements are only ever accumulated, so removal needs a full
        # recomputation from the remaining states.
        self._recalculate_resource_requirements()

        # Invalidate validation cache
        self._validation_cache = None

    def add_arc(
        self,
        source_state: str,
        target_state: str,
        pre_test: str | None = None,
        transform: str | None = None,
        metadata: Dict[str, Any] | None = None
    ) -> Arc:
        """Add an arc between two states.

        Duplicate arcs are not rejected; each call appends a new arc.

        Args:
            source_state: Source state name.
            target_state: Target state name.
            pre_test: Optional pre-test function name.
            transform: Optional transform function name.
            metadata: Optional arc metadata. A non-empty dict is stored
                by reference, not copied.

        Returns:
            Created arc.

        Raises:
            ValueError: If states don't exist.
        """
        if source_state not in self._states:
            raise ValueError(f"Source state '{source_state}' not found")
        if target_state not in self._states:
            raise ValueError(f"Target state '{target_state}' not found")

        arc = Arc(
            source_state=source_state,
            target_state=target_state,
            pre_test=pre_test,
            transform=transform,
            metadata=metadata or {}
        )

        self._arcs.append(arc)
        self._arc_index.setdefault(source_state, []).append(arc)

        # Invalidate validation cache
        self._validation_cache = None

        return arc

    def remove_arc(self, arc: Arc) -> None:
        """Remove an arc from the network.

        Matching uses ``Arc`` equality, so the first stored arc equal to
        ``arc`` is removed even if it is a different object.

        Args:
            arc: Arc to remove.

        Raises:
            ValueError: If arc doesn't exist.
        """
        if arc not in self._arcs:
            raise ValueError("Arc not found in network")

        self._arcs.remove(arc)

        # Update arc index, dropping the bucket once it is empty
        if arc.source_state in self._arc_index:
            outgoing = self._arc_index[arc.source_state]
            outgoing.remove(arc)
            if not outgoing:
                del self._arc_index[arc.source_state]

        # Invalidate validation cache
        self._validation_cache = None

    def get_state(self, name: str) -> State | None:
        """Get a state by name.

        Args:
            name: State name.

        Returns:
            State if found, None otherwise.
        """
        return self._states.get(name)

    def get_arcs_from_state(self, state_name: str) -> List[Arc]:
        """Get all arcs originating from a state.

        Args:
            state_name: Source state name.

        Returns:
            New list of arcs from the state (empty if there are none).
            A copy is returned so callers cannot corrupt the internal
            index by mutating the result.
        """
        return list(self._arc_index.get(state_name, ()))

    def get_arcs_to_state(self, state_name: str) -> List[Arc]:
        """Get all arcs targeting a state.

        Args:
            state_name: Target state name.

        Returns:
            List of arcs to the state.
        """
        return [arc for arc in self._arcs if arc.target_state == state_name]

    def validate(self) -> Tuple[bool, List[str]]:
        """Validate network consistency.

        Checks, in order: an initial state exists, at least one final
        state exists, every state is reachable from the initial state,
        every non-final state has an outgoing arc, and every cycle
        contains at least one final state. The result is cached until the
        network is next modified through this class's methods.

        Returns:
            Tuple of (is_valid, list_of_errors). The list is a fresh copy
            on every call.
        """
        if self._validation_cache is not None:
            cached = self._validation_cache
            return cached['is_valid'], list(cached['errors'])

        errors: List[str] = []

        # Check for initial state
        if not self._initial_state:
            errors.append("No initial state defined")
        elif self._initial_state not in self._states:
            errors.append(f"Initial state '{self._initial_state}' not found")

        # Check for final states
        if not self._final_states:
            errors.append("No final states defined")
        else:
            for final_state in self._final_states:
                if final_state not in self._states:
                    errors.append(f"Final state '{final_state}' not found")

        # Check for unreachable states; iterate the state dict (insertion
        # order) so the messages come out in a deterministic order.
        if self._initial_state:
            reachable = self._find_reachable_states(self._initial_state)
            for state_name in self._states:
                if state_name not in reachable:
                    errors.append(
                        f"State '{state_name}' is unreachable from initial state"
                    )

        # Check for states with no outgoing arcs (except final states)
        for state_name in self._states:
            if state_name in self._final_states:
                continue
            if not self._arc_index.get(state_name):
                errors.append(f"Non-final state '{state_name}' has no outgoing arcs")

        # Check for cycles that don't include final states
        for cycle in self._find_cycles():
            if not any(state in self._final_states for state in cycle):
                errors.append(
                    f"Cycle detected without final states: {' -> '.join(cycle)}"
                )

        # Cache validation result
        is_valid = not errors
        self._validation_cache = {
            'is_valid': is_valid,
            'errors': errors
        }

        # Hand out a copy so the cached list stays private
        return is_valid, list(errors)

    def get_resource_requirements(self) -> NetworkResourceRequirements:
        """Get aggregated resource requirements for the network.

        Returns:
            The live internal requirements object (not a copy).
        """
        return self._resource_requirements

    def is_streaming_enabled(self) -> bool:
        """Check if any state in the network requires streaming.

        Returns:
            True if streaming is required.
        """
        return self._streaming_enabled

    def analyze_dependencies(self) -> Dict[str, Set[str]]:
        """Analyze resource dependencies between states.

        Iterates each state's ``resource_requirements`` attribute directly
        and records which states mention each yielded item.

        NOTE(review): ``_update_resource_requirements`` reads the same
        attribute as an object with ``databases``/``filesystems``/...
        attributes, while this method iterates it. Confirm which shape
        ``State.resource_requirements`` really has.

        Returns:
            Dictionary mapping resources to dependent states.
        """
        dependencies: Dict[Any, Set[str]] = {}

        for state_name, state in self._states.items():
            if hasattr(state, 'resource_requirements'):
                for resource in state.resource_requirements:
                    dependencies.setdefault(resource, set()).add(state_name)

        return dependencies

    def _update_resource_requirements(self, state: State) -> None:
        """Fold one state's resource requirements into the network totals.

        Attributes the state's requirements object does not provide are
        skipped. Only the four resource sets and the streaming flag are
        aggregated; ``custom`` and ``estimated_memory_mb`` are left alone.

        Args:
            state: State to analyze.
        """
        if not hasattr(state, 'resource_requirements'):
            return

        reqs = state.resource_requirements
        aggregate = self._resource_requirements

        # Update resource sets based on type
        for kind in ('databases', 'filesystems', 'http_services', 'llms'):
            if hasattr(reqs, kind):
                getattr(aggregate, kind).update(getattr(reqs, kind))

        # Update streaming flag, mirroring it into the aggregate object so
        # to_dict()'s 'resource_requirements' section agrees with the
        # network-level flag.
        if hasattr(reqs, 'streaming_enabled'):
            self._streaming_enabled = self._streaming_enabled or reqs.streaming_enabled
            aggregate.streaming_enabled = bool(self._streaming_enabled)

    def _recalculate_resource_requirements(self) -> None:
        """Recalculate all resource requirements from scratch."""
        self._resource_requirements = NetworkResourceRequirements()
        self._streaming_enabled = False

        for state in self._states.values():
            self._update_resource_requirements(state)

    def _rebuild_arc_index(self) -> None:
        """Rebuild the source-state -> arcs index from ``self._arcs``."""
        self._arc_index = {}
        for arc in self._arcs:
            self._arc_index.setdefault(arc.source_state, []).append(arc)

    def _find_reachable_states(self, start_state: str) -> Set[str]:
        """Find all states reachable from a given state.

        Args:
            start_state: Starting state name.

        Returns:
            Set of reachable state names, including ``start_state``.
        """
        reachable: Set[str] = set()
        to_visit = [start_state]

        while to_visit:
            current = to_visit.pop()
            if current in reachable:
                continue

            reachable.add(current)

            # Queue target states of outgoing arcs
            for arc in self._arc_index.get(current, ()):
                if arc.target_state not in reachable:
                    to_visit.append(arc.target_state)

        return reachable

    def _find_cycles(self) -> List[List[str]]:
        """Find cycles in the network with a depth-first search.

        The search is iterative (explicit stack of arc iterators), so long
        chains of states cannot exhaust Python's recursion limit. A cycle
        is recorded whenever an arc points back to a state on the current
        DFS path; it is reported as that path from the repeated state
        onward, with the repeated state appended again to close the loop.

        Returns:
            List of cycles (each cycle is a list of state names).
        """
        cycles: List[List[str]] = []
        visited: Set[str] = set()
        path: List[str] = []       # states on the current DFS path, in order
        on_path: Set[str] = set()  # same states, for O(1) membership tests

        for root in self._states:
            if root in visited:
                continue

            visited.add(root)
            path.append(root)
            on_path.add(root)
            pending = [iter(self._arc_index.get(root, ()))]

            while pending:
                # Resume the deepest state's arc iterator where it left off
                for arc in pending[-1]:
                    target = arc.target_state
                    if target not in visited:
                        # Descend into the unvisited target
                        visited.add(target)
                        path.append(target)
                        on_path.add(target)
                        pending.append(iter(self._arc_index.get(target, ())))
                        break
                    if target in on_path:
                        # Back edge: found a cycle
                        start = path.index(target)
                        cycles.append(path[start:] + [target])
                else:
                    # All arcs of the deepest state handled: backtrack
                    pending.pop()
                    on_path.discard(path.pop())

        return cycles

    def to_dict(self) -> Dict[str, Any]:
        """Convert network to dictionary representation.

        States are serialized with their own ``to_dict`` when they have
        one, otherwise with ``str(state)``. Arc metadata dicts are included
        by reference.

        Returns:
            Dictionary representation.
        """
        reqs = self._resource_requirements
        return {
            'name': self.name,
            'description': self.description,
            'initial_state': self._initial_state,
            'final_states': list(self._final_states),
            'states': {
                name: state.to_dict() if hasattr(state, 'to_dict') else str(state)
                for name, state in self._states.items()
            },
            'arcs': [
                {
                    'source': arc.source_state,
                    'target': arc.target_state,
                    'pre_test': arc.pre_test,
                    'transform': arc.transform,
                    'metadata': arc.metadata
                }
                for arc in self._arcs
            ],
            'resource_requirements': {
                'databases': list(reqs.databases),
                'filesystems': list(reqs.filesystems),
                'http_services': list(reqs.http_services),
                'llms': list(reqs.llms),
                'custom': {
                    k: list(v) for k, v in reqs.custom.items()
                },
                'streaming_enabled': reqs.streaming_enabled,
                'estimated_memory_mb': reqs.estimated_memory_mb
            },
            'streaming_enabled': self._streaming_enabled
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StateNetwork":
        """Create network from dictionary representation.

        Only state names are restored: each state is rebuilt as
        ``State(name=...)`` and the serialized state payload, the
        ``resource_requirements`` section and the ``streaming_enabled``
        flag are not read back.

        Args:
            data: Dictionary representation. ``'name'`` is required.

        Returns:
            StateNetwork instance.
        """
        network = cls(
            name=data['name'],
            description=data.get('description')
        )

        initial_name = data.get('initial_state')
        final_names = data.get('final_states', [])

        # Add states
        for state_name in data.get('states', {}):
            # Create basic state (can be enhanced with proper State deserialization)
            network.add_state(
                State(name=state_name),
                initial=state_name == initial_name,
                final=state_name in final_names
            )

        # Add arcs
        for arc_data in data.get('arcs', []):
            network.add_arc(
                source_state=arc_data['source'],
                target_state=arc_data['target'],
                pre_test=arc_data.get('pre_test'),
                transform=arc_data.get('transform'),
                metadata=arc_data.get('metadata', {})
            )

        return network