Coverage for aipyapp/cli/command/cmd_task.py: 0%
134 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
1import time
2from pathlib import Path
3import json
4import os
5import shlex
7from rich.panel import Panel
9from ... import T, EventBus
10from ...aipy.event_serializer import EventSerializer
11from ...aipy.task_state import TaskState
12from .base import Completable
13from .base_parser import ParserCommand
14from .result import TaskModeResult
15from .utils import print_records
class TaskCommand(ParserCommand):
    """The ``task`` command: list, load, resume and replay saved tasks."""

    name = 'task'
    description = T('Task operations')

    def add_subcommands(self, subparsers):
        """Register the ``list``, ``use``, ``resume`` and ``replay`` subcommands."""
        subparsers.add_parser('list', help=T('List recent tasks'))
        parser = subparsers.add_parser('use', help=T('Load a recent task by task id'))
        parser.add_argument('tid', type=str, help=T('Task ID'))
        parser = subparsers.add_parser('resume', help=T('Load task from task.json file'))
        parser.add_argument('path', type=str, help=T('Path to task.json file'))
        parser = subparsers.add_parser('replay', help=T('Replay task from task.json file'))
        parser.add_argument('path', type=str, help=T('Path to task.json file'))
        parser.add_argument('--speed', type=float, default=1.0, help=T('Replay speed multiplier (default: 1.0)'))

    def cmd_list(self, args, ctx):
        """Print the rows returned by the task manager's ``list_tasks()``."""
        rows = ctx.tm.list_tasks()
        print_records(rows)

    def get_arg_values(self, arg, subcommand=None, partial_value=''):
        """Return completion candidates for an argument of a subcommand.

        ``use <tid>`` completes from the task manager's tasks; ``resume`` and
        ``replay`` ``<path>`` complete from the filesystem. Anything else is
        delegated to the base class.
        """
        if subcommand == 'use' and arg.name == 'tid':
            tasks = self.manager.tm.get_tasks()
            # `or ''` avoids a TypeError when a task has no instruction.
            return [Completable(task.task_id, (task.instruction or '')[:32]) for task in tasks]
        if subcommand in ('resume', 'replay') and arg.name == 'path':
            return self._get_path_completions(partial_value)
        return super().get_arg_values(arg, subcommand)

    @staticmethod
    def _unquote_partial_path(partial_path):
        """Strip shell quoting from a partially typed path argument.

        While a quoted name is still being typed its closing quote is missing,
        which makes ``shlex.split`` raise ``ValueError``; retry with each
        closing quote before falling back to the raw input.
        """
        for closing in ('', "'", '"'):
            try:
                parts = shlex.split(partial_path + closing)
            except ValueError:
                continue
            # Empty or whitespace-only input splits into an empty list.
            return parts[0] if parts else ''
        return partial_path

    @staticmethod
    def _quote_completion(name):
        """Shell-quote *name* when it contains whitespace, quotes or backslashes."""
        if any(ch.isspace() or ch in '\'"\\' for ch in name):
            return shlex.quote(name)
        return name

    def _get_path_completions(self, partial_path=''):
        """Return filesystem completions for a path argument.

        Entries of the directory being typed into are matched against the
        basename prefix. ``.json`` files come first, then directories (shown
        with a trailing ``/``), then other files. Hidden entries are offered
        only when the prefix itself starts with ``.``. A missing or unreadable
        directory yields an empty list.
        """
        # `or ''` keeps None away from shlex.split.
        unquoted_path = self._unquote_partial_path(partial_path or '')

        if os.path.isabs(unquoted_path):
            search_dir = os.path.dirname(unquoted_path)
        else:
            # Relative (or empty) input is resolved against the working directory.
            search_dir = os.path.join(os.getcwd(), os.path.dirname(unquoted_path))
        prefix = os.path.basename(unquoted_path)
        show_hidden = prefix.startswith('.')

        json_files, directories, other_files = [], [], []
        try:
            if not os.path.isdir(search_dir):
                return []
            # listdir order is arbitrary; sort for stable completion output.
            for item in sorted(os.listdir(search_dir)):
                if not item.startswith(prefix):
                    continue
                if item.startswith('.') and not show_hidden:
                    continue
                full_path = os.path.join(search_dir, item)
                if os.path.isdir(full_path):
                    # The trailing slash goes inside the quotes when quoting is needed.
                    directories.append(Completable(self._quote_completion(item + '/'), "Directory"))
                elif item.endswith('.json'):
                    json_files.append(Completable(self._quote_completion(item), "JSON file"))
                else:
                    other_files.append(Completable(self._quote_completion(item), "File"))
        except OSError:
            # Best effort: an inaccessible directory simply offers no completions.
            return []

        return json_files + directories + other_files

    def cmd_use(self, args, ctx):
        """Load the recent task whose id is ``args.tid``."""
        task = ctx.tm.get_task_by_id(args.tid)
        return TaskModeResult(task=task)

    def _load_task_state(self, path):
        """Load a ``TaskState`` from the task.json file at *path*."""
        return TaskState.from_file(path)

    def cmd_resume(self, args, ctx):
        """Load a task from a task.json file and add it to the task manager."""
        task_state = self._load_task_state(args.path)
        task = ctx.tm.load_task(task_state)
        return TaskModeResult(task=task)

    def cmd_replay(self, args, ctx):
        """Replay the events recorded in a task.json file.

        Prints a summary panel, then re-emits the recorded events with their
        original spacing divided by ``args.speed``.
        """
        from rich.markup import escape

        # `not > 0` also rejects NaN; zero would cause a division by zero later.
        if not args.speed > 0:
            ctx.console.print(f"Invalid replay speed: {args.speed} (must be greater than 0)")
            return

        task_state = self._load_task_state(args.path)
        instruction = task_state.instruction
        task_id = task_state.task_id
        events = task_state.get_component_state('events') or []

        # Task text is user data: escape it so Rich does not interpret it as markup.
        panel = Panel(
            f"🎬 Task Replay\n\n"
            f"Task ID: {escape(str(task_id))}\n"
            f"Instruction: {escape(str(instruction))}\n"
            f"Events: {len(events)}\n"
            f"Speed: {args.speed}x",
            title="Replay Mode",
            border_style="cyan"
        )
        ctx.console.print(panel)

        if events:
            self._replay_events(ctx, events, args.speed)

    def _replay_events(self, ctx, events, speed):
        """Re-emit recorded events through a new EventBus wired to a display plugin.

        The delay between consecutive events is their ``relative_time``
        difference divided by *speed* (must be > 0). Before each
        ``round_start`` event the user is asked to confirm; declining stops
        the replay.
        """
        display = ctx.tm.display_manager.create_display_plugin()
        event_bus = EventBus()
        event_bus.add_listener(display)

        # Rebuild the objects carried inside the serialized events.
        replay_events = EventSerializer.deserialize_events(events)

        prev_time = None
        for event in replay_events:
            event_name = event['type']
            if event_name == 'round_start' and not self._confirm_round_start(ctx, event):
                ctx.console.print("\n🛑 重放已取消")
                return

            # Reproduce the original spacing relative to the last timestamped event.
            cur_time = event.get('relative_time')
            if cur_time is not None:
                if prev_time is not None:
                    wait_time = (cur_time - prev_time) / speed
                    if wait_time > 0:
                        time.sleep(wait_time)
                prev_time = cur_time

            data = event.get('data')
            # Copy so listeners cannot mutate the deserialized event in place.
            event_data = data.copy() if isinstance(data, dict) else {}
            event_bus.emit(event_name, **event_data)

    def _confirm_round_start(self, ctx, event):
        """Ask the user whether to replay the round that *event* starts.

        Returns True to continue and False to stop. Ctrl-C stops the replay;
        EOF on input (non-interactive use) continues automatically.
        """
        from rich.markup import escape

        console = ctx.console
        data = event.get('data')
        if not isinstance(data, dict):
            data = {}

        round_num = data.get('round', 'Unknown')
        instruction = data.get('instruction', 'Unknown instruction')

        # Escape event data so brackets in it are shown literally by Rich.
        panel = Panel(
            f"📋 即将重放 Step {escape(str(round_num))}\n\n"
            f"指令: {escape(str(instruction))}\n\n"
            f"⚠️ 继续重放此步骤吗?",
            title="🔄 Step 重放确认",
            border_style="yellow"
        )
        console.print(panel)

        # Rich would parse `[y/n]` as a markup tag and drop it from the prompt.
        prompt = escape("\n请选择 [y/n]: ")
        try:
            while True:
                choice = console.input(prompt).lower().strip()
                if choice in ['y', 'yes', '是']:
                    console.print("✅ 继续重放...")
                    return True
                if choice in ['n', 'no', '否']:
                    return False
                console.print("❓ 请输入 'y' 继续或 'n' 取消")
        except KeyboardInterrupt:
            console.print("\n\n❌ 用户中断,取消重放")
            return False
        except EOFError:
            # No interactive input available (e.g. automated runs): keep going.
            console.print("\n⚠️ 检测到非交互式环境,自动继续重放")
            return True

    def cmd(self, args, ctx):
        """Default action with no subcommand: list recent tasks."""
        self.cmd_list(args, ctx)