Coverage for aipyapp/aipy/blocks.py: 19%

206 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 12:02 +0200

1#! /usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3 

4import re 

5import json 

6from pathlib import Path 

7from collections import OrderedDict 

8from dataclasses import dataclass 

9from typing import Optional, Dict, Any 

10 

11from loguru import logger 

12 

13from .libmcp import extract_call_tool_str, extra_call_tool_blocks 

14from ..interface import Trackable 

15 

@dataclass
class CodeBlock:
    """A named, versioned block of code extracted from model output.

    Attributes:
        name: Block identifier (what Cmd-Exec / Cmd-Edit instructions refer to).
        version: Version number; a higher version supersedes a lower one.
        lang: Language tag from the code fence. May be None when the fence had
            no tag, because the parser's tag group is optional.
        code: The block's source text.
        path: Optional file path that ``save()`` writes the code to.
        deps: Optional mapping of dependency category -> set of values.
    """
    name: str
    version: int
    lang: str
    code: str
    path: Optional[str] = None
    deps: Optional[Dict[str, set]] = None

    def add_dep(self, dep_name: str, dep_value: Any):
        """Record one or more dependency values under ``dep_name``.

        Args:
            dep_name: Dependency category key.
            dep_value: A single value, or a list/set/tuple/frozenset of values
                that are all added.
        """
        if self.deps is None:
            self.deps = {}

        existing = self.deps.get(dep_name)
        if existing is None:
            existing = set()
            self.deps[dep_name] = existing
        elif not isinstance(existing, set):
            # Deps restored from JSON arrive as lists; normalize to a set so
            # add()/update() below work and duplicates are collapsed.
            existing = set(existing)
            self.deps[dep_name] = existing

        if isinstance(dep_value, (list, set, tuple, frozenset)):
            existing.update(dep_value)
        else:
            existing.add(dep_value)

    def save(self):
        """Write ``code`` to ``path`` as UTF-8, creating parent directories.

        Returns:
            False if no path is set (nothing written), True after writing.
        """
        if not self.path:
            return False

        path = Path(self.path)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(self.code, encoding='utf-8')
        return True

    @property
    def abs_path(self):
        """Absolute ``Path`` of ``path``, or None when no path is set."""
        if self.path:
            return Path(self.path).absolute()
        return None

    def get_lang(self):
        """Return the lower-cased language tag ('' when the fence had none)."""
        return (self.lang or '').lower()

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict tagged with ``'__type__': 'CodeBlock'``.

        NOTE: ``deps`` is returned as-is, so its values are sets; a JSON
        serializer must handle those (``from_dict`` accepts lists back).
        """
        return {
            '__type__': 'CodeBlock',
            'name': self.name,
            'version': self.version,
            'lang': self.lang,
            'code': self.code,
            'path': self.path,
            'deps': self.deps
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CodeBlock':
        """Rebuild a CodeBlock from a dict such as ``to_dict()`` produces.

        Missing keys fall back to '' (name/lang/code), 1 (version) or None
        (path/deps). Dependency collections are normalized to sets.
        """
        deps = data.get('deps')
        if deps:
            # JSON has no set type, so persisted deps usually come back as lists.
            deps = {key: set(values or ()) for key, values in deps.items()}
        return cls(
            name=data.get('name', ''),
            version=data.get('version', 1),
            lang=data.get('lang', ''),
            code=data.get('code', ''),
            path=data.get('path'),
            deps=deps
        )

    def __repr__(self):
        return f"<CodeBlock name={self.name}, version={self.version}, lang={self.lang}, path={self.path}>"

88 

class CodeBlocks(Trackable):
    """Registry of code blocks and commands parsed from model markdown output.

    ``history`` holds every block version in creation order, and ``blocks``
    maps each name to its latest version. The Trackable checkpoint is
    ``len(history)``.
    """

    def __init__(self):
        self.history = []
        self.blocks = OrderedDict()
        # <!-- Block-Start: {json} --> ```lang\n<code>\n``` <!-- Block-End: {json} -->
        self.code_pattern = re.compile(
            r'<!--\s*Block-Start:\s*(\{.*?\})\s*-->\s*```(\w+)?\s*\n(.*?)\n```\s*<!--\s*Block-End:\s*(\{.*?\})\s*-->',
            re.DOTALL
        )
        # One-line commands: <!-- Cmd-<Name>: {json} -->
        self.line_pattern = re.compile(
            r'<!--\s*Cmd-(\w+):\s*(\{.*?\})\s*-->'
        )
        self.log = logger.bind(src='code_blocks')

    def __len__(self):
        """Number of distinct block names currently known."""
        return len(self.blocks)

    def parse(self, markdown_text, parse_mcp=False):
        """Parse code blocks and Cmd-* instructions out of ``markdown_text``.

        Accepted blocks are appended to ``history``, written to their ``path``
        (if any) and merged into ``blocks``.

        Args:
            markdown_text: The model's markdown reply.
            parse_mcp: Also look for an MCP call_tool payload.

        Returns:
            A dict containing only the non-empty keys among 'errors' (list of
            dicts), 'commands' (list of dicts), 'blocks' (list of CodeBlock
            parsed from this text) and 'call_tool'.
        """
        errors = []
        blocks = self._parse_blocks(markdown_text, errors)
        # Merge first so commands may reference blocks defined in this same text.
        self.blocks.update(blocks)
        commands = self._parse_commands(markdown_text, errors)

        ret = {}
        if errors:
            ret['errors'] = errors
        if commands:
            ret['commands'] = commands
        if blocks:
            ret['blocks'] = list(blocks.values())

        if parse_mcp:
            # Try the parsed code blocks first, then fall back to the raw text.
            json_content = extra_call_tool_blocks(list(blocks.values())) or extract_call_tool_str(markdown_text)

            if json_content:
                ret['call_tool'] = json_content
                self.log.info("Parsed MCP call_tool", json_content=json_content)

        return ret

    def _parse_blocks(self, markdown_text, errors):
        """Extract Block-Start/Block-End code blocks, appending problems to ``errors``.

        Returns an OrderedDict of name -> CodeBlock accepted from this text.
        Accepted blocks are added to ``history`` and saved, but not merged
        into ``self.blocks`` (the caller does that).
        """
        blocks = OrderedDict()
        for match in self.code_pattern.finditer(markdown_text):
            start_json, lang, content, end_json = match.groups()
            try:
                start_meta = json.loads(start_json)
                end_meta = json.loads(end_json)
            except json.JSONDecodeError as e:
                self.log.exception('Error parsing code block', start_json=start_json, end_json=end_json)
                errors.append({'JSONDecodeError': {'Block-Start': start_json, 'Block-End': end_json, 'reason': str(e)}})
                continue

            code_name = start_meta.get("name")
            end_name = end_meta.get("name")
            if code_name != end_name:
                self.log.error("Start and end name mismatch", start_name=code_name, end_name=end_name)
                errors.append({'Start and end name mismatch': {'start_name': code_name, 'end_name': end_name}})
                continue

            # A re-sent block must carry a strictly higher version than the
            # one seen earlier in this text or already registered.
            version = start_meta.get("version", 1)
            old_block = blocks.get(code_name) or self.blocks.get(code_name)
            if old_block is not None and old_block.version >= version:
                old_version = old_block.version
                self.log.error("Duplicate code name with same or newer version", code_name=code_name, old_version=old_version, version=version)
                errors.append({'Duplicate code name with same or newer version': {'code_name': code_name, 'old_version': old_version, 'version': version}})
                continue

            block = CodeBlock(
                name=code_name,
                version=version,
                lang=lang,
                code=content,
                path=start_meta.get('path'),
            )

            blocks[code_name] = block
            self.history.append(block)
            self.log.info("Parsed code block", code_block=block)

            # Saving is best-effort: a write failure is logged, not raised.
            try:
                if block.save():
                    self.log.info("Saved code block", code_block=block)
            except Exception as e:
                self.log.error("Failed to save file", code_block=block, reason=e)
        return blocks

    def _parse_commands(self, markdown_text, errors):
        """Collect Cmd-Exec / Cmd-Edit instructions in order, appending problems to ``errors``."""
        commands = []
        for cmd, json_str in self.line_pattern.findall(markdown_text):
            try:
                line_meta = json.loads(json_str)
            except json.JSONDecodeError as e:
                self.log.error(f"Invalid JSON in Cmd-{cmd} block", json_str=json_str, reason=e)
                errors.append({f'Invalid JSON in Cmd-{cmd} block': {'json_str': json_str, 'reason': str(e)}})
                continue

            error = None
            if cmd == 'Exec':
                exec_name = line_meta.get("name")
                if not exec_name:
                    error = {'Cmd-Exec block without name': {'json_str': json_str}}
                elif exec_name not in self.blocks:
                    error = {'Cmd-Exec block not found': {'exec_name': exec_name, 'json_str': json_str}}
                else:
                    commands.append({
                        'type': 'exec',
                        'block_name': exec_name
                    })
            elif cmd == 'Edit':
                edit_name = line_meta.get("name")
                if not edit_name:
                    error = {'Cmd-Edit block without name': {'json_str': json_str}}
                elif edit_name not in self.blocks:
                    error = {'Cmd-Edit block not found': {'edit_name': edit_name, 'json_str': json_str}}
                elif not line_meta.get("old"):
                    error = {'Cmd-Edit block without old string': {'json_str': json_str}}
                elif "new" not in line_meta:
                    # "new" may legitimately be "" (deletion), so test presence.
                    error = {'Cmd-Edit block without new string': {'json_str': json_str}}
                else:
                    commands.append({
                        'type': 'edit',
                        'instruction': {
                            'name': edit_name,
                            'old': line_meta.get("old"),
                            'new': line_meta.get("new"),
                            'replace_all': line_meta.get("replace_all", False),
                            'json_str': json_str
                        }
                    })
            else:
                error = {f'Unknown command in Cmd-{cmd} block': {'cmd': cmd}}

            if error:
                errors.append(error)
        return commands

    def get_code_by_name(self, code_name):
        """Return the latest code text for ``code_name``, or None (logged) if unknown."""
        try:
            return self.blocks[code_name].code
        except KeyError:
            self.log.error("Code name not found", code_name=code_name)
            return None

    def get_block_by_name(self, code_name):
        """Return the latest CodeBlock for ``code_name``, or None (logged) if unknown."""
        try:
            return self.blocks[code_name]
        except KeyError:
            self.log.error("Code name not found", code_name=code_name)
            return None

    def apply_edit_modification(self, edit_instruction):
        """Apply a string-replacement edit, producing a new block version.

        The original block is left untouched. The new version (version + 1)
        is saved to disk on a best-effort basis, becomes the entry in
        ``blocks`` and is appended to ``history``.

        Args:
            edit_instruction: Dict with 'name', 'old', 'new' and optional
                'replace_all' (default False).

        Returns:
            tuple: (success: bool, message: str, new_block: CodeBlock or None)
        """
        name = edit_instruction['name']
        old_str = edit_instruction['old']
        new_str = edit_instruction['new']
        replace_all = edit_instruction.get('replace_all', False)

        if name not in self.blocks:
            return False, f"代码块 '{name}' 不存在", None

        if not old_str:
            # An empty pattern matches between every character; replacing it
            # would splice new_str throughout the code.
            return False, "old 字符串不能为空", None

        original_block = self.blocks[name]

        match_count = original_block.code.count(old_str)
        if match_count == 0:
            return False, f"未找到匹配的代码片段: {old_str[:50]}...", None
        if match_count > 1 and not replace_all:
            return False, f"代码片段匹配 {match_count} 个位置,请设置 replace_all: true 或提供更具体的上下文", None

        if replace_all:
            new_code = original_block.code.replace(old_str, new_str)
            replaced_count = match_count
        else:
            new_code = original_block.code.replace(old_str, new_str, 1)
            replaced_count = 1

        # Copy each dependency set too: dict.copy() is shallow, so the new
        # version would otherwise share (and mutate) the original's sets.
        new_deps = None
        if original_block.deps:
            new_deps = {key: set(values or ()) for key, values in original_block.deps.items()}

        new_block = CodeBlock(
            name=original_block.name,
            version=original_block.version + 1,
            lang=original_block.lang,
            code=new_code,
            path=original_block.path,
            deps=new_deps
        )

        try:
            new_block.save()
            self.log.info("Created and saved new block version", code_block=new_block, replaced_count=replaced_count)
        except Exception as e:
            self.log.error("Failed to save new block", code_block=new_block, reason=e)

        # A name always maps to its newest version; history keeps every version.
        self.blocks[name] = new_block
        self.history.append(new_block)

        message = f"成功替换 {replaced_count} 处匹配项,创建版本 v{new_block.version}"
        return True, message, new_block

    def to_list(self):
        """Serialize the full history.

        Returns:
            list: One ``CodeBlock.to_dict()`` dict per block version, oldest first.
        """
        return [block.to_dict() for block in self.history]

    def get_state(self):
        """Return the data to persist (same as ``to_list()``)."""
        return self.to_list()

    def restore_state(self, blocks_data):
        """Replace all state with blocks rebuilt from ``get_state()`` output.

        Entries are replayed in order, so a later version of a name wins.
        Raises KeyError if an entry lacks 'name', 'version', 'lang' or 'code'.
        """
        self.history.clear()
        self.blocks.clear()

        if blocks_data:
            for block_data in blocks_data:
                code_block = CodeBlock(
                    name=block_data['name'],
                    version=block_data['version'],
                    lang=block_data['lang'],
                    code=block_data['code'],
                    path=block_data.get('path'),
                    deps=block_data.get('deps')
                )
                self.history.append(code_block)
                self.blocks[code_block.name] = code_block

    def clear(self):
        """Forget all blocks and history."""
        self.history.clear()
        self.blocks.clear()

    def _rebuild_blocks(self):
        """Recompute ``blocks`` in place so each name maps to its latest version in ``history``."""
        self.blocks.clear()
        for block in self.history:
            self.blocks[block.name] = block

    # Trackable interface
    def get_checkpoint(self) -> int:
        """Return the current checkpoint, i.e. ``len(history)``."""
        return len(self.history)

    def restore_to_checkpoint(self, checkpoint: Optional[int]):
        """Roll back to a checkpoint from ``get_checkpoint()``; None resets everything.

        Versions created after the checkpoint are dropped. A name whose older
        version predates the checkpoint reverts to that version rather than
        disappearing from ``blocks``.
        """
        if checkpoint is None:
            self.clear()
            return

        if checkpoint < len(self.history):
            self.history = self.history[:checkpoint]
            self._rebuild_blocks()