Coverage for src/dataknobs_fsm/core/arc.py: 38%

206 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Arc implementation for FSM state transitions.""" 

2 

3import logging 

4from dataclasses import dataclass, field 

5from enum import Enum 

6from typing import Any, Callable, Dict, TYPE_CHECKING 

7 

8from dataknobs_fsm.core.exceptions import FunctionError, ResourceError 

9from dataknobs_fsm.functions.base import FunctionContext 

10 

11if TYPE_CHECKING: 

12 from dataknobs_fsm.execution.context import ExecutionContext 

13 

14logger = logging.getLogger(__name__) 

15 

16 

class DataIsolationMode(Enum):
    """How data is shared with a sub-network when a push arc fires."""

    COPY = "copy"            # deep-copy the data before pushing
    REFERENCE = "reference"  # share the same object by reference
    SERIALIZE = "serialize"  # round-trip through serialization for full isolation

22 

23 

24@dataclass 

25class ArcDefinition: 

26 """Definition of an arc between states. 

27 

28 This class defines the static properties of an arc, 

29 including the transition logic and resource requirements. 

30 """ 

31 

32 target_state: str 

33 pre_test: str | None = None 

34 transform: str | None = None 

35 priority: int = 0 # Higher priority arcs are evaluated first 

36 definition_order: int = 0 # Track definition order for stable sorting 

37 metadata: Dict[str, Any] = field(default_factory=dict) 

38 

39 # Resource requirements for this arc 

40 required_resources: Dict[str, str] = field(default_factory=dict) 

41 # e.g., {'database': 'main_db', 'llm': 'gpt4'} 

42 

43 def __hash__(self) -> int: 

44 """Make ArcDefinition hashable.""" 

45 return hash(( 

46 self.target_state, 

47 self.pre_test, 

48 self.transform, 

49 self.priority 

50 )) 

51 

52 

@dataclass
class PushArc(ArcDefinition):
    """Arc that pushes execution into a sub-network.

    Push arcs allow hierarchical state machine composition by pushing
    execution to a sub-network and returning to the parent network
    when the sub-network completes.  Inherits all ``ArcDefinition``
    fields (target_state, pre_test, transform, priority, ...).
    """

    target_network: str = ""  # Name of the target (sub-)network
    return_state: str | None = None  # State to return to after sub-network
    # How data is shared with the sub-network (deep copy by default).
    isolation_mode: DataIsolationMode = DataIsolationMode.COPY
    pass_context: bool = True  # Whether to pass execution context

    # Mapping of data from parent to child network,
    # e.g. {'parent_field': 'child_field'}
    data_mapping: Dict[str, str] = field(default_factory=dict)

    # Mapping of results from child back to parent network,
    # e.g. {'child_result': 'parent_field'}
    result_mapping: Dict[str, str] = field(default_factory=dict)

75 

76class ArcExecution: 

77 """Handles the execution of arc transitions. 

78  

79 This class manages the runtime execution of arcs, 

80 including resource allocation, streaming support, 

81 and transaction participation. 

82 """ 

83 

84 def __init__( 

85 self, 

86 arc_def: ArcDefinition, 

87 source_state: str, 

88 function_registry 

89 ): 

90 """Initialize arc execution. 

91 

92 Args: 

93 arc_def: Arc definition. 

94 source_state: Source state name. 

95 function_registry: Registry of available functions (FunctionRegistry or dict). 

96 """ 

97 self.arc_def = arc_def 

98 self.source_state = source_state 

99 self.function_registry = function_registry 

100 

101 # Execution statistics 

102 self.execution_count = 0 

103 self.success_count = 0 

104 self.failure_count = 0 

105 self.total_execution_time = 0.0 

106 

107 def _log_warning(self, message: str) -> None: 

108 """Log a warning message. 

109 

110 Args: 

111 message: Warning message to log. 

112 """ 

113 logger.warning(message) 

114 

115 def _log_error(self, message: str) -> None: 

116 """Log an error message. 

117 

118 Args: 

119 message: Error message to log. 

120 """ 

121 logger.error(message) 

122 

123 def can_execute( 

124 self, 

125 context: "ExecutionContext", 

126 data: Any = None 

127 ) -> bool: 

128 """Check if arc can be executed. 

129  

130 This runs the pre-test function if defined. 

131  

132 Args: 

133 context: Execution context. 

134 data: Current data. 

135  

136 Returns: 

137 True if arc can be executed. 

138 """ 

139 if not self.arc_def.pre_test: 

140 return True 

141 

142 # Handle both FunctionRegistry and dict for pre-test function lookup 

143 if hasattr(self.function_registry, 'get_function'): 

144 # FunctionRegistry object 

145 pre_test_func = self.function_registry.get_function(self.arc_def.pre_test) 

146 elif isinstance(self.function_registry, dict): 

147 # Plain dictionary 

148 pre_test_func = self.function_registry.get(self.arc_def.pre_test) 

149 else: 

150 pre_test_func = None 

151 

152 if pre_test_func is None: 

153 raise FunctionError( 

154 f"Pre-test function '{self.arc_def.pre_test}' not found", 

155 from_state=self.source_state, 

156 to_state=self.arc_def.target_state 

157 ) 

158 

159 try: 

160 

161 # Create function context with resources 

162 func_context = self._create_function_context(context) 

163 

164 # Execute pre-test 

165 result = pre_test_func(data, func_context) 

166 

167 # Handle tuple return from InterfaceWrapper (returns (result, error)) 

168 if isinstance(result, tuple) and len(result) == 2: 

169 return bool(result[0]) 

170 return bool(result) 

171 

172 except Exception as e: 

173 raise FunctionError( 

174 f"Pre-test execution failed: {e}", 

175 from_state=self.source_state, 

176 to_state=self.arc_def.target_state 

177 ) from e 

178 

179 def execute( 

180 self, 

181 context: "ExecutionContext", 

182 data: Any = None, 

183 stream_enabled: bool = False 

184 ) -> Any: 

185 """Execute the arc transition. 

186  

187 This runs the transform function if defined and 

188 manages resource allocation. 

189  

190 Args: 

191 context: Execution context. 

192 data: Current data. 

193 stream_enabled: Whether streaming is enabled. 

194  

195 Returns: 

196 Transformed data. 

197 """ 

198 import time 

199 start_time = time.time() 

200 

201 try: 

202 # Get state resources from context if available 

203 state_resources = getattr(context, 'current_state_resources', None) 

204 

205 # Allocate required resources (merging with state resources) 

206 resources = self._allocate_resources(context, state_resources) 

207 

208 # Execute transform if defined 

209 if self.arc_def.transform: 

210 # Handle both FunctionRegistry and dict 

211 if hasattr(self.function_registry, 'get_function'): 

212 # FunctionRegistry object 

213 transform_func = self.function_registry.get_function(self.arc_def.transform) 

214 elif isinstance(self.function_registry, dict): 

215 # Plain dictionary 

216 transform_func = self.function_registry.get(self.arc_def.transform) 

217 else: 

218 transform_func = None 

219 

220 if transform_func is None: 

221 raise FunctionError( 

222 f"Transform function '{self.arc_def.transform}' not found", 

223 from_state=self.source_state, 

224 to_state=self.arc_def.target_state 

225 ) 

226 

227 # Create function context with resources 

228 func_context = self._create_function_context( 

229 context, 

230 resources, 

231 stream_enabled 

232 ) 

233 

234 # Handle streaming vs non-streaming execution 

235 if stream_enabled and hasattr(transform_func, 'stream_capable'): 

236 result = self._execute_streaming( 

237 transform_func, 

238 data, 

239 func_context 

240 ) 

241 else: 

242 # Call the transform function properly 

243 # Check if it has a transform method (wrapped function) 

244 if hasattr(transform_func, 'transform'): 

245 result = transform_func.transform(data, func_context) 

246 elif callable(transform_func): 

247 result = transform_func(data, func_context) 

248 else: 

249 raise ValueError(f"Transform {self.arc_def.transform} is not callable") 

250 

251 # Handle ExecutionResult objects 

252 from dataknobs_fsm.functions.base import ExecutionResult 

253 if isinstance(result, ExecutionResult): 

254 if result.success: 

255 result = result.data 

256 else: 

257 raise FunctionError( 

258 result.error or "Transform failed", 

259 from_state=self.source_state, 

260 to_state=self.arc_def.target_state 

261 ) 

262 else: 

263 # No transform, pass data through 

264 result = data 

265 

266 # Update statistics 

267 self.execution_count += 1 

268 self.success_count += 1 

269 

270 return result 

271 

272 except Exception as e: 

273 self.execution_count += 1 

274 self.failure_count += 1 

275 

276 raise FunctionError( 

277 f"Arc execution failed: {e}", 

278 from_state=self.source_state, 

279 to_state=self.arc_def.target_state 

280 ) from e 

281 finally: 

282 elapsed = time.time() - start_time 

283 self.total_execution_time += elapsed 

284 

285 # Release resources 

286 if 'resources' in locals(): 

287 self._release_resources(context, resources) 

288 

289 def execute_with_transaction( 

290 self, 

291 context: "ExecutionContext", 

292 data: Any = None, 

293 transaction_id: str | None = None 

294 ) -> Any: 

295 """Execute arc within a transaction context. 

296  

297 Args: 

298 context: Execution context. 

299 data: Current data. 

300 transaction_id: Transaction identifier. 

301  

302 Returns: 

303 Transformed data. 

304 """ 

305 # Get or create transaction 

306 if transaction_id is None: 

307 import uuid 

308 transaction_id = str(uuid.uuid4()) 

309 

310 try: 

311 # Begin transaction on required resources 

312 self._begin_transaction(context, transaction_id) 

313 

314 # Execute arc 

315 result = self.execute(context, data) 

316 

317 # Commit transaction 

318 self._commit_transaction(context, transaction_id) 

319 

320 return result 

321 

322 except Exception: 

323 # Rollback transaction 

324 self._rollback_transaction(context, transaction_id) 

325 raise 

326 

327 def execute_push( 

328 self, 

329 push_arc: PushArc, 

330 context: "ExecutionContext", 

331 data: Any = None 

332 ) -> Any: 

333 """Execute a push arc to a sub-network. 

334  

335 Args: 

336 push_arc: Push arc definition. 

337 context: Execution context. 

338 data: Current data. 

339  

340 Returns: 

341 Result from sub-network execution. 

342 """ 

343 # Prepare data for sub-network based on isolation mode 

344 if push_arc.isolation_mode == DataIsolationMode.COPY: 

345 import copy 

346 sub_data = copy.deepcopy(data) 

347 elif push_arc.isolation_mode == DataIsolationMode.SERIALIZE: 

348 import json 

349 serialized = json.dumps(data) 

350 sub_data = json.loads(serialized) 

351 else: 

352 sub_data = data 

353 

354 # Apply data mapping 

355 if push_arc.data_mapping: 

356 mapped_data = {} 

357 for parent_field, child_field in push_arc.data_mapping.items(): 

358 if hasattr(data, parent_field): 

359 mapped_data[child_field] = getattr(data, parent_field) 

360 elif isinstance(data, dict) and parent_field in data: 

361 mapped_data[child_field] = data[parent_field] 

362 sub_data = mapped_data 

363 

364 # Push context to sub-network 

365 context.push_network(push_arc.target_network, push_arc.return_state) 

366 

367 # Execute sub-network (this would be handled by execution engine) 

368 # For now, we just return the data 

369 result = sub_data 

370 

371 # Apply result mapping 

372 if push_arc.result_mapping: 

373 for child_field, parent_field in push_arc.result_mapping.items(): 

374 if isinstance(result, dict) and child_field in result: 

375 if isinstance(data, dict): 

376 data[parent_field] = result[child_field] 

377 elif hasattr(data, parent_field): 

378 setattr(data, parent_field, result[child_field]) 

379 

380 return result 

381 

382 def _create_function_context( 

383 self, 

384 exec_context: "ExecutionContext", 

385 resources: Dict[str, Any] | None = None, 

386 stream_enabled: bool = False 

387 ) -> FunctionContext: 

388 """Create function context for execution. 

389  

390 Args: 

391 exec_context: Execution context. 

392 resources: Allocated resources. 

393 stream_enabled: Whether streaming is enabled. 

394  

395 Returns: 

396 Function context. 

397 """ 

398 return FunctionContext( 

399 state_name=self.source_state, 

400 function_name=self.arc_def.transform or self.arc_def.pre_test, 

401 metadata={ 

402 'source_state': self.source_state, 

403 'target_state': self.arc_def.target_state, 

404 'arc_priority': self.arc_def.priority, 

405 'stream_enabled': stream_enabled 

406 }, 

407 resources=resources or {} 

408 ) 

409 

410 def _allocate_resources( 

411 self, 

412 context: "ExecutionContext", 

413 state_resources: Dict[str, Any] | None = None 

414 ) -> Dict[str, Any]: 

415 """Allocate required resources for arc execution, merging with state resources. 

416 

417 Args: 

418 context: Execution context. 

419 state_resources: Already allocated state resources to merge with. 

420 

421 Returns: 

422 Dictionary of merged resources (state + arc-specific). 

423 """ 

424 # Start with state resources if provided 

425 resources = dict(state_resources) if state_resources else {} 

426 

427 # Get resource manager from context 

428 resource_manager = getattr(context, 'resource_manager', None) 

429 if not resource_manager: 

430 # No resource manager available - return existing resources 

431 return resources 

432 

433 # Generate unique owner ID for this arc execution 

434 # Create an arc identifier from source and target states 

435 arc_identifier = f"{self.source_state}_to_{self.arc_def.target_state}" 

436 owner_id = f"arc_{arc_identifier}_{getattr(context, 'execution_id', 'unknown')}" 

437 

438 for resource_type, resource_name in self.arc_def.required_resources.items(): 

439 # Skip if already have this resource from state 

440 if resource_type in resources: 

441 self._log_warning( 

442 f"Arc resource '{resource_type}' already allocated by state, skipping" 

443 ) 

444 continue 

445 

446 try: 

447 # Acquire arc-specific resource 

448 resource = resource_manager.acquire( 

449 name=resource_name, 

450 owner_id=owner_id, 

451 timeout=30.0 # 30 second timeout 

452 ) 

453 resources[resource_type] = resource 

454 

455 # Track for cleanup (only arc-specific resources) 

456 if not hasattr(context, '_arc_acquired_resources'): 

457 context._arc_acquired_resources = {} 

458 context._arc_acquired_resources[resource_name] = owner_id 

459 

460 except Exception as e: 

461 # Resource acquisition failed - clean up only arc-specific resources 

462 self._release_arc_resources(context, getattr(context, '_arc_acquired_resources', {})) 

463 raise ResourceError( 

464 resource_id=resource_name, 

465 message=f"Failed to acquire arc resource: {e}", 

466 details={"operation": "acquire", "error": str(e)} 

467 ) from e 

468 

469 return resources 

470 

471 def _release_arc_resources( 

472 self, 

473 context: "ExecutionContext", 

474 arc_resources: Dict[str, str] 

475 ) -> None: 

476 """Release only arc-specific resources, not state resources. 

477 

478 Args: 

479 context: Execution context. 

480 arc_resources: Map of resource_name -> owner_id for arc resources only. 

481 """ 

482 if not arc_resources: 

483 return 

484 

485 resource_manager = getattr(context, 'resource_manager', None) 

486 if not resource_manager: 

487 return 

488 

489 for resource_name, owner_id in arc_resources.items(): 

490 try: 

491 resource_manager.release(resource_name, owner_id) 

492 except Exception as e: 

493 self._log_error(f"Failed to release arc resource {resource_name}: {e}") 

494 

495 # Clear arc resources tracking 

496 if hasattr(context, '_arc_acquired_resources'): 

497 context._arc_acquired_resources = {} 

498 

499 def _release_resources( 

500 self, 

501 context: "ExecutionContext", 

502 resources: Dict[str, Any] 

503 ) -> None: 

504 """Release allocated resources. 

505  

506 Args: 

507 context: Execution context. 

508 resources: Resources to release. 

509 """ 

510 # Get resource manager from context 

511 resource_manager = getattr(context, 'resource_manager', None) 

512 if not resource_manager: 

513 return 

514 

515 # Get acquired resources from context if available 

516 acquired_resources = getattr(context, '_acquired_resources', {}) 

517 

518 # Release each resource 

519 for resource_type in resources.keys(): 

520 # Find the resource name for this resource type 

521 resource_name = None 

522 for rtype, rname in self.arc_def.required_resources.items(): 

523 if rtype == resource_type: 

524 resource_name = rname 

525 break 

526 

527 if resource_name and resource_name in acquired_resources: 

528 owner_id = acquired_resources[resource_name] 

529 try: 

530 resource_manager.release(resource_name, owner_id) 

531 # Remove from tracking 

532 del acquired_resources[resource_name] 

533 except Exception: 

534 # Best effort cleanup - don't propagate release errors 

535 pass 

536 

537 def _execute_streaming( 

538 self, 

539 func: Callable, 

540 data: Any, 

541 context: FunctionContext 

542 ) -> Any: 

543 """Execute function with streaming support. 

544  

545 Args: 

546 func: Function to execute. 

547 data: Input data. 

548 context: Function context. 

549  

550 Returns: 

551 Streamed result. 

552 """ 

553 # This would integrate with the streaming system 

554 # For now, we just execute normally 

555 return func(data, context) 

556 

557 def _begin_transaction( 

558 self, 

559 context: "ExecutionContext", 

560 transaction_id: str 

561 ) -> None: 

562 """Begin transaction on required resources. 

563  

564 Args: 

565 context: Execution context. 

566 transaction_id: Transaction ID. 

567 """ 

568 # This would interface with transactional resources 

569 pass 

570 

571 def _commit_transaction( 

572 self, 

573 context: "ExecutionContext", 

574 transaction_id: str 

575 ) -> None: 

576 """Commit transaction on resources. 

577  

578 Args: 

579 context: Execution context. 

580 transaction_id: Transaction ID. 

581 """ 

582 # This would interface with transactional resources 

583 pass 

584 

585 def _rollback_transaction( 

586 self, 

587 context: "ExecutionContext", 

588 transaction_id: str 

589 ) -> None: 

590 """Rollback transaction on resources. 

591  

592 Args: 

593 context: Execution context. 

594 transaction_id: Transaction ID. 

595 """ 

596 # This would interface with transactional resources 

597 pass 

598 

599 def get_statistics(self) -> Dict[str, Any]: 

600 """Get execution statistics. 

601  

602 Returns: 

603 Dictionary of statistics. 

604 """ 

605 avg_time = 0.0 

606 if self.execution_count > 0: 

607 avg_time = self.total_execution_time / self.execution_count 

608 

609 return { 

610 'source_state': self.source_state, 

611 'target_state': self.arc_def.target_state, 

612 'execution_count': self.execution_count, 

613 'success_count': self.success_count, 

614 'failure_count': self.failure_count, 

615 'total_execution_time': self.total_execution_time, 

616 'average_execution_time': avg_time, 

617 'success_rate': ( 

618 self.success_count / self.execution_count 

619 if self.execution_count > 0 else 0.0 

620 ) 

621 }