Coverage for aipyapp/aipy/libmcp.py: 20%
213 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
1import asyncio
2import contextlib
3import json
4import os
5import re
6import sys
7from datetime import timedelta
9from loguru import logger
10from mcp import ClientSession, StdioServerParameters
11from mcp.client.stdio import stdio_client
12from mcp.client.sse import sse_client
13from mcp.client.streamable_http import streamablehttp_client
14from mcp.shared.message import SessionMessage
15from mcp.types import JSONRPCMessage
16from .. import T
18# 猴子补丁:修复第三方库中的 _handle_json_response 方法
19# MR如果合并后,可以去掉这段代码 https://github.com/modelcontextprotocol/python-sdk/pull/1057
20async def _patched_handle_json_response(
21 self,
22 response,
23 read_stream_writer,
24 is_initialization: bool = False,
25) -> None:
26 """修复后的 JSON 响应处理方法"""
27 try:
28 content = await response.aread()
30 # Parse JSON first to determine structure
31 data = json.loads(content)
33 if isinstance(data, list):
34 messages = [JSONRPCMessage.model_validate(item) for item in data] # type: ignore
35 else:
36 message = JSONRPCMessage.model_validate(data)
37 messages = [message]
39 for message in messages:
40 if is_initialization:
41 self._maybe_extract_protocol_version_from_message(message)
43 session_message = SessionMessage(message)
44 await read_stream_writer.send(session_message)
45 except Exception as exc:
46 logger.error(f"Error parsing JSON response: {exc}")
47 try:
48 await read_stream_writer.aclose()
49 except Exception as close_exc:
50 logger.error(f"Error closing read_stream_writer: {close_exc}")
51 # 重新抛出异常,让上层调用者处理
52 raise exc
# Apply the monkey patch.
def _apply_streamable_http_patch():
    """Install ``_patched_handle_json_response`` on the SDK's StreamableHTTP transport.

    The method is replaced on the ``StreamableHTTPTransport`` class itself.
    If the class cannot be imported, a warning is logged and the SDK is left
    unpatched.
    """
    try:
        from mcp.client.streamable_http import StreamableHTTPTransport
    except ImportError as e:
        logger.warning(f"Failed to apply StreamableHTTP patch: {e}")
        return

    StreamableHTTPTransport._handle_json_response = _patched_handle_json_response
    logger.debug("Applied StreamableHTTP patch for _handle_json_response")


# Apply the patch once, when this module is imported.
_apply_streamable_http_patch()
# Precompiled regular expressions used by extract_call_tool_str.
# Fenced code block: an opening ``` optionally followed by "json". Group 1
# lazily captures the body up to the next ```, excluding the surrounding
# whitespace. Other language tags (e.g. ```python) are not consumed and end
# up inside the capture.
CODE_BLOCK_PATTERN = re.compile(r"```(?:json)?\s*([\s\S]*?)\s*```")
# A brace-delimited span that may contain at most ONE level of nested {...}.
# Objects nested deeper are not matched as a whole; only their inner
# fragments can match.
JSON_PATTERN = re.compile(r"(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})")
def extra_call_tool_blocks(blocks) -> str:
    """
    Extract an MCP call_tool JSON payload from a list of code blocks.

    Which blocks are considered:
    - the block must have a ``lang`` attribute;
    - ``lang`` must be ``"json"`` (case-insensitive), or empty/``None``
      (an untagged block).

    A block's text is ``block.code`` when that attribute exists, otherwise
    ``str(block)``. The first block whose text parses as a JSON object
    containing both ``"action"`` and ``"name"`` is returned. If the object
    has ``"arguments"``, that value must be a dict.

    Args:
        blocks (list): CodeBlock objects; ``None`` or empty yields ``""``.

    Returns:
        str: The matching JSON re-serialized with ``ensure_ascii=False``, or
        an empty string if no block matches.
    """
    if not blocks:
        return ""

    for block in blocks:
        if not hasattr(block, "lang"):
            continue
        # A falsy lang (None or "") marks an untagged fence; accept it
        # alongside explicit json blocks.
        lang = (block.lang or "").lower()
        if lang not in ("json", ""):
            continue

        content = getattr(block, "code", "") if hasattr(block, "code") else str(block)
        if not content:
            continue

        try:
            data = json.loads(content.strip())
        except json.JSONDecodeError:
            continue

        # Only JSON objects shaped like a call_tool action qualify.
        if not isinstance(data, dict):
            continue
        if "action" not in data or "name" not in data:
            continue
        if "arguments" in data and not isinstance(data["arguments"], dict):
            continue

        # Return a JSON string, not the dict.
        return json.dumps(data, ensure_ascii=False)

    return ""
def extract_call_tool_str(text) -> str:
    """
    Extract an MCP call_tool JSON payload from free-form text.

    Candidates are tried in this order:
      1. the bodies of fenced code blocks matched by CODE_BLOCK_PATTERN;
      2. brace-delimited spans matched by JSON_PATTERN, which allows at most
         one level of nested braces;
      3. as a fallback, a JSON value decoded with ``JSONDecoder.raw_decode``
         starting at each ``{`` in the text. This catches call_tool objects
         whose ``arguments`` nest deeper than JSON_PATTERN can match.

    The first candidate that is a JSON object containing ``"action"`` and
    ``"name"`` is returned. If the object has ``"arguments"``, that value
    must be a dict.

    Args:
        text (str): The input text that may contain MCP call_tool JSON.

    Returns:
        str: The JSON re-serialized with ``ensure_ascii=False`` if found and
        valid, otherwise an empty string (also for empty or ``None`` text).
    """
    if not text:
        return ""

    def _as_call_tool_str(data):
        """Return ``data`` as a JSON string if it is shaped like a call_tool action, else None."""
        if not isinstance(data, dict):
            return None
        if "action" not in data or "name" not in data:
            return None
        if "arguments" in data and not isinstance(data["arguments"], dict):
            return None
        # Return a JSON string, not the dict.
        return json.dumps(data, ensure_ascii=False)

    # Original candidate order: fenced blocks first, then regex-matched objects.
    candidates = CODE_BLOCK_PATTERN.findall(text)
    candidates.extend(JSON_PATTERN.findall(text))
    for candidate in candidates:
        try:
            data = json.loads(candidate.strip())
        except json.JSONDecodeError:
            continue
        result = _as_call_tool_str(data)
        if result is not None:
            return result

    # Fallback for deeply nested objects that JSON_PATTERN cannot match.
    # It only runs when nothing above matched, so earlier results are unchanged.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            data, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            result = _as_call_tool_str(data)
            if result is not None:
                return result
        start = text.find("{", start + 1)

    return ""
class MCPConfigReader:
    """Collect MCP server definitions from a user config file and built-in Trustoken services.

    Args:
        config_path: Path to the user's MCP JSON config. A falsy value disables
            user-defined servers.
        tt_api_key: Trustoken API key. A falsy value disables the built-in
            servers and the automatic ``Authorization`` header injection.
    """

    def __init__(self, config_path, tt_api_key):
        self.config_path = config_path
        self.tt_api_key = tt_api_key

    @staticmethod
    def _matches_base_url(url, base):
        """Return True if ``url`` is ``base`` itself or a path/query/fragment under it.

        A plain ``startswith`` would also accept look-alike hosts, for example
        ``https://sapi.trustoken.ai.example.com`` or
        ``https://sapi.trustoken.ai:x@example.com``, and send the API key there.
        """
        if not isinstance(url, str) or not url.startswith(base):
            return False
        rest = url[len(base):]
        return not rest or base.endswith("/") or rest[0] in "/?#"

    def _rewrite_config(self, servers):
        """Inject the Trustoken bearer token into matching server configs, in place.

        A server is updated only when both conditions hold:
        - its ``transport.type`` is ``"streamable_http"``;
        - its ``url`` lies under the (translated) Trustoken API base URL.

        For matching servers, ``Authorization`` is set in ``headers``. A missing
        or non-dict ``headers`` value is replaced by a new dict. Non-dict server
        entries are skipped.

        Returns:
            dict: The same ``servers`` mapping that was passed in.
        """
        if not self.tt_api_key:
            return servers

        base_url = T("https://sapi.trustoken.ai")
        for server_config in servers.values():
            if not isinstance(server_config, dict):
                continue
            transport = server_config.get("transport") or {}
            if not isinstance(transport, dict) or transport.get("type") != "streamable_http":
                continue
            if not self._matches_base_url(server_config.get("url", ""), base_url):
                continue

            headers = server_config.get("headers")
            if not isinstance(headers, dict):
                headers = server_config["headers"] = {}
            headers["Authorization"] = f"Bearer {self.tt_api_key}"

        return servers

    def get_user_mcp(self) -> dict:
        """Read the config file and return its ``mcpServers`` mapping, including disabled servers.

        Returns ``{}`` in each of these cases, printing a message for the error
        cases:
        - no path is configured;
        - the file is missing;
        - the file is not valid JSON;
        - the file is not a JSON object;
        - ``mcpServers`` is absent, null, or not an object.

        The file is decoded as UTF-8 and a leading BOM is tolerated.
        """
        if not self.config_path:
            return {}
        try:
            # utf-8-sig decodes plain UTF-8 identically and also strips a BOM,
            # which json.load would otherwise reject.
            with open(self.config_path, "r", encoding="utf-8-sig") as f:
                config = json.load(f)
        except FileNotFoundError:
            print(f"Config file not found: {self.config_path}")
            return {}
        except json.JSONDecodeError as e:
            print(T("Error decoding MCP config file {}: {}").format(self.config_path, e))
            return {}

        if not isinstance(config, dict):
            print(f"Invalid MCP config file (expected a JSON object): {self.config_path}")
            return {}

        servers = config.get("mcpServers") or {}
        if not isinstance(servers, dict):
            print(f"Invalid 'mcpServers' section in MCP config file: {self.config_path}")
            return {}

        return self._rewrite_config(servers)

    def get_sys_mcp(self) -> dict:
        """
        Return the built-in Trustoken MCP server configurations.

        Returns:
            dict: Server name -> config, using streamable_http transport with a
            bearer token. Returns ``{}`` (and logs a warning) when no Trustoken
            API key is set.
        """
        if not self.tt_api_key:
            logger.warning("No Trustoken API key provided, sys_mcp will not be available.")
            return {}

        base_url = T('https://sapi.trustoken.ai')

        def _server(path):
            return {
                "url": f"{base_url}{path}",
                "transport": {
                    "type": "streamable_http"
                },
                "headers": {
                    "Authorization": f"Bearer {self.tt_api_key}"
                },
            }

        return {
            "Trustoken-map": _server("/aio-api/mcp/amap/"),
            "Trustoken-search": _server("/mcp/"),
        }
class MCPClientSync:
    """Synchronous wrapper around an MCP ``ClientSession``.

    Each public call runs one coroutine with ``asyncio.run``. That coroutine
    opens a new transport and session, initializes it, performs a single
    operation, and closes everything again.

    Args:
        server_config (dict): Server definition.
            - With a ``url``, an HTTP transport is used: Streamable HTTP when
              ``transport.type == "streamable_http"``, SSE otherwise.
            - Without a ``url``, the server is launched over stdio from
              ``command``/``args``/``env``.
            - ``headers``, ``timeout`` and ``sse_read_timeout`` (seconds) are
              optional settings for the HTTP transports.
        suppress_output (bool): When true, ``sys.stdout``/``sys.stderr`` are
            redirected to ``os.devnull`` while the coroutine runs.
    """

    def __init__(self, server_config, suppress_output=True):
        self.server_config = server_config
        self.suppress_output = suppress_output
        self.connection_type = self._determine_connection_type()

    def _determine_connection_type(self):
        """Return the connection type: "stdio", "sse" or "streamable_http"."""
        if "url" not in self.server_config:
            return "stdio"
        # A null or non-dict "transport" is treated like a missing one.
        transport = self.server_config.get("transport") or {}
        if isinstance(transport, dict) and transport.get("type") == "streamable_http":
            return "streamable_http"
        # URL-based servers default to SSE.
        return "sse"

    @contextlib.contextmanager
    def _suppress_stdout_stderr(self):
        """Temporarily redirect ``sys.stdout`` and ``sys.stderr`` to ``os.devnull``.

        Does nothing when ``self.suppress_output`` is false. The previous streams
        are restored even if the body raises.
        """
        if not self.suppress_output:
            yield
            return

        original_stdout = sys.stdout
        original_stderr = sys.stderr
        try:
            # os.devnull is the portable path of the null device.
            with open(os.devnull, "w") as devnull:
                sys.stdout = devnull
                sys.stderr = devnull
                yield
        finally:
            sys.stdout = original_stdout
            sys.stderr = original_stderr

    def _run_async(self, coro):
        """Run ``coro`` with ``asyncio.run`` and return its result, or ``None`` on error.

        The error is printed only after the output redirection has been undone,
        so it stays visible when ``suppress_output`` is enabled.
        """
        try:
            with self._suppress_stdout_stderr():
                return asyncio.run(coro)
        except Exception as e:
            print(f"Error running async function: {e}")
            return None

    def list_tools(self) -> list:
        """Return the server's tools as a list of dicts (``[]`` on failure)."""
        return self._run_async(self._list_tools()) or []

    def call_tool(self, tool_name, arguments):
        """Call ``tool_name`` with ``arguments`` and return the dumped result.

        Returns an error dict if the tool call fails, or ``None`` if running
        the coroutine itself fails.
        """
        return self._run_async(self._call_tool(tool_name, arguments))

    def _http_client_kwargs(self, to_timeout):
        """Build the url/headers/timeout kwargs shared by the SSE and Streamable HTTP clients.

        Timeout values are parsed with ``float`` (so "2.5" or 2.5 are kept, not
        truncated). ``to_timeout`` then converts those seconds into the type
        the client expects.
        """
        cfg = self.server_config
        kwargs = {"url": cfg["url"]}
        if isinstance(cfg.get("headers"), dict):
            kwargs["headers"] = cfg["headers"]
        for key in ("timeout", "sse_read_timeout"):
            if key in cfg:
                kwargs[key] = to_timeout(float(cfg[key]))
        return kwargs

    async def _create_client_session(self):
        """Create the transport context manager that matches ``self.connection_type``."""
        if self.connection_type == "stdio":
            server_params = StdioServerParameters(
                command=self.server_config.get("command"),
                args=self.server_config.get("args", []),
                env=self.server_config.get("env"),
            )
            return stdio_client(server_params)
        if self.connection_type == "sse":
            # sse_client takes timeouts as plain seconds.
            return sse_client(**self._http_client_kwargs(lambda seconds: seconds))
        if self.connection_type == "streamable_http":
            # streamablehttp_client is given timedelta timeouts.
            return streamablehttp_client(
                **self._http_client_kwargs(lambda seconds: timedelta(seconds=seconds))
            )
        raise ValueError(f"Unsupported connection type: {self.connection_type}")

    async def _execute_with_session(self, operation):
        """Open a transport and an initialized ``ClientSession``, run ``operation(session)``, and close both.

        The Streamable HTTP transport yields three values; only the read and
        write streams are used.
        """
        client_session = await self._create_client_session()

        if self.connection_type == "streamable_http":
            async with client_session as (read, write, _):
                async with ClientSession(read, write) as session:
                    await session.initialize()
                    return await operation(session)
        else:
            async with client_session as (read, write):
                async with ClientSession(read, write) as session:
                    await session.initialize()
                    return await operation(session)

    async def _list_tools(self):
        """Fetch the tool list; log and return ``[]`` on any failure."""
        try:
            async def list_operation(session):
                server_tools = await session.list_tools()
                return server_tools.model_dump().get("tools", [])

            tools = await self._execute_with_session(list_operation)

            if sys.platform == "win32":
                # Workaround for an exception raised on Windows (see the original FIX note).
                await asyncio.sleep(3)
            return tools
        except Exception as e:
            logger.exception(f"Failed to list tools: {e}")
            return []

    async def _call_tool(self, tool_name, arguments):
        """Invoke one tool and return ``result.model_dump()``, or an error dict on failure."""
        try:
            async def call_operation(session):
                result = await session.call_tool(tool_name, arguments=arguments)
                return result.model_dump()

            ret = await self._execute_with_session(call_operation)

            if sys.platform == "win32":
                # Workaround for an exception raised on Windows (see the original FIX note).
                await asyncio.sleep(3)
            return ret
        except Exception as e:
            logger.exception(f"Failed to call tool {tool_name}: {e}")
            # Note: this runs inside _run_async's redirection, so with
            # suppress_output enabled this print goes to os.devnull.
            print(f"Failed to call tool {tool_name}: {e}")
            return {"error": str(e), "tool_name": tool_name, "arguments": arguments}