Coverage for src/alprina_cli/chat_enhanced.py: 0%

255 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2Interactive chat interface for Alprina. 

3Provides conversational AI assistant for security scanning. 

4""" 

5 

6from typing import Optional 

7from pathlib import Path 

8import re 

9import sys 

10 

11from rich.console import Console 

12from rich.markdown import Markdown 

13from rich.panel import Panel 

14from rich.table import Table 

15from rich.progress import Progress, SpinnerColumn, TextColumn 

16from prompt_toolkit import PromptSession 

17from prompt_toolkit.history import FileHistory 

18from prompt_toolkit.auto_suggest import AutoSuggestFromHistory 

19from prompt_toolkit.key_binding import KeyBindings 

20from loguru import logger 

21 

22from .context_manager import ConversationContext 

23from .llm_provider import get_llm_client 

24from .scanner import scan_command 

25from .mitigation import mitigate_command 

26 

27console = Console() 

28 

29 

class AlprinaChatSession:
    """Interactive chat session with Alprina AI.

    Reads lines from a prompt_toolkit session, routes them to slash commands,
    to the scan / mitigation helpers, or to the LLM client, and renders the
    output through the module-level Rich ``console``.
    """

    # Inputs that end the session when typed on their own (case-insensitive).
    _EXIT_WORDS = frozenset({'exit', 'quit', 'q', 'bye'})

    # prompt_toolkit does not understand Rich markup, so the prompt is given
    # as prompt_toolkit formatted text: a list of (style, text) fragments.
    _PROMPT = [("", "\n"), ("bold fg:ansigreen", "You: ")]

    # Intent heuristics. Keywords are anchored at a word start so that e.g.
    # "latest" no longer counts as "test" and "prefix" no longer counts as "fix".
    _SCAN_VERB_RE = re.compile(r'\b(?:scan|analyze|check|test|audit)\w*', re.IGNORECASE)
    _SCAN_HINT_RE = re.compile(r'\bmy\b|\bfile\w*', re.IGNORECASE)
    _MITIGATION_RE = re.compile(r'\b(?:fix|remediate|solve|patch)\w*', re.IGNORECASE)

    # Target extraction. URLs keep ports and query strings; paths must start
    # at a token boundary so "client/server" is not read as "/server".
    _URL_RE = re.compile(r'https?://[^\s<>"\']+', re.IGNORECASE)
    _REL_PATH_RE = re.compile(r'(?<![\w./\-:])\.{1,2}/[\w/.\-]+')
    _ABS_PATH_RE = re.compile(r'(?<![\w./\-:])/[\w/.\-]+')

    def __init__(
        self,
        model: str = "claude-3-5-sonnet-20241022",
        streaming: bool = True,
        context_file: Optional[Path] = None
    ):
        """
        Initialize chat session.

        Args:
            model: LLM model to use
            streaming: Enable streaming responses
            context_file: Load previous scan context
        """
        self.model = model
        self.streaming = streaming
        self.context = ConversationContext()
        self.llm = get_llm_client(model=model)

        # Set up prompt session with persistent history under ~/.alprina
        history_file = Path.home() / '.alprina' / 'chat_history.txt'
        history_file.parent.mkdir(parents=True, exist_ok=True)

        self.session = PromptSession(
            history=FileHistory(str(history_file)),
            auto_suggest=AutoSuggestFromHistory(),
            key_bindings=self._create_key_bindings()
        )

        # Load context if provided
        if context_file:
            self.context.load_scan_results(context_file)

        logger.info("Chat session initialized with model: {}", model)

    def _create_key_bindings(self):
        """Create custom key bindings.

        Returns:
            KeyBindings with a Ctrl+C handler installed
        """
        kb = KeyBindings()

        @kb.add('c-c')
        def _(event):
            """Abort the current prompt by raising KeyboardInterrupt."""
            # Exiting with an exception (instead of a bare exit(), which makes
            # prompt() return None) lets start() reach its KeyboardInterrupt
            # handler and print the "use exit" hint.
            event.app.exit(exception=KeyboardInterrupt, style="class:aborting")

        return kb

    def start(self):
        """Start interactive chat loop.

        Runs until the user types an exit word or sends EOF (Ctrl+D). Any
        other exception is logged, shown to the user, and the loop continues.
        """
        self._show_welcome()

        while True:
            try:
                # Get user input
                user_input = self.session.prompt(self._PROMPT)

                # Defensive: a prompt that exits without a result yields None.
                if user_input is None:
                    continue

                # Trim once so "exit " and " /help" behave like their bare forms.
                user_input = user_input.strip()
                if not user_input:
                    continue

                # Check for exit
                if user_input.lower() in self._EXIT_WORDS:
                    self._handle_exit()
                    break

                # Handle special commands
                if user_input.startswith('/'):
                    self._handle_command(user_input)
                    continue

                # Process as chat message
                self._process_message(user_input)

            except KeyboardInterrupt:
                console.print("\n[yellow]Use 'exit' or Ctrl+D to quit[/yellow]")
                continue
            except EOFError:
                self._handle_exit()
                break
            except Exception as e:
                # loguru attaches the traceback via opt(exception=True);
                # it has no exc_info keyword.
                logger.opt(exception=True).error("Chat error: {}", e)
                # markup=False: the exception text may contain '[' / ']'.
                console.print(f"Error: {e}", style="red", markup=False)

    def _show_welcome(self):
        """Show welcome message and, if scan context is loaded, its summary."""
        console.print(Panel(
            "[bold cyan]Alprina AI Security Assistant[/bold cyan]\n\n"
            "I can help you with:\n"
            "• Running security scans on code, APIs, and infrastructure\n"
            "• Explaining vulnerabilities and security findings\n"
            "• Providing remediation steps and code fixes\n"
            "• Answering security questions and best practices\n\n"
            "[dim]Type '/help' for commands or just ask me anything![/dim]\n"
            "[dim]Type 'exit' to quit[/dim]",
            title="🛡️ Alprina Chat",
            border_style="cyan"
        ))

        # Show context if loaded
        if self.context.scan_results:
            console.print(f"\n[cyan]📊 Context loaded:[/cyan] {self.context.get_context_summary()}\n")

    def _process_message(self, user_input: str):
        """
        Record a free-form user message and route it by detected intent.

        Scan and mitigation requests are dispatched to their handlers;
        everything else is answered by the LLM.

        Args:
            user_input: User's message
        """
        # Add to context
        self.context.add_user_message(user_input)

        # Check if user wants to scan something
        if self._is_scan_request(user_input):
            self._handle_scan_request(user_input)
            return

        # Check if user wants mitigation
        if self._is_mitigation_request(user_input):
            self._handle_mitigation_request(user_input)
            return

        self._respond(user_input)

    def _respond(self, user_input: str):
        """
        Get an LLM answer for the conversation so far, display and record it.

        The caller must already have added ``user_input`` to the context.

        Args:
            user_input: User's message (forwarded to the response getters)
        """
        console.print("\n[bold cyan]Alprina:[/bold cyan]")

        if self.streaming:
            # Streaming prints the text as it arrives; do not render it again.
            response = self._get_streaming_response(user_input)
        else:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                console=console,
                transient=True
            ) as progress:
                progress.add_task("Thinking...", total=None)
                response = self._get_response(user_input)

            console.print(Markdown(response))

        self.context.add_assistant_message(response)

    def _get_response(self, user_input: str) -> str:
        """
        Get AI response (non-streaming).

        Args:
            user_input: User's message. Not read here: the request is built
                from the messages stored in the conversation context.

        Returns:
            AI response text, or an error sentence if the call failed
        """
        try:
            return self.llm.chat(
                messages=self.context.get_messages_for_llm(),
                system_prompt=self._build_system_prompt(),
                max_tokens=4096,
                temperature=0.7
            )
        except Exception as e:
            logger.opt(exception=True).error("Failed to get AI response: {}", e)
            return f"I encountered an error: {e}. Please try again."

    def _get_streaming_response(self, user_input: str) -> str:
        """
        Get AI response with streaming, printing chunks as they arrive.

        Args:
            user_input: User's message. Not read here: the request is built
                from the messages stored in the conversation context.

        Returns:
            Full AI response text, or an error sentence if the call failed
        """
        try:
            system_prompt = self._build_system_prompt()
            messages = self.context.get_messages_for_llm()

            chunks = []
            for chunk in self.llm.chat_streaming(
                messages=messages,
                system_prompt=system_prompt,
                max_tokens=4096,
                temperature=0.7
            ):
                # Model output is arbitrary text: print it verbatim rather than
                # letting Rich parse '[...]' as markup or re-colour it.
                console.print(chunk, end="", markup=False, highlight=False)
                chunks.append(chunk)

            console.print()  # Newline after streaming
            return "".join(chunks)

        except Exception as e:
            logger.opt(exception=True).error("Failed to get streaming response: {}", e)
            error_text = f"\n\nI encountered an error: {e}. Please try again."
            # Nothing else renders the return value in streaming mode, so the
            # error has to be shown here.
            console.print(error_text, markup=False, highlight=False)
            return error_text

    def _build_system_prompt(self) -> str:
        """
        Build context-aware system prompt.

        Returns:
            System prompt string, with the detailed scan context appended
            when scan results are loaded
        """
        base_prompt = """You are Alprina, an AI-powered security assistant designed to help developers and security professionals.

Your capabilities:
- Analyzing security scan results and vulnerabilities
- Explaining security concepts and vulnerabilities in clear terms
- Providing actionable remediation steps and code fixes
- Answering questions about security best practices
- Guiding users through security assessments

Guidelines:
- Be concise but comprehensive in explanations
- Always prioritize security and ethical practices
- Provide code examples when relevant
- Use markdown formatting for better readability
- If you don't know something, admit it rather than guessing
- Focus on actionable insights, not just theory

Communication style:
- Professional but friendly
- Clear and jargon-free when possible
- Use examples to illustrate concepts
- Ask clarifying questions if needed"""

        # Add scan context if available
        if self.context.scan_results:
            context_info = self.context.get_detailed_context()
            base_prompt += f"\n\nCurrent Scan Context:\n{context_info}"

        return base_prompt

    def _is_scan_request(self, text: str) -> bool:
        """Check if user wants to run a scan.

        True when the text has a scan verb (a word starting with scan,
        analyze, check, test or audit) and also hints at a target: the word
        "my", a word starting with "file", "./", or "http".
        """
        if not self._SCAN_VERB_RE.search(text):
            return False
        return (
            './' in text
            or 'http' in text.lower()
            or self._SCAN_HINT_RE.search(text) is not None
        )

    def _is_mitigation_request(self, text: str) -> bool:
        """Check if user wants mitigation steps.

        True when a word starts with fix, remediate, solve or patch.
        """
        return self._MITIGATION_RE.search(text) is not None

    def _handle_scan_request(self, user_input: str):
        """Handle scan request from natural language.

        Args:
            user_input: User's message, searched for a URL or path to scan
        """
        # Try to extract target from user input
        target = self._extract_target(user_input)

        if not target:
            console.print("[yellow]I'd be happy to run a scan! What would you like me to scan?[/yellow]")
            console.print("[dim]Example: ./src, https://api.example.com, or /path/to/file[/dim]")
            return

        self._run_scan(target)

    def _run_scan(self, target: str):
        """Run a safe-only scan on an explicit target and report the outcome.

        Args:
            target: Path or URL to scan
        """
        console.print("\n[cyan]→ Running security scan on:[/cyan]", end=" ")
        # The target is user text; print it without markup interpretation.
        console.print(target, markup=False, highlight=False)
        console.print()

        # Determine profile based on target: existing local paths get a code
        # audit, everything else is treated as a web target.
        profile = "code-audit" if Path(target).exists() else "web-recon"

        # Run scan (this will use the existing scan_command)
        try:
            scan_command(target, profile=profile, safe_only=True, output=None)

            # NOTE(review): scan results are not captured into self.context
            # here, so later /explain and /report will not see them.
            console.print("\n[green]✓ Scan complete![/green]")
            console.print("[dim]Type 'explain' to learn about findings or 'fix' for remediation steps[/dim]\n")

        except Exception as e:
            console.print(f"Scan failed: {e}", style="red", markup=False)
            logger.opt(exception=True).error("Scan failed: {}", e)

    def _handle_mitigation_request(self, user_input: str):
        """Handle mitigation request.

        Args:
            user_input: User's message (not inspected; the mitigation command
                is invoked without a finding id or report file)
        """
        console.print("\n[cyan]→ Getting remediation suggestions...[/cyan]\n")

        try:
            # Run mitigation command
            mitigate_command(finding_id=None, report_file=None)

        except Exception as e:
            console.print(f"Failed to get mitigation: {e}", style="red", markup=False)
            logger.opt(exception=True).error("Mitigation failed: {}", e)

    def _extract_target(self, text: str) -> Optional[str]:
        """Extract scan target from natural language.

        URLs are tried first, then ./ or ../ relative paths, then absolute
        paths. Trailing sentence punctuation is removed from the match.

        Args:
            text: Free-form user text

        Returns:
            The first URL or path found, or None when there is none
        """
        for pattern in (self._URL_RE, self._REL_PATH_RE, self._ABS_PATH_RE):
            match = pattern.search(text)
            if not match:
                continue

            candidate = match.group(0).rstrip(',;:!?)"\'')
            # Drop a sentence-ending period ("scan ./src.") but keep
            # meaningful dots such as "./.." or "file.py".
            if len(candidate) > 1 and candidate.endswith('.') and candidate[-2].isalnum():
                candidate = candidate[:-1]
            if candidate:
                return candidate

        return None

    def _handle_command(self, command: str):
        """
        Handle special chat commands.

        Args:
            command: Command string starting with /
        """
        cmd_parts = command.split(None, 1)
        cmd_name = cmd_parts[0].lower()
        argument = cmd_parts[1].strip() if len(cmd_parts) > 1 else ""

        if cmd_name == '/help':
            self._show_help()
        elif cmd_name == '/scan':
            if not argument:
                console.print("[yellow]Usage: /scan <target>[/yellow]")
            else:
                # The target is explicit: scan it as given instead of
                # re-parsing it through the natural-language extractor.
                self._run_scan(argument.split()[0])
        elif cmd_name == '/explain':
            if not argument:
                self._show_findings()
            else:
                self._explain_finding(argument.split()[0])
        elif cmd_name == '/fix':
            if not argument:
                console.print("[yellow]Usage: /fix <finding_id>[/yellow]")
            else:
                self._fix_finding(argument.split()[0])
        elif cmd_name == '/report':
            self._show_report()
        elif cmd_name == '/clear':
            self.context.clear()
            console.print("[green]✓ Conversation history cleared[/green]")
        elif cmd_name == '/stats':
            self._show_stats()
        elif cmd_name == '/save':
            self._save_conversation()
        else:
            # cmd_name is raw user input; never parse it as markup.
            console.print(f"Unknown command: {cmd_name}", style="red", markup=False)
            console.print("[dim]Type /help for available commands[/dim]")

    def _show_help(self):
        """Show help message listing the available commands."""
        help_table = Table(title="Available Commands", show_header=True, header_style="bold cyan")
        help_table.add_column("Command", style="cyan")
        help_table.add_column("Description")

        commands = [
            ("/help", "Show this help message"),
            ("/scan <target>", "Run security scan on target"),
            ("/explain [id]", "Explain finding (or list all)"),
            ("/fix <id>", "Get fix for specific finding"),
            ("/report", "Show current scan summary"),
            ("/clear", "Clear conversation history"),
            ("/stats", "Show conversation statistics"),
            ("/save", "Save conversation to file"),
            ("exit", "Exit chat session"),
        ]

        for cmd, desc in commands:
            help_table.add_row(cmd, desc)

        console.print(help_table)

    def _show_findings(self):
        """Show current scan findings as a table."""
        if not self.context.current_findings:
            console.print("[yellow]No scan findings available. Run a scan first![/yellow]")
            return

        findings_table = Table(title="Current Findings", show_header=True, header_style="bold")
        findings_table.add_column("ID", style="dim")
        findings_table.add_column("Severity")
        findings_table.add_column("Title")
        findings_table.add_column("File", style="dim")

        severity_styles = {
            'HIGH': 'bold red',
            'MEDIUM': 'bold yellow',
            'LOW': 'bold green'
        }

        for finding in self.context.current_findings:
            # Table cells must be strings; ids and other fields may not be.
            severity = str(finding.get('severity', 'UNKNOWN'))
            # Case-insensitive style lookup so "high" is coloured like "HIGH".
            severity_style = severity_styles.get(severity.upper(), 'white')

            findings_table.add_row(
                str(finding.get('id', 'N/A')),
                f"[{severity_style}]{severity}[/{severity_style}]",
                str(finding.get('title', 'Unknown')),
                str(finding.get('file', 'N/A'))
            )

        console.print(findings_table)

    def _explain_finding(self, finding_id: str):
        """Explain specific finding.

        Args:
            finding_id: Identifier looked up in the conversation context
        """
        finding = self.context.get_finding(finding_id)
        if not finding:
            console.print(f"Finding {finding_id} not found", style="red", markup=False)
            return

        # Use AI to explain the finding. The request goes straight to the LLM:
        # routing it through the intent heuristics could misread the finding's
        # own text as a scan request.
        explanation_request = f"Can you explain this security finding in detail?\n\n{finding}"
        self.context.add_user_message(explanation_request)
        self._respond(explanation_request)

    def _fix_finding(self, finding_id: str):
        """Get fix for specific finding.

        Args:
            finding_id: Identifier looked up in the conversation context
        """
        finding = self.context.get_finding(finding_id)
        if not finding:
            console.print(f"Finding {finding_id} not found", style="red", markup=False)
            return

        # Use AI to provide fix. Sent directly to the LLM: the word "fix" in
        # the request would otherwise trigger the generic mitigation handler
        # and the finding would never reach the model.
        fix_request = f"How can I fix this security vulnerability?\n\n{finding}"
        self.context.add_user_message(fix_request)
        self._respond(fix_request)

    def _show_report(self):
        """Show scan report summary."""
        if not self.context.scan_results:
            console.print("[yellow]No scan results available[/yellow]")
            return

        summary = self.context.get_context_summary()
        console.print(Panel(
            f"[bold]Scan Report[/bold]\n\n{summary}",
            border_style="cyan"
        ))

    def _show_stats(self):
        """Show conversation statistics."""
        stats = self.context.get_statistics()

        stats_table = Table(title="Session Statistics", show_header=False)
        stats_table.add_column("Metric", style="cyan")
        stats_table.add_column("Value", style="bold")

        rows = [
            ("Messages", str(stats['total_messages'])),
            (" User", str(stats['user_messages'])),
            (" Assistant", str(stats['assistant_messages'])),
            ("Findings", str(stats['total_findings'])),
            (" HIGH", str(stats['high_severity'])),
            (" MEDIUM", str(stats['medium_severity'])),
            (" LOW", str(stats['low_severity'])),
            ("Duration", f"{stats['session_duration']:.0f}s"),
        ]
        for metric, value in rows:
            stats_table.add_row(metric, value)

        console.print(stats_table)

    def _save_conversation(self):
        """Save conversation to a timestamped JSON file under ~/.alprina/conversations."""
        timestamp = self.context.session_start.strftime('%Y%m%d_%H%M%S')
        output_file = Path.home() / '.alprina' / 'conversations' / f"chat_{timestamp}.json"
        output_file.parent.mkdir(parents=True, exist_ok=True)

        self.context.save_conversation(output_file)
        console.print(f"[green]✓ Conversation saved to:[/green] {output_file}")

    def _handle_exit(self):
        """Handle exit gracefully by printing a goodbye and session stats."""
        stats = self.context.get_statistics()
        console.print("\n[cyan]Thanks for using Alprina![/cyan]")
        console.print(f"[dim]Session stats: {stats['total_messages']} messages, {stats['session_duration']:.0f}s duration[/dim]")
        console.print("[dim]💾 Use /save before exit to save your conversation[/dim]\n")

501 

502 

def chat_command(
    model: str = "claude-3-5-sonnet-20241022",
    streaming: bool = True,
    load_results: Optional[Path] = None
):
    """
    Start interactive chat session with Alprina AI.

    Any exception raised while creating or running the session is logged,
    reported on the console, and the process exits with status 1.

    Args:
        model: LLM model to use
        streaming: Enable streaming responses
        load_results: Load previous scan results for context
    """
    try:
        session = AlprinaChatSession(
            model=model,
            streaming=streaming,
            context_file=load_results
        )
        session.start()
    except Exception as e:
        # loguru records the traceback through opt(exception=True); passing
        # exc_info=True would only be treated as a str.format() argument.
        logger.opt(exception=True).error("Chat session error: {}", e)
        # markup=False: exception text may contain '[' / ']' that Rich would
        # otherwise try to parse as markup tags.
        console.print(f"Chat error: {e}", style="red", markup=False)
        sys.exit(1)