Coverage for src/dataknobs_fsm/core/network.py: 44%

247 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""State network implementation for FSM.""" 

2 

3from dataclasses import dataclass, field 

4from typing import Any, Dict, List, Set, Tuple 

5 

6from dataknobs_fsm.core.state import State 

7 

8 

9@dataclass 

10class Arc: 

11 """Represents an arc (transition) between states. 

12  

13 Attributes: 

14 source_state: Name of the source state. 

15 target_state: Name of the target state. 

16 pre_test: Optional pre-test function name. 

17 transform: Optional transform function name. 

18 metadata: Additional arc metadata. 

19 """ 

20 source_state: str 

21 target_state: str 

22 pre_test: str | None = None 

23 transform: str | None = None 

24 metadata: Dict[str, Any] = field(default_factory=dict) 

25 

26 def __hash__(self) -> int: 

27 """Make Arc hashable for use in sets.""" 

28 return hash((self.source_state, self.target_state, self.pre_test, self.transform)) 

29 

30 def __eq__(self, other: object) -> bool: 

31 """Check equality.""" 

32 if not isinstance(other, Arc): 

33 return False 

34 return ( 

35 self.source_state == other.source_state and 

36 self.target_state == other.target_state and 

37 self.pre_test == other.pre_test and 

38 self.transform == other.transform 

39 ) 

40 

41 @property 

42 def name(self) -> str: 

43 """Generate a name for the arc.""" 

44 # Use metadata name if available, otherwise generate from states 

45 if 'name' in self.metadata: 

46 return self.metadata['name'] 

47 return f"{self.source_state}->{self.target_state}" 

48 

49 

@dataclass
class NetworkResourceRequirements:
    """Aggregated resource requirements for a network.

    Attributes:
        databases: Set of required database resources.
        filesystems: Set of required filesystem resources.
        http_services: Set of required HTTP service resources.
        llms: Set of required LLM resources.
        custom: Dictionary of custom resource requirements.
        streaming_enabled: Whether any state requires streaming.
        estimated_memory_mb: Estimated memory requirement in MB.
    """

    databases: Set[str] = field(default_factory=set)
    filesystems: Set[str] = field(default_factory=set)
    http_services: Set[str] = field(default_factory=set)
    llms: Set[str] = field(default_factory=set)
    custom: Dict[str, Set[str]] = field(default_factory=dict)
    streaming_enabled: bool = False
    estimated_memory_mb: int = 0

    def merge(self, other: "NetworkResourceRequirements") -> None:
        """Merge another set of requirements into this one (in place).

        The named resource sets and the per-key ``custom`` sets are unioned,
        ``streaming_enabled`` is OR-ed, and ``estimated_memory_mb`` becomes
        the larger of the two values (it is not summed).

        Args:
            other: Requirements to merge. It is not modified.
        """
        self.databases.update(other.databases)
        self.filesystems.update(other.filesystems)
        self.http_services.update(other.http_services)
        self.llms.update(other.llms)

        for key, values in other.custom.items():
            # setdefault creates a fresh set for new keys, so this object
            # never ends up sharing a set instance with ``other``.
            self.custom.setdefault(key, set()).update(values)

        self.streaming_enabled = self.streaming_enabled or other.streaming_enabled
        self.estimated_memory_mb = max(self.estimated_memory_mb, other.estimated_memory_mb)

    def is_empty(self) -> bool:
        """Check if there are no resource requirements.

        Only the resource collections are inspected; ``streaming_enabled``
        and ``estimated_memory_mb`` do not affect the result.

        Returns:
            True if no resources are required.
        """
        return not (
            self.databases
            or self.filesystems
            or self.http_services
            or self.llms
            or self.custom
        )

103 

104 

105class StateNetwork: 

106 """Represents a network of states and their transitions. 

107  

108 A state network is a directed graph where nodes are states 

109 and edges are arcs (transitions) between states. 

110 """ 

111 

112 def __init__(self, name: str, description: str | None = None): 

113 """Initialize state network. 

114  

115 Args: 

116 name: Network name/identifier. 

117 description: Optional network description. 

118 """ 

119 self.name = name 

120 self.description = description 

121 

122 # State management 

123 self._states: Dict[str, State] = {} 

124 self._initial_state: str | None = None 

125 self._final_states: Set[str] = set() 

126 

127 # Arc management 

128 self._arcs: List[Arc] = [] 

129 self._arc_index: Dict[str, List[Arc]] = {} # source_state -> [arcs] 

130 

131 # Resource tracking 

132 self._resource_requirements = NetworkResourceRequirements() 

133 self._streaming_enabled = False 

134 

135 # Validation cache 

136 self._validation_cache: Dict[str, Any] | None = None 

137 

138 @property 

139 def states(self) -> Dict[str, State]: 

140 """Get all states in the network.""" 

141 return self._states 

142 

143 @property 

144 def arcs(self) -> Dict[str, Any]: 

145 """Get all arcs in the network.""" 

146 # Import here to avoid circular dependency 

147 from dataknobs_fsm.core.arc import ArcDefinition 

148 

149 # Return arcs as a dict indexed by "source:target"  

150 # Convert Arc to ArcDefinition for compatibility 

151 arc_dict = {} 

152 for arc in self._arcs: 

153 key = f"{arc.source_state}:{arc.target_state}" 

154 # Create ArcDefinition from Arc 

155 arc_def = ArcDefinition( 

156 target_state=arc.target_state, 

157 pre_test=arc.pre_test, 

158 transform=arc.transform 

159 ) 

160 # Copy metadata if it exists 

161 if hasattr(arc, 'metadata') and arc.metadata: 

162 arc_def.metadata = arc.metadata.copy() 

163 arc_dict[key] = arc_def 

164 return arc_dict 

165 

166 @property 

167 def initial_states(self) -> Set[str]: 

168 """Get initial states (returns set for compatibility).""" 

169 if self._initial_state: 

170 return {self._initial_state} 

171 return set() 

172 

173 @property 

174 def final_states(self) -> Set[str]: 

175 """Get final states.""" 

176 return self._final_states.copy() 

177 

178 def is_initial_state(self, state_name: str) -> bool: 

179 """Check if a state is an initial state. 

180  

181 Args: 

182 state_name: Name of the state to check 

183  

184 Returns: 

185 True if the state is an initial state 

186 """ 

187 return self._initial_state == state_name 

188 

189 def is_final_state(self, state_name: str) -> bool: 

190 """Check if a state is a final state. 

191  

192 Args: 

193 state_name: Name of the state to check 

194  

195 Returns: 

196 True if the state is a final state 

197 """ 

198 return state_name in self._final_states 

199 

200 @property 

201 def resource_requirements(self) -> Dict[str, Any]: 

202 """Get resource requirements.""" 

203 return { 

204 'databases': self._resource_requirements.databases, 

205 'filesystems': self._resource_requirements.filesystems, 

206 'http_services': self._resource_requirements.http_services, 

207 'llms': self._resource_requirements.llms, 

208 'custom': self._resource_requirements.custom 

209 } 

210 

211 @property 

212 def supports_streaming(self) -> bool: 

213 """Check if network supports streaming.""" 

214 return self._streaming_enabled 

215 

216 def add_state( 

217 self, 

218 state: State, 

219 initial: bool = False, 

220 final: bool = False 

221 ) -> None: 

222 """Add a state to the network. 

223  

224 Args: 

225 state: State to add. 

226 initial: Mark as initial state. 

227 final: Mark as final state. 

228  

229 Raises: 

230 ValueError: If state with same name already exists. 

231 """ 

232 if state.name in self._states: 

233 raise ValueError(f"State '{state.name}' already exists in network") 

234 

235 self._states[state.name] = state 

236 

237 if initial: 

238 if self._initial_state: 

239 raise ValueError( 

240 f"Initial state already set to '{self._initial_state}'" 

241 ) 

242 self._initial_state = state.name 

243 

244 if final: 

245 self._final_states.add(state.name) 

246 

247 # Update resource requirements 

248 self._update_resource_requirements(state) 

249 

250 # Invalidate validation cache 

251 self._validation_cache = None 

252 

253 def remove_state(self, state_name: str) -> None: 

254 """Remove a state from the network. 

255  

256 Args: 

257 state_name: Name of state to remove. 

258  

259 Raises: 

260 KeyError: If state doesn't exist. 

261 """ 

262 if state_name not in self._states: 

263 raise KeyError(f"State '{state_name}' not found in network") 

264 

265 # Remove state 

266 del self._states[state_name] 

267 

268 # Remove from initial/final if needed 

269 if self._initial_state == state_name: 

270 self._initial_state = None 

271 self._final_states.discard(state_name) 

272 

273 # Remove arcs involving this state 

274 self._arcs = [ 

275 arc for arc in self._arcs 

276 if arc.source_state != state_name and arc.target_state != state_name 

277 ] 

278 

279 # Rebuild arc index 

280 self._rebuild_arc_index() 

281 

282 # Recalculate resource requirements 

283 self._recalculate_resource_requirements() 

284 

285 # Invalidate validation cache 

286 self._validation_cache = None 

287 

288 def add_arc( 

289 self, 

290 source_state: str, 

291 target_state: str, 

292 pre_test: str | None = None, 

293 transform: str | None = None, 

294 metadata: Dict[str, Any] | None = None 

295 ) -> Arc: 

296 """Add an arc between two states. 

297  

298 Args: 

299 source_state: Source state name. 

300 target_state: Target state name. 

301 pre_test: Optional pre-test function name. 

302 transform: Optional transform function name. 

303 metadata: Optional arc metadata. 

304  

305 Returns: 

306 Created arc. 

307  

308 Raises: 

309 ValueError: If states don't exist. 

310 """ 

311 if source_state not in self._states: 

312 raise ValueError(f"Source state '{source_state}' not found") 

313 if target_state not in self._states: 

314 raise ValueError(f"Target state '{target_state}' not found") 

315 

316 arc = Arc( 

317 source_state=source_state, 

318 target_state=target_state, 

319 pre_test=pre_test, 

320 transform=transform, 

321 metadata=metadata or {} 

322 ) 

323 

324 self._arcs.append(arc) 

325 

326 # Update arc index 

327 if source_state not in self._arc_index: 

328 self._arc_index[source_state] = [] 

329 self._arc_index[source_state].append(arc) 

330 

331 # Invalidate validation cache 

332 self._validation_cache = None 

333 

334 return arc 

335 

336 def remove_arc(self, arc: Arc) -> None: 

337 """Remove an arc from the network. 

338  

339 Args: 

340 arc: Arc to remove. 

341  

342 Raises: 

343 ValueError: If arc doesn't exist. 

344 """ 

345 if arc not in self._arcs: 

346 raise ValueError("Arc not found in network") 

347 

348 self._arcs.remove(arc) 

349 

350 # Update arc index 

351 if arc.source_state in self._arc_index: 

352 self._arc_index[arc.source_state].remove(arc) 

353 if not self._arc_index[arc.source_state]: 

354 del self._arc_index[arc.source_state] 

355 

356 # Invalidate validation cache 

357 self._validation_cache = None 

358 

359 def get_state(self, name: str) -> State | None: 

360 """Get a state by name. 

361  

362 Args: 

363 name: State name. 

364  

365 Returns: 

366 State if found, None otherwise. 

367 """ 

368 return self._states.get(name) 

369 

370 def get_arcs_from_state(self, state_name: str) -> List[Arc]: 

371 """Get all arcs originating from a state. 

372  

373 Args: 

374 state_name: Source state name. 

375  

376 Returns: 

377 List of arcs from the state. 

378 """ 

379 return self._arc_index.get(state_name, []) 

380 

381 def get_arcs_to_state(self, state_name: str) -> List[Arc]: 

382 """Get all arcs targeting a state. 

383  

384 Args: 

385 state_name: Target state name. 

386  

387 Returns: 

388 List of arcs to the state. 

389 """ 

390 return [arc for arc in self._arcs if arc.target_state == state_name] 

391 

392 def validate(self) -> Tuple[bool, List[str]]: 

393 """Validate network consistency. 

394  

395 Returns: 

396 Tuple of (is_valid, list_of_errors). 

397 """ 

398 # Return cached result if available 

399 if self._validation_cache is not None: 

400 return self._validation_cache['is_valid'], self._validation_cache['errors'] 

401 

402 errors = [] 

403 

404 # Check for initial state 

405 if not self._initial_state: 

406 errors.append("No initial state defined") 

407 elif self._initial_state not in self._states: 

408 errors.append(f"Initial state '{self._initial_state}' not found") 

409 

410 # Check for final states 

411 if not self._final_states: 

412 errors.append("No final states defined") 

413 else: 

414 for final_state in self._final_states: 

415 if final_state not in self._states: 

416 errors.append(f"Final state '{final_state}' not found") 

417 

418 # Check for unreachable states 

419 if self._initial_state: 

420 reachable = self._find_reachable_states(self._initial_state) 

421 unreachable = set(self._states.keys()) - reachable 

422 for state in unreachable: 

423 errors.append(f"State '{state}' is unreachable from initial state") 

424 

425 # Check for states with no outgoing arcs (except final states) 

426 for state_name in self._states: 

427 if state_name not in self._final_states: 

428 if state_name not in self._arc_index or not self._arc_index[state_name]: 

429 errors.append(f"Non-final state '{state_name}' has no outgoing arcs") 

430 

431 # Check for cycles that don't include final states 

432 cycles = self._find_cycles() 

433 for cycle in cycles: 

434 if not any(state in self._final_states for state in cycle): 

435 errors.append( 

436 f"Cycle detected without final states: {' -> '.join(cycle)}" 

437 ) 

438 

439 # Cache validation result 

440 is_valid = len(errors) == 0 

441 self._validation_cache = { 

442 'is_valid': is_valid, 

443 'errors': errors 

444 } 

445 

446 return is_valid, errors 

447 

448 def get_resource_requirements(self) -> NetworkResourceRequirements: 

449 """Get aggregated resource requirements for the network. 

450  

451 Returns: 

452 Resource requirements. 

453 """ 

454 return self._resource_requirements 

455 

456 def is_streaming_enabled(self) -> bool: 

457 """Check if any state in the network requires streaming. 

458  

459 Returns: 

460 True if streaming is required. 

461 """ 

462 return self._streaming_enabled 

463 

464 def analyze_dependencies(self) -> Dict[str, Set[str]]: 

465 """Analyze resource dependencies between states. 

466  

467 Returns: 

468 Dictionary mapping resources to dependent states. 

469 """ 

470 dependencies = {} 

471 

472 for state_name, state in self._states.items(): 

473 if hasattr(state, 'resource_requirements'): 

474 for resource in state.resource_requirements: 

475 if resource not in dependencies: 

476 dependencies[resource] = set() 

477 dependencies[resource].add(state_name) 

478 

479 return dependencies 

480 

481 def _update_resource_requirements(self, state: State) -> None: 

482 """Update resource requirements based on a state. 

483  

484 Args: 

485 state: State to analyze. 

486 """ 

487 if hasattr(state, 'resource_requirements'): 

488 reqs = state.resource_requirements 

489 

490 # Update resource sets based on type 

491 if hasattr(reqs, 'databases'): 

492 self._resource_requirements.databases.update(reqs.databases) 

493 if hasattr(reqs, 'filesystems'): 

494 self._resource_requirements.filesystems.update(reqs.filesystems) 

495 if hasattr(reqs, 'http_services'): 

496 self._resource_requirements.http_services.update(reqs.http_services) 

497 if hasattr(reqs, 'llms'): 

498 self._resource_requirements.llms.update(reqs.llms) 

499 

500 # Update streaming flag 

501 if hasattr(reqs, 'streaming_enabled'): 

502 self._streaming_enabled = self._streaming_enabled or reqs.streaming_enabled 

503 

504 def _recalculate_resource_requirements(self) -> None: 

505 """Recalculate all resource requirements from scratch.""" 

506 self._resource_requirements = NetworkResourceRequirements() 

507 self._streaming_enabled = False 

508 

509 for state in self._states.values(): 

510 self._update_resource_requirements(state) 

511 

512 def _rebuild_arc_index(self) -> None: 

513 """Rebuild the arc index from scratch.""" 

514 self._arc_index = {} 

515 for arc in self._arcs: 

516 if arc.source_state not in self._arc_index: 

517 self._arc_index[arc.source_state] = [] 

518 self._arc_index[arc.source_state].append(arc) 

519 

520 def _find_reachable_states(self, start_state: str) -> Set[str]: 

521 """Find all states reachable from a given state. 

522  

523 Args: 

524 start_state: Starting state name. 

525  

526 Returns: 

527 Set of reachable state names. 

528 """ 

529 reachable = set() 

530 to_visit = [start_state] 

531 

532 while to_visit: 

533 current = to_visit.pop() 

534 if current in reachable: 

535 continue 

536 

537 reachable.add(current) 

538 

539 # Add target states of outgoing arcs 

540 for arc in self.get_arcs_from_state(current): 

541 if arc.target_state not in reachable: 

542 to_visit.append(arc.target_state) 

543 

544 return reachable 

545 

546 def _find_cycles(self) -> List[List[str]]: 

547 """Find all cycles in the network. 

548  

549 Returns: 

550 List of cycles (each cycle is a list of state names). 

551 """ 

552 cycles = [] 

553 visited = set() 

554 rec_stack = [] 

555 

556 def dfs(state: str) -> None: 

557 visited.add(state) 

558 rec_stack.append(state) 

559 

560 for arc in self.get_arcs_from_state(state): 

561 if arc.target_state not in visited: 

562 dfs(arc.target_state) 

563 elif arc.target_state in rec_stack: 

564 # Found a cycle 

565 cycle_start = rec_stack.index(arc.target_state) 

566 cycle = rec_stack[cycle_start:] + [arc.target_state] 

567 cycles.append(cycle) 

568 

569 rec_stack.pop() 

570 

571 # Start DFS from all unvisited states 

572 for state_name in self._states: 

573 if state_name not in visited: 

574 dfs(state_name) 

575 

576 return cycles 

577 

578 def to_dict(self) -> Dict[str, Any]: 

579 """Convert network to dictionary representation. 

580  

581 Returns: 

582 Dictionary representation. 

583 """ 

584 return { 

585 'name': self.name, 

586 'description': self.description, 

587 'initial_state': self._initial_state, 

588 'final_states': list(self._final_states), 

589 'states': { 

590 name: state.to_dict() if hasattr(state, 'to_dict') else str(state) 

591 for name, state in self._states.items() 

592 }, 

593 'arcs': [ 

594 { 

595 'source': arc.source_state, 

596 'target': arc.target_state, 

597 'pre_test': arc.pre_test, 

598 'transform': arc.transform, 

599 'metadata': arc.metadata 

600 } 

601 for arc in self._arcs 

602 ], 

603 'resource_requirements': { 

604 'databases': list(self._resource_requirements.databases), 

605 'filesystems': list(self._resource_requirements.filesystems), 

606 'http_services': list(self._resource_requirements.http_services), 

607 'llms': list(self._resource_requirements.llms), 

608 'custom': { 

609 k: list(v) for k, v in self._resource_requirements.custom.items() 

610 }, 

611 'streaming_enabled': self._resource_requirements.streaming_enabled, 

612 'estimated_memory_mb': self._resource_requirements.estimated_memory_mb 

613 }, 

614 'streaming_enabled': self._streaming_enabled 

615 } 

616 

617 @classmethod 

618 def from_dict(cls, data: Dict[str, Any]) -> "StateNetwork": 

619 """Create network from dictionary representation. 

620  

621 Args: 

622 data: Dictionary representation. 

623  

624 Returns: 

625 StateNetwork instance. 

626 """ 

627 network = cls( 

628 name=data['name'], 

629 description=data.get('description') 

630 ) 

631 

632 # Add states 

633 for state_name in data.get('states', {}): 

634 # Create basic state (can be enhanced with proper State deserialization) 

635 state = State(name=state_name) 

636 is_initial = state_name == data.get('initial_state') 

637 is_final = state_name in data.get('final_states', []) 

638 network.add_state(state, initial=is_initial, final=is_final) 

639 

640 # Add arcs 

641 for arc_data in data.get('arcs', []): 

642 network.add_arc( 

643 source_state=arc_data['source'], 

644 target_state=arc_data['target'], 

645 pre_test=arc_data.get('pre_test'), 

646 transform=arc_data.get('transform'), 

647 metadata=arc_data.get('metadata', {}) 

648 ) 

649 

650 return network