Coverage for aipyapp/cli/command/cmd_task.py: 0%

134 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 12:02 +0200

1import time 

2from pathlib import Path 

3import json 

4import os 

5import shlex 

6 

7from rich.panel import Panel 

8 

9from ... import T, EventBus 

10from ...aipy.event_serializer import EventSerializer 

11from ...aipy.task_state import TaskState 

12from .base import Completable 

13from .base_parser import ParserCommand 

14from .result import TaskModeResult 

15from .utils import print_records 

16 

17 

class TaskCommand(ParserCommand):
    """``task`` command: list recent tasks, load one, resume or replay from task.json."""

    name = 'task'
    description = T('Task operations')

    def add_subcommands(self, subparsers):
        """Register the ``list``, ``use``, ``resume`` and ``replay`` subcommands."""
        subparsers.add_parser('list', help=T('List recent tasks'))

        parser = subparsers.add_parser('use', help=T('Load a recent task by task id'))
        parser.add_argument('tid', type=str, help=T('Task ID'))

        parser = subparsers.add_parser('resume', help=T('Load task from task.json file'))
        parser.add_argument('path', type=str, help=T('Path to task.json file'))

        parser = subparsers.add_parser('replay', help=T('Replay task from task.json file'))
        parser.add_argument('path', type=str, help=T('Path to task.json file'))
        parser.add_argument('--speed', type=float, default=1.0, help=T('Replay speed multiplier (default: 1.0)'))

    def cmd_list(self, args, ctx):
        """Print the rows returned by ``ctx.tm.list_tasks()``."""
        rows = ctx.tm.list_tasks()
        print_records(rows)

    def get_arg_values(self, arg, subcommand=None, partial_value=''):
        """Return completion candidates for ``arg`` of ``subcommand``.

        - ``use <tid>``: every task ID from the task manager, described by
          the first 32 characters of its instruction.
        - ``resume <path>`` / ``replay <path>``: filesystem entries matching
          ``partial_value``.
        Anything else is delegated to the base class.
        """
        if subcommand == 'use' and arg.name == 'tid':
            tasks = self.manager.tm.get_tasks()
            # `or ''` keeps a task without an instruction from breaking completion.
            return [Completable(task.task_id, (task.instruction or '')[:32]) for task in tasks]
        if subcommand in ('resume', 'replay') and arg.name == 'path':
            return self._get_path_completions(partial_value)
        # NOTE(review): partial_value is not forwarded, matching the original
        # call; confirm whether the base-class signature accepts it.
        return super().get_arg_values(arg, subcommand)

    def _get_path_completions(self, partial_path=''):
        """Return file-path completions for ``partial_path``, ``.json`` files first.

        The (possibly shell-quoted) input is split into a directory to list
        and a basename prefix to match; relative input is resolved against the
        current working directory. Entries starting with ``.`` are only offered
        when the prefix itself starts with ``.``. Candidates are ordered: JSON
        files, then directories (with a trailing ``/``), then other files.
        Names containing a space are shell-quoted. A missing or unreadable
        directory yields an empty list.
        """
        unquoted_path = self._unquote_path(partial_path)

        # os.path.dirname/basename understand every separator the platform
        # accepts; os.path.join discards the CWD when dir_part is absolute,
        # so relative and absolute input share one code path.
        dir_part = os.path.dirname(unquoted_path)
        search_dir = os.path.join(os.getcwd(), dir_part) if dir_part else os.getcwd()
        prefix = os.path.basename(unquoted_path)
        show_hidden = prefix.startswith('.')

        try:
            items = os.listdir(search_dir) if os.path.isdir(search_dir) else []
        except OSError:
            # Best effort: an unreadable directory just offers no candidates.
            return []

        json_files, directories, other_files = [], [], []
        for item in items:
            if not item.startswith(prefix) or (item.startswith('.') and not show_hidden):
                continue
            if os.path.isdir(os.path.join(search_dir, item)):
                # The trailing '/' goes inside the quotes for spaced names.
                directories.append(Completable(self._quote_if_spaced(item + '/'), "Directory"))
            elif item.endswith('.json'):
                json_files.append(Completable(self._quote_if_spaced(item), "JSON file"))
            else:
                other_files.append(Completable(self._quote_if_spaced(item), "File"))

        return json_files + directories + other_files

    @staticmethod
    def _unquote_path(partial_path):
        """Return the first shell token of ``partial_path`` with quoting removed.

        Unbalanced quotes (the user is still typing) fall back to the raw
        input; empty or whitespace-only input yields ''.
        """
        if not partial_path:
            return ''
        try:
            tokens = shlex.split(partial_path)
        except ValueError:
            return partial_path
        # Whitespace-only input tokenizes to []; indexing it would raise IndexError.
        return tokens[0] if tokens else ''

    @staticmethod
    def _quote_if_spaced(name):
        """Shell-quote ``name`` when it contains a space; otherwise return it unchanged."""
        return shlex.quote(name) if ' ' in name else name

    def cmd_use(self, args, ctx):
        """Switch to the task whose ID is ``args.tid``."""
        task = ctx.tm.get_task_by_id(args.tid)
        return TaskModeResult(task=task)

    def _load_task_state(self, path):
        """Load a ``TaskState`` from the task.json file at ``path``."""
        return TaskState.from_file(path)

    def cmd_resume(self, args, ctx):
        """Load the task saved in ``args.path`` into the task manager and switch to it."""
        task_state = self._load_task_state(args.path)
        task = ctx.tm.load_task(task_state)
        return TaskModeResult(task=task)

    def cmd_replay(self, args, ctx):
        """Show a summary panel for the task saved in ``args.path``, then replay its events."""
        # The instruction and task id are external text; escape them so "[...]"
        # is not parsed as rich markup (a stray "[/x]" raises MarkupError).
        from rich.markup import escape

        task_state = self._load_task_state(args.path)
        events = task_state.get_component_state('events') or []

        panel = Panel(
            f"🎬 Task Replay\n\n"
            f"Task ID: {escape(str(task_state.task_id))}\n"
            f"Instruction: {escape(str(task_state.instruction))}\n"
            f"Events: {len(events)}\n"
            f"Speed: {args.speed}x",
            title="Replay Mode",
            border_style="cyan"
        )
        ctx.console.print(panel)

        if events:
            self._replay_events(ctx, events, args.speed)

    def _replay_events(self, ctx, events, speed):
        """Re-emit recorded events to a fresh display plugin, reproducing their timing.

        The gap between consecutive events' ``relative_time`` values is slept,
        divided by ``speed``; a non-positive ``speed`` replays without delays.
        Before each ``round_start`` event the user is asked to confirm, and
        declining stops the replay.
        """
        display = ctx.tm.display_manager.create_display_plugin()
        event_bus = EventBus()
        event_bus.add_listener(display)

        # Rebuild the objects that were serialized into the recorded events.
        replay_events = EventSerializer.deserialize_events(events)

        prev_event = None
        for event in replay_events:
            if event['type'] == 'round_start' and not self._confirm_round_start(ctx, event):
                ctx.console.print("\n🛑 重放已取消")
                return

            # speed > 0 guards the division (--speed 0 would raise ZeroDivisionError).
            if prev_event is not None and speed > 0:
                wait_time = (event['relative_time'] - prev_event['relative_time']) / speed
                if wait_time > 0:
                    time.sleep(wait_time)
            prev_event = event

            data = event['data']
            payload = dict(data) if isinstance(data, dict) else {}
            event_bus.emit(event['type'], **payload)

    def _confirm_round_start(self, ctx, event):
        """Ask the user whether to replay the round announced by ``event``.

        Returns True to continue and False to cancel. Ctrl-C cancels; EOF on
        input (non-interactive use) continues automatically.
        """
        from rich.markup import escape

        console = ctx.console
        data = event.get('data')
        # Same guard as _replay_events: non-dict data carries no round details.
        if not isinstance(data, dict):
            data = {}

        round_num = data.get('round', 'Unknown')
        instruction = data.get('instruction', 'Unknown instruction')

        panel = Panel(
            f"📋 即将重放 Step {escape(str(round_num))}\n\n"
            f"指令: {escape(str(instruction))}\n\n"
            f"⚠️ 继续重放此步骤吗?",
            title="🔄 Step 重放确认",
            border_style="yellow"
        )
        console.print(panel)

        try:
            while True:
                # "\\[" keeps rich from treating "[y/n]" as a style tag and hiding it.
                choice = console.input("\n请选择 \\[y/n]: ").lower().strip()
                if choice in ('y', 'yes', '是'):
                    console.print("✅ 继续重放...")
                    return True
                if choice in ('n', 'no', '否'):
                    return False
                console.print("❓ 请输入 'y' 继续或 'n' 取消")
        except KeyboardInterrupt:
            console.print("\n\n❌ 用户中断,取消重放")
            return False
        except EOFError:
            # No interactive stdin (e.g. automated tests): keep replaying.
            console.print("\n⚠️ 检测到非交互式环境,自动继续重放")
            return True

    def cmd(self, args, ctx):
        """Default action when no subcommand is given: list recent tasks."""
        self.cmd_list(args, ctx)