Coverage for aipyapp/cli/command/manager.py: 0%

308 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 12:45 +0200

1import shlex 

2import argparse 

3from collections import OrderedDict 

4from dataclasses import dataclass 

5from typing import Any 

6 

7from ... import __pkgpath__ 

8from .base import BaseCommand, CommandMode 

9from .cmd_info import InfoCommand 

10from .cmd_help import HelpCommand 

11from .cmd_llm import LLMCommand 

12from .cmd_role import RoleCommand 

13from .cmd_task import TaskCommand 

14from .cmd_mcp import MCPCommand 

15from .cmd_display import DisplayCommand 

16from .cmd_context import ContextCommand 

17from .cmd_steps import StepsCommand 

18from .cmd_block import BlockCommand 

19from .cmd_plugin import Command as PluginCommand 

20from .cmd_custom import CustomCommand 

21from .custom_command_manager import CustomCommandManager 

22from .result import CommandResult 

23 

24from loguru import logger 

25from prompt_toolkit.completion import Completer, Completion 

26from prompt_toolkit.key_binding import KeyBindings 

27from pathlib import Path 

28 

# Built-in command classes, instantiated and registered in this order.
# When two classes share a name, the later one wins in the command tables.
COMMANDS = [
    InfoCommand,
    LLMCommand,
    RoleCommand,
    DisplayCommand,
    PluginCommand,
    StepsCommand,
    BlockCommand,
    ContextCommand,
    TaskCommand,
    MCPCommand,
    HelpCommand,
    CustomCommand,
]

33 

@dataclass
class CommandContext:
    """Execution context handed to commands; every field defaults to None."""

    # Current task object (None when no task is active).
    task: Any = None
    # Task manager object supplied by the caller.
    tm: Any = None
    # Console object used for output.
    console: Any = None

40 

class CommandError(Exception):
    """Root of all errors raised while parsing or executing a slash command."""

44 

class CommandInputError(CommandError):
    """Raised when the raw input line is not a usable slash command."""

    def __init__(self, message):
        # Keep the message available as an attribute as well as in args.
        super().__init__(message)
        self.message = message

50 

class CommandArgumentError(CommandError):
    """Raised when a command's arguments cannot be parsed."""

    def __init__(self, message):
        # Keep the message available as an attribute as well as in args.
        super().__init__(message)
        self.message = message

56 

class InvalidCommandError(CommandError):
    """Raised when the requested command name is not registered."""

    def __init__(self, command):
        # The offending name is kept on the instance; the message embeds it.
        super().__init__(f"Invalid command: {command}")
        self.command = command

62 

class InvalidSubcommandError(CommandError):
    """Raised when a command receives a subcommand it does not define."""

    def __init__(self, message):
        # Keep the message available as an attribute as well as in args.
        super().__init__(message)
        self.message = message

68 

class CommandManager(Completer):
    """Registers, dispatches and auto-completes slash commands.

    Two tables are maintained: ``commands_main`` holds commands whose
    ``modes`` include ``CommandMode.MAIN`` and ``commands_task`` those that
    include ``CommandMode.TASK``. ``self.commands`` points at the table for the
    current mode and is what command lookup and completion use. The instance
    is also the default prompt_toolkit completer for input starting with ``/``.
    """

    def __init__(self, settings, tm, console):
        """Create the manager and register all built-in and custom commands.

        Args:
            settings: Mapping that must provide ``'config_dir'``; custom
                commands are also loaded from ``<config_dir>/commands``.
            tm: Presumably the task manager; exposed to commands via ``context``.
            console: Console object; exposed to commands via ``context``.
        """
        self.settings = settings
        self.tm = tm
        self.task = None
        self.console = console
        self.mode = CommandMode.MAIN
        self.commands_main = OrderedDict()
        self.commands_task = OrderedDict()
        # Active table: aliases commands_main or commands_task depending on mode.
        self.commands = self.commands_main
        self.log = logger.bind(src="CommandManager")
        self.custom_command_manager = CustomCommandManager()
        # Wrap __pkgpath__ in Path() before joining so this works whether it is
        # a str or already a Path.
        self.custom_command_manager.add_command_dir(Path(__pkgpath__) / "commands")
        self.custom_command_manager.add_command_dir(Path(self.settings['config_dir']) / "commands")
        self.init()

    @property
    def context(self):
        """Return a new CommandContext built from the current task, tm and console."""
        return CommandContext(task=self.task, tm=self.tm, console=self.console)

    def init(self):
        """Instantiate, register and initialize all commands.

        Built-in classes from ``COMMANDS`` are constructed with this manager as
        their only argument and registered first. Custom commands returned by
        ``CustomCommandManager.scan_commands()`` are registered only when
        ``validate_command_name`` accepts them. ``init()`` is called on every
        registered command only after all registrations are done.
        """
        commands = []

        for command_class in COMMANDS:
            command = command_class(self)
            self.register_command(command)
            commands.append(command)

        custom_commands = self.custom_command_manager.scan_commands()
        registered_custom = self._register_custom_commands(custom_commands)
        commands.extend(registered_custom)

        for command in commands:
            command.init()

        # Report what was actually registered, not what was scanned.
        self.log.info(
            f"Initialized {len(COMMANDS)} built-in commands and {len(registered_custom)} custom commands"
        )

    def _register_custom_commands(self, custom_commands):
        """Register each custom command whose name validates; return the registered ones."""
        registered = []
        for custom_command in custom_commands:
            # Validate against the tables as they are right now, which include
            # custom commands registered earlier in this same loop.
            existing_names = list(self.commands_main.keys()) + list(self.commands_task.keys())
            if not self.custom_command_manager.validate_command_name(custom_command.name, existing_names):
                continue
            custom_command.manager = self  # back-reference used by the command
            self.register_command(custom_command)
            registered.append(custom_command)
        return registered

    def is_task_mode(self):
        """Return True when the manager is in task mode."""
        return self.mode == CommandMode.TASK

    def is_main_mode(self):
        """Return True when the manager is in main mode."""
        return self.mode == CommandMode.MAIN

    def set_task_mode(self, task):
        """Enter task mode for *task*; lookup and completion then use ``commands_task``."""
        self.task = task
        self.mode = CommandMode.TASK
        self.commands = self.commands_task

    def set_main_mode(self):
        """Return to main mode, clear the current task and use ``commands_main``."""
        self.mode = CommandMode.MAIN
        self.task = None
        self.commands = self.commands_main

    def create_key_bindings(self):
        """Build the prompt key bindings.

        - Ctrl+F: insert ``@`` and switch the buffer to a file-path completer.
        - Esc: restore this manager as the buffer's completer and close any
          open completion menu.
        - Ctrl+T: insert the current local time as ``YYYY-MM-DD HH:MM:SS``.
        """
        kb = KeyBindings()

        @kb.add('c-f')  # Ctrl+F: enter file-completion mode
        def _(event):
            """Ctrl+F: insert '@' and start file-reference completion."""
            buffer = event.app.current_buffer
            buffer.insert_text('@')
            # Temporarily swap in the file-reference completer; Esc restores self.
            buffer.completer = self._create_file_reference_completer()
            buffer.start_completion()

        # NOTE(review): eager=True presumably fires Esc without waiting for a
        # possible escape sequence; confirm it does not shadow Alt-key bindings.
        @kb.add('escape', eager=True)  # Esc: restore default completion mode
        def _(event):
            """Esc: restore the default command completer."""
            buffer = event.app.current_buffer
            buffer.completer = self
            if buffer.complete_state:
                buffer.cancel_completion()

        @kb.add('c-t')  # Ctrl+T: insert current timestamp
        def _(event):
            """Ctrl+T: insert the current timestamp."""
            from datetime import datetime
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            event.app.current_buffer.insert_text(timestamp)

        return kb

    def _create_file_reference_completer(self):
        """Return a completer for file-system paths typed after the last ``@``.

        Relative paths are resolved against the current working directory;
        absolute paths are used as-is. Names containing a space are inserted
        shell-quoted (``shlex.quote``) in place of the typed prefix.
        """
        import os

        class FileReferenceCompleter(Completer):
            def get_completions(self, document, complete_event):
                text_before_cursor = document.text_before_cursor

                # Only the text after the last '@' is treated as a path.
                at_pos = text_before_cursor.rfind('@')
                if at_pos == -1:
                    return
                raw_path = text_before_cursor[at_pos + 1:]

                # Accept quoted input; fall back to the raw text when the quotes
                # are still unbalanced (the user is mid-typing).
                try:
                    parts = shlex.split(raw_path)
                except ValueError:
                    parts = [raw_path]
                if raw_path and len(parts) != 1:
                    # Whitespace only, or the cursor has moved past the path onto
                    # other words: there is no path fragment to complete here.
                    return
                unquoted_path = parts[0] if raw_path else ''

                # Directory to list and basename prefix to match. Joining cwd with
                # an absolute dirname yields that absolute dirname, so relative
                # and absolute input share one code path.
                search_dir = Path.cwd() / os.path.dirname(unquoted_path)
                search_prefix = os.path.basename(unquoted_path)

                try:
                    if not search_dir.is_dir():
                        return
                    for item in sorted(search_dir.iterdir(), key=lambda p: p.name):
                        name = item.name
                        # Skip non-matches and exact matches (nothing left to add).
                        if not name.startswith(search_prefix) or name == search_prefix:
                            continue
                        display_text = name + "/" if item.is_dir() else name
                        if ' ' in name:
                            # Replace the typed prefix with the whole quoted name.
                            # Slicing the quoted string by len(search_prefix) would
                            # drop the opening quote instead and leave it unbalanced.
                            # NOTE(review): assumes the prefix was typed without
                            # escapes, so its on-screen length equals len(search_prefix).
                            yield Completion(
                                shlex.quote(name),
                                start_position=-len(search_prefix),
                                display=display_text,
                            )
                        else:
                            yield Completion(name[len(search_prefix):], display=display_text)
                except OSError:
                    # Unreadable directory or entry (PermissionError is an
                    # OSError): offer no completions rather than failing.
                    pass

        return FileReferenceCompleter()

    def register_command(self, command):
        """Add *command* to the table of every mode listed in ``command.modes``.

        Raises:
            ValueError: *command* is not a BaseCommand instance.
        """
        if not isinstance(command, BaseCommand):
            raise ValueError("Command must be an instance of BaseCommand")

        if CommandMode.MAIN in command.modes:
            self.commands_main[command.name] = command
        if CommandMode.TASK in command.modes:
            self.commands_task[command.name] = command

    def get_completions(self, document, complete_event):
        """Yield completions for slash-command input (text starting with ``/``)."""
        text = document.text_before_cursor
        if not text.startswith('/'):
            return

        text = text[1:]
        try:
            words = shlex.split(text)
        except ValueError as e:
            # Raised on every keystroke while a quoted argument is still open,
            # which is normal typing, so keep it out of the error log.
            self.log.debug(f"输入解析错误: {e}")
            return

        # Command name completion
        if self._should_complete_main_command(words, text):
            yield from self._complete_main_commands(words)
            return

        command_instance, arguments, subcmd = self._get_command_and_arguments(words)
        if not command_instance:
            return

        # Subcommand completion
        if self._should_complete_subcommand(words, text, command_instance):
            yield from self._complete_subcommands(words, command_instance)
            return

        # No argument table (e.g. only the command typed): nothing more to offer.
        if arguments is None:
            return
        if text.endswith(' '):
            yield from self._complete_after_space(words, arguments, command_instance, subcmd)
        else:
            yield from self._complete_partial_input(words, arguments, command_instance, subcmd)

    def _should_complete_main_command(self, words, text):
        """Return True while the first word (the command name) is still being typed."""
        return len(words) == 0 or (len(words) == 1 and not text.endswith(' '))

    def _complete_main_commands(self, words):
        """Complete command names of the current mode."""
        partial_cmd = words[0] if len(words) > 0 else ''
        yield from self._complete_items(self.commands.values(), partial_cmd)

    def _should_complete_subcommand(self, words, text, command_instance):
        """Return truthy when the second word (the subcommand) is still being typed."""
        return (command_instance.subcommands and
                (len(words) == 1 or (len(words) == 2 and not text.endswith(' '))))

    def _complete_subcommands(self, words, command_instance):
        """Complete subcommand names of *command_instance*."""
        partial_subcmd = words[1] if len(words) > 1 else ''
        yield from self._complete_items(command_instance.get_subcommands().values(), partial_subcmd)

    def _get_command_and_arguments(self, words):
        """Resolve ``(command_instance, arguments, subcmd)`` from the typed words.

        Returns ``(None, None, None)`` for an unknown command or subcommand, and
        ``(command_instance, None, None)`` when a command with subcommands has
        no subcommand typed yet.
        """
        cmd = words[0]
        if cmd not in self.commands:
            return None, None, None

        command_instance = self.commands[cmd]
        subcommands = command_instance.subcommands

        if subcommands:
            if len(words) < 2:
                return command_instance, None, None
            subcmd = words[1]
            if subcmd not in subcommands:
                return None, None, None
            arguments = subcommands[subcmd]['arguments']
        else:
            subcmd = None
            arguments = command_instance.arguments

        return command_instance, arguments, subcmd

    def _complete_after_space(self, words, arguments, command_instance, subcmd):
        """Complete when the cursor follows a space (a new word is starting)."""
        # Only the word just finished matters: if it is an option that takes a
        # value, offer that option's values. A value-less flag, or an option
        # whose value was already typed, falls through to the generic listing.
        last_word = words[-1] if words else None
        if last_word and last_word.startswith('-'):
            arg = arguments.get(last_word, None)
            if arg and arg.get('requires_value'):
                yield from self._complete_option_argument(last_word, arguments, command_instance, subcmd,
                                                          start_position=0)
                return

        # Offer values of the first positional argument that has any.
        for arg_name, arg in arguments.items():
            if not arg_name.startswith('-') and arg['requires_value']:
                choices = command_instance.get_arg_values(arg, subcmd, '')
                if choices:
                    yield from self._complete_items(choices, '')
                    return

        # Otherwise list every available argument.
        yield from self._complete_items(arguments.values(), '')

    def _complete_partial_input(self, words, arguments, command_instance, subcmd):
        """Complete a word that is still being typed (no trailing space)."""
        partial_arg = words[-1]
        last_word = words[-2] if len(words) > 1 else None

        # A partially typed option (e.g. "--mo"): complete option names.
        if partial_arg.startswith('-'):
            yield from self._complete_items(arguments.values(), partial_arg)
            return

        # The previous word is an option expecting a value: complete that value.
        if last_word and last_word.startswith('-'):
            arg = arguments.get(last_word, None)
            if arg and arg['requires_value']:
                choices = command_instance.get_arg_values(arg, subcmd, partial_arg)
                if choices:
                    yield from self._complete_items(choices, partial_arg)
                    return

        # A positional argument value.
        for arg_name, arg in arguments.items():
            if not arg_name.startswith('-') and arg['requires_value']:
                choices = command_instance.get_arg_values(arg, subcmd, partial_arg)
                if choices:
                    yield from self._complete_items(choices, partial_arg)
                    return

        # Otherwise list matching arguments.
        yield from self._complete_items(arguments.values(), partial_arg)

    def _complete_option_argument(self, option_name, arguments, command_instance, subcmd, start_position):
        """Yield candidate values for the option *option_name*.

        Values come from ``command_instance.get_arg_values`` when it returns
        any; otherwise from the keys of the option's ``'choices'`` mapping.
        Nothing is yielded for unknown options or options that take no value.
        """
        arg = arguments.get(option_name, None)
        if not arg or not arg.get('requires_value'):
            return

        choices = command_instance.get_arg_values(arg, subcmd, '')
        if choices:
            yield from self._complete_items(choices, '', start_pos=start_position)
            return

        if arg.get('choices'):
            for choice_name in arg['choices'].keys():
                yield Completion(choice_name, start_position=start_position,
                                 display_meta=f"Option: {choice_name}")

    def _complete_items(self, items, partial, start_pos=None):
        """Yield a Completion for each item whose ``name`` starts with *partial*.

        *start_pos* defaults to ``-len(partial)`` so the typed prefix is replaced.
        """
        start_pos = -len(partial) if start_pos is None else start_pos
        for item in items:
            if item.name.startswith(partial):
                yield Completion(
                    item.name,
                    start_position=start_pos,
                    display_meta=item.desc
                )

    def execute(self, user_input: str) -> CommandResult:
        """Parse and run a slash command.

        Args:
            user_input: Raw input line; must start with ``/``.

        Returns:
            CommandResult with the command name, the parsed ``subcommand``
            attribute (None if absent), the parsed arguments as a dict and the
            command's return value.

        Raises:
            CommandInputError: missing ``/``, nothing after it, or unbalanced
                quotes.
            InvalidCommandError: the command is not registered in the current mode.
            CommandArgumentError: argparse raised ``ArgumentError``.
            CommandError: argparse exited (``SystemExit``), the command raised a
                CommandError subclass (propagated unchanged), or it raised any
                other exception (wrapped).
        """
        if not user_input.startswith('/'):
            raise CommandInputError(user_input)

        user_input = user_input[1:].strip()
        if not user_input:
            raise CommandInputError(user_input)

        try:
            args = shlex.split(user_input)
        except ValueError as e:
            # e.g. unbalanced quotes: report as an input error rather than
            # leaking a bare ValueError past callers that catch CommandError.
            raise CommandInputError(f"{user_input}: {e}") from e

        command = args[0]
        if command not in self.commands:
            raise InvalidCommandError(command)

        command_instance = self.commands[command]
        parser = command_instance.parser
        try:
            # Parse the arguments after the command name.
            parsed_args = parser.parse_args(args[1:])
            parsed_args.raw_args = args[1:]
            ret = command_instance.execute(parsed_args)
        except SystemExit as e:
            # argparse calls sys.exit() on bad arguments and on --help.
            raise CommandError(f"SystemExit: {e}") from e
        except argparse.ArgumentError as e:
            raise CommandArgumentError(f"ArgumentError: {e}") from e
        except CommandError:
            # Keep specific subclasses (e.g. InvalidSubcommandError) intact.
            raise
        except Exception as e:
            raise CommandError(f"Error: {e}") from e

        return CommandResult(command=command, subcommand=getattr(parsed_args, 'subcommand', None),
                             args=vars(parsed_args), result=ret)

    def reload_custom_commands(self):
        """Unregister all custom commands and load them again.

        Registered entries are treated as custom commands when they have a
        ``file_path`` attribute.

        Returns:
            int: Number of custom commands returned by the loader, including any
            rejected by name validation.
        """
        for table in (self.commands_main, self.commands_task):
            stale = [name for name, command in table.items() if hasattr(command, 'file_path')]
            for name in stale:
                del table[name]

        custom_commands = self.custom_command_manager.reload_commands()
        for custom_command in self._register_custom_commands(custom_commands):
            custom_command.init()

        self.log.info(f"Reloaded {len(custom_commands)} custom commands")
        return len(custom_commands)