Coverage for src/dataknobs_fsm/functions/manager.py: 28%
257 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""Unified function management for FSM.
3This module provides a central, robust system for managing both sync and async functions
4across all FSM components. It handles function registration, wrapping, resolution, and execution
5in a consistent manner.
6"""
8import asyncio
9import inspect
10from typing import Any, Callable, Dict, Union, Protocol, runtime_checkable
11from enum import Enum
12import logging
14from dataknobs_fsm.functions.base import (
15 IValidationFunction,
16 ITransformFunction,
17 IStateTestFunction,
18 IEndStateTestFunction,
19 ExecutionResult
20)
22logger = logging.getLogger(__name__)
class FunctionSource(Enum):
    """Enumerates the places a function definition can originate from."""

    # A callable that was explicitly registered under a name.
    REGISTERED = "registered"
    # A callable built from an inline code string.
    INLINE = "inline"
    # A callable shipped as an FSM built-in.
    BUILTIN = "builtin"
    # A reference that points at a registered function.
    REFERENCE = "reference"
@runtime_checkable
class AsyncCallable(Protocol):
    """Structural type describing objects that are called as coroutines.

    Because the protocol is ``runtime_checkable``, ``isinstance`` checks only
    verify that a ``__call__`` member exists; they do not verify that it is
    actually declared ``async``.
    """

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the callable; the result is obtained by awaiting the call."""
        ...
class FunctionWrapper:
    """Unified wrapper for all function types.

    The wrapper records, at construction time, whether the wrapped callable is
    async and then offers ``execute_sync``, ``execute_async`` and ``__call__``
    so callers can drive sync and async functions through one interface.
    Attribute lookups that the wrapper itself cannot satisfy are forwarded to
    the wrapped callable.
    """

    def __init__(
        self,
        func: Callable,
        name: str,
        source: FunctionSource = FunctionSource.REGISTERED,
        interface: type | None = None
    ):
        """Initialize function wrapper.

        Args:
            func: The actual function (sync or async)
            name: Function name for identification
            source: Where the function came from
            interface: Optional interface the function should implement
                (only stored; not enforced here)
        """
        self.func = func
        self.name = name
        self.source = source
        self.interface = interface

        # Decide once whether calls must be awaited.
        self._is_async = self._check_async(func)

        # Mirror the wrapped callable's metadata, falling back to our own name.
        self.__name__ = getattr(func, '__name__', name)
        self.__doc__ = getattr(func, '__doc__', '')

    def _check_async(self, func: Callable) -> bool:
        """Check if a function is async.

        Args:
            func: Function to check

        Returns:
            True if ``func`` is a coroutine function, or a callable object
            whose ``__call__`` is a coroutine function; False otherwise
        """
        # Direct coroutine function check
        if asyncio.iscoroutinefunction(func):
            return True

        # Callable objects (not plain functions/methods, which also expose
        # __call__) may implement an async __call__.
        if callable(func) and not inspect.isfunction(func) and not inspect.ismethod(func):
            try:
                if asyncio.iscoroutinefunction(func.__call__):  # type: ignore[operator]
                    return True
            except AttributeError:
                pass

        return False

    @property
    def is_async(self) -> bool:
        """Check if wrapped function is async."""
        return self._is_async

    async def execute_async(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the function asynchronously.

        Async functions are awaited directly.  Sync functions are run in the
        running loop's default executor so they do not block the event loop.

        Args:
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Function result
        """
        if self._is_async:
            # Direct async execution
            return await self.func(*args, **kwargs)

        # run_in_executor() only forwards positional arguments, so bind
        # everything (including keyword arguments) up front.
        import functools

        loop = asyncio.get_running_loop()
        call = functools.partial(self.func, *args, **kwargs)
        return await loop.run_in_executor(None, call)

    def execute_sync(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the function synchronously.

        Args:
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Function result

        Raises:
            RuntimeError: If trying to execute async function synchronously
        """
        if self._is_async:
            raise RuntimeError(
                f"Cannot execute async function '{self.name}' synchronously. "
                "Use execute_async instead."
            )

        return self.func(*args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped function.

        This preserves the async nature of the wrapped function: for async
        functions a coroutine is returned (to be awaited by the caller); sync
        functions are called directly and their result is returned.
        """
        if self._is_async:
            return self.execute_async(*args, **kwargs)
        return self.func(*args, **kwargs)

    # Make wrapper detectable as async when wrapping async functions
    def __getattr__(self, name):
        """Forward attribute access to wrapped function.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: If the attribute is one of the wrapper's own core
                attributes that has not been set yet, or the wrapped callable
                does not have it.
        """
        # If these are missing (e.g. the instance was created without
        # __init__ by copy/pickle), looking them up below would re-enter
        # __getattr__ forever.  Fail fast instead.
        if name in ('func', '_is_async'):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        if name == '_is_coroutine' and self._is_async:
            # asyncio.iscoroutinefunction() recognises this marker.  It is a
            # private asyncio attribute, so this relies on asyncio internals.
            return asyncio.coroutines._is_coroutine
        return getattr(self.func, name)

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"FunctionWrapper(name={self.name}, "
            f"async={self._is_async}, source={self.source.value})"
        )
class InterfaceWrapper:
    """Wrapper that adapts functions to specific FSM interfaces.

    Depending on the requested interface, the instance gets the matching
    method pair attached (``transform``/``get_transform_description``,
    ``validate``/``get_validation_rules``, ``test``/``get_test_description``
    or ``should_end``/``get_end_condition``).  For any other interface no
    methods are attached.
    """

    def __init__(self, wrapper: FunctionWrapper, interface: type):
        """Initialize interface wrapper.

        Args:
            wrapper: The function wrapper
            interface: The interface to implement
        """
        self.wrapper = wrapper
        self.interface = interface
        self._setup_interface_methods()

    def _setup_interface_methods(self):
        """Set up methods based on interface."""
        if self.interface is ITransformFunction:
            self.transform = self._create_method('transform')
            self.get_transform_description = lambda: f"Transform: {self.wrapper.name}"

        elif self.interface is IValidationFunction:
            self.validate = self._create_method('validate')
            self.get_validation_rules = lambda: {"name": self.wrapper.name}

        elif self.interface is IStateTestFunction:
            self.test = self._create_test_method()
            self.get_test_description = lambda: f"Test: {self.wrapper.name}"

        elif self.interface is IEndStateTestFunction:
            self.should_end = self._create_test_method()
            self.get_end_condition = lambda: f"End test: {self.wrapper.name}"

    def _expects_state_object(self) -> bool:
        """Report whether the wrapped function declares exactly one parameter.

        Single-parameter functions (common for inline lambdas) are called
        with one wrapped state object instead of ``(data, context)``.
        """
        try:
            return len(inspect.signature(self.wrapper.func).parameters) == 1
        except Exception:
            # Deliberate best effort: if the signature cannot be determined,
            # assume the standard (data, context) calling convention.
            return False

    @staticmethod
    def _call_args(data: Any, context: Dict[str, Any] | None, expects_state_obj: bool) -> tuple:
        """Build the positional arguments for the wrapped function."""
        if expects_state_obj:
            # Imported lazily, as in the original code (presumably to avoid
            # an import cycle; not verifiable from this file).
            from dataknobs_fsm.core.data_wrapper import wrap_for_lambda
            return (wrap_for_lambda(data),)
        return (data, context)

    @staticmethod
    def _as_test_result(result: Any):
        """Normalise a test outcome to a tuple; non-tuples become (bool, None)."""
        if isinstance(result, tuple):
            return result
        return (bool(result), None)

    def _create_method(self, method_name: str):
        """Create an interface method that wraps the function.

        Args:
            method_name: Name of the interface method

        Returns:
            A sync or async callable ``(data, context=None)`` that calls the
            wrapped function.  For ``'validate'`` and ``'transform'`` a result
            that is not already an ``ExecutionResult`` is wrapped with
            ``ExecutionResult.success_result``.
        """
        expects_state_obj = self._expects_state_object()
        wrap_result = method_name in ('validate', 'transform')

        def finish(result: Any) -> Any:
            # Wrap in ExecutionResult if needed
            if wrap_result and not isinstance(result, ExecutionResult):
                return ExecutionResult.success_result(result)
            return result

        if self.wrapper.is_async:
            async def async_method(data: Any, context: Dict[str, Any] | None = None) -> Any:
                call_args = self._call_args(data, context, expects_state_obj)
                return finish(await self.wrapper.execute_async(*call_args))
            return async_method

        def sync_method(data: Any, context: Dict[str, Any] | None = None) -> Any:
            call_args = self._call_args(data, context, expects_state_obj)
            return finish(self.wrapper.execute_sync(*call_args))
        return sync_method

    def _create_test_method(self):
        """Create a test method that returns (bool, reason).

        Returns:
            A sync or async callable ``(data, context=None)``.  A tuple
            returned by the wrapped function is passed through unchanged; any
            other value is converted to ``(bool(value), None)``.
        """
        expects_state_obj = self._expects_state_object()

        if self.wrapper.is_async:
            async def async_test(data: Any, context: Dict[str, Any] | None = None):
                call_args = self._call_args(data, context, expects_state_obj)
                return self._as_test_result(await self.wrapper.execute_async(*call_args))
            return async_test

        def sync_test(data: Any, context: Dict[str, Any] | None = None):
            call_args = self._call_args(data, context, expects_state_obj)
            return self._as_test_result(self.wrapper.execute_sync(*call_args))
        return sync_test

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Make the wrapper callable by delegating to the function wrapper."""
        return self.wrapper(*args, **kwargs)

    @property
    def is_async(self) -> bool:
        """Check if wrapped function is async."""
        return self.wrapper.is_async

    @property
    def __name__(self) -> str:
        """Get function name."""
        return self.wrapper.__name__

    @property
    def _is_async(self) -> bool:
        """Expose _is_async for detection."""
        return self.wrapper.is_async
class FunctionManager:
    """Central manager for all FSM functions.

    Keeps three tables: explicitly registered functions, built-in functions,
    and a cache of wrappers compiled from inline code strings.  Registered
    functions take precedence over built-ins with the same name.
    """

    def __init__(self):
        """Initialize function manager."""
        self._functions: Dict[str, FunctionWrapper] = {}
        self._builtin_functions: Dict[str, FunctionWrapper] = {}
        self._inline_cache: Dict[str, FunctionWrapper] = {}

    def register_function(
        self,
        name: str,
        func: Callable,
        source: FunctionSource = FunctionSource.REGISTERED,
        interface: type | None = None
    ) -> FunctionWrapper:
        """Register a function.

        Functions with source ``BUILTIN`` go to the built-in table; all
        others go to the registered table.  An existing entry with the same
        name in that table is replaced.

        Args:
            name: Function name
            func: The function to register
            source: Source of the function
            interface: Optional interface to implement

        Returns:
            FunctionWrapper for the registered function
        """
        wrapper = FunctionWrapper(func, name, source, interface)

        if source == FunctionSource.BUILTIN:
            self._builtin_functions[name] = wrapper
        else:
            self._functions[name] = wrapper

        logger.debug(
            f"Registered {'async' if wrapper.is_async else 'sync'} "
            f"function '{name}' from {source.value}"
        )

        return wrapper

    def register_functions(
        self,
        functions: Dict[str, Callable],
        source: FunctionSource = FunctionSource.REGISTERED,
        interface: type | None = None
    ) -> Dict[str, FunctionWrapper]:
        """Register multiple functions.

        Args:
            functions: Dictionary of name -> function
            source: Source of the functions
            interface: Optional interface recorded on every wrapper

        Returns:
            Dictionary of name -> wrapper
        """
        return {
            name: self.register_function(name, func, source, interface)
            for name, func in functions.items()
        }

    def resolve_function(
        self,
        reference: Union[str, Dict[str, Any], Callable],
        interface: type | None = None
    ) -> Union[FunctionWrapper, InterfaceWrapper, None]:
        """Resolve a function reference to a wrapper.

        Resolution rules:
            * callable: wrapped directly;
            * str: looked up in registered, then built-in functions, and
              otherwise treated as inline code;
            * dict: ``{'type': 'registered', 'name': ...}`` or
              ``{'type': 'inline', 'code': ...}`` (``type`` defaults to
              ``'inline'``).

        Args:
            reference: Function reference (name, dict, or callable)
            interface: Optional interface to adapt to

        Returns:
            A FunctionWrapper, an InterfaceWrapper when ``interface`` is
            given, or None if the reference could not be resolved
        """
        wrapper = None

        if callable(reference):
            # Direct callable
            wrapper = FunctionWrapper(
                reference,
                getattr(reference, '__name__', 'anonymous'),
                FunctionSource.REGISTERED
            )

        elif isinstance(reference, str):
            # Known names win; anything else is treated as inline code.
            wrapper = self.get_function(reference)
            if wrapper is None:
                wrapper = self._create_inline_wrapper(reference)

        elif isinstance(reference, dict):
            ref_type = reference.get('type', 'inline')

            if ref_type == 'registered':
                name = reference.get('name')
                if name:
                    wrapper = self.get_function(name)

            elif ref_type == 'inline':
                code = reference.get('code')
                if code:
                    wrapper = self._create_inline_wrapper(code)

        # Apply interface if needed
        if wrapper is not None and interface:
            return self._adapt_to_interface(wrapper, interface)

        return wrapper

    def _build_inline_namespace(self) -> Dict[str, Any]:
        """Build the globals dict used to execute inline code.

        Contains ``asyncio`` plus every registered (non-built-in) function
        under its registered name, unwrapped so inline code can call it.
        """
        namespace: Dict[str, Any] = {'asyncio': asyncio}
        for name, registered in self._functions.items():
            namespace[name] = registered.func if hasattr(registered, 'func') else registered
        return namespace

    @staticmethod
    def _exec_inline_definition(code: str, namespace: Dict[str, Any]):
        """Execute ``code`` as-is and return a callable it defines, if any.

        Returns None when execution raises or defines no new callable.
        Selection is deterministic: functions defined by the code itself are
        preferred over other new callables (e.g. imported names), and the
        first one in definition order wins.
        """
        known = set(namespace)
        try:
            exec(code, namespace)
        except Exception:
            # Not a standalone definition (e.g. a bare body referring to
            # `data`); the caller falls back to other strategies.
            return None

        # Dicts preserve insertion order, so this follows definition order.
        fresh = [
            value for name, value in namespace.items()
            if name not in known and callable(value)
        ]
        for candidate in fresh:
            # Functions created by the exec'd code use `namespace` as globals.
            if getattr(candidate, '__globals__', None) is namespace:
                return candidate
        return fresh[0] if fresh else None

    @staticmethod
    def _build_inline_source(code: str) -> str:
        """Wrap a code fragment in an ``inline_func(data, context=None)`` def.

        The function is declared ``async`` when the fragment contains
        ``await``.  A fragment that looks like a single expression becomes
        ``return <expr>``; otherwise the ``;``-separated statements form the
        body and ``return data`` is appended when the fragment has no
        ``return`` of its own.
        """
        if 'await' in code:
            header = "async def inline_func(data, context=None):\n"
        else:
            header = "def inline_func(data, context=None):\n"

        lines = code.split(';')
        stripped = code.strip()

        # Heuristic for conditions: comparisons, boolean ops, lookups.
        is_expression = (
            '==' in code or '!=' in code or '<' in code or '>' in code or
            ' and ' in code or ' or ' in code or ' not ' in code or
            stripped.startswith('not ') or
            '.get(' in code or
            'in ' in code or
            stripped in ['True', 'False']
        )

        if is_expression and 'return' not in code and len(lines) == 1:
            candidate = header + f"    return {stripped}\n"
            try:
                # The heuristic also matches statements such as
                # "data['ok'] = data['a'] > 1"; only use the expression form
                # when it really compiles.
                compile(candidate, '<inline>', 'exec')
            except (SyntaxError, ValueError):
                pass
            else:
                return candidate

        # For statements, add them as-is
        func_def = header
        for line in lines:
            stmt = line.strip()
            if stmt:
                func_def += f"    {stmt}\n"

        # Ensure we return data if no explicit return (for transforms)
        if 'return' not in code:
            func_def += "    return data\n"
        return func_def

    def _create_inline_wrapper(self, code: str) -> FunctionWrapper:
        """Create a wrapper for inline code.

        Strategies, in order: execute the code as a full definition and take
        a callable it defines; evaluate it as a ``lambda``; wrap it as the
        body of a generated ``inline_func``.  Successful wrappers are cached
        by code string.  On any failure the error is logged and a wrapper
        around a function that returns its ``data`` argument unchanged is
        returned (and not cached).

        NOTE(security): this runs ``exec``/``eval`` on ``code``; only pass
        code strings from trusted sources.  The first strategy executes the
        code at wrapper-creation time.

        Args:
            code: Python code string

        Returns:
            FunctionWrapper for the inline code
        """
        # Check cache first
        if code in self._inline_cache:
            return self._inline_cache[code]

        try:
            namespace = self._build_inline_namespace()

            func = self._exec_inline_definition(code, namespace)
            if func is None:
                if code.strip().startswith('lambda'):
                    # Evaluate lambda directly
                    func = eval(code, namespace)
                else:
                    exec(self._build_inline_source(code), namespace)
                    func = namespace.get('inline_func')

            if func is None or not callable(func):
                raise ValueError(f"Failed to create inline function from code: {code}")

            wrapper = FunctionWrapper(func, f"inline_{id(code)}", FunctionSource.INLINE)
            self._inline_cache[code] = wrapper
            return wrapper

        except Exception as e:
            logger.error(f"Failed to create inline function: {e}")
            # Return a no-op wrapper
            return FunctionWrapper(
                lambda data, context=None: data,  # noqa: ARG005
                f"inline_error_{id(code)}",
                FunctionSource.INLINE
            )

    def _adapt_to_interface(
        self,
        wrapper: FunctionWrapper,
        interface: type
    ) -> Union[InterfaceWrapper, FunctionWrapper]:
        """Adapt a wrapper to implement a specific interface.

        Args:
            wrapper: The function wrapper
            interface: The interface to implement

        Returns:
            InterfaceWrapper that implements the interface
        """
        return InterfaceWrapper(wrapper, interface)

    def get_function(self, name: str) -> FunctionWrapper | None:
        """Get a registered function by name.

        Registered functions are checked before built-ins.

        Args:
            name: Function name

        Returns:
            FunctionWrapper or None
        """
        found = self._functions.get(name)
        if found is None:
            found = self._builtin_functions.get(name)
        return found

    def has_function(self, name: str) -> bool:
        """Check if a function is registered.

        Args:
            name: Function name

        Returns:
            True if registered (as a regular or built-in function)
        """
        return name in self._functions or name in self._builtin_functions

    def list_functions(self) -> Dict[str, Dict[str, Any]]:
        """List all registered functions.

        When a name exists both as a registered and a built-in function, the
        registered entry is reported, matching the precedence used by
        ``get_function`` and ``resolve_function``.

        Returns:
            Dictionary of name -> ``{'source', 'async', 'type'}`` info
        """
        result: Dict[str, Dict[str, Any]] = {}

        for name, wrapper in self._functions.items():
            result[name] = {
                'source': wrapper.source.value,
                'async': wrapper.is_async,
                'type': 'registered'
            }

        for name, wrapper in self._builtin_functions.items():
            if name in result:
                # Shadowed by a registered function of the same name.
                continue
            result[name] = {
                'source': wrapper.source.value,
                'async': wrapper.is_async,
                'type': 'builtin'
            }

        return result

    def clear(self):
        """Clear all registered functions (and the inline cache) except builtins."""
        self._functions.clear()
        self._inline_cache.clear()

    def clear_all(self):
        """Clear all functions including builtins."""
        self.clear()
        self._builtin_functions.clear()
# Single FunctionManager shared by the module-level helper functions;
# it is created once, when this module is imported.
_global_manager = FunctionManager()


def get_function_manager() -> FunctionManager:
    """Return the shared, module-level function manager.

    Returns:
        The FunctionManager instance held in ``_global_manager``
    """
    manager = _global_manager
    return manager
def register_function(
    name: str,
    func: Callable,
    source: FunctionSource = FunctionSource.REGISTERED,
    interface: type | None = None
) -> FunctionWrapper:
    """Register a function with the global manager.

    Thin convenience wrapper around ``FunctionManager.register_function`` on
    the module-level manager; it accepts the same arguments.

    Args:
        name: Function name
        func: The function
        source: Function source
        interface: Optional interface the function should implement

    Returns:
        FunctionWrapper
    """
    return _global_manager.register_function(name, func, source, interface)
def resolve_function(
    reference: Union[str, Dict[str, Any], Callable],
    interface: type | None = None
) -> Union[FunctionWrapper, InterfaceWrapper, None]:
    """Resolve a function reference through the global manager.

    Args:
        reference: Function reference (name, dict, or callable)
        interface: Optional interface to adapt the result to

    Returns:
        Whatever the global manager's ``resolve_function`` returns: a
        FunctionWrapper, an InterfaceWrapper, or None
    """
    manager = _global_manager
    resolved = manager.resolve_function(reference, interface)
    return resolved