Coverage for aipyapp/aipy/libmcp.py: 20%

213 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 12:02 +0200

1import asyncio 

2import contextlib 

3import json 

4import os 

5import re 

6import sys 

7from datetime import timedelta 

8 

9from loguru import logger 

10from mcp import ClientSession, StdioServerParameters 

11from mcp.client.stdio import stdio_client 

12from mcp.client.sse import sse_client 

13from mcp.client.streamable_http import streamablehttp_client 

14from mcp.shared.message import SessionMessage 

15from mcp.types import JSONRPCMessage 

16from .. import T 

17 

# Monkey patch: replaces the third-party StreamableHTTP `_handle_json_response`.
# Drop this once https://github.com/modelcontextprotocol/python-sdk/pull/1057 is merged.
async def _patched_handle_json_response(
    self,
    response,
    read_stream_writer,
    is_initialization: bool = False,
) -> None:
    """Parse a JSON HTTP response body and forward its JSON-RPC message(s).

    The body may hold a single JSON-RPC message object or a JSON array of them
    (a batch). All messages are validated first; then each one is wrapped in
    ``SessionMessage`` and sent to *read_stream_writer*. When
    *is_initialization* is true, each message is also handed to
    ``self._maybe_extract_protocol_version_from_message``.

    On any failure the error is logged, *read_stream_writer* is closed on a
    best-effort basis (a failure to close is only logged), and the original
    exception is re-raised so the caller can handle it.
    """
    try:
        raw_body = await response.aread()

        # Decode first so a batch (JSON array) can be told apart from a single message.
        payload = json.loads(raw_body)
        items = payload if isinstance(payload, list) else [payload]
        parsed = [JSONRPCMessage.model_validate(item) for item in items]  # type: ignore

        for rpc_message in parsed:
            if is_initialization:
                self._maybe_extract_protocol_version_from_message(rpc_message)
            await read_stream_writer.send(SessionMessage(rpc_message))
    except Exception as err:
        logger.error(f"Error parsing JSON response: {err}")
        try:
            await read_stream_writer.aclose()
        except Exception as close_err:
            logger.error(f"Error closing read_stream_writer: {close_err}")
        # Propagate the original failure to the caller.
        raise err

53 

# Apply the monkey patch.
def _apply_streamable_http_patch():
    """Install ``_patched_handle_json_response`` on the StreamableHTTP transport.

    If the transport class cannot be imported, or it no longer defines a
    ``_handle_json_response`` method (e.g. a newer SDK renamed it), a warning
    is logged and nothing is patched, instead of raising or silently
    attaching an attribute that nothing calls.
    """
    try:
        from mcp.client.streamable_http import StreamableHTTPTransport
    except ImportError as e:
        logger.warning(f"Failed to apply StreamableHTTP patch: {e}")
        return

    if not hasattr(StreamableHTTPTransport, "_handle_json_response"):
        # Assigning anyway would add a dead attribute while logging success.
        logger.warning(
            "Failed to apply StreamableHTTP patch: "
            "StreamableHTTPTransport has no _handle_json_response method"
        )
        return

    # Replace the original method on the class.
    StreamableHTTPTransport._handle_json_response = _patched_handle_json_response
    logger.debug("Applied StreamableHTTP patch for _handle_json_response")


# Apply the patch when this module is imported.
_apply_streamable_http_patch()

67 

68 

# Pre-compiled regular expressions.
# A fenced code block, optionally tagged ```json; group 1 is the block body.
CODE_BLOCK_PATTERN = re.compile(r"```(?:json)?\s*([\s\S]*?)\s*```")
# A `{...}` span whose values may contain objects that themselves contain no
# braces (at most one level of nesting). Braces inside JSON string values are
# not understood, so this is only a heuristic.
JSON_PATTERN = re.compile(r"(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})")


def _is_call_tool(data) -> bool:
    """Return True if *data* looks like an MCP call_tool request.

    It must be a dict containing "action" and "name"; "arguments", when
    present, must itself be a dict.
    """
    if not isinstance(data, dict):
        return False
    if "action" not in data or "name" not in data:
        return False
    return "arguments" not in data or isinstance(data["arguments"], dict)


def _iter_top_level_json_objects(text):
    """Yield each top-level JSON object embedded in *text*, left to right.

    Uses the real JSON decoder, so arbitrary nesting and braces inside string
    values are handled (unlike JSON_PATTERN). After a successful decode,
    scanning resumes after the decoded object; after a failure, it resumes at
    the next character.
    """
    decoder = json.JSONDecoder()
    pos = 0
    while True:
        start = text.find("{", pos)
        if start == -1:
            return
        try:
            obj, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pos = start + 1
            continue
        yield obj
        pos = end


def extra_call_tool_blocks(blocks) -> str:
    """
    Extract an MCP call_tool JSON from a list of code blocks.

    Only blocks that have a ``lang`` attribute equal to ``json`` (any case) or
    empty/None (untagged) are considered. A block's text is taken from its
    ``code`` attribute, falling back to ``str(block)``.

    Args:
        blocks (list): CodeBlock objects.

    Returns:
        str: the first matching call_tool JSON re-serialized (non-ASCII kept),
        or an empty string if none is found.
    """
    if not blocks:
        return ""

    for block in blocks:
        if not hasattr(block, 'lang'):
            continue
        # `or ''` lets untagged blocks (None / '') through; the previous
        # `block.lang and ...` guard made the '' case unreachable.
        lang = (block.lang or '').lower()
        if lang not in ('json', ''):
            continue

        content = block.code if hasattr(block, 'code') else str(block)
        if not content:
            continue
        try:
            data = json.loads(content.strip())
        except json.JSONDecodeError:
            continue
        if _is_call_tool(data):
            # Return a JSON string, not a dict.
            return json.dumps(data, ensure_ascii=False)

    return ""


def extract_call_tool_str(text) -> str:
    """
    Extract MCP call_tool JSON from text.

    Candidates are tried in this order, and the first valid call_tool object
    wins:
      1. bodies of fenced code blocks (CODE_BLOCK_PATTERN);
      2. bare objects matched by JSON_PATTERN;
      3. any top-level JSON object found by the real decoder. This catches
         deeply nested ``arguments`` and braces inside strings, which the
         regex misses. It runs last so earlier results are unchanged.

    Args:
        text (str): The input text that may contain MCP call_tool JSON.

    Returns:
        str: The JSON str if found and valid, otherwise empty str (also for
        empty or None input).
    """
    if not text:
        return ""

    candidates = CODE_BLOCK_PATTERN.findall(text)
    candidates.extend(JSON_PATTERN.findall(text))

    for candidate in candidates:
        try:
            data = json.loads(candidate.strip())
        except json.JSONDecodeError:
            continue
        if _is_call_tool(data):
            # Return a JSON string, not a dict.
            return json.dumps(data, ensure_ascii=False)

    for data in _iter_top_level_json_objects(text):
        if _is_call_tool(data):
            return json.dumps(data, ensure_ascii=False)

    return ""

152 

153 

class MCPConfigReader:
    """Load MCP server definitions from a user config file and built-in Trustoken servers."""

    def __init__(self, config_path, tt_api_key):
        # Path to the user's mcp.json; a falsy value means "no user servers".
        self.config_path = config_path
        # Trustoken API key; when falsy, no auth headers are injected and
        # get_sys_mcp() returns an empty dict.
        self.tt_api_key = tt_api_key

    def _rewrite_config(self, servers):
        """Inject the Trustoken bearer token into matching server entries.

        An entry is rewritten when its ``url`` starts with the (translatable)
        Trustoken base URL and its ``transport.type`` is ``streamable_http``.
        Entries are modified in place and *servers* is returned. Malformed
        entries (non-dict entry or transport, non-string url) are left untouched.
        """
        if not self.tt_api_key:
            return servers

        base_url = T("https://sapi.trustoken.ai")
        for server_config in servers.values():
            if not isinstance(server_config, dict):
                continue
            url = server_config.get("url", "")
            transport = server_config.get("transport", {})
            if not isinstance(url, str) or not isinstance(transport, dict):
                continue
            if url.startswith(base_url) and transport.get("type") == "streamable_http":
                headers = server_config.get("headers")
                if not isinstance(headers, dict):
                    headers = server_config["headers"] = {}
                headers["Authorization"] = f"Bearer {self.tt_api_key}"

        return servers

    def get_user_mcp(self) -> dict:
        """Read the mcp.json file and return its MCP servers, including disabled ones.

        Returns an empty dict when no path is configured, the file is missing,
        the JSON is invalid, or the top level / ``mcpServers`` is not an object.
        """
        if not self.config_path:
            return {}
        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        except FileNotFoundError:
            print(f"Config file not found: {self.config_path}")
            return {}
        except json.JSONDecodeError as e:
            print(T("Error decoding MCP config file {}: {}").format(self.config_path, e))
            return {}

        # Valid JSON is not necessarily an object (e.g. a list or null).
        servers = config.get("mcpServers", {}) if isinstance(config, dict) else None
        if not isinstance(servers, dict):
            print(f"Invalid MCP config file {self.config_path}: 'mcpServers' must be a JSON object")
            return {}
        return self._rewrite_config(servers)

    def get_sys_mcp(self) -> dict:
        """
        Return the built-in (Trustoken) MCP server configurations.

        Returns:
            dict: server name -> config dict; empty when no Trustoken API key is set.
        """
        if not self.tt_api_key:
            logger.warning("No Trustoken API key provided, sys_mcp will not be available.")
            return {}

        base_url = T('https://sapi.trustoken.ai')

        def _server(path):
            # Build fresh dicts per server so entries never share state.
            return {
                "url": f"{base_url}{path}",
                "transport": {"type": "streamable_http"},
                "headers": {"Authorization": f"Bearer {self.tt_api_key}"},
            }

        return {
            "Trustoken-map": _server("/aio-api/mcp/amap/"),
            "Trustoken-search": _server("/mcp/"),
        }

227 

class MCPClientSync:
    """Synchronous wrapper around an async MCP client session.

    Each public call opens a new connection (stdio, SSE or streamable HTTP,
    chosen from *server_config*), runs one operation on an initialized
    ``ClientSession`` inside ``asyncio.run``, and closes the connection again.
    """

    def __init__(self, server_config, suppress_output=True):
        self.server_config = server_config
        # When True, sys.stdout / sys.stderr are redirected to os.devnull
        # while the async work runs.
        self.suppress_output = suppress_output
        self.connection_type = self._determine_connection_type()

    def _determine_connection_type(self):
        """Return the connection type: "stdio", "sse" or "streamable_http".

        A config with a "url" is remote: streamable HTTP when
        ``transport.type == "streamable_http"``, otherwise SSE. Without a
        "url" the server is reached over stdio.
        """
        if "url" not in self.server_config:
            return "stdio"
        # `or {}` also tolerates an explicit `"transport": null`.
        transport = self.server_config.get("transport") or {}
        if isinstance(transport, dict) and transport.get("type") == "streamable_http":
            return "streamable_http"
        return "sse"

    @contextlib.contextmanager
    def _suppress_stdout_stderr(self):
        """Temporarily redirect stdout and stderr to os.devnull.

        Does nothing when ``suppress_output`` is False. Note that this swaps
        the process-global ``sys.stdout`` / ``sys.stderr`` objects, so output
        from other threads is silenced too while it is active.
        """
        if not self.suppress_output:
            yield
            return

        original_stdout = sys.stdout
        original_stderr = sys.stderr
        try:
            # os.devnull is the cross-platform null device path.
            with open(os.devnull, "w") as devnull:
                sys.stdout = devnull
                sys.stderr = devnull
                yield
        finally:
            sys.stdout = original_stdout
            sys.stderr = original_stderr

    def _run_async(self, coro):
        """Run *coro* on a new event loop and return its result.

        If the run fails, the error is printed and None is returned. A common
        cause is calling this from a thread that already has a running event
        loop. The message is printed only after output suppression has ended,
        so it is actually visible.
        """
        error = None
        with self._suppress_stdout_stderr():
            try:
                return asyncio.run(coro)
            except Exception as e:
                error = e
                # asyncio.run() can refuse to start without awaiting coro;
                # closing it avoids a "never awaited" warning (no-op if it ran).
                if asyncio.iscoroutine(coro):
                    coro.close()
        print(f"Error running async function: {error}")
        return None

    def list_tools(self) -> list:
        """Return the server's tool descriptions as dicts ([] on failure)."""
        return self._run_async(self._list_tools()) or []

    def call_tool(self, tool_name, arguments):
        """Call *tool_name* with *arguments* and return the dumped result dict."""
        return self._run_async(self._call_tool(tool_name, arguments))

    def _remote_client_kwargs(self, as_timedelta=False):
        """Build the url/headers/timeout kwargs shared by the SSE and HTTP clients.

        Timeouts are read as (possibly fractional) seconds. They are passed on
        as floats, or as ``timedelta`` objects when *as_timedelta* is true.
        Headers are forwarded only when they are a dict.
        """
        cfg = self.server_config
        kwargs = {"url": cfg["url"]}
        headers = cfg.get("headers")
        if isinstance(headers, dict):
            kwargs["headers"] = headers
        for key in ("timeout", "sse_read_timeout"):
            if key in cfg:
                # float() rather than int(): keeps 2.5 and accepts "2.5".
                seconds = float(cfg[key])
                kwargs[key] = timedelta(seconds=seconds) if as_timedelta else seconds
        return kwargs

    async def _create_client_session(self):
        """Return the (not yet entered) transport context manager for this server."""
        if self.connection_type == "stdio":
            server_params = StdioServerParameters(
                command=self.server_config.get("command"),
                args=self.server_config.get("args", []),
                env=self.server_config.get("env"),
            )
            return stdio_client(server_params)
        if self.connection_type == "sse":
            return sse_client(**self._remote_client_kwargs())
        if self.connection_type == "streamable_http":
            return streamablehttp_client(**self._remote_client_kwargs(as_timedelta=True))
        raise ValueError(f"Unsupported connection type: {self.connection_type}")

    async def _execute_with_session(self, operation):
        """Open the transport, initialize a ClientSession and run *operation* on it.

        *operation* is an async callable that receives the initialized session;
        its result is returned. Transport and session are closed on exit.
        """
        transport = await self._create_client_session()
        async with transport as streams:
            # stdio/SSE clients yield (read, write); the streamable HTTP client
            # yields an extra third element that is not needed here.
            read_stream, write_stream = streams[0], streams[1]
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()
                return await operation(session)

    async def _list_tools(self):
        """Fetch the tool list as plain dicts; log and return [] on error."""
        try:
            async def list_operation(session):
                server_tools = await session.list_tools()
                return server_tools.model_dump().get("tools", [])

            tools = await self._execute_with_session(list_operation)

            if sys.platform == "win32":
                # Workaround for an exception observed on Windows; presumably
                # gives transport cleanup time before the loop closes.
                await asyncio.sleep(3)
            return tools
        except Exception as e:
            logger.exception(f"Failed to list tools: {e}")
            return []

    async def _call_tool(self, tool_name, arguments):
        """Call a tool and return its dumped result, or an error dict on failure."""
        try:
            async def call_operation(session):
                result = await session.call_tool(tool_name, arguments=arguments)
                return result.model_dump()

            ret = await self._execute_with_session(call_operation)

            if sys.platform == "win32":
                # Workaround for an exception observed on Windows; presumably
                # gives transport cleanup time before the loop closes.
                await asyncio.sleep(3)
            return ret
        except Exception as e:
            logger.exception(f"Failed to call tool {tool_name}: {e}")
            print(f"Failed to call tool {tool_name}: {e}")
            return {"error": str(e), "tool_name": tool_name, "arguments": arguments}

363