Coverage for aipyapp/aipy/functions.py: 26%

72 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 12:02 +0200

1import inspect 

2from typing import Any, Callable, Dict, Optional, List 

3 

4from loguru import logger 

5from pydantic import create_model, ValidationError 

6 

7 

class FunctionError(Exception):
    """Common base class for every error raised while registering or calling managed functions."""

11 

class FunctionNotFoundError(FunctionError):
    """Raised when a requested function/method name is not present in the registry."""

15 

class ParameterValidationError(FunctionError):
    """Raised when the arguments supplied for a call do not satisfy the function's parameters."""

19 

class FunctionManager:
    """Function manager - manage all callable functions.

    Each registered callable is stored together with a pydantic model built from
    its signature; ``call`` validates keyword arguments against that model before
    invoking the callable.
    """

    def __init__(self):
        """Initialize an empty function manager."""
        # name -> {"func", "signature", "doc", "param_model", "positional_only", "group"}
        self.function_registry: Dict[str, Dict[str, Any]] = {}
        self.logger = logger.bind(src=self.__class__.__name__)

    @staticmethod
    def _build_param_model(func_name: str, sig: inspect.Signature):
        """Build the pydantic model that validates the parameters described by ``sig``.

        Args:
            func_name: Registered name, used to name the generated model
            sig: Signature of the function being registered

        Returns:
            Tuple ``(model_class, positional_only_names)`` where the second item lists
            positional-only parameter names in declaration order
        """
        fields: Dict[str, Any] = {}
        positional_only: List[str] = []
        accepts_var_kwargs = False

        for param_name, param in sig.parameters.items():
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                # *args can never be supplied through keyword arguments, so it gets no field.
                continue
            if param.kind is inspect.Parameter.VAR_KEYWORD:
                # **kwargs: accept unknown keys instead of modelling a literal "kwargs" field.
                accepts_var_kwargs = True
                continue
            if param.kind is inspect.Parameter.POSITIONAL_ONLY:
                positional_only.append(param_name)

            annotation = Any if param.annotation is inspect.Parameter.empty else param.annotation
            # Ellipsis marks the field as required in pydantic.
            default = ... if param.default is inspect.Parameter.empty else param.default
            fields[param_name] = (annotation, default)

        config = {"extra": "allow"} if accepts_var_kwargs else None
        model = create_model(f"{func_name}_Params", __config__=config, **fields)
        return model, tuple(positional_only)

    def register_function(
        self,
        func: Callable,
        name: Optional[str] = None,
        group: Optional[str] = None,
    ) -> bool:
        """Register function

        Args:
            func: Function to register
            name: Custom function name, default is the function name
            group: Optional group label; ``clear_registry(group)`` removes all
                functions registered with the same label

        Returns:
            True once the function is registered (failures raise instead)

        Raises:
            ValueError / TypeError: from ``inspect.signature`` when no signature can
                be obtained for ``func``; errors from pydantic model creation also
                propagate
        """
        func_name = name or func.__name__
        sig = inspect.signature(func)
        param_model, positional_only = self._build_param_model(func_name, sig)

        # Re-registering an existing name replaces the previous entry.
        self.function_registry[func_name] = {
            "func": func,
            "signature": str(sig),
            "doc": inspect.getdoc(func) or "",
            "param_model": param_model,
            "positional_only": positional_only,
            "group": group,
        }

        self.logger.debug(f"Registered function: {func_name}")
        return True

    def register_functions(self, functions: Dict[str, Callable]) -> int:
        """Batch register functions

        Args:
            functions: Mapping of registration name -> callable

        Returns:
            Number of functions registered successfully; a callable that fails to
            register is logged and skipped rather than aborting the batch
        """
        success_count = 0
        for func_name, func in functions.items():
            try:
                registered = self.register_function(func, func_name)
            except Exception as e:
                # One non-introspectable callable must not abort the remaining ones.
                self.logger.error(f"Failed to register function '{func_name}': {e}")
                continue
            if registered:
                success_count += 1
        self.logger.info(f"Registered {success_count}/{len(functions)} functions")
        return success_count

    def call(self, func_name: str, **kwargs) -> Any:
        """Call the registered function

        Positional-only parameters are passed positionally; if the function accepts
        ``**kwargs``, keys that match no named parameter are forwarded unvalidated.

        Args:
            func_name: Function name
            **kwargs: Function parameters

        Returns:
            Function execution result

        Raises:
            FunctionNotFoundError: Function not found
            ParameterValidationError: Parameter validation failed
            Exception: Any exception raised by the function itself is logged and re-raised
        """
        entry = self.function_registry.get(func_name)
        if entry is None:
            self.logger.error(f"Function '{func_name}' not found")
            raise FunctionNotFoundError(f"Function '{func_name}' not found")

        fn = entry["func"]
        param_model = entry.get("param_model")
        positional_only = entry.get("positional_only", ())

        # Only validation is guarded here, so a ValidationError raised *inside* the
        # called function is not misreported as a bad-arguments error.
        try:
            parsed = param_model(**kwargs)
        except ValidationError as e:
            self.logger.error(f"Function '{func_name}' parameter validation failed: {e}")
            raise ParameterValidationError(f"Parameter validation failed: {e}") from e

        # dict(parsed) keeps validated values as objects (unlike model_dump(), which
        # turns nested models/dataclasses into dicts) and includes allowed extras.
        call_kwargs = dict(parsed)
        self.logger.info(f"Calling function {func_name}, parameters: {call_kwargs}")
        # Positional-only parameters cannot be passed by keyword.
        call_args = [call_kwargs.pop(param_name) for param_name in positional_only]

        try:
            return fn(*call_args, **call_kwargs)
        except Exception as e:
            self.logger.error(f"Function '{func_name}' execution failed: {e}")
            raise

    def get_functions(self) -> Dict[str, Dict[str, str]]:
        """Get all functions

        Returns:
            Mapping of function name -> {"signature", "docstring"}
        """
        return {
            name: {"signature": meta["signature"], "docstring": meta["doc"]}
            for name, meta in self.function_registry.items()
        }

    def unregister_function(self, func_name: str) -> bool:
        """Unregister function

        Args:
            func_name: Function name

        Returns:
            Whether the function is unregistered successfully (False if it was not registered)
        """
        if func_name not in self.function_registry:
            return False
        del self.function_registry[func_name]
        self.logger.debug(f"Unregistered function: {func_name}")
        return True

    def clear_registry(self, group: Optional[str] = None):
        """Clear the registry

        Args:
            group: Specify the group, None means clear all
        """
        if group is None:
            self.function_registry.clear()
            self.logger.info("Cleared all registered functions")
            return

        # .get: entries without a group (or created outside register_function) never match.
        to_remove = [
            name for name, meta in self.function_registry.items()
            if meta.get("group") == group
        ]
        for name in to_remove:
            del self.function_registry[name]
        self.logger.info(f"Cleared functions in group {group}")

    def get_registry(self) -> Dict[str, Dict[str, Any]]:
        """Get the complete registry (read-only)

        Returns:
            Copy of the registry; each entry dict is copied too, so mutating the
            result does not alter the manager (the stored callables and models
            themselves are shared)
        """
        return {name: dict(meta) for name, meta in self.function_registry.items()}