Coverage for aipyapp/aipy/blocks.py: 19%
206 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
1#! /usr/bin/env python3
2# -*- coding: utf-8 -*-
4import re
5import json
6from pathlib import Path
7from collections import OrderedDict
8from dataclasses import dataclass
9from typing import Optional, Dict, Any
11from loguru import logger
13from .libmcp import extract_call_tool_str, extra_call_tool_blocks
14from ..interface import Trackable
@dataclass
class CodeBlock:
    """A named, versioned piece of source code.

    Instances carry the code text together with its language tag, an optional
    target file path and an optional mapping of dependency names to sets of
    dependency values.
    """
    # Identifier of the block.
    name: str
    # Version number of this revision of the block.
    version: int
    # Language tag; may effectively be None when the caller has no tag.
    lang: str
    # The code text itself.
    code: str
    # Optional file path that ``save()`` writes the code to.
    path: Optional[str] = None
    # Optional mapping: dependency name -> set of dependency values.
    deps: Optional[Dict[str, set]] = None

    @staticmethod
    def _as_set(value: Any) -> set:
        """Return ``value`` as a new set; a non-collection becomes a 1-element set."""
        if isinstance(value, (list, set, frozenset, tuple)):
            return set(value)
        return {value}

    def add_dep(self, dep_name: str, dep_value: Any):
        """Add one or more dependency values under ``dep_name``.

        Args:
            dep_name: Key in ``self.deps`` to add to (created on demand).
            dep_value: A single value, or a list/set/frozenset/tuple of values.
                Strings are treated as a single value, not as an iterable.
        """
        if self.deps is None:
            self.deps = {}

        deps = self.deps.get(dep_name)
        if not isinstance(deps, set):
            # Missing entry, or a list/tuple left over from deserialized data:
            # normalize to a set so update()/add() below always work.
            deps = set() if deps is None else self._as_set(deps)
            self.deps[dep_name] = deps

        # dep_value may be a single value or a collection of values.
        if isinstance(dep_value, (list, set, frozenset, tuple)):
            deps.update(dep_value)
        else:
            deps.add(dep_value)

    def save(self):
        """Write the code to ``self.path`` (UTF-8), creating parent directories.

        Returns:
            bool: False when no path is set, True after the file was written.
        """
        if not self.path:
            return False

        path = Path(self.path)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(self.code, encoding='utf-8')
        return True

    @property
    def abs_path(self):
        """Absolute ``Path`` for ``self.path``, or None when no path is set."""
        if self.path:
            return Path(self.path).absolute()
        return None

    def get_lang(self):
        """Return the language tag in lower case ('' when no tag is set)."""
        # lang can be None (e.g. a fenced block without a language tag), so
        # guard before calling .lower().
        return (self.lang or '').lower()

    def to_dict(self) -> Dict[str, Any]:
        """Convert the block to a plain dictionary.

        Note: ``deps`` is returned as stored, i.e. its values may be sets.
        """
        return {
            '__type__': 'CodeBlock',
            'name': self.name,
            'version': self.version,
            'lang': self.lang,
            'code': self.code,
            'path': self.path,
            'deps': self.deps
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CodeBlock':
        """Rebuild a block from a dictionary such as the one from ``to_dict``.

        Missing keys fall back to defaults. Dependency values given as lists
        or tuples are converted to sets so that ``add_dep`` keeps working on
        the restored object.
        """
        raw_deps = data.get('deps')
        if raw_deps:
            deps = {key: cls._as_set(value) for key, value in raw_deps.items()}
        else:
            # Preserve None / empty mapping exactly as given.
            deps = raw_deps

        return cls(
            name=data.get('name', ''),
            version=data.get('version', 1),
            lang=data.get('lang', ''),
            code=data.get('code', ''),
            path=data.get('path'),
            deps=deps
        )

    def __repr__(self):
        return f"<CodeBlock name={self.name}, version={self.version}, lang={self.lang}, path={self.path}>"
class CodeBlocks(Trackable):
    """Registry of code blocks parsed from markdown text.

    ``history`` holds every block version in creation order; ``blocks`` maps
    each block name to its latest version. Every entry of ``blocks`` is also
    present in ``history``, which is what checkpoint restoration relies on.
    """

    def __init__(self):
        self.history = []
        self.blocks = OrderedDict()
        # Matches a fenced code block wrapped in Block-Start / Block-End
        # HTML comments, each carrying a JSON object.
        self.code_pattern = re.compile(
            r'<!--\s*Block-Start:\s*(\{.*?\})\s*-->\s*```(\w+)?\s*\n(.*?)\n```\s*<!--\s*Block-End:\s*(\{.*?\})\s*-->',
            re.DOTALL
        )
        # Matches a single-line command comment: <!-- Cmd-<Name>: {json} -->
        self.line_pattern = re.compile(
            r'<!--\s*Cmd-(\w+):\s*(\{.*?\})\s*-->'
        )
        self.log = logger.bind(src='code_blocks')

    def __len__(self):
        return len(self.blocks)

    @staticmethod
    def _copy_deps(deps):
        """Return an independent copy of a deps mapping with set values.

        None / empty mappings are returned unchanged. List or tuple values
        (e.g. from deserialized data) are converted to sets.
        """
        if not deps:
            return deps
        return {
            key: set(value) if isinstance(value, (list, set, frozenset, tuple)) else {value}
            for key, value in deps.items()
        }

    def parse(self, markdown_text, parse_mcp=False):
        """Extract code blocks and Cmd-* commands from ``markdown_text``.

        Valid blocks are registered in ``self.blocks`` / ``self.history`` and
        saved to disk when they carry a path (save failures are only logged).

        Args:
            markdown_text: Text to scan.
            parse_mcp: When true, also try to extract an MCP tool call, first
                from the parsed blocks and then from the raw text.

        Returns:
            dict: May contain the keys 'errors', 'commands', 'blocks' and
            'call_tool'; each key is present only when it has content.
        """
        blocks = OrderedDict()
        errors = []
        for match in self.code_pattern.finditer(markdown_text):
            start_json, lang, content, end_json = match.groups()
            try:
                start_meta = json.loads(start_json)
                end_meta = json.loads(end_json)
            except json.JSONDecodeError as e:
                self.log.exception('Error parsing code block', start_json=start_json, end_json=end_json)
                error = {'JSONDecodeError': {'Block-Start': start_json, 'Block-End': end_json, 'reason': str(e)}}
                errors.append(error)
                continue

            code_name = start_meta.get("name")
            if code_name != end_meta.get("name"):
                self.log.error("Start and end name mismatch", start_name=code_name, end_name=end_meta.get("name"))
                error = {'Start and end name mismatch': {'start_name': code_name, 'end_name': end_meta.get("name")}}
                errors.append(error)
                continue

            version = start_meta.get("version", 1)
            if code_name in blocks or code_name in self.blocks:
                # A block from this same parse takes precedence over an
                # already registered one when checking the version.
                old_block = blocks.get(code_name) or self.blocks.get(code_name)
                old_version = old_block.version
                if old_version >= version:
                    self.log.error("Duplicate code name with same or newer version", code_name=code_name, old_version=old_version, version=version)
                    error = {'Duplicate code name with same or newer version': {'code_name': code_name, 'old_version': old_version, 'version': version}}
                    errors.append(error)
                    continue

            # Build the code block object.
            block = CodeBlock(
                name=code_name,
                version=version,
                lang=lang,
                code=content,
                path=start_meta.get('path'),
            )

            blocks[code_name] = block
            self.history.append(block)
            self.log.info("Parsed code block", code_block=block)

            try:
                block.save()
                self.log.info("Saved code block", code_block=block)
            except Exception as e:
                # Best effort: a failed save must not abort parsing.
                self.log.error("Failed to save file", code_block=block, reason=e)

        self.blocks.update(blocks)

        commands = []  # all commands, collected in order of appearance
        line_matches = self.line_pattern.findall(markdown_text)
        for line_match in line_matches:
            cmd, json_str = line_match
            try:
                line_meta = json.loads(json_str)
            except json.JSONDecodeError as e:
                self.log.error(f"Invalid JSON in Cmd-{cmd} block", json_str=json_str, reason=e)
                error = {f'Invalid JSON in Cmd-{cmd} block': {'json_str': json_str, 'reason': str(e)}}
                errors.append(error)
                continue

            error = None
            if cmd == 'Exec':
                exec_name = line_meta.get("name")
                if not exec_name:
                    error = {'Cmd-Exec block without name': {'json_str': json_str}}
                elif exec_name not in self.blocks:
                    error = {'Cmd-Exec block not found': {'exec_name': exec_name, 'json_str': json_str}}
                else:
                    commands.append({
                        'type': 'exec',
                        'block_name': exec_name
                    })
            elif cmd == 'Edit':
                edit_name = line_meta.get("name")
                if not edit_name:
                    error = {'Cmd-Edit block without name': {'json_str': json_str}}
                elif edit_name not in self.blocks:
                    error = {'Cmd-Edit block not found': {'edit_name': edit_name, 'json_str': json_str}}
                elif not line_meta.get("old"):
                    error = {'Cmd-Edit block without old string': {'json_str': json_str}}
                elif "new" not in line_meta:
                    # An empty "new" is allowed (deletion); only a missing key is an error.
                    error = {'Cmd-Edit block without new string': {'json_str': json_str}}
                else:
                    commands.append({
                        'type': 'edit',
                        'instruction': {
                            'name': edit_name,
                            'old': line_meta.get("old"),
                            'new': line_meta.get("new"),
                            'replace_all': line_meta.get("replace_all", False),
                            'json_str': json_str
                        }
                    })
            else:
                error = {f'Unknown command in Cmd-{cmd} block': {'cmd': cmd}}

            if error:
                errors.append(error)

        ret = {}
        if errors:
            ret['errors'] = errors
        if commands:
            ret['commands'] = commands
        if blocks:
            ret['blocks'] = list(blocks.values())

        if parse_mcp:
            # Try the parsed code blocks first, then fall back to the raw text.
            json_content = extra_call_tool_blocks(list(blocks.values())) or extract_call_tool_str(markdown_text)

            if json_content:
                ret['call_tool'] = json_content
                self.log.info("Parsed MCP call_tool", json_content=json_content)

        return ret

    def get_code_by_name(self, code_name):
        """Return the code text of the latest block named ``code_name``, or None."""
        try:
            return self.blocks[code_name].code
        except KeyError:
            self.log.error("Code name not found", code_name=code_name)
            return None

    def get_block_by_name(self, code_name):
        """Return the latest block named ``code_name``, or None."""
        try:
            return self.blocks[code_name]
        except KeyError:
            self.log.error("Code name not found", code_name=code_name)
            return None

    def apply_edit_modification(self, edit_instruction):
        """Apply an edit instruction to a block, producing a new version.

        The original block object is left untouched; a new block with
        ``version + 1`` replaces it in ``self.blocks`` and is appended to
        ``self.history``.

        Args:
            edit_instruction: Mapping with 'name', 'old', 'new' and optionally
                'replace_all'.

        Returns:
            tuple: (success: bool, message: str, new_block: CodeBlock or None)
        """
        name = edit_instruction['name']
        old_str = edit_instruction['old']
        new_str = edit_instruction['new']
        replace_all = edit_instruction.get('replace_all', False)

        if name not in self.blocks:
            return False, f"代码块 '{name}' 不存在", None

        original_block = self.blocks[name]

        # The snippet to replace must occur in the current code.
        if old_str not in original_block.code:
            return False, f"未找到匹配的代码片段: {old_str[:50]}...", None

        # An ambiguous match is refused unless replace_all was requested.
        match_count = original_block.code.count(old_str)
        if match_count > 1 and not replace_all:
            return False, f"代码片段匹配 {match_count} 个位置,请设置 replace_all: true 或提供更具体的上下文", None

        # Perform the replacement to obtain the new code.
        if replace_all:
            new_code = original_block.code.replace(old_str, new_str)
            replaced_count = match_count
        else:
            # Replace only the first occurrence.
            new_code = original_block.code.replace(old_str, new_str, 1)
            replaced_count = 1

        # Create the new block (version + 1). The deps sets are copied
        # individually so that add_dep() on one version cannot mutate the
        # sets of the other version.
        new_block = CodeBlock(
            name=original_block.name,
            version=original_block.version + 1,
            lang=original_block.lang,
            code=new_code,
            path=original_block.path,
            deps=self._copy_deps(original_block.deps) if original_block.deps else None
        )

        # Save the new block to its file; a failure is logged, not raised.
        try:
            new_block.save()
            self.log.info("Created and saved new block version", code_block=new_block, replaced_count=replaced_count)
        except Exception as e:
            self.log.error("Failed to save new block", code_block=new_block, reason=e)

        # The name always points at the newest version.
        self.blocks[name] = new_block

        # Record the new version in the history.
        self.history.append(new_block)

        message = f"成功替换 {replaced_count} 处匹配项,创建版本 v{new_block.version}"
        return True, message, new_block

    def to_list(self):
        """Return every block version in history as a list of dictionaries.

        Returns:
            list: One ``CodeBlock.to_dict()`` result per history entry, in
            history order.
        """
        blocks = [block.to_dict() for block in self.history]
        return blocks

    def get_state(self):
        """Return the data that needs to be persisted (see ``to_list``)."""
        return self.to_list()

    def restore_state(self, blocks_data):
        """Replace the current state with blocks rebuilt from ``blocks_data``.

        Args:
            blocks_data: Iterable of dictionaries with the keys 'name',
                'version', 'lang', 'code' and optionally 'path' and 'deps';
                a falsy value just clears the state.
        """
        self.history.clear()
        self.blocks.clear()

        if blocks_data:
            for block_data in blocks_data:
                code_block = CodeBlock(
                    name=block_data['name'],
                    version=block_data['version'],
                    lang=block_data['lang'],
                    code=block_data['code'],
                    path=block_data.get('path'),
                    # Normalize dependency values to sets (they may arrive as
                    # lists) so add_dep() works on restored blocks.
                    deps=self._copy_deps(block_data.get('deps'))
                )
                self.history.append(code_block)
                # Later history entries override earlier ones of the same name.
                self.blocks[code_block.name] = code_block

    def clear(self):
        """Drop all blocks and the whole history."""
        self.history.clear()
        self.blocks.clear()

    # Trackable interface implementation
    def get_checkpoint(self) -> int:
        """Return the current checkpoint, which is the length of the history."""
        return len(self.history)

    def restore_to_checkpoint(self, checkpoint: Optional[int]):
        """Roll the state back to ``checkpoint``.

        Args:
            checkpoint: A value previously returned by ``get_checkpoint``, or
                None to reset to the initial (empty) state.
        """
        if checkpoint is None:
            # Back to the initial state.
            self.clear()
            return

        if checkpoint >= len(self.history):
            # Nothing was added after this checkpoint.
            return

        # Truncate the history, then rebuild the name -> latest-version map
        # from what is left. Merely deleting the names of the dropped entries
        # would lose blocks whose older version predates the checkpoint.
        self.history = self.history[:checkpoint]
        self.blocks.clear()
        for block in self.history:
            self.blocks[block.name] = block