Coverage for src/alprina_cli/context_manager.py: 0%

97 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2Conversation context management for chat interface. 

3Maintains scan results, findings, and conversation history. 

4""" 

5 

6from typing import List, Dict, Any, Optional 

7from pathlib import Path 

8import json 

9from datetime import datetime 

10from loguru import logger 

11 

12 

class ConversationContext:
    """Manages conversation context and scan results.

    Holds the chat message history (bounded by ``max_history``), the most
    recently loaded scan results, and the findings extracted from them.
    """

    def __init__(self, max_history: int = 50):
        """
        Initialize conversation context.

        Args:
            max_history: Maximum number of messages to keep in history,
                system messages included. Values <= 0 keep no history.
        """
        self.messages: List[Dict[str, str]] = []
        self.scan_results: Dict[str, Any] = {}
        self.current_findings: List[Dict] = []
        self.max_history = max_history
        self.session_start = datetime.now()

    def _add_message(self, role: str, content: str) -> None:
        """Append a timestamped message with ``role`` and enforce the history cap."""
        self.messages.append({
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        })
        self._trim_history()

    def add_user_message(self, content: str):
        """Add user message to context."""
        self._add_message("user", content)
        logger.debug(f"Added user message: {content[:50]}...")

    def add_assistant_message(self, content: str):
        """Add assistant response to context."""
        self._add_message("assistant", content)
        logger.debug(f"Added assistant message: {content[:50]}...")

    def add_system_message(self, content: str):
        """Add system message (scan results, etc.)."""
        self._add_message("system", content)

    def _trim_history(self):
        """
        Trim message history so it holds at most ``max_history`` messages.

        The oldest non-system messages are dropped first, so system messages
        (scan summaries) survive the longest. The oldest system messages are
        dropped only when system messages alone exceed the limit. Surviving
        messages keep their original chronological order.
        """
        excess = len(self.messages) - self.max_history
        if excess <= 0:
            return

        # Drop order: non-system messages oldest-first, then system messages
        # oldest-first. Only the first `excess` candidates are removed.
        candidates = [i for i, m in enumerate(self.messages) if m["role"] != "system"]
        candidates += [i for i, m in enumerate(self.messages) if m["role"] == "system"]
        dropped = set(candidates[:excess])

        self.messages = [m for i, m in enumerate(self.messages) if i not in dropped]
        logger.debug(f"Trimmed history to {len(self.messages)} messages")

    def _set_scan_results(self, results: Dict[str, Any]) -> None:
        """
        Replace the active scan results and findings.

        The findings are read before any attribute is assigned. A malformed
        ``results`` object (e.g. a non-dict) therefore raises without leaving
        the context half-updated. A missing or null ``findings`` entry
        becomes an empty list.
        """
        findings = results.get("findings") or []
        self.scan_results = results
        self.current_findings = findings

    def load_scan_results(self, file_path: Path):
        """
        Load scan results from file.

        Args:
            file_path: Path to scan results JSON file (read as UTF-8)

        Raises:
            Any error from opening/parsing the file, or from a JSON document
            that is not an object. The error is logged and re-raised.
        """
        try:
            with open(file_path, encoding="utf-8") as f:
                results = json.load(f)
            self._set_scan_results(results)
        except Exception as e:
            logger.error(f"Failed to load scan results: {e}")
            raise

        # Add system message about loaded results
        summary = self.get_context_summary()
        self.add_system_message(f"Loaded scan results: {summary}")

        logger.info(f"Loaded {len(self.current_findings)} findings from {file_path}")

    def load_scan_results_dict(self, results: Dict[str, Any]):
        """
        Load scan results from dictionary.

        Args:
            results: Scan results dictionary; its "findings" entry (if any)
                becomes the current findings list.
        """
        self._set_scan_results(results)

        # Add system message about results
        summary = self.get_context_summary()
        self.add_system_message(f"Scan completed: {summary}")

        logger.info(f"Loaded {len(self.current_findings)} findings from scan")

    def get_finding(self, finding_id: str) -> Optional[Dict]:
        """
        Get specific finding by ID.

        Args:
            finding_id: Finding identifier

        Returns:
            The first finding whose "id" equals ``finding_id``, or None if absent
        """
        return next(
            (finding for finding in self.current_findings if finding.get("id") == finding_id),
            None,
        )

    @staticmethod
    def _normalize_severity(value: Any) -> Any:
        """Upper-case string severities so that 'high' and 'HIGH' compare equal."""
        return value.upper() if isinstance(value, str) else value

    def get_findings_by_severity(self, severity: str) -> List[Dict]:
        """
        Get findings filtered by severity.

        The comparison is case-insensitive on both sides.

        Args:
            severity: Severity level (HIGH, MEDIUM, LOW)

        Returns:
            List of findings matching severity, in their original order
        """
        wanted = severity.upper()
        return [
            f for f in self.current_findings
            if self._normalize_severity(f.get("severity")) == wanted
        ]

    def get_messages(self, include_system: bool = True) -> List[Dict]:
        """
        Get conversation messages.

        Args:
            include_system: Whether to include system messages

        Returns:
            A new list of message dictionaries. Mutating the list does not
            alter the stored history.
        """
        if include_system:
            return list(self.messages)
        return [m for m in self.messages if m["role"] != "system"]

    def get_messages_for_llm(self) -> List[Dict[str, str]]:
        """
        Get messages formatted for LLM API (without timestamps).

        Returns:
            List of message dictionaries with role and content only
        """
        return [
            {"role": m["role"], "content": m["content"]}
            for m in self.messages
        ]

    def get_context_summary(self) -> str:
        """
        Get summary of current scan context.

        Returns:
            Human-readable summary string
        """
        if not self.scan_results:
            return "No active scan context"

        target = self.scan_results.get('target', 'Unknown')
        total = len(self.current_findings)
        # Reuse the (case-insensitive) filter so counts agree with it.
        high, medium, low = (
            len(self.get_findings_by_severity(level)) for level in ("HIGH", "MEDIUM", "LOW")
        )

        return f"{total} findings (HIGH: {high}, MEDIUM: {medium}, LOW: {low}) in {target}"

    def get_detailed_context(self, max_per_severity: int = 3) -> str:
        """
        Get detailed context for system prompt.

        Args:
            max_per_severity: How many findings to list for each severity
                level before summarising the rest as "... and N more".
                Negative values are treated as 0.

        Returns:
            Detailed context string with findings
        """
        if not self.scan_results:
            return "No scan context available."

        context = f"""
Current Scan Context:
=====================
Target: {self.scan_results.get('target', 'Unknown')}
Scan ID: {self.scan_results.get('scan_id', 'Unknown')}
Profile: {self.scan_results.get('profile', 'default')}
Timestamp: {self.scan_results.get('timestamp', 'Unknown')}

Findings Summary:
-----------------
Total Findings: {len(self.current_findings)}
"""

        # Add severity breakdown
        for severity in ("HIGH", "MEDIUM", "LOW"):
            findings = self.get_findings_by_severity(severity)
            if not findings:
                continue
            context += f"\n{severity} Severity ({len(findings)}):\n"
            shown = findings[:max(max_per_severity, 0)]
            for finding in shown:
                context += f" - {finding.get('id')}: {finding.get('title')} ({finding.get('file', 'N/A')})\n"
            hidden = len(findings) - len(shown)
            if hidden > 0:
                context += f" ... and {hidden} more\n"

        return context

    def clear(self):
        """Clear conversation history (keeps scan results)."""
        self.messages = []
        logger.info("Cleared conversation history")

    def clear_all(self):
        """Clear everything including scan results."""
        self.messages = []
        self.scan_results = {}
        self.current_findings = []
        logger.info("Cleared all context")

    def save_conversation(self, file_path: Path):
        """
        Save conversation to file.

        Args:
            file_path: Path to save conversation (written as UTF-8 JSON)
        """
        conversation_data = {
            "session_start": self.session_start.isoformat(),
            "messages": self.messages,
            "scan_summary": self.get_context_summary(),
            "total_messages": len(self.messages)
        }

        with open(file_path, 'w', encoding="utf-8") as f:
            json.dump(conversation_data, f, indent=2)

        logger.info(f"Saved conversation to {file_path}")

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get conversation statistics.

        Returns:
            Dictionary with message counts per role, finding counts per
            severity, and the elapsed session time in seconds
        """
        return {
            "total_messages": len(self.messages),
            "user_messages": sum(1 for m in self.messages if m["role"] == "user"),
            "assistant_messages": sum(1 for m in self.messages if m["role"] == "assistant"),
            "system_messages": sum(1 for m in self.messages if m["role"] == "system"),
            "total_findings": len(self.current_findings),
            "high_severity": len(self.get_findings_by_severity("HIGH")),
            "medium_severity": len(self.get_findings_by_severity("MEDIUM")),
            "low_severity": len(self.get_findings_by_severity("LOW")),
            "session_duration": (datetime.now() - self.session_start).total_seconds()
        }