Coverage for aipyapp/aipy/mcp_tool.py: 12%

217 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 12:02 +0200

1import json 

2import re 

3import hashlib 

4from collections import namedtuple 

5from . import cache 

6from .libmcp import MCPClientSync, MCPConfigReader 

7from .. import T 

8 

9def build_function_call_tool_name(server_name: str, tool_name: str) -> str: 

10 """ 

11 构建函数调用工具名称 

12 

13 Args: 

14 server_name: 服务器名称 

15 tool_name: 工具名称 

16 

17 Returns: 

18 处理后的工具名称 

19 """ 

20 

21 sanitized_server = server_name.strip() 

22 if not sanitized_server.isascii(): 

23 # 生成MD5并取前4位 

24 sanitized_server = hashlib.md5(sanitized_server.encode('utf-8')).hexdigest()[:4] 

25 else: 

26 # 用下划线替换无效字符,保留 a-z, A-Z, 0-9, 下划线和短横线 

27 sanitized_server = re.sub(r'[^a-zA-Z0-9_-]', '_', sanitized_server) 

28 

29 sanitized_tool = tool_name.strip().replace('-', '_') 

30 

31 # 合并服务器名称和工具名称 

32 name = sanitized_tool 

33 if sanitized_server[:16] not in sanitized_tool: 

34 name = f"{sanitized_server[:16] or ''}.{sanitized_tool or ''}" 

35 

36 # 确保名称以字母或下划线开头 

37 if not re.match(r'^[a-zA-Z]', name): 

38 name = f"tool-{name}" 

39 

40 # 移除连续的下划线/短横线 

41 name = re.sub(r'[_-]{2,}', '_', name) 

42 

43 # 最大截断为 63 个字符 

44 if len(name) > 63: 

45 name = name[:63] 

46 

47 # 处理边缘情况:确保截断后仍有有效名称 

48 if name.endswith('_') or name.endswith('-'): 

49 name = name[:-1] 

50 

51 return name 

52 

53 

54class MCPToolManager: 

55 def __init__(self, config_path, tt_api_key): 

56 self.config_path = config_path 

57 self.config_reader = MCPConfigReader(config_path, tt_api_key=tt_api_key) 

58 self.user_mcp = self.config_reader.get_user_mcp() 

59 self.sys_mcp = self.config_reader.get_sys_mcp() 

60 self.mcp_servers = self.sys_mcp | self.user_mcp 

61 self._tools_dict = {} # 缓存已获取的工具列表 

62 self._inited = False 

63 

64 # 全局启用/禁用用户MCP标志,默认禁用 

65 self._user_mcp_enabled = False 

66 self._sys_mcp_enabled = False 

67 

68 # 服务器状态缓存,记录每个服务器的启用/禁用状态 

69 self._server_status = self._init_server_status() 

70 

71 def _init_server_status(self): 

72 """初始化服务器状态,从配置文件中读取初始状态,包括禁用的服务器,同时会被命令行更新,维护在内存中""" 

73 

74 server_status = {} 

75 for server_name, server_config in self.mcp_servers.items(): 

76 # sys_mcp 服务器默认禁用,user_mcp 服务器默认启用 

77 if server_name in self.sys_mcp: 

78 # sys_mcp 服务器默认禁用 

79 is_enabled = False 

80 else: 

81 # user_mcp 服务器默认启用,除非配置中明确设置为禁用 

82 is_enabled = not ( 

83 server_config.get("disabled", False) 

84 or server_config.get("enabled", True) is False 

85 ) 

86 server_status[server_name] = is_enabled 

87 return server_status 

88 

89 def list_tools(self, mcp_type="user", force_load=False): 

90 """返回所有MCP服务器的工具列表 

91 [{'description': 'Get weather alerts for a US state.\n' 

92 '\n' 

93 ' Args:\n' 

94 ' state: Two-letter US state code (e.g. CA, ' 

95 'NY)\n' 

96 ' ', 

97 'inputSchema': {'properties': {'state': {'title': 'State', 

98 'type': 'string'}}, 

99 'required': ['state'], 

100 'title': 'get_alertsArguments', 

101 'type': 'object'}, 

102 'name': 'get_alerts', 

103 'server': 'server1' 

104 }, 

105 ] 

106 """ 

107 if mcp_type == "user": 

108 if not self._user_mcp_enabled and not force_load: 

109 return [] 

110 mcp_servers = self.user_mcp 

111 if self._user_mcp_enabled: 

112 print( 

113 T( 

114 "Initializing MCP server, this may take a while if it's " 

115 "the first load, please wait patiently..." 

116 ) 

117 ) 

118 else: 

119 mcp_servers = self.sys_mcp 

120 

121 all_tools = [] 

122 for server_name, server_config in mcp_servers.items(): 

123 if server_name not in self._tools_dict: 

124 try: 

125 print("+ Loading MCP", server_name) 

126 key = f"mcp_tool:{server_name}:{cache.cache_key(server_config)}" 

127 tools = cache.get_cache(key) 

128 if tools is not None: 

129 # 如果缓存中有工具列表,直接使用 

130 self._tools_dict[server_name] = tools 

131 continue 

132 

133 client = MCPClientSync(server_config) 

134 tools = client.list_tools() 

135 

136 # 为每个工具添加服务器标识 

137 for tool in tools: 

138 tool["server"] = server_name 

139 tool["id"] = build_function_call_tool_name( 

140 server_name, tool.get("name", "") 

141 ) 

142 

143 if tools: 

144 cache.set_cache(key, tools, ttl=60 * 60 * 24 * 2) 

145 

146 self._tools_dict[server_name] = tools 

147 except Exception as e: 

148 print(f"Error listing tools for server {server_name}: {e}") 

149 self._tools_dict[server_name] = [] 

150 

151 # 添加到总工具列表 

152 all_tools.extend(self._tools_dict[server_name]) 

153 self._inited = True 

154 

155 return all_tools 

156 

157 def get_tools_prompt(self): 

158 """获取工具列表并转换为 Markdown 格式""" 

159 tools = self.get_available_tools() # 获取启用的工具 

160 if not tools: 

161 return "" 

162 

163 ret = [] 

164 for tool in tools: 

165 # 构建工具信息的 JSON 对象 

166 # 去掉 inputSchema里的additionalProperties、$schema 

167 input_schema = tool.get("inputSchema", {}).copy() 

168 if "additionalProperties" in input_schema: 

169 del input_schema["additionalProperties"] 

170 if "$schema" in input_schema: 

171 del input_schema["$schema"] 

172 

173 ret.append( 

174 { 

175 "name": tool.get("id", ""), 

176 "description": tool.get("description", ""), 

177 "arguments": input_schema, 

178 } 

179 ) 

180 

181 # 转换为 JSON 字符串并添加到 Markdown 中 

182 json_str = json.dumps(ret, ensure_ascii=False) 

183 return f"```json\n{json_str}\n```\n" 

184 

185 def get_available_tools(self): 

186 """返回已经启用的工具列表""" 

187 

188 if not self._inited: 

189 self.list_tools() 

190 

191 all_tools = [] 

192 for server_name, tools in self._tools_dict.items(): 

193 # 检查全局启用状态和单个服务器状态 

194 is_user_server = server_name in self.user_mcp 

195 

196 # user_mcp服务器需要同时满足全局启用和服务器启用 

197 # sys_mcp服务器只需要服务器启用 

198 server_enabled = self._server_status.get(server_name, True) 

199 if is_user_server: 

200 server_enabled = server_enabled and self._user_mcp_enabled 

201 

202 if server_enabled: 

203 for tool in tools: 

204 tool["server"] = server_name 

205 all_tools.append(tool) 

206 return all_tools 

207 

208 def get_server_info(self, mcp_type="user") -> dict: 

209 """返回所有服务器的列表及其启用状态""" 

210 if not self._inited: 

211 # 强制加载工具信息,即使全局禁用也要获取服务器和工具数量信息 

212 self.list_tools(mcp_type=mcp_type, force_load=True) 

213 

214 if mcp_type == "user": 

215 mcp_servers = self.user_mcp 

216 else: 

217 mcp_servers = self.sys_mcp 

218 # 返回服务器列表及其启用状态 

219 servers_info = {} 

220 for server_name, _ in mcp_servers.items(): 

221 is_enabled = self._server_status.get(server_name, True) 

222 tools = self._tools_dict.get(server_name, []) 

223 servers_info[server_name] = { 

224 "enabled": is_enabled, 

225 "tools_count": len(tools), 

226 } 

227 

228 return servers_info 

229 

230 def call_tool(self, tool_name, arguments): 

231 """调用指定名称的工具,自动选择最匹配的服务器""" 

232 # 获取所有工具 

233 all_tools = self.get_available_tools() 

234 if not all_tools: 

235 return { 

236 "isError": True, 

237 "content": [{ 

238 "type": "text", 

239 "text": "No tools available to call." 

240 }] 

241 } 

242 

243 # 查找匹配的工具,根据id查找 

244 matching_tools = [t for t in all_tools if t["id"] == tool_name] 

245 if not matching_tools: 

246 return { 

247 "isError": True, 

248 "content": [{ 

249 "type": "text", 

250 "text": f"No tool found with name: {tool_name}" 

251 }] 

252 } 

253 

254 # 选择参数匹配度最高的工具 

255 best_match = None 

256 best_score = -1 

257 

258 for tool in matching_tools: 

259 score = 0 

260 required_params = [] 

261 

262 # 检查工具的输入模式 

263 if "inputSchema" in tool and "properties" in tool["inputSchema"]: 

264 properties = tool["inputSchema"]["properties"] 

265 required_params = tool["inputSchema"].get("required", []) 

266 

267 # 检查所有必需参数是否提供 

268 required_provided = all(param in arguments for param in required_params) 

269 if not required_provided: 

270 continue 

271 

272 # 计算匹配的参数数量 

273 matching_params = sum(1 for param in arguments if param in properties) 

274 extra_params = len(arguments) - matching_params 

275 

276 # 评分:匹配参数越多越好,额外参数越少越好 

277 score = matching_params - 0.1 * extra_params 

278 

279 if score > best_score: 

280 best_score = score 

281 best_match = tool 

282 

283 if not best_match: 

284 # 返回错误信息而不是抛出异常 

285 error_msg = f"No suitable tool found for {tool_name} with given arguments" 

286 return { 

287 "isError": True, 

288 "content": [{ 

289 "type": "text", 

290 "text": error_msg 

291 }] 

292 } 

293 

294 # 获取服务器配置 

295 server_name = best_match["server"] 

296 real_tool_name = best_match["name"] 

297 server_config = self.mcp_servers[server_name] 

298 

299 try: 

300 # 创建客户端并调用工具 

301 client = MCPClientSync(server_config) 

302 ret = client.call_tool(real_tool_name, arguments) 

303 # ret需要是字典,并且如果同时包含content和structuredContent字段, 

304 # 则丢弃content 

305 if (isinstance(ret, dict) and ret.get("content") 

306 and ret.get("structuredContent")): 

307 del ret["content"] 

308 return ret 

309 except Exception as e: 

310 # 捕获工具调用异常,返回错误信息给LLM 

311 error_msg = f"Tool call failed: {str(e)}" 

312 return { 

313 "isError": True, 

314 "content": [{ 

315 "type": "text", 

316 "text": error_msg 

317 }] 

318 } 

319 

320 @property 

321 def is_mcp_enabled(self): 

322 """是否需要启用MCP处理 

323 1. 如果系统MCP全局启用,则返回True 

324 2. 如果用户MCP全局禁用,则返回False 

325 3. 如果用户MCP全局启用,则返回用户MCP服务器中是否至少有一个服务器启用 

326 """ 

327 if self._sys_mcp_enabled: 

328 return True 

329 

330 if not self._user_mcp_enabled: 

331 return False 

332 

333 user_server_status = [self._server_status.get(server_name, False) for server_name in self.user_mcp] 

334 return any(user_server_status) 

335 

336 @property 

337 def is_sys_mcp_enabled(self): 

338 return self._sys_mcp_enabled 

339 

340 @property 

341 def is_user_mcp_enabled(self): 

342 return self._user_mcp_enabled 

343 

344 def enable_user_mcp(self, enable=True): 

345 """全局启用/禁用用户MCP""" 

346 self._user_mcp_enabled = enable 

347 if enable: 

348 self.list_tools(mcp_type="user") 

349 return True 

350 

351 def enable_sys_mcp(self, enable=True): 

352 """全局启用/禁用系统MCP""" 

353 self._sys_mcp_enabled = enable 

354 for server_name in self.sys_mcp: 

355 self._server_status[server_name] = enable 

356 if enable: 

357 self.list_tools(mcp_type="sys") 

358 return True 

359 

360 def enable_user_server(self, server_name, enable=True): 

361 """启用/禁用指定用户MCP服务器""" 

362 if server_name == "*" or not server_name: 

363 for srv_name in self.user_mcp.keys(): 

364 self._server_status[srv_name] = enable 

365 ret = True 

366 elif server_name in self.user_mcp: 

367 self._server_status[server_name] = enable 

368 ret = True 

369 else: 

370 return False 

371 

372 self.list_tools(mcp_type="user") 

373 return ret 

374 

375 def list_user_servers(self): 

376 """返回所有用户MCP服务器列表""" 

377 status = self.get_server_info(mcp_type="user") 

378 ServerRecord = namedtuple("ServerRecord", ["Name", "Enabled", "ToolsCount"]) 

379 return [ServerRecord(name, info['enabled'], info['tools_count']) for name, info in status.items()] 

380 

381 def get_status(self): 

382 """返回MCP状态信息供状态栏使用""" 

383 total_tools = 0 

384 enabled_tools = 0 

385 enabled_servers = 0 

386 total_servers = 0 

387 

388 sys_mcp_info = self.get_server_info(mcp_type="sys") 

389 for server_name, server_info in sys_mcp_info.items(): 

390 total_servers += 1 

391 if server_info['enabled']: 

392 enabled_servers += 1 

393 enabled_tools += server_info['tools_count'] 

394 total_tools += server_info['tools_count'] 

395 if self.is_user_mcp_enabled: 

396 user_mcp_info = self.get_server_info(mcp_type="user") 

397 for server_name, server_info in user_mcp_info.items(): 

398 total_servers += 1 

399 if server_info['enabled']: 

400 enabled_servers += 1 

401 enabled_tools += server_info['tools_count'] 

402 total_tools += server_info['tools_count'] 

403 return { 

404 'sys_mcp_enabled': self._sys_mcp_enabled, 

405 'user_mcp_enabled': self._user_mcp_enabled, 

406 'total_servers': total_servers, 

407 'total_tools': total_tools, 

408 'enabled_servers': enabled_servers, 

409 'enabled_tools': enabled_tools, 

410 }