Coverage for aipyapp/aipy/task.py: 16%
371 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 13:03 +0200
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 13:03 +0200
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4import os
5import json
6import uuid
7import time
8import re
9import yaml
10from typing import Dict, Tuple
11from datetime import datetime
12from collections import namedtuple, OrderedDict
13from importlib.resources import read_text
15import requests
16from loguru import logger
18from .. import T, __respkg__
19from ..exec import BlockExecutor
20from .runtime import CliPythonRuntime
21from .utils import get_safe_filename
22from .blocks import CodeBlocks, CodeBlock
23from ..interface import Stoppable, EventBus
24from .step_manager import StepManager
25from .multimodal import MMContent, LLMContext
26from .context_manager import ContextManager, ContextConfig
27from .event_recorder import EventRecorder
28from .task_state import TaskState
# HTML templates shipped as package resources: the "white" console template is
# passed as ``code_format`` when exporting the console transcript, and the
# "code" template carries a ``{{code}}`` placeholder that receives task JSON.
CONSOLE_WHITE_HTML, CONSOLE_CODE_HTML = (
    read_text(__respkg__, resource_name)
    for resource_name in ("console_white.html", "console_code.html")
)
33class TaskError(Exception):
34 """Task 异常"""
35 pass
37class TaskInputError(TaskError):
38 """Task 输入异常"""
39 def __init__(self, message: str, original_error: Exception = None):
40 self.message = message
41 self.original_error = original_error
42 super().__init__(self.message)
44class TastStateError(TaskError):
45 """Task 状态异常"""
46 def __init__(self, message: str, **kwargs):
47 self.message = message
48 self.data = kwargs
49 super().__init__(self.message)
class Task(Stoppable, EventBus):
    """A single user task.

    Drives the LLM conversation, executes the code blocks / edits / MCP tool
    calls found in replies, and persists results under a per-task working
    directory (``context.cwd / task_id``).
    """

    MAX_ROUNDS = 16

    def __init__(self, context):
        super().__init__()
        self.task_id = uuid.uuid4().hex
        self.log = logger.bind(src='task', id=self.task_id)

        self.context = context
        self.settings = context.settings
        self.prompts = context.prompts

        self.start_time = None
        self.done_time = None
        self.instruction = None
        # Initialised here so to_record() works before run() has been called.
        self.title = None
        self.saved = None

        # TODO: remove the gui parameter
        self.gui = self.settings.gui

        # Each task works inside its own sub-directory named after the task id.
        self.cwd = context.cwd / self.task_id
        self.max_rounds = self.settings.get('max_rounds', self.MAX_ROUNDS)

        self.mcp = context.mcp
        self.display = None
        self.plugins = {}

        # The context manager is owned by the Task (moved here from Client).
        context_settings = self.settings.get('context_manager', {})
        self.context_manager = ContextManager(ContextConfig.from_dict(context_settings))

        # The Client is created with this task's context manager.
        self.client = context.client_manager.Client(self, self.context_manager)
        self.role = context.role_manager.current_role
        self.code_blocks = CodeBlocks()
        self.runtime = CliPythonRuntime(self)
        self.runner = BlockExecutor()
        self.runner.set_python_runtime(self.runtime)

        # Register every trackable object with the step manager.
        self.step_manager = StepManager()
        self.step_manager.register_trackable('messages', self.context_manager)
        self.step_manager.register_trackable('runner', self.runner)
        self.step_manager.register_trackable('blocks', self.code_blocks)

        # Optional event recorder (for replay); it is also tracked per step.
        enable_replay = self.settings.get('enable_replay_recording', True)
        if enable_replay:
            self.event_recorder = EventRecorder(enabled=True)
            self.step_manager.register_trackable('events', self.event_recorder)
        else:
            self.event_recorder = None

        self.init_plugins()

    def emit(self, event_name: str, **kwargs):
        """Record the event (when recording is enabled), then dispatch it via EventBus."""
        if self.event_recorder:
            # The recorder receives its own shallow copy of the payload.
            self.event_recorder.record_event(event_name, kwargs.copy())
        super().emit(event_name, **kwargs)

    def restore_state(self, task_data):
        """Restore this task from saved state.

        Args:
            task_data: Task state as a dict or a TaskState instance.

        Returns:
            Task: this task (returned for convenience).

        Raises:
            TastStateError: if task_data is neither a dict nor a TaskState.
        """
        if isinstance(task_data, dict):
            task_state = TaskState.from_dict(task_data)
        elif isinstance(task_data, TaskState):
            task_state = task_data
        else:
            raise TastStateError('Invalid task_data type, expected dict or TaskState')

        task_state.restore_to_task(self)
        return self

    def get_status(self):
        """Return a small status dict: LLM name, code block count and step count."""
        return {
            'llm': self.client.name,
            'blocks': len(self.code_blocks),
            'steps': len(self.step_manager),
        }

    def init_plugins(self):
        """Create the role's task plugins and the display plugin, and subscribe them to events."""
        plugin_manager = self.context.plugin_manager
        for plugin_name, plugin_data in self.role.plugins.items():
            plugin = plugin_manager.create_task_plugin(plugin_name, plugin_data)
            if not plugin:
                self.log.warning(f"Create task plugin {plugin_name} failed")
                continue
            self.add_listener(plugin)
            self.runtime.register_plugin(plugin)
            self.plugins[plugin_name] = plugin

        # Display plugin (only when a display manager is available).
        if self.context.display_manager:
            self.display = self.context.display_manager.create_display_plugin()
            self.add_listener(self.display)

    def to_record(self):
        """Return a TaskRecord namedtuple summarising this task for listings."""
        TaskRecord = namedtuple('TaskRecord', ['task_id', 'start_time', 'done_time', 'instruction'])
        start_time = datetime.fromtimestamp(self.start_time).strftime('%H:%M:%S') if self.start_time else '-'
        done_time = datetime.fromtimestamp(self.done_time).strftime('%H:%M:%S') if self.done_time else '-'
        return TaskRecord(
            task_id=self.task_id,
            start_time=start_time,
            done_time=done_time,
            instruction=self.title[:32] if self.title else '-'
        )

    def use(self, name):
        """Switch the LLM client to *name*; returns whatever ``client.use`` returns."""
        return self.client.use(name)

    def save(self, path):
        """Export the console transcript to *path* as HTML via the display plugin."""
        if self.display is None:
            # No display plugin was created, so there is no console to export.
            self.log.warning('No display plugin, skipping console save')
            return
        self.display.save(path, clear=False, code_format=CONSOLE_WHITE_HTML)

    def save_html(self, path, task):
        """Write *task* as JSON into the console HTML template and save it to *path*.

        A leading system message in ``task['chats']`` is omitted from the output.
        The caller's ``task`` object is not modified.
        """
        chats = task['chats'] if 'chats' in task else None
        if isinstance(chats, list) and chats:
            first = chats[0]
            if isinstance(first, dict) and first.get('role') == 'system':
                # Build a shallow copy instead of popping from the caller's list.
                task = {**task, 'chats': chats[1:]}

        task_json = json.dumps(task, ensure_ascii=False, default=str)
        # Escape "</" so data such as "</script>" cannot terminate an enclosing
        # <script> element; "\/" is a valid JSON/JS escape that decodes to "/".
        task_json = task_json.replace('</', '<\\/')
        html_content = CONSOLE_CODE_HTML.replace('{{code}}', task_json)
        try:
            with open(path, 'w', encoding='utf-8') as f:
                f.write(html_content)
        except Exception as e:
            self.log.exception('Error saving html')
            self.emit('exception', msg='save_html', exception=e)

    def _auto_save(self):
        """Save task state (task.json) and the console (console.html) into the task directory."""
        # Nothing to save into if the task directory does not exist.
        if not self.cwd.exists():
            self.log.warning('Task directory not found, skipping save')
            return

        self.done_time = time.time()
        try:
            task_state = TaskState(self)
            task_state.save_to_file(self.cwd / "task.json")

            filename = self.cwd / "console.html"
            self.save(filename)

            self.saved = True
            self.log.info('Task auto saved')
        except Exception as e:
            self.log.exception('Error saving task')
            self.emit('exception', msg='save_task', exception=e)

    def done(self):
        """Finish the task.

        Saves the task if needed, renames its directory after the instruction,
        emits ``task_end`` with the resulting directory path, reports code
        errors and optionally uploads the result.
        """
        if not self.instruction or not self.start_time:
            self.log.warning('Task not started, skipping save')
            return

        os.chdir(self.context.cwd)  # Change back to the original working directory
        curname = self.task_id
        if os.path.exists(curname):
            if not self.saved:
                self.log.warning('Task not saved, trying to save')
                self._auto_save()

            # Until a rename succeeds, the directory is still at its original name.
            path = curname
            newname = get_safe_filename(self.instruction, extension=None)
            if newname:
                try:
                    os.rename(curname, newname)
                    path = newname
                except Exception:
                    self.log.exception('Error renaming task directory', curname=curname, newname=newname)
        else:
            path = None
            self.log.warning('Task directory not found')

        self.log.info('Task done', path=path)
        self.emit('task_end', path=path)
        self.context.diagnose.report_code_error(self.runner.history)
        if self.settings.get('share_result'):
            self.sync_to_cloud()

    def process_reply(self, markdown):
        """Parse an LLM reply and act on it.

        Returns the next LLM reply text, or None when there is nothing to act on.
        """
        ret = self.code_blocks.parse(markdown, parse_mcp=self.mcp)
        self.emit('parse_reply', result=ret)
        if not ret:
            return None

        if 'call_tool' in ret:
            return self.process_mcp_reply(ret['call_tool'])

        errors = ret.get('errors')
        if errors:
            prompt = self.prompts.get_parse_error_prompt(errors)
            return self.chat(prompt)

        commands = ret.get('commands', [])
        if commands:
            return self.process_commands(commands)

        return None

    def run_code_block(self, block):
        """Execute one code block, emitting ``exec`` before and ``exec_result`` after."""
        self.emit('exec', block=block)
        result = self.runner(block)
        self.emit('exec_result', result=result, block=block)
        return result

    def process_code_reply(self, exec_blocks):
        """Run *exec_blocks* in order and send their results back to the LLM."""
        results = OrderedDict()
        for block in exec_blocks:
            result = self.run_code_block(block)
            results[block.name] = result

        msg = self.prompts.get_results_prompt(results)
        return self.chat(msg)

    @staticmethod
    def _truncate(text, limit=100):
        """Return *text* cut to *limit* characters plus '...' when it is longer."""
        return text[:limit] + '...' if len(text) > limit else text

    def _apply_edit(self, edit_instruction):
        """Emit ``edit_start``, apply one edit instruction and summarise the outcome.

        Returns:
            tuple: ``(success, message, result)`` where *result* is the summary
            dict reported back to the LLM (long old/new strings truncated).
        """
        self.emit('edit_start', instruction=edit_instruction)
        success, message, modified_block = self.code_blocks.apply_edit_modification(edit_instruction)
        result = {
            'success': success,
            'message': message,
            'block_name': edit_instruction['name'],
            'old_str': self._truncate(edit_instruction['old']),
            'new_str': self._truncate(edit_instruction['new']),
            'replace_all': edit_instruction.get('replace_all', False),
        }
        if modified_block:
            result['new_version'] = modified_block.version
        return success, message, result

    def process_commands(self, commands):
        """Run mixed ``exec``/``edit`` commands in order and send the combined results to the LLM.

        An ``exec`` of a block whose earlier ``edit`` in this batch failed is
        skipped (and reported as skipped) instead of running stale code.
        """
        all_results = OrderedDict()
        failed_blocks = set()  # names of blocks whose edit failed in this batch

        for command in commands:
            cmd_type = command['type']

            if cmd_type == 'exec':
                block_name = command['block_name']
                if block_name in failed_blocks:
                    self.log.warning(f'Skipping execution of {block_name} due to previous edit failure')
                    result = {
                        'error': f'Execution skipped: previous edit of {block_name} failed',
                        'skipped': True
                    }
                else:
                    # Look the block up now so the latest (possibly edited) version runs.
                    block = self.code_blocks.blocks[block_name]
                    result = self.run_code_block(block)
                all_results[f"exec_{block_name}"] = {
                    'type': 'exec',
                    'block_name': block_name,
                    'result': result
                }

            elif cmd_type == 'edit':
                edit_instruction = command['instruction']
                block_name = edit_instruction['name']
                success, message, edit_result = self._apply_edit(edit_instruction)
                if not success:
                    failed_blocks.add(block_name)
                    self.log.warning(f'Edit failed for {block_name}: {message}')
                result = {'type': 'edit', **edit_result}
                all_results[f"edit_{block_name}"] = result
                self.emit('edit_result', result=result, instruction=edit_instruction)

        msg = self.prompts.get_mixed_results_prompt(all_results)
        return self.chat(msg)

    def process_edit_reply(self, edit_instructions):
        """Apply edit instructions in order and send their results back to the LLM."""
        results = OrderedDict()
        for edit_instruction in edit_instructions:
            _, _, result = self._apply_edit(edit_instruction)
            results[edit_instruction['name']] = result
            self.emit('edit_result', result=result, instruction=edit_instruction)

        msg = self.prompts.get_edit_results_prompt(results)
        return self.chat(msg)

    def process_mcp_reply(self, json_content):
        """Execute an MCP tool call given as JSON text and send the result to the LLM."""
        block = {'content': json_content, 'language': 'json'}
        self.emit('mcp_call', block=block)

        call_tool = json.loads(json_content)
        result = self.mcp.call_tool(call_tool['name'], call_tool.get('arguments', {}))
        code_block = CodeBlock(
            code=json_content,
            lang='json',
            name=call_tool.get('name', 'MCP Tool Call'),
            version=1,
        )
        self.emit('mcp_result', block=code_block, result=result)
        msg = self.prompts.get_mcp_result_prompt(result)
        return self.chat(msg)

    def _get_summary(self, detail=False):
        """Build the round summary string (plus token usages when *detail* is true).

        Must only be called after ``run()`` has set ``start_time``.
        """
        data = {}
        context_manager = self.context_manager
        if detail:
            data['usages'] = context_manager.get_usage()

        summary = context_manager.get_summary()
        summary['elapsed_time'] = time.time() - self.start_time
        summarys = "{rounds} | {time:.3f}s/{elapsed_time:.3f}s | Tokens: {input_tokens}/{output_tokens}/{total_tokens}".format(**summary)
        data['summary'] = summarys
        return data

    def chat(self, context: LLMContext, *, system_prompt=None):
        """Send *context* to the LLM client; return the reply content or None."""
        self.emit('query_start', llm=self.client.name)
        msg = self.client(context, system_prompt=system_prompt)
        self.emit('response_complete', llm=self.client.name, msg=msg)
        return msg.content if msg else None

    def _get_system_prompt(self):
        """Build the default system prompt from MCP tools, runtime functions and the role."""
        params = {}
        if self.mcp:
            params['mcp_tools'] = self.mcp.get_tools_prompt()
        params['util_functions'] = self.runtime.get_builtin_functions()
        params['tool_functions'] = self.runtime.get_plugin_functions()
        params['role'] = self.role
        return self.prompts.get_default_prompt(**params)

    def _parse_front_matter(self, md_text: str) -> Tuple[Dict, str]:
        """Split a Markdown string into its YAML front matter and body.

        Args:
            md_text: Markdown text that may start with a ``---`` delimited YAML block.

        Returns:
            (yaml_dict, content): the parsed front matter (an empty dict when
            there is none, or when it does not load as a mapping) and the
            remaining Markdown body. When the leading block is not a mapping,
            the text is returned unchanged.
        """
        front_matter_pattern = r"^\s*---\s*\n(.*?)\n---\s*"
        match = re.match(front_matter_pattern, md_text, re.DOTALL)
        if not match:
            return {}, md_text

        yaml_str = match.group(1)
        try:
            yaml_dict = yaml.safe_load(yaml_str) or {}
        except yaml.YAMLError:
            yaml_dict = {}
            self.log.error('Invalid front matter', yaml_str=yaml_str)

        if not isinstance(yaml_dict, dict):
            # Text that merely looks like front matter (e.g. Markdown rules around
            # prose) can load as a scalar or list; treat it as ordinary content.
            self.log.warning('Front matter is not a mapping, ignoring', yaml_str=yaml_str)
            return {}, md_text

        self.log.info('Front matter', yaml_dict=yaml_dict)
        return yaml_dict, md_text[match.end():]

    def run(self, instruction: str, title: str | None = None):
        """Run the automatic processing loop until the LLM stops returning actionable content.

        Args:
            instruction: The user's input (may contain @file style multimodal markers).
            title: Optional display title; the instruction is used when omitted.

        Raises:
            TaskInputError: if the input cannot be parsed or the model cannot handle it.
        """
        mmc = MMContent(instruction, base_path=self.context.cwd)
        try:
            content = mmc.content
        except Exception as e:
            raise TaskInputError(T("Invalid input"), e) from e

        if not self.client.has_capability(content):
            raise TaskInputError(T("Current model does not support this content"))

        user_prompt = content
        if not self.start_time:
            # First round of this task.
            self.start_time = time.time()
            self.instruction = instruction
            self.title = title or instruction
            if self.event_recorder:
                # Recording is disabled when enable_replay_recording is false.
                self.event_recorder.start_recording()
            if isinstance(content, str):
                user_prompt = self.prompts.get_task_prompt(content, gui=self.gui)
            system_prompt = self._get_system_prompt()
            self.emit('task_start', instruction=instruction, task_id=self.task_id, title=title)
        else:
            # Follow-up round on an already started task.
            system_prompt = None
            title = title or instruction
            if isinstance(content, str):
                user_prompt = self.prompts.get_chat_prompt(content, self.instruction)
        self.emit('round_start', instruction=instruction, step=len(self.step_manager) + 1, title=title)

        self.cwd.mkdir(exist_ok=True)
        os.chdir(self.cwd)

        rounds = 1
        max_rounds = self.max_rounds
        self.saved = False

        response = self.chat(user_prompt, system_prompt=system_prompt)
        if not response:
            self.log.error('No response from LLM')
            # Still record the (failed) step.
            self.step_manager.create_checkpoint(instruction, 0, '')
            return

        while rounds <= max_rounds:
            status, content = self._parse_front_matter(response)
            if status:
                response = content
                self.log.info('Task status', status=status)
                self.emit('task_status', status=status)
            prev_response = response
            response = self.process_reply(response)
            rounds += 1
            if self.is_stopped():
                self.log.info('Task stopped')
                break
            if not response:
                # Nothing more to act on: keep the last real reply as the round's response.
                response = prev_response
                break

        self.step_manager.create_checkpoint(instruction, rounds, response)
        summary = self._get_summary()
        self.emit('round_end', summary=summary, response=response)
        self._auto_save()
        self.log.info('Round done', rounds=rounds)

    @staticmethod
    def _get_author():
        """Best-effort name of the current OS user; '' when it cannot be determined."""
        try:
            return os.getlogin()
        except OSError:
            # os.getlogin() fails without a controlling terminal (daemons, CI, some containers).
            pass
        try:
            import getpass
            return getpass.getuser()
        except Exception:
            return ''

    def sync_to_cloud(self):
        """Upload the task result to the AIPy store.

        Returns:
            bool: False when no Trustoken API key is configured or the request
            failed; True once a response was received (status code and URL
            are reported through the ``upload_result`` event).
        """
        url = T("https://store.aipy.app/api/work")

        llm_settings = self.settings.get('llm', {})
        trustoken_apikey = llm_settings.get('Trustoken', {}).get('api_key')
        if not trustoken_apikey:
            trustoken_apikey = llm_settings.get('trustoken', {}).get('api_key')
        if not trustoken_apikey:
            return False
        self.log.info('Uploading result to cloud')
        try:
            # Serialize twice to remove values that are not valid JSON:
            # json.dumps(default=str) stringifies unsupported types but leaves
            # NaN/Infinity, which json.loads(parse_constant=str) then turns into strings.
            data = json.loads(
                json.dumps({
                    'apikey': trustoken_apikey,
                    'author': self._get_author(),
                    'instruction': self.instruction,
                    'llm': self.context_manager.json(),
                    'runner': self.runner.history,
                }, ensure_ascii=False, default=str),
                parse_constant=str)
            response = requests.post(url, json=data, verify=True, timeout=30)
        except Exception as e:
            self.emit('exception', msg='sync_to_cloud', exception=e)
            return False

        url = None
        status_code = response.status_code
        if status_code in (200, 201):
            try:
                data = response.json()
            except ValueError:
                # A 2xx reply whose body is not JSON: report it without a URL.
                self.log.warning('Invalid JSON in upload response')
                data = {}
            url = data.get('url', '') if isinstance(data, dict) else ''

        self.emit('upload_result', status_code=status_code, url=url)
        return True

    def delete_step(self, index: int):
        """Delete the step at *index* via the step manager."""
        return self.step_manager.delete_step(index)

    def clear_steps(self):
        """Remove all steps via the step manager; always returns True."""
        self.step_manager.clear_all()
        return True

    def list_steps(self):
        """Return the step manager's list of steps."""
        return self.step_manager.list_steps()

    def list_code_blocks(self):
        """Return one BlockRecord row per code block in history (for tabular display)."""
        BlockRecord = namedtuple('BlockRecord', ['Index', 'Name', 'Version', 'Language', 'Path', 'Size'])

        rows = []
        for index, block in enumerate(self.code_blocks.history):
            code_size = len(block.code) if block.code else 0
            size_str = f"{code_size} chars"

            # Keep only the tail of long paths so the table stays narrow.
            path_str = block.path if block.path else '-'
            if path_str != '-' and len(path_str) > 40:
                path_str = '...' + path_str[-37:]

            rows.append(BlockRecord(
                Index=index,
                Name=block.name,
                Version=f"v{block.version}",
                Language=block.lang,
                Path=path_str,
                Size=size_str
            ))

        return rows

    def get_code_block(self, index):
        """Return the code block at *index* in history, or None when out of range."""
        if index < 0 or index >= len(self.code_blocks.history):
            return None
        return self.code_blocks.history[index]