Coverage for aipyapp/aipy/functions.py: 26%
72 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
1import inspect
2from typing import Any, Callable, Dict, Optional, List
4from loguru import logger
5from pydantic import create_model, ValidationError
class FunctionError(Exception):
    """Common base class for the exceptions raised around function calls."""
class FunctionNotFoundError(FunctionError):
    """Raised when the requested function/method name is not registered."""
class ParameterValidationError(FunctionError):
    """Raised when call arguments do not pass parameter validation."""
20class FunctionManager:
21 """Function manager - manage all callable functions"""
23 def __init__(self):
24 """Initialize function manager"""
25 self.function_registry: Dict[str, Dict[str, Any]] = {}
26 self.logger = logger.bind(src=self.__class__.__name__)
28 def register_function(self, func: Callable, name: str = None) -> bool:
29 """Register function
31 Args:
32 func: Function to register
33 name: Custom function name, default is the function name
35 Returns:
36 Whether the function is registered successfully
37 """
38 func_name = name or func.__name__
39 sig = inspect.signature(func)
41 fields = {}
42 for param_name, param in sig.parameters.items():
43 annotation = param.annotation if param.annotation != inspect.Parameter.empty else Any
44 default = param.default if param.default != inspect.Parameter.empty else ...
45 fields[param_name] = (annotation, default)
47 ParamModel = create_model(f"{func_name}_Params", **fields)
49 self.function_registry[func_name] = {
50 "func": func,
51 "signature": str(sig),
52 "doc": inspect.getdoc(func) or "",
53 "param_model": ParamModel
54 }
56 self.logger.debug(f"Registered function: {func_name}")
57 return True
59 def register_functions(self, functions: Dict[str, Callable]) -> int:
60 """Batch register object instance methods
62 Args:
63 functions: Function name list
64 """
65 success_count = 0
66 for func_name, func in functions.items():
67 if self.register_function(func, func_name):
68 success_count += 1
69 self.logger.info(f"Registered {success_count}/{len(functions)} functions")
70 return success_count
72 def call(self, func_name: str, **kwargs) -> Any:
73 """Call the registered function
75 Args:
76 func_name: Function name
77 **kwargs: Function parameters
79 Returns:
80 Function execution result
82 Raises:
83 FunctionNotFoundError: Function not found
84 ParameterValidationError: Parameter validation failed
85 """
86 entry = self.function_registry.get(func_name)
87 if not entry:
88 self.logger.error(f"Function '{func_name}' not found")
89 raise FunctionNotFoundError(f"Function '{func_name}' not found")
91 fn = entry["func"]
92 ParamModel = entry.get("param_model")
94 try:
95 parsed = ParamModel(**kwargs)
96 self.logger.info(f"Calling function {func_name}, parameters: {parsed.model_dump()}")
97 result = fn(**parsed.model_dump())
98 return result
100 except ValidationError as e:
101 self.logger.error(f"Function '{func_name}' parameter validation failed: {e}")
102 raise ParameterValidationError(f"Parameter validation failed: {e}")
103 except Exception as e:
104 self.logger.error(f"Function '{func_name}' execution failed: {e}")
105 raise
107 def get_functions(self) -> Dict[str, Dict[str, str]]:
108 """Get all functions
110 Returns:
111 All functions
112 """
113 functions = {}
114 for name, meta in self.function_registry.items():
115 functions[name] = {
116 "signature": meta["signature"],
117 "docstring": meta["doc"],
118 }
119 return functions
121 def unregister_function(self, func_name: str) -> bool:
122 """Unregister function
124 Args:
125 func_name: Function name
127 Returns:
128 Whether the function is unregistered successfully
129 """
130 if func_name in self.function_registry:
131 del self.function_registry[func_name]
132 self.logger.debug(f"Unregistered function: {func_name}")
133 return True
134 return False
136 def clear_registry(self, group: Optional[str] = None):
137 """Clear the registry
139 Args:
140 group: Specify the group, None means clear all
141 """
142 if group is None:
143 self.function_registry.clear()
144 self.logger.info("Cleared all registered functions")
145 else:
146 to_remove = [
147 name for name, meta in self.function_registry.items()
148 if meta["group"] == group
149 ]
150 for name in to_remove:
151 del self.function_registry[name]
152 self.logger.info(f"Cleared functions in group {group}")
154 def get_registry(self) -> Dict[str, Dict[str, Any]]:
155 """Get the complete registry (read-only)
157 Returns:
158 Copy of the registry
159 """
160 return self.function_registry.copy()