Coverage for aipyapp/aipy/mcp_tool.py: 12%
217 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
1import json
2import re
3import hashlib
4from collections import namedtuple
5from . import cache
6from .libmcp import MCPClientSync, MCPConfigReader
7from .. import T
9def build_function_call_tool_name(server_name: str, tool_name: str) -> str:
10 """
11 构建函数调用工具名称
13 Args:
14 server_name: 服务器名称
15 tool_name: 工具名称
17 Returns:
18 处理后的工具名称
19 """
21 sanitized_server = server_name.strip()
22 if not sanitized_server.isascii():
23 # 生成MD5并取前4位
24 sanitized_server = hashlib.md5(sanitized_server.encode('utf-8')).hexdigest()[:4]
25 else:
26 # 用下划线替换无效字符,保留 a-z, A-Z, 0-9, 下划线和短横线
27 sanitized_server = re.sub(r'[^a-zA-Z0-9_-]', '_', sanitized_server)
29 sanitized_tool = tool_name.strip().replace('-', '_')
31 # 合并服务器名称和工具名称
32 name = sanitized_tool
33 if sanitized_server[:16] not in sanitized_tool:
34 name = f"{sanitized_server[:16] or ''}.{sanitized_tool or ''}"
36 # 确保名称以字母或下划线开头
37 if not re.match(r'^[a-zA-Z]', name):
38 name = f"tool-{name}"
40 # 移除连续的下划线/短横线
41 name = re.sub(r'[_-]{2,}', '_', name)
43 # 最大截断为 63 个字符
44 if len(name) > 63:
45 name = name[:63]
47 # 处理边缘情况:确保截断后仍有有效名称
48 if name.endswith('_') or name.endswith('-'):
49 name = name[:-1]
51 return name
54class MCPToolManager:
55 def __init__(self, config_path, tt_api_key):
56 self.config_path = config_path
57 self.config_reader = MCPConfigReader(config_path, tt_api_key=tt_api_key)
58 self.user_mcp = self.config_reader.get_user_mcp()
59 self.sys_mcp = self.config_reader.get_sys_mcp()
60 self.mcp_servers = self.sys_mcp | self.user_mcp
61 self._tools_dict = {} # 缓存已获取的工具列表
62 self._inited = False
64 # 全局启用/禁用用户MCP标志,默认禁用
65 self._user_mcp_enabled = False
66 self._sys_mcp_enabled = False
68 # 服务器状态缓存,记录每个服务器的启用/禁用状态
69 self._server_status = self._init_server_status()
71 def _init_server_status(self):
72 """初始化服务器状态,从配置文件中读取初始状态,包括禁用的服务器,同时会被命令行更新,维护在内存中"""
74 server_status = {}
75 for server_name, server_config in self.mcp_servers.items():
76 # sys_mcp 服务器默认禁用,user_mcp 服务器默认启用
77 if server_name in self.sys_mcp:
78 # sys_mcp 服务器默认禁用
79 is_enabled = False
80 else:
81 # user_mcp 服务器默认启用,除非配置中明确设置为禁用
82 is_enabled = not (
83 server_config.get("disabled", False)
84 or server_config.get("enabled", True) is False
85 )
86 server_status[server_name] = is_enabled
87 return server_status
89 def list_tools(self, mcp_type="user", force_load=False):
90 """返回所有MCP服务器的工具列表
91 [{'description': 'Get weather alerts for a US state.\n'
92 '\n'
93 ' Args:\n'
94 ' state: Two-letter US state code (e.g. CA, '
95 'NY)\n'
96 ' ',
97 'inputSchema': {'properties': {'state': {'title': 'State',
98 'type': 'string'}},
99 'required': ['state'],
100 'title': 'get_alertsArguments',
101 'type': 'object'},
102 'name': 'get_alerts',
103 'server': 'server1'
104 },
105 ]
106 """
107 if mcp_type == "user":
108 if not self._user_mcp_enabled and not force_load:
109 return []
110 mcp_servers = self.user_mcp
111 if self._user_mcp_enabled:
112 print(
113 T(
114 "Initializing MCP server, this may take a while if it's "
115 "the first load, please wait patiently..."
116 )
117 )
118 else:
119 mcp_servers = self.sys_mcp
121 all_tools = []
122 for server_name, server_config in mcp_servers.items():
123 if server_name not in self._tools_dict:
124 try:
125 print("+ Loading MCP", server_name)
126 key = f"mcp_tool:{server_name}:{cache.cache_key(server_config)}"
127 tools = cache.get_cache(key)
128 if tools is not None:
129 # 如果缓存中有工具列表,直接使用
130 self._tools_dict[server_name] = tools
131 continue
133 client = MCPClientSync(server_config)
134 tools = client.list_tools()
136 # 为每个工具添加服务器标识
137 for tool in tools:
138 tool["server"] = server_name
139 tool["id"] = build_function_call_tool_name(
140 server_name, tool.get("name", "")
141 )
143 if tools:
144 cache.set_cache(key, tools, ttl=60 * 60 * 24 * 2)
146 self._tools_dict[server_name] = tools
147 except Exception as e:
148 print(f"Error listing tools for server {server_name}: {e}")
149 self._tools_dict[server_name] = []
151 # 添加到总工具列表
152 all_tools.extend(self._tools_dict[server_name])
153 self._inited = True
155 return all_tools
157 def get_tools_prompt(self):
158 """获取工具列表并转换为 Markdown 格式"""
159 tools = self.get_available_tools() # 获取启用的工具
160 if not tools:
161 return ""
163 ret = []
164 for tool in tools:
165 # 构建工具信息的 JSON 对象
166 # 去掉 inputSchema里的additionalProperties、$schema
167 input_schema = tool.get("inputSchema", {}).copy()
168 if "additionalProperties" in input_schema:
169 del input_schema["additionalProperties"]
170 if "$schema" in input_schema:
171 del input_schema["$schema"]
173 ret.append(
174 {
175 "name": tool.get("id", ""),
176 "description": tool.get("description", ""),
177 "arguments": input_schema,
178 }
179 )
181 # 转换为 JSON 字符串并添加到 Markdown 中
182 json_str = json.dumps(ret, ensure_ascii=False)
183 return f"```json\n{json_str}\n```\n"
185 def get_available_tools(self):
186 """返回已经启用的工具列表"""
188 if not self._inited:
189 self.list_tools()
191 all_tools = []
192 for server_name, tools in self._tools_dict.items():
193 # 检查全局启用状态和单个服务器状态
194 is_user_server = server_name in self.user_mcp
196 # user_mcp服务器需要同时满足全局启用和服务器启用
197 # sys_mcp服务器只需要服务器启用
198 server_enabled = self._server_status.get(server_name, True)
199 if is_user_server:
200 server_enabled = server_enabled and self._user_mcp_enabled
202 if server_enabled:
203 for tool in tools:
204 tool["server"] = server_name
205 all_tools.append(tool)
206 return all_tools
208 def get_server_info(self, mcp_type="user") -> dict:
209 """返回所有服务器的列表及其启用状态"""
210 if not self._inited:
211 # 强制加载工具信息,即使全局禁用也要获取服务器和工具数量信息
212 self.list_tools(mcp_type=mcp_type, force_load=True)
214 if mcp_type == "user":
215 mcp_servers = self.user_mcp
216 else:
217 mcp_servers = self.sys_mcp
218 # 返回服务器列表及其启用状态
219 servers_info = {}
220 for server_name, _ in mcp_servers.items():
221 is_enabled = self._server_status.get(server_name, True)
222 tools = self._tools_dict.get(server_name, [])
223 servers_info[server_name] = {
224 "enabled": is_enabled,
225 "tools_count": len(tools),
226 }
228 return servers_info
230 def call_tool(self, tool_name, arguments):
231 """调用指定名称的工具,自动选择最匹配的服务器"""
232 # 获取所有工具
233 all_tools = self.get_available_tools()
234 if not all_tools:
235 return {
236 "isError": True,
237 "content": [{
238 "type": "text",
239 "text": "No tools available to call."
240 }]
241 }
243 # 查找匹配的工具,根据id查找
244 matching_tools = [t for t in all_tools if t["id"] == tool_name]
245 if not matching_tools:
246 return {
247 "isError": True,
248 "content": [{
249 "type": "text",
250 "text": f"No tool found with name: {tool_name}"
251 }]
252 }
254 # 选择参数匹配度最高的工具
255 best_match = None
256 best_score = -1
258 for tool in matching_tools:
259 score = 0
260 required_params = []
262 # 检查工具的输入模式
263 if "inputSchema" in tool and "properties" in tool["inputSchema"]:
264 properties = tool["inputSchema"]["properties"]
265 required_params = tool["inputSchema"].get("required", [])
267 # 检查所有必需参数是否提供
268 required_provided = all(param in arguments for param in required_params)
269 if not required_provided:
270 continue
272 # 计算匹配的参数数量
273 matching_params = sum(1 for param in arguments if param in properties)
274 extra_params = len(arguments) - matching_params
276 # 评分:匹配参数越多越好,额外参数越少越好
277 score = matching_params - 0.1 * extra_params
279 if score > best_score:
280 best_score = score
281 best_match = tool
283 if not best_match:
284 # 返回错误信息而不是抛出异常
285 error_msg = f"No suitable tool found for {tool_name} with given arguments"
286 return {
287 "isError": True,
288 "content": [{
289 "type": "text",
290 "text": error_msg
291 }]
292 }
294 # 获取服务器配置
295 server_name = best_match["server"]
296 real_tool_name = best_match["name"]
297 server_config = self.mcp_servers[server_name]
299 try:
300 # 创建客户端并调用工具
301 client = MCPClientSync(server_config)
302 ret = client.call_tool(real_tool_name, arguments)
303 # ret需要是字典,并且如果同时包含content和structuredContent字段,
304 # 则丢弃content
305 if (isinstance(ret, dict) and ret.get("content")
306 and ret.get("structuredContent")):
307 del ret["content"]
308 return ret
309 except Exception as e:
310 # 捕获工具调用异常,返回错误信息给LLM
311 error_msg = f"Tool call failed: {str(e)}"
312 return {
313 "isError": True,
314 "content": [{
315 "type": "text",
316 "text": error_msg
317 }]
318 }
320 @property
321 def is_mcp_enabled(self):
322 """是否需要启用MCP处理
323 1. 如果系统MCP全局启用,则返回True
324 2. 如果用户MCP全局禁用,则返回False
325 3. 如果用户MCP全局启用,则返回用户MCP服务器中是否至少有一个服务器启用
326 """
327 if self._sys_mcp_enabled:
328 return True
330 if not self._user_mcp_enabled:
331 return False
333 user_server_status = [self._server_status.get(server_name, False) for server_name in self.user_mcp]
334 return any(user_server_status)
336 @property
337 def is_sys_mcp_enabled(self):
338 return self._sys_mcp_enabled
340 @property
341 def is_user_mcp_enabled(self):
342 return self._user_mcp_enabled
344 def enable_user_mcp(self, enable=True):
345 """全局启用/禁用用户MCP"""
346 self._user_mcp_enabled = enable
347 if enable:
348 self.list_tools(mcp_type="user")
349 return True
351 def enable_sys_mcp(self, enable=True):
352 """全局启用/禁用系统MCP"""
353 self._sys_mcp_enabled = enable
354 for server_name in self.sys_mcp:
355 self._server_status[server_name] = enable
356 if enable:
357 self.list_tools(mcp_type="sys")
358 return True
360 def enable_user_server(self, server_name, enable=True):
361 """启用/禁用指定用户MCP服务器"""
362 if server_name == "*" or not server_name:
363 for srv_name in self.user_mcp.keys():
364 self._server_status[srv_name] = enable
365 ret = True
366 elif server_name in self.user_mcp:
367 self._server_status[server_name] = enable
368 ret = True
369 else:
370 return False
372 self.list_tools(mcp_type="user")
373 return ret
375 def list_user_servers(self):
376 """返回所有用户MCP服务器列表"""
377 status = self.get_server_info(mcp_type="user")
378 ServerRecord = namedtuple("ServerRecord", ["Name", "Enabled", "ToolsCount"])
379 return [ServerRecord(name, info['enabled'], info['tools_count']) for name, info in status.items()]
381 def get_status(self):
382 """返回MCP状态信息供状态栏使用"""
383 total_tools = 0
384 enabled_tools = 0
385 enabled_servers = 0
386 total_servers = 0
388 sys_mcp_info = self.get_server_info(mcp_type="sys")
389 for server_name, server_info in sys_mcp_info.items():
390 total_servers += 1
391 if server_info['enabled']:
392 enabled_servers += 1
393 enabled_tools += server_info['tools_count']
394 total_tools += server_info['tools_count']
395 if self.is_user_mcp_enabled:
396 user_mcp_info = self.get_server_info(mcp_type="user")
397 for server_name, server_info in user_mcp_info.items():
398 total_servers += 1
399 if server_info['enabled']:
400 enabled_servers += 1
401 enabled_tools += server_info['tools_count']
402 total_tools += server_info['tools_count']
403 return {
404 'sys_mcp_enabled': self._sys_mcp_enabled,
405 'user_mcp_enabled': self._user_mcp_enabled,
406 'total_servers': total_servers,
407 'total_tools': total_tools,
408 'enabled_servers': enabled_servers,
409 'enabled_tools': enabled_tools,
410 }