Coverage for aipyapp/aipy/task.py: 16%

371 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 13:03 +0200

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import os 

5import json 

6import uuid 

7import time 

8import re 

9import yaml 

10from typing import Dict, Tuple 

11from datetime import datetime 

12from collections import namedtuple, OrderedDict 

13from importlib.resources import read_text 

14 

15import requests 

16from loguru import logger 

17 

18from .. import T, __respkg__ 

19from ..exec import BlockExecutor 

20from .runtime import CliPythonRuntime 

21from .utils import get_safe_filename 

22from .blocks import CodeBlocks, CodeBlock 

23from ..interface import Stoppable, EventBus 

24from .step_manager import StepManager 

25from .multimodal import MMContent, LLMContext 

26from .context_manager import ContextManager, ContextConfig 

27from .event_recorder import EventRecorder 

28from .task_state import TaskState 

29 

# HTML templates shipped as package resources of __respkg__:
#   CONSOLE_WHITE_HTML - code_format template passed to the display's save() by Task.save()
#   CONSOLE_CODE_HTML  - page template whose '{{code}}' placeholder Task.save_html() fills
CONSOLE_WHITE_HTML, CONSOLE_CODE_HTML = (
    read_text(__respkg__, resource_name)
    for resource_name in ("console_white.html", "console_code.html")
)

32 

class TaskError(Exception):
    """Base class for every error raised by a Task."""

36 

class TaskInputError(TaskError):
    """Raised when the user input given to a task is invalid or unsupported.

    Attributes:
        message: Human-readable description of the problem.
        original_error: The exception that caused this error, if any.
    """

    def __init__(self, message: str, original_error: Exception = None):
        super().__init__(message)
        self.original_error = original_error
        self.message = message

43 

class TaskStateError(TaskError):
    """Raised when task state data cannot be used to restore a task.

    Attributes:
        message: Human-readable description of the problem.
        data: Any extra keyword arguments given by the raiser, kept for diagnostics.
    """

    def __init__(self, message: str, **kwargs):
        self.message = message
        self.data = kwargs
        super().__init__(self.message)


# Backward-compatible alias for the original, misspelled name; existing code
# (and possibly external callers) still raise/catch TastStateError.
TastStateError = TaskStateError

50 

51class Task(Stoppable, EventBus): 

52 MAX_ROUNDS = 16 

53 

54 def __init__(self, context): 

55 super().__init__() 

56 self.task_id = uuid.uuid4().hex 

57 self.log = logger.bind(src='task', id=self.task_id) 

58 

59 self.context = context 

60 self.settings = context.settings 

61 self.prompts = context.prompts 

62 

63 self.start_time = None 

64 self.done_time = None 

65 self.instruction = None 

66 self.saved = None 

67 

68 #TODO: 移除 gui 参数 

69 self.gui = self.settings.gui 

70 

71 self.cwd = context.cwd / self.task_id 

72 self.max_rounds = self.settings.get('max_rounds', self.MAX_ROUNDS) 

73 

74 self.mcp = context.mcp 

75 self.display = None 

76 self.plugins = {} 

77 

78 # 创建上下文管理器(从Client移到Task) 

79 context_settings = self.settings.get('context_manager', {}) 

80 self.context_manager = ContextManager(ContextConfig.from_dict(context_settings)) 

81 

82 # 创建Client时传入context_manager 

83 self.client = context.client_manager.Client(self, self.context_manager) 

84 self.role = context.role_manager.current_role 

85 self.code_blocks = CodeBlocks() 

86 self.runtime = CliPythonRuntime(self) 

87 self.runner = BlockExecutor() 

88 self.runner.set_python_runtime(self.runtime) 

89 

90 # 注册所有可追踪对象到步骤管理器 

91 self.step_manager = StepManager() 

92 self.step_manager.register_trackable('messages', self.context_manager) 

93 self.step_manager.register_trackable('runner', self.runner) 

94 self.step_manager.register_trackable('blocks', self.code_blocks) 

95 

96 # 初始化事件记录器 

97 enable_replay = self.settings.get('enable_replay_recording', True) 

98 if enable_replay: 

99 self.event_recorder = EventRecorder(enabled=True) 

100 # 注册事件记录器到步骤管理器 

101 self.step_manager.register_trackable('events', self.event_recorder) 

102 else: 

103 self.event_recorder = None 

104 

105 self.init_plugins() 

106 

107 def emit(self, event_name: str, **kwargs): 

108 """重写broadcast方法以记录事件""" 

109 # 记录事件到事件记录器 

110 if self.event_recorder: 

111 self.event_recorder.record_event(event_name, kwargs.copy()) 

112 

113 # 调用父类的broadcast方法 

114 super().emit(event_name, **kwargs) 

115 

116 def restore_state(self, task_data): 

117 """从任务状态加载任务 

118  

119 Args: 

120 task_data: 任务状态数据(字典格式)或 TaskState 对象 

121  

122 Returns: 

123 Task: 加载的任务对象 

124 """ 

125 # 支持传入字典或 TaskState 对象 

126 if isinstance(task_data, dict): 

127 task_state = TaskState.from_dict(task_data) 

128 elif isinstance(task_data, TaskState): 

129 task_state = task_data 

130 else: 

131 raise TastStateError('Invalid task_data type, expected dict or TaskState') 

132 

133 # 使用 TaskState 恢复状态 

134 task_state.restore_to_task(self) 

135 

136 def get_status(self): 

137 return { 

138 'llm': self.client.name, 

139 'blocks': len(self.code_blocks), 

140 'steps': len(self.step_manager), 

141 } 

142 

143 def init_plugins(self): 

144 """初始化插件""" 

145 plugin_manager = self.context.plugin_manager 

146 for plugin_name, plugin_data in self.role.plugins.items(): 

147 plugin = plugin_manager.create_task_plugin(plugin_name, plugin_data) 

148 if not plugin: 

149 self.log.warning(f"Create task plugin {plugin_name} failed") 

150 continue 

151 self.add_listener(plugin) 

152 self.runtime.register_plugin(plugin) 

153 self.plugins[plugin_name] = plugin 

154 

155 # 注册显示效果插件 

156 if self.context.display_manager: 

157 self.display = self.context.display_manager.create_display_plugin() 

158 self.add_listener(self.display) 

159 

160 def to_record(self): 

161 TaskRecord = namedtuple('TaskRecord', ['task_id', 'start_time', 'done_time', 'instruction']) 

162 start_time = datetime.fromtimestamp(self.start_time).strftime('%H:%M:%S') if self.start_time else '-' 

163 done_time = datetime.fromtimestamp(self.done_time).strftime('%H:%M:%S') if self.done_time else '-' 

164 return TaskRecord( 

165 task_id=self.task_id, 

166 start_time=start_time, 

167 done_time=done_time, 

168 instruction=self.title[:32] if self.title else '-' 

169 ) 

170 

171 def use(self, name): 

172 ret = self.client.use(name) 

173 return ret 

174 

175 def save(self, path): 

176 self.display.save(path, clear=False, code_format=CONSOLE_WHITE_HTML) 

177 

178 def save_html(self, path, task): 

179 if 'chats' in task and isinstance(task['chats'], list) and len(task['chats']) > 0: 

180 if task['chats'][0]['role'] == 'system': 

181 task['chats'].pop(0) 

182 

183 task_json = json.dumps(task, ensure_ascii=False, default=str) 

184 html_content = CONSOLE_CODE_HTML.replace('{{code}}', task_json) 

185 try: 

186 with open(path, 'w', encoding='utf-8') as f: 

187 f.write(html_content) 

188 except Exception as e: 

189 self.log.exception('Error saving html') 

190 self.emit('exception', msg='save_html', exception=e) 

191 

192 def _auto_save(self): 

193 """自动保存任务状态""" 

194 # 如果任务目录不存在,则不保存 

195 if not self.cwd.exists(): 

196 self.log.warning('Task directory not found, skipping save') 

197 return 

198 

199 self.done_time = time.time() 

200 try: 

201 # 创建 TaskState 对象并保存 

202 task_state = TaskState(self) 

203 task_state.save_to_file(self.cwd / "task.json") 

204 

205 # 保存 HTML 控制台 

206 filename = self.cwd / "console.html" 

207 self.save(filename) 

208 

209 self.saved = True 

210 self.log.info('Task auto saved') 

211 except Exception as e: 

212 self.log.exception('Error saving task') 

213 self.emit('exception', msg='save_task', exception=e) 

214 

215 def done(self): 

216 if not self.instruction or not self.start_time: 

217 self.log.warning('Task not started, skipping save') 

218 return 

219 

220 os.chdir(self.context.cwd) # Change back to the original working directory 

221 curname = self.task_id 

222 if os.path.exists(curname): 

223 if not self.saved: 

224 self.log.warning('Task not saved, trying to save') 

225 self._auto_save() 

226 

227 newname = get_safe_filename(self.instruction, extension=None) 

228 if newname: 

229 try: 

230 os.rename(curname, newname) 

231 except Exception as e: 

232 self.log.exception('Error renaming task directory', curname=curname, newname=newname) 

233 else: 

234 newname = None 

235 self.log.warning('Task directory not found') 

236 

237 self.log.info('Task done', path=newname) 

238 self.emit('task_end', path=newname) 

239 self.context.diagnose.report_code_error(self.runner.history) 

240 if self.settings.get('share_result'): 

241 self.sync_to_cloud() 

242 

243 def process_reply(self, markdown): 

244 ret = self.code_blocks.parse(markdown, parse_mcp=self.mcp) 

245 self.emit('parse_reply', result=ret) 

246 if not ret: 

247 return None 

248 

249 if 'call_tool' in ret: 

250 return self.process_mcp_reply(ret['call_tool']) 

251 

252 errors = ret.get('errors') 

253 if errors: 

254 prompt = self.prompts.get_parse_error_prompt(errors) 

255 return self.chat(prompt) 

256 

257 commands = ret.get('commands', []) 

258 if commands: 

259 return self.process_commands(commands) 

260 

261 return None 

262 

263 def run_code_block(self, block): 

264 """运行代码块""" 

265 self.emit('exec', block=block) 

266 result = self.runner(block) 

267 self.emit('exec_result', result=result, block=block) 

268 return result 

269 

270 def process_code_reply(self, exec_blocks): 

271 results = OrderedDict() 

272 for block in exec_blocks: 

273 result = self.run_code_block(block) 

274 results[block.name] = result 

275 

276 msg = self.prompts.get_results_prompt(results) 

277 return self.chat(msg) 

278 

279 def process_commands(self, commands): 

280 """按顺序处理混合指令,智能错误处理""" 

281 all_results = OrderedDict() 

282 failed_blocks = set() # 记录编辑失败的代码块 

283 

284 for command in commands: 

285 cmd_type = command['type'] 

286 

287 if cmd_type == 'exec': 

288 block_name = command['block_name'] 

289 

290 # 如果这个代码块之前编辑失败,跳过执行 

291 if block_name in failed_blocks: 

292 self.log.warning(f'Skipping execution of {block_name} due to previous edit failure') 

293 all_results[f"exec_{block_name}"] = { 

294 'type': 'exec', 

295 'block_name': block_name, 

296 'result': { 

297 'error': f'Execution skipped: previous edit of {block_name} failed', 

298 'skipped': True 

299 } 

300 } 

301 continue 

302 

303 # 动态获取最新版本的代码块 

304 block = self.code_blocks.blocks[block_name] 

305 result = self.run_code_block(block) 

306 all_results[f"exec_{block_name}"] = { 

307 'type': 'exec', 

308 'block_name': block_name, 

309 'result': result 

310 } 

311 

312 elif cmd_type == 'edit': 

313 edit_instruction = command['instruction'] 

314 block_name = edit_instruction['name'] 

315 

316 self.emit('edit_start', instruction=edit_instruction) 

317 success, message, modified_block = self.code_blocks.apply_edit_modification(edit_instruction) 

318 

319 # 编辑失败时标记这个代码块 

320 if not success: 

321 failed_blocks.add(block_name) 

322 self.log.warning(f'Edit failed for {block_name}: {message}') 

323 

324 result = { 

325 'type': 'edit', 

326 'success': success, 

327 'message': message, 

328 'block_name': block_name, 

329 'old_str': edit_instruction['old'][:100] + '...' if len(edit_instruction['old']) > 100 else edit_instruction['old'], 

330 'new_str': edit_instruction['new'][:100] + '...' if len(edit_instruction['new']) > 100 else edit_instruction['new'], 

331 'replace_all': edit_instruction.get('replace_all', False) 

332 } 

333 

334 if modified_block: 

335 result['new_version'] = modified_block.version 

336 

337 all_results[f"edit_{block_name}"] = result 

338 self.emit('edit_result', result=result, instruction=edit_instruction) 

339 

340 # 生成混合结果的prompt 

341 msg = self.prompts.get_mixed_results_prompt(all_results) 

342 return self.chat(msg) 

343 

344 def process_edit_reply(self, edit_instructions): 

345 """处理编辑指令""" 

346 results = OrderedDict() 

347 for edit_instruction in edit_instructions: 

348 self.emit('edit_start', instruction=edit_instruction) 

349 success, message, modified_block = self.code_blocks.apply_edit_modification(edit_instruction) 

350 

351 result = { 

352 'success': success, 

353 'message': message, 

354 'block_name': edit_instruction['name'], 

355 'old_str': edit_instruction['old'][:100] + '...' if len(edit_instruction['old']) > 100 else edit_instruction['old'], 

356 'new_str': edit_instruction['new'][:100] + '...' if len(edit_instruction['new']) > 100 else edit_instruction['new'], 

357 'replace_all': edit_instruction.get('replace_all', False) 

358 } 

359 

360 if modified_block: 

361 result['new_version'] = modified_block.version 

362 

363 results[edit_instruction['name']] = result 

364 self.emit('edit_result', result=result, instruction=edit_instruction) 

365 

366 msg = self.prompts.get_edit_results_prompt(results) 

367 return self.chat(msg) 

368 

369 def process_mcp_reply(self, json_content): 

370 """处理 MCP 工具调用的回复""" 

371 block = {'content': json_content, 'language': 'json'} 

372 self.emit('mcp_call', block=block) 

373 

374 call_tool = json.loads(json_content) 

375 result = self.mcp.call_tool(call_tool['name'], call_tool.get('arguments', {})) 

376 code_block = CodeBlock( 

377 code=json_content, 

378 lang='json', 

379 name=call_tool.get('name', 'MCP Tool Call'), 

380 version=1, 

381 ) 

382 self.emit('mcp_result', block=code_block, result=result) 

383 msg = self.prompts.get_mcp_result_prompt(result) 

384 return self.chat(msg) 

385 

386 def _get_summary(self, detail=False): 

387 data = {} 

388 context_manager = self.context_manager 

389 if detail: 

390 data['usages'] = context_manager.get_usage() 

391 

392 summary = context_manager.get_summary() 

393 summary['elapsed_time'] = time.time() - self.start_time 

394 summarys = "{rounds} | {time:.3f}s/{elapsed_time:.3f}s | Tokens: {input_tokens}/{output_tokens}/{total_tokens}".format(**summary) 

395 data['summary'] = summarys 

396 return data 

397 

398 def chat(self, context: LLMContext, *, system_prompt=None): 

399 self.emit('query_start', llm=self.client.name) 

400 msg = self.client(context, system_prompt=system_prompt) 

401 self.emit('response_complete', llm=self.client.name, msg=msg) 

402 return msg.content if msg else None 

403 

404 def _get_system_prompt(self): 

405 params = {} 

406 if self.mcp: 

407 params['mcp_tools'] = self.mcp.get_tools_prompt() 

408 params['util_functions'] = self.runtime.get_builtin_functions() 

409 params['tool_functions'] = self.runtime.get_plugin_functions() 

410 params['role'] = self.role 

411 return self.prompts.get_default_prompt(**params) 

412 

413 def _parse_front_matter(self, md_text: str) -> Tuple[Dict, str]: 

414 """ 

415 解析 Markdown 字符串,提取 YAML front matter 和正文内容。 

416 

417 参数: 

418 md_text: 包含 YAML front matter 和 Markdown 内容的字符串 

419 

420 返回: 

421 (yaml_dict, content): 

422 - yaml_dict 是解析后的 YAML 字典,若无 front matter 则为空字典 

423 - content 是去除 front matter 后的 Markdown 正文字符串 

424 """ 

425 front_matter_pattern = r"^\s*---\s*\n(.*?)\n---\s*" 

426 match = re.match(front_matter_pattern, md_text, re.DOTALL) 

427 if match: 

428 yaml_str = match.group(1) 

429 try: 

430 yaml_dict = yaml.safe_load(yaml_str) or {} 

431 except yaml.YAMLError: 

432 yaml_dict = {} 

433 self.log.error('Invalid front matter', yaml_str=yaml_str) 

434 self.log.info('Front matter', yaml_dict=yaml_dict) 

435 content = md_text[match.end():] 

436 else: 

437 yaml_dict = {} 

438 content = md_text 

439 

440 return yaml_dict, content 

441 

442 def run(self, instruction: str, title: str | None = None): 

443 """ 

444 执行自动处理循环,直到 LLM 不再返回代码消息 

445 instruction: 用户输入的字符串(可包含@file等多模态标记) 

446 """ 

447 mmc = MMContent(instruction, base_path=self.context.cwd) 

448 try: 

449 content = mmc.content 

450 except Exception as e: 

451 raise TaskInputError(T("Invalid input"), e) from e 

452 

453 if not self.client.has_capability(content): 

454 raise TaskInputError(T("Current model does not support this content")) 

455 

456 user_prompt = content 

457 if not self.start_time: 

458 self.start_time = time.time() 

459 self.instruction = instruction 

460 self.title = title or instruction 

461 # 开始事件记录 

462 self.event_recorder.start_recording() 

463 if isinstance(content, str): 

464 user_prompt = self.prompts.get_task_prompt(content, gui=self.gui) 

465 system_prompt = self._get_system_prompt() 

466 self.emit('task_start', instruction=instruction, task_id=self.task_id, title=title) 

467 else: 

468 system_prompt = None 

469 title = title or instruction 

470 if isinstance(content, str): 

471 user_prompt = self.prompts.get_chat_prompt(content, self.instruction) 

472 # 记录轮次开始事件 

473 self.emit('round_start', instruction=instruction, step=len(self.step_manager) + 1, title=title) 

474 

475 self.cwd.mkdir(exist_ok=True) 

476 os.chdir(self.cwd) 

477 

478 rounds = 1 

479 max_rounds = self.max_rounds 

480 self.saved = False 

481 

482 response = self.chat(user_prompt, system_prompt=system_prompt) 

483 if not response: 

484 self.log.error('No response from LLM') 

485 # 使用新的步骤管理器记录失败的步骤 

486 self.step_manager.create_checkpoint(instruction, 0, '') 

487 return 

488 

489 while rounds <= max_rounds: 

490 status, content = self._parse_front_matter(response) 

491 if status: 

492 response = content 

493 self.log.info('Task status', status=status) 

494 self.emit('task_status', status=status) 

495 prev_response = response 

496 response = self.process_reply(response) 

497 rounds += 1 

498 if self.is_stopped(): 

499 self.log.info('Task stopped') 

500 break 

501 if not response: 

502 response = prev_response 

503 break 

504 

505 # 使用新的步骤管理器记录步骤 

506 self.step_manager.create_checkpoint(instruction, rounds, response) 

507 summary = self._get_summary() 

508 self.emit('round_end', summary=summary, response=response) 

509 self._auto_save() 

510 self.log.info('Round done', rounds=rounds) 

511 

512 def sync_to_cloud(self): 

513 """ Sync result 

514 """ 

515 url = T("https://store.aipy.app/api/work") 

516 

517 trustoken_apikey = self.settings.get('llm', {}).get('Trustoken', {}).get('api_key') 

518 if not trustoken_apikey: 

519 trustoken_apikey = self.settings.get('llm', {}).get('trustoken', {}).get('api_key') 

520 if not trustoken_apikey: 

521 return False 

522 self.log.info('Uploading result to cloud') 

523 try: 

524 # Serialize twice to remove the non-compliant JSON type. 

525 # First, use the json.dumps() `default` to convert the non-compliant JSON type to str. 

526 # However, NaN/Infinity will remain. 

527 # Second, use the json.loads() 'parse_constant' to convert NaN/Infinity to str. 

528 data = json.loads( 

529 json.dumps({ 

530 'apikey': trustoken_apikey, 

531 'author': os.getlogin(), 

532 'instruction': self.instruction, 

533 'llm': self.context_manager.json(), 

534 'runner': self.runner.history, 

535 }, ensure_ascii=False, default=str), 

536 parse_constant=str) 

537 response = requests.post(url, json=data, verify=True, timeout=30) 

538 except Exception as e: 

539 self.emit('exception', msg='sync_to_cloud', exception=e) 

540 return False 

541 

542 url = None 

543 status_code = response.status_code 

544 if status_code in (200, 201): 

545 data = response.json() 

546 url = data.get('url', '') 

547 

548 self.emit('upload_result', status_code=status_code, url=url) 

549 return True 

550 

551 def delete_step(self, index: int): 

552 """删除指定索引的步骤 - 使用新的步骤管理器""" 

553 return self.step_manager.delete_step(index) 

554 

555 def clear_steps(self): 

556 """清空所有步骤 - 使用新的步骤管理器""" 

557 self.step_manager.clear_all() 

558 return True 

559 

560 def list_steps(self): 

561 """列出所有步骤 - 使用新的步骤管理器""" 

562 return self.step_manager.list_steps() 

563 

564 def list_code_blocks(self): 

565 """列出所有代码块""" 

566 BlockRecord = namedtuple('BlockRecord', ['Index', 'Name', 'Version', 'Language', 'Path', 'Size']) 

567 

568 rows = [] 

569 for index, block in enumerate(self.code_blocks.history): 

570 # 计算代码大小 

571 code_size = len(block.code) if block.code else 0 

572 size_str = f"{code_size} chars" 

573 

574 # 处理路径显示 

575 path_str = block.path if block.path else '-' 

576 if path_str != '-' and len(path_str) > 40: 

577 path_str = '...' + path_str[-37:] 

578 

579 rows.append(BlockRecord( 

580 Index=index, 

581 Name=block.name, 

582 Version=f"v{block.version}", 

583 Language=block.lang, 

584 Path=path_str, 

585 Size=size_str 

586 )) 

587 

588 return rows 

589 

590 def get_code_block(self, index): 

591 """获取指定索引的代码块""" 

592 if index < 0 or index >= len(self.code_blocks.history): 

593 return None 

594 return self.code_blocks.history[index]