Coverage for aipyapp/aipy/prompts.py: 32%
78 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
1#! /usr/bin/env python3
2# -*- coding: utf-8 -*-
import os
import locale
import platform
import inspect
import json
import shutil
import subprocess
import re
from datetime import datetime

from jinja2 import Environment, FileSystemLoader, TemplateNotFound, select_autoescape

from .. import __respath__
# Matches version tokens such as "1.2", "1.2.3" or "v18.17.1".
_VERSION_PATTERN = re.compile(r"\bv?\d+(\.\d+){1,2}\b")


def check_commands(commands, *, timeout=5):
    """
    Check whether several commands exist and try to obtain their versions.

    :param commands: dict mapping a command name to the arguments that make it
        print its version (e.g. ``["--version"]``); a list or tuple is accepted.
    :param timeout: seconds to wait for each version probe (default 5).
    :return: dict with exactly one entry per input command, e.g.
        ``{"node": "v18.17.1", "bash": "5.1.16", "osascript": None}``.
        The value is ``None`` when the command is not on PATH or the probe
        fails (cannot be executed, times out). Otherwise it is the first
        version-like token found in stdout+stderr, or the stripped raw output
        when no such token is present.
    """
    result = {}

    for cmd, version_args in commands.items():
        path = shutil.which(cmd)
        if not path:
            result[cmd] = None
            continue

        try:
            proc = subprocess.run(
                [path, *version_args],
                capture_output=True,
                text=True,
                errors="replace",  # tool output may not be in the locale encoding
                timeout=timeout,
            )
        except (OSError, subprocess.SubprocessError, ValueError):
            # Best effort: a command that exists but cannot be probed is
            # reported like a missing one instead of being dropped from the
            # result (templates rely on every key being present).
            result[cmd] = None
            continue

        # Some tools print their version on stderr, so search both streams.
        output = (proc.stdout or '') + (proc.stderr or '')
        match = _VERSION_PATTERN.search(output)
        result[cmd] = match.group(0) if match else output.strip()

    return result
class Prompts:
    """Render prompt templates stored as ``<name>.j2`` files using Jinja2.

    On construction the Jinja environment is given globals usable from every
    template: ``commands`` (result of ``check_commands`` for a few external
    tools), ``os`` (system / platform / locale info) and ``python_version``,
    plus a custom ``tojson`` filter.
    """

    def __init__(self, template_dir: "str | os.PathLike[str] | None" = None):
        """
        :param template_dir: directory containing the ``*.j2`` templates;
            when empty/None, ``__respath__ / 'prompts'`` is used.
        """
        if not template_dir:
            template_dir = __respath__ / 'prompts'
        self.template_dir = os.path.abspath(template_dir)
        # Autoescaping stays off (Jinja's default): prompts are plain text,
        # not HTML, so escaping would corrupt characters like < > & '.
        self.env = Environment(loader=FileSystemLoader(self.template_dir))
        self._init_env()  # register template globals and filters

    def _init_env(self):
        """Register the globals and filters shared by all templates."""
        commands_to_check = {
            "node": ["--version"],
            "bash": ["--version"],
            "powershell": ["-Command", "$PSVersionTable.PSVersion.ToString()"],
            "osascript": ["-e", 'return "AppleScript OK"'],
        }
        # NOTE: this spawns each available tool once, so construction may
        # block while the probes run.
        self.env.globals['commands'] = check_commands(commands_to_check)
        self.env.globals['os'] = {
            'system': platform.system(),
            'platform': platform.platform(),
            'locale': self._current_locale(),
        }
        self.env.globals['python_version'] = platform.python_version()
        # Replaces Jinja's built-in ``tojson`` (which HTML-escapes its output)
        # with plain JSON that keeps non-ASCII text readable; ``default=str``
        # stringifies values json cannot serialize natively.
        self.env.filters['tojson'] = lambda x: json.dumps(x, ensure_ascii=False, default=str)

    @staticmethod
    def _current_locale():
        """Return ``locale.getlocale()``, or ``(None, None)`` if it cannot be parsed."""
        try:
            return locale.getlocale()
        except ValueError:
            # getlocale() raises ValueError for locale names it cannot parse;
            # that must not prevent prompts from being built.
            return (None, None)

    def get_prompt(self, template_name: str, **kwargs) -> str:
        """
        Load the template ``<template_name>.j2`` and render it with ``kwargs``.

        :param template_name: template name without the ``.j2`` suffix
            (e.g. ``'task'`` loads ``task.j2``)
        :param kwargs: variables passed to the template
        :return: the rendered string
        :raises FileNotFoundError: if the template file does not exist.
            Other template errors (e.g. syntax errors) propagate unchanged.
        """
        template_file = f"{template_name}.j2"
        try:
            template = self.env.get_template(template_file)
        except TemplateNotFound as e:
            raise FileNotFoundError(
                f"Prompt template not found: {template_file} in {self.template_dir}"
            ) from e
        return template.render(**kwargs)

    def get_default_prompt(self, **kwargs) -> str:
        """
        Render the ``default`` template.

        :param kwargs: template variables supplied by the caller
        :return: the rendered string
        """
        # Hook point: automatically supplied variables could be merged in here
        # (caller-supplied kwargs should take precedence).
        return self.get_prompt('default', **kwargs)

    def get_task_prompt(self, instruction: str, gui: bool = False) -> str:
        """
        Render the ``task`` template.

        :param instruction: the user's input
        :param gui: whether GUI mode is active; outside GUI mode the ``TERM``
            environment variable is added to the contexts
        :return: the rendered string
        """
        contexts = {'Today': datetime.now().strftime('%Y-%m-%d')}
        if not gui:
            contexts['TERM'] = os.environ.get('TERM', 'unknown')
        constraints = {}
        return self.get_prompt('task', instruction=instruction, contexts=contexts, constraints=constraints, gui=gui)

    def get_results_prompt(self, results: dict) -> str:
        """
        Render the ``results`` template.

        :param results: results dict
        :return: the rendered string
        """
        return self.get_prompt('results', results=results)

    def get_edit_results_prompt(self, results: dict) -> str:
        """
        Render the ``edit_results`` template.

        :param results: edit results dict
        :return: the rendered string
        """
        return self.get_prompt('edit_results', results=results)

    def get_mixed_results_prompt(self, results: dict) -> str:
        """
        Render the ``mixed_results`` template (execution and edit results).

        :param results: mixed results dict
        :return: the rendered string
        """
        return self.get_prompt('mixed_results', results=results)

    def get_mcp_result_prompt(self, result: dict) -> str:
        """
        Render the ``result_mcp`` template for an MCP tool-call result.

        :param result: result dict
        :return: the rendered string
        """
        return self.get_prompt('result_mcp', result=result)

    def get_chat_prompt(self, instruction: str, task: str) -> str:
        """
        Render the ``chat`` template.

        :param instruction: the user's input
        :param task: the initial task (passed to the template as ``initial_task``)
        :return: the rendered string
        """
        return self.get_prompt('chat', instruction=instruction, initial_task=task)

    def get_parse_error_prompt(self, errors: list) -> str:
        """
        Render the ``parse_error`` template for message-parsing errors.

        :param errors: list of errors
        :return: the rendered string
        """
        return self.get_prompt('parse_error', errors=errors)
if __name__ == '__main__':
    # Manual smoke test: render the default prompt, then print the name,
    # signature and docstring of the generic renderer.
    demo_prompts = Prompts()
    print(demo_prompts.get_default_prompt())
    renderer = demo_prompts.get_prompt
    for describe in (lambda fn: fn.__name__, inspect.signature, inspect.getdoc):
        print(describe(renderer))