Coverage for aipyapp/aipy/prompts.py: 32%

78 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 12:02 +0200

1#! /usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3 

4import os 

5import locale 

6import platform 

7import inspect 

8import json 

9import shutil 

10import subprocess 

11import re 

12from datetime import datetime 

13 

14from jinja2 import Environment, FileSystemLoader, select_autoescape 

15 

16from .. import __respath__ 

17 

18def check_commands(commands): 

19 """ 

20 检查多个命令是否存在,并获取其版本号。 

21 :param commands: dict,键为命令名,值为获取版本的参数(如 ["--version"]) 

22 :return: dict,例如 {"node": "v18.17.1", "bash": "5.1.16", ...} 

23 """ 

24 result = {} 

25 

26 for cmd, version_args in commands.items(): 

27 path = shutil.which(cmd) 

28 if not path: 

29 result[cmd] = None 

30 continue 

31 

32 try: 

33 proc = subprocess.run( 

34 [cmd] + version_args, 

35 capture_output=True, 

36 text=True, 

37 timeout=5 

38 ) 

39 # 合并 stdout 和 stderr,然后提取类似 1.2.3 或 v1.2.3 的版本 

40 output = (proc.stdout or '') + (proc.stderr or '') 

41 version_match = re.search(r"\bv?\d+(\.\d+){1,2}\b", output) 

42 version = version_match.group(0) if version_match else output.strip() 

43 result[cmd] = version 

44 except Exception as e: 

45 pass 

46 

47 return result 

48 

49class Prompts: 

50 def __init__(self, template_dir: str = None): 

51 if not template_dir: 

52 template_dir = __respath__ / 'prompts' 

53 self.template_dir = os.path.abspath(template_dir) 

54 self.env = Environment( 

55 loader=FileSystemLoader(self.template_dir), 

56 #autoescape=select_autoescape(['j2']) 

57 ) 

58 self._init_env() # 调用 _init_env 方法注册全局变量 

59 

60 def _init_env(self): 

61 # 可以在这里注册全局变量或 filter 

62 commands_to_check = { 

63 "node": ["--version"], 

64 "bash": ["--version"], 

65 "powershell": ["-Command", "$PSVersionTable.PSVersion.ToString()"], 

66 "osascript": ["-e", 'return "AppleScript OK"'] 

67 } 

68 self.env.globals['commands'] = check_commands(commands_to_check) 

69 osinfo = {'system': platform.system(), 'platform': platform.platform(), 'locale': locale.getlocale()} 

70 self.env.globals['os'] = osinfo 

71 self.env.globals['python_version'] = platform.python_version() 

72 self.env.filters['tojson'] = lambda x: json.dumps(x, ensure_ascii=False, default=str) 

73 

74 def get_prompt(self, template_name: str, **kwargs) -> str: 

75 """ 

76 加载指定模板并用 kwargs 渲染 

77 :param template_name: 模板文件名(如 'my_prompt.txt') 

78 :param kwargs: 用于模板渲染的关键字参数 

79 :return: 渲染后的字符串 

80 """ 

81 template_name = f"{template_name}.j2" 

82 try: 

83 template = self.env.get_template(template_name) 

84 except Exception as e: 

85 raise FileNotFoundError(f"Prompt template not found: {template_name} in {self.template_dir}") from e 

86 return template.render(**kwargs) 

87 

88 def get_default_prompt(self, **kwargs) -> str: 

89 """ 

90 使用 default.jinja 模板,自动补充部分变量后渲染 

91 :param kwargs: 用户传入的模板变量 

92 :return: 渲染后的字符串 

93 """ 

94 # 自动补充变量 

95 extra_vars = {} 

96 all_vars = {**extra_vars, **kwargs} 

97 return self.get_prompt('default', **all_vars) 

98 

99 def get_task_prompt(self, instruction: str, gui: bool = False) -> str: 

100 """ 

101 获取任务提示 

102 :param instruction: 用户输入的字符串 

103 :param gui: 是否使用 GUI 模式 

104 :return: 渲染后的字符串 

105 """ 

106 contexts = {} 

107 contexts['Today'] = datetime.now().strftime('%Y-%m-%d') 

108 if not gui: 

109 contexts['TERM'] = os.environ.get('TERM', 'unknown') 

110 constraints = {} 

111 return self.get_prompt('task', instruction=instruction, contexts=contexts, constraints=constraints, gui=gui) 

112 

113 def get_results_prompt(self, results: dict) -> str: 

114 """ 

115 获取结果提示 

116 :param results: 结果字典 

117 :return: 渲染后的字符串 

118 """ 

119 return self.get_prompt('results', results=results) 

120 

121 def get_edit_results_prompt(self, results: dict) -> str: 

122 """ 

123 获取编辑结果提示 

124 :param results: 编辑结果字典 

125 :return: 渲染后的字符串 

126 """ 

127 return self.get_prompt('edit_results', results=results) 

128 

129 def get_mixed_results_prompt(self, results: dict) -> str: 

130 """ 

131 获取混合结果提示(包含执行和编辑结果) 

132 :param results: 混合结果字典 

133 :return: 渲染后的字符串 

134 """ 

135 return self.get_prompt('mixed_results', results=results) 

136 

137 def get_mcp_result_prompt(self, result: dict) -> str: 

138 """ 

139 获取 MCP 工具调用结果提示 

140 :param result: 结果字典 

141 :return: 渲染后的字符串 

142 """ 

143 return self.get_prompt('result_mcp', result=result) 

144 

145 def get_chat_prompt(self, instruction: str, task: str) -> str: 

146 """ 

147 获取聊天提示 

148 :param instruction: 用户输入的字符串 

149 :param task: 初始任务 

150 :return: 渲染后的字符串 

151 """ 

152 return self.get_prompt('chat', instruction=instruction, initial_task=task) 

153 

154 def get_parse_error_prompt(self, errors: list) -> str: 

155 """ 

156 获取消息解析错误提示 

157 :param errors: 错误列表 

158 :return: 渲染后的字符串 

159 """ 

160 return self.get_prompt('parse_error', errors=errors) 

161 

162if __name__ == '__main__': 

163 prompts = Prompts() 

164 print(prompts.get_default_prompt()) 

165 func = prompts.get_prompt 

166 print(func.__name__) 

167 print(inspect.signature(func)) 

168 print(inspect.getdoc(func))