Coverage for src/dataknobs_fsm/functions/manager.py: 28%

257 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Unified function management for FSM. 

2 

3This module provides a central, robust system for managing both sync and async functions 

4across all FSM components. It handles function registration, wrapping, resolution, and execution 

5in a consistent manner. 

6""" 

7 

8import asyncio 

9import inspect 

10from typing import Any, Callable, Dict, Union, Protocol, runtime_checkable 

11from enum import Enum 

12import logging 

13 

14from dataknobs_fsm.functions.base import ( 

15 IValidationFunction, 

16 ITransformFunction, 

17 IStateTestFunction, 

18 IEndStateTestFunction, 

19 ExecutionResult 

20) 

21 

22logger = logging.getLogger(__name__) 

23 

24 

class FunctionSource(Enum):
    """Enumerates where a function definition originated."""

    # Function that was explicitly registered under a name.
    REGISTERED = "registered"
    # Function built from an inline code string.
    INLINE = "inline"
    # Function that ships with the FSM itself.
    BUILTIN = "builtin"
    # Reference to a function that was registered elsewhere.
    REFERENCE = "reference"

31 

32 

@runtime_checkable
class AsyncCallable(Protocol):
    """Structural type describing objects whose ``__call__`` is a coroutine."""

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the object with arbitrary arguments and await its result."""
        ...

39 

40 

class FunctionWrapper:
    """Unified wrapper for all function types.

    Pairs a sync or async callable with identifying metadata (name, source,
    optional interface) and exposes explicit synchronous and asynchronous
    execution entry points. Attribute lookups the wrapper cannot satisfy
    itself are forwarded to the wrapped callable.
    """

    def __init__(
        self,
        func: Callable,
        name: str,
        source: FunctionSource = FunctionSource.REGISTERED,
        interface: type | None = None
    ):
        """Initialize function wrapper.

        Args:
            func: The actual function (sync or async)
            name: Function name for identification
            source: Where the function came from
            interface: Optional interface the function should implement
        """
        self.func = func
        self.name = name
        self.source = source
        self.interface = interface

        # Decide once, at construction time, whether the callable is async.
        self._is_async = self._check_async(func)

        # Mirror the wrapped callable's metadata, falling back to the given
        # name when the callable has no __name__ of its own.
        self.__name__ = getattr(func, '__name__', name)
        self.__doc__ = getattr(func, '__doc__', '')

    def _check_async(self, func: Callable) -> bool:
        """Check if a function is async.

        Args:
            func: Function to check

        Returns:
            True if ``func`` is a coroutine function, or a callable object
            whose ``__call__`` is a coroutine function; False otherwise
        """
        if asyncio.iscoroutinefunction(func):
            return True

        # Plain functions and bound methods are fully covered by the check
        # above; only other callable objects need their __call__ inspected.
        if not callable(func) or inspect.isfunction(func) or inspect.ismethod(func):
            return False

        try:
            return bool(asyncio.iscoroutinefunction(func.__call__))  # type: ignore[operator]
        except AttributeError:
            return False

    @property
    def is_async(self) -> bool:
        """Check if wrapped function is async."""
        return self._is_async

    async def execute_async(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the function asynchronously.

        Async functions are awaited directly. Sync functions are run in the
        running loop's default executor so they do not block the event loop.

        Args:
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Function result
        """
        if self._is_async:
            return await self.func(*args, **kwargs)

        from functools import partial

        # run_in_executor() only forwards positional arguments, so the call
        # (including keyword arguments) is bound up front with partial().
        bound_call = partial(self.func, *args, **kwargs)
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, bound_call)

    def execute_sync(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the function synchronously.

        Args:
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Function result

        Raises:
            RuntimeError: If trying to execute async function synchronously
        """
        if self._is_async:
            raise RuntimeError(
                f"Cannot execute async function '{self.name}' synchronously. "
                "Use execute_async instead."
            )

        return self.func(*args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped function.

        This preserves the async nature of the wrapped function: for async
        functions a coroutine is returned (to be awaited by the caller); for
        sync functions the result is returned directly.
        """
        if self._is_async:
            return self.execute_async(*args, **kwargs)
        return self.func(*args, **kwargs)

    # Make wrapper detectable as async when wrapping async functions
    def __getattr__(self, name):
        """Forward attribute access to wrapped function.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: If neither the wrapper nor the wrapped function
                provides ``name``
        """
        # If the wrapper's own state is missing (instance created without
        # __init__, e.g. while being copied or unpickled), reading self.func
        # or self._is_async below would re-enter this method without end.
        # Report those names as missing instead.
        if name in ('func', '_is_async'):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        if name == '_is_coroutine' and self._is_async:
            # Mark as coroutine function for asyncio detection.
            # NOTE: relies on a private asyncio attribute.
            return asyncio.coroutines._is_coroutine
        return getattr(self.func, name)

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"FunctionWrapper(name={self.name}, "
            f"async={self._is_async}, source={self.source.value})"
        )

172 

173 

class InterfaceWrapper:
    """Adapter that exposes a FunctionWrapper through an FSM interface.

    Depending on the requested interface, instance attributes named after
    that interface's methods (``transform``, ``validate``, ``test`` or
    ``should_end`` plus a matching description getter) are attached at
    construction time.
    """

    def __init__(self, wrapper: FunctionWrapper, interface: type):
        """Store the wrapper and attach the interface-specific methods.

        Args:
            wrapper: The function wrapper to adapt
            interface: The interface to implement
        """
        self.wrapper = wrapper
        self.interface = interface
        self._setup_interface_methods()

    def _setup_interface_methods(self):
        """Attach the methods that belong to the configured interface."""
        target = self.interface

        if target == ITransformFunction:
            self.transform = self._create_method('transform')
            self.get_transform_description = lambda: f"Transform: {self.wrapper.name}"
            return

        if target == IValidationFunction:
            self.validate = self._create_method('validate')
            self.get_validation_rules = lambda: {"name": self.wrapper.name}
            return

        if target == IStateTestFunction:
            self.test = self._create_test_method()
            self.get_test_description = lambda: f"Test: {self.wrapper.name}"
            return

        if target == IEndStateTestFunction:
            self.should_end = self._create_test_method()
            self.get_end_condition = lambda: f"End test: {self.wrapper.name}"

    def _takes_single_argument(self) -> bool:
        """Report whether the wrapped function declares exactly one parameter.

        A one-parameter function (common for inline lambdas) is treated as
        expecting a state object rather than ``(data, context)``. When the
        signature cannot be determined the standard form is assumed.
        """
        target = self.wrapper.func
        try:
            parameters = inspect.signature(target).parameters
        except Exception:
            return False
        return len(parameters) == 1

    def _build_call_args(self, data: Any, context: Any, single_argument: bool) -> tuple:
        """Return the positional arguments to pass to the wrapped function."""
        if not single_argument:
            return (data, context)
        # Imported lazily, only when a state object is actually needed.
        from dataknobs_fsm.core.data_wrapper import wrap_for_lambda
        return (wrap_for_lambda(data),)

    def _create_method(self, method_name: str):
        """Build the callable used for a transform/validate style method.

        Args:
            method_name: Name of the interface method

        Returns:
            A sync or async function (matching the wrapped function) that
            accepts ``(data, context=None)``. For 'validate' and 'transform'
            a result that is not already an ExecutionResult is wrapped with
            ``ExecutionResult.success_result``.
        """
        single_argument = self._takes_single_argument()
        needs_result_wrapping = method_name in ['validate', 'transform']

        def finalize(outcome: Any) -> Any:
            if needs_result_wrapping and not isinstance(outcome, ExecutionResult):
                return ExecutionResult.success_result(outcome)
            return outcome

        if not self.wrapper.is_async:
            def sync_method(data: Any, context: Dict[str, Any] | None = None) -> Any:
                call_args = self._build_call_args(data, context, single_argument)
                return finalize(self.wrapper.execute_sync(*call_args))
            return sync_method

        async def async_method(data: Any, context: Dict[str, Any] | None = None) -> Any:
            call_args = self._build_call_args(data, context, single_argument)
            return finalize(await self.wrapper.execute_async(*call_args))
        return async_method

    def _create_test_method(self):
        """Build the callable used for test methods.

        Returns:
            A sync or async function accepting ``(data, context=None)``. A
            tuple result is passed through unchanged; anything else becomes
            ``(bool(result), None)``.
        """
        single_argument = self._takes_single_argument()

        def as_verdict(outcome: Any):
            return outcome if isinstance(outcome, tuple) else (bool(outcome), None)

        if not self.wrapper.is_async:
            def sync_test(data: Any, context: Dict[str, Any] | None = None):
                call_args = self._build_call_args(data, context, single_argument)
                return as_verdict(self.wrapper.execute_sync(*call_args))
            return sync_test

        async def async_test(data: Any, context: Dict[str, Any] | None = None):
            call_args = self._build_call_args(data, context, single_argument)
            return as_verdict(await self.wrapper.execute_async(*call_args))
        return async_test

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Delegate a direct call to the underlying function wrapper."""
        return self.wrapper(*args, **kwargs)

    @property
    def is_async(self) -> bool:
        """Whether the wrapped function is async."""
        return self.wrapper.is_async

    @property
    def __name__(self) -> str:
        """Name reported by the underlying function wrapper."""
        return self.wrapper.__name__

    @property
    def _is_async(self) -> bool:
        """Expose the async flag under the private name as well."""
        return self.wrapper.is_async

317 

318 

class FunctionManager:
    """Central manager for all FSM functions.

    This class provides a unified interface for registering, resolving,
    and managing functions across the entire FSM system.

    SECURITY: string references that do not name a registered function are
    compiled and executed as Python code via ``exec``/``eval``. Only resolve
    references that come from trusted configuration.
    """

    def __init__(self):
        """Initialize function manager."""
        self._functions: Dict[str, FunctionWrapper] = {}
        self._builtin_functions: Dict[str, FunctionWrapper] = {}
        # Inline wrappers keyed by their source code string.
        self._inline_cache: Dict[str, FunctionWrapper] = {}

    def register_function(
        self,
        name: str,
        func: Callable,
        source: FunctionSource = FunctionSource.REGISTERED,
        interface: type | None = None
    ) -> FunctionWrapper:
        """Register a function.

        Functions registered with ``FunctionSource.BUILTIN`` are kept in a
        separate table from all other sources.

        Args:
            name: Function name
            func: The function to register
            source: Source of the function
            interface: Optional interface to implement

        Returns:
            FunctionWrapper for the registered function
        """
        wrapper = FunctionWrapper(func, name, source, interface)

        registry = (
            self._builtin_functions
            if source == FunctionSource.BUILTIN
            else self._functions
        )
        registry[name] = wrapper

        logger.debug(
            "Registered %s function '%s' from %s",
            'async' if wrapper.is_async else 'sync',
            name,
            source.value,
        )

        return wrapper

    def register_functions(
        self,
        functions: Dict[str, Callable],
        source: FunctionSource = FunctionSource.REGISTERED
    ) -> Dict[str, FunctionWrapper]:
        """Register multiple functions.

        Args:
            functions: Dictionary of name -> function
            source: Source of the functions

        Returns:
            Dictionary of name -> wrapper
        """
        return {
            name: self.register_function(name, func, source)
            for name, func in functions.items()
        }

    def resolve_function(
        self,
        reference: Union[str, Dict[str, Any], Callable],
        interface: type | None = None
    ) -> Union[FunctionWrapper, InterfaceWrapper, None]:
        """Resolve a function reference to a wrapper.

        Args:
            reference: Function reference. A callable is wrapped directly
                (an existing wrapper is reused instead of being wrapped
                again); a string is looked up among registered, then builtin
                functions and otherwise treated as inline code; a dict is
                interpreted through its 'type' key ('registered' with 'name',
                or 'inline' with 'code'; 'inline' is the default type).
            interface: Optional interface to adapt to

        Returns:
            FunctionWrapper, an InterfaceWrapper when ``interface`` is given,
            or None if the reference could not be resolved
        """
        wrapper = None

        if isinstance(reference, InterfaceWrapper):
            # Re-wrapping would lose the async flag, because the adapter's
            # __call__ is a plain method; reuse the wrapper it adapts.
            wrapper = reference.wrapper

        elif isinstance(reference, FunctionWrapper):
            # Already wrapped - avoid stacking wrappers.
            wrapper = reference

        elif callable(reference):
            # Direct callable
            wrapper = FunctionWrapper(
                reference,
                getattr(reference, '__name__', 'anonymous'),
                FunctionSource.REGISTERED
            )

        elif isinstance(reference, str):
            # String reference - check registered functions first
            if reference in self._functions:
                wrapper = self._functions[reference]
            elif reference in self._builtin_functions:
                wrapper = self._builtin_functions[reference]
            else:
                # Treat as inline code
                wrapper = self._create_inline_wrapper(reference)

        elif isinstance(reference, dict):
            # Dictionary reference
            ref_type = reference.get('type', 'inline')

            if ref_type == 'registered':
                name = reference.get('name')
                if name:
                    wrapper = self.get_function(name)

            elif ref_type == 'inline':
                code = reference.get('code')
                if code:
                    wrapper = self._create_inline_wrapper(code)

        # Apply interface if needed
        if wrapper is not None and interface is not None:
            return self._adapt_to_interface(wrapper, interface)

        return wrapper

    def _create_inline_wrapper(self, code: str) -> FunctionWrapper:
        """Create a wrapper for inline code.

        The code is interpreted, in order, as: a script that defines a
        callable, a lambda expression, or the body of a function taking
        ``(data, context=None)``. Successful results are cached per code
        string. On any failure the error is logged and a pass-through
        wrapper (returning ``data`` unchanged) is returned and not cached.

        Args:
            code: Python code string

        Returns:
            FunctionWrapper for the inline code
        """
        # Check cache first
        cached = self._inline_cache.get(code)
        if cached is not None:
            return cached

        try:
            func = self._compile_inline_code(code)
            if func is None or not callable(func):
                raise ValueError(f"Failed to create inline function from code: {code}")

            wrapper = FunctionWrapper(func, f"inline_{id(code)}", FunctionSource.INLINE)
            self._inline_cache[code] = wrapper
            return wrapper

        except Exception as e:
            # Deliberate best-effort: a broken snippet degrades to a no-op.
            logger.error("Failed to create inline function: %s", e)
            return FunctionWrapper(
                lambda data, context=None: data,  # noqa: ARG005
                f"inline_error_{id(code)}",
                FunctionSource.INLINE
            )

    def _compile_inline_code(self, code: str) -> Any:
        """Turn an inline code string into a callable (or None)."""
        # SECURITY: exec/eval run the code with full interpreter privileges;
        # inline code must come from a trusted source.
        namespace: Dict[str, Any] = {'asyncio': asyncio}

        # Expose the registered functions themselves (not their wrappers) so
        # inline code can call them by name.
        for name, registered in self._functions.items():
            namespace[name] = registered.func if hasattr(registered, 'func') else registered

        # First try to exec the code directly (might be a full function
        # definition) and look for a callable it introduced.
        known_names = set(namespace)
        try:
            exec(code, namespace)
        except Exception:
            func = None
        else:
            func = self._pick_defined_callable(namespace, known_names)

        if func:
            return func

        if code.strip().startswith('lambda'):
            # Evaluate lambda directly
            return eval(code, namespace)

        return self._compile_function_body(code, namespace)

    @staticmethod
    def _pick_defined_callable(namespace: Dict[str, Any], known_names: set) -> Any:
        """Choose, deterministically, the callable that inline code defined.

        Functions defined by the code itself are preferred over callables it
        merely imported; among equals the one defined last wins. Returns
        None when the code introduced no callable.
        """
        # Dicts preserve insertion order, so this list is in definition order.
        fresh = [
            value for name, value in namespace.items()
            if name not in known_names and callable(value)
        ]
        if not fresh:
            return None

        own = [
            value for value in fresh
            if inspect.isfunction(value) and value.__globals__ is namespace
        ]
        return (own or fresh)[-1]

    @staticmethod
    def _compile_function_body(code: str, namespace: Dict[str, Any]) -> Any:
        """Compile ``code`` as the body of ``inline_func(data, context=None)``."""
        import textwrap

        # Treat as function body - make it async when the code awaits.
        if 'await' in code:
            func_def = "async def inline_func(data, context=None):\n"
        else:
            func_def = "def inline_func(data, context=None):\n"

        lines = code.split(';')

        # Check if this looks like a simple expression (for conditions)
        # Common patterns: comparisons, boolean ops, method calls that return bool
        is_expression = (
            '==' in code or '!=' in code or '<' in code or '>' in code or
            ' and ' in code or ' or ' in code or ' not ' in code or
            code.strip().startswith('not ') or
            '.get(' in code or
            'in ' in code or
            code.strip() in ['True', 'False']
        )

        if is_expression and 'return' not in code and len(lines) == 1:
            # For expressions, return the expression result
            func_def += f"    return {code.strip()}\n"
        else:
            # For statements, add them to the body. Every line of a
            # multi-line statement is indented, not just the first, so that
            # multi-line bodies compile.
            for line in lines:
                stmt = textwrap.dedent(line).strip()
                if stmt:
                    func_def += textwrap.indent(stmt, "    ") + "\n"

            # Ensure we return data if no explicit return (for transforms)
            if 'return' not in code:
                func_def += "    return data\n"

        exec(func_def, namespace)
        return namespace.get('inline_func')

    def _adapt_to_interface(
        self,
        wrapper: FunctionWrapper,
        interface: type
    ) -> Union[InterfaceWrapper, FunctionWrapper]:
        """Adapt a wrapper to implement a specific interface.

        Args:
            wrapper: The function wrapper
            interface: The interface to implement

        Returns:
            InterfaceWrapper that implements the interface
        """
        return InterfaceWrapper(wrapper, interface)

    def get_function(self, name: str) -> FunctionWrapper | None:
        """Get a registered function by name.

        Explicitly registered functions take precedence over builtins.

        Args:
            name: Function name

        Returns:
            FunctionWrapper or None
        """
        return self._functions.get(name) or self._builtin_functions.get(name)

    def has_function(self, name: str) -> bool:
        """Check if a function is registered.

        Args:
            name: Function name

        Returns:
            True if registered (as a regular or builtin function)
        """
        return name in self._functions or name in self._builtin_functions

    def list_functions(self) -> Dict[str, Dict[str, Any]]:
        """List all registered functions.

        Returns:
            Dictionary of name -> info with 'source', 'async' and 'type'
            keys. When a name exists in both tables the builtin entry is the
            one reported.
        """
        result: Dict[str, Dict[str, Any]] = {
            name: {
                'source': wrapper.source.value,
                'async': wrapper.is_async,
                'type': 'registered'
            }
            for name, wrapper in self._functions.items()
        }

        result.update({
            name: {
                'source': wrapper.source.value,
                'async': wrapper.is_async,
                'type': 'builtin'
            }
            for name, wrapper in self._builtin_functions.items()
        })

        return result

    def clear(self):
        """Clear all registered functions (and cached inline ones) except builtins."""
        self._functions.clear()
        self._inline_cache.clear()

    def clear_all(self):
        """Clear all functions including builtins."""
        self.clear()
        self._builtin_functions.clear()

612 

613 

# Module-wide FunctionManager shared by the convenience functions below.
_global_manager = FunctionManager()


def get_function_manager() -> FunctionManager:
    """Return the module-level FunctionManager.

    Returns:
        The single FunctionManager instance created when this module was
        imported
    """
    return _global_manager

625 

626 

def register_function(
    name: str,
    func: Callable,
    source: FunctionSource = FunctionSource.REGISTERED,
    interface: type | None = None
) -> FunctionWrapper:
    """Register a function with the global manager.

    Mirrors ``FunctionManager.register_function``, including its optional
    ``interface`` argument.

    Args:
        name: Function name
        func: The function
        source: Function source
        interface: Optional interface the function should implement

    Returns:
        FunctionWrapper
    """
    return _global_manager.register_function(name, func, source, interface)

643 

644 

def resolve_function(
    reference: Union[str, Dict[str, Any], Callable],
    interface: type | None = None
) -> Union[FunctionWrapper, InterfaceWrapper, None]:
    """Resolve a function reference through the global manager.

    Args:
        reference: Function reference (name, dict, or callable)
        interface: Optional interface to adapt the result to

    Returns:
        Whatever the global manager's ``resolve_function`` returns: a
        FunctionWrapper, an InterfaceWrapper, or None
    """
    resolved = _global_manager.resolve_function(reference, interface)
    return resolved