Coverage for src/alprina_cli/services/fix_generator.py: 15%

142 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2AI-Powered Fix Generator - Generates secure code fixes for vulnerabilities. 

3Uses GPT-4/Claude to create context-aware security fixes with explanations. 

4""" 

5 

6import os 

7import difflib 

8from typing import Dict, List, Optional 

9from pathlib import Path 

10from loguru import logger 

11 

12from ..llm_provider import get_llm_client 

13 

14 

15class FixGenerator: 

16 """ 

17 Generate AI-powered fixes for security vulnerabilities. 

18  

19 Uses existing LLM integration (GPT-4/Claude) to generate secure 

20 code alternatives with explanations and confidence scoring. 

21 """ 

22 

23 def __init__(self): 

24 """Initialize fix generator with LLM client.""" 

25 self.llm = get_llm_client() 

26 logger.info(f"FixGenerator initialized with LLM provider: {self.llm.provider}") 

27 

28 def generate_fix( 

29 self, 

30 code: str, 

31 vulnerability: Dict, 

32 filename: str, 

33 context_lines: int = 10 

34 ) -> Dict: 

35 """ 

36 Generate AI-powered fix for a vulnerability. 

37 

38 Args: 

39 code: Full source code of the file 

40 vulnerability: Vulnerability details (type, severity, line, etc.) 

41 filename: Name of the file being fixed 

42 context_lines: Lines of context around vulnerability (default: 10) 

43 

44 Returns: 

45 Dict with: 

46 - fixed_code: Complete fixed code 

47 - explanation: Why this fix works 

48 - changes: List of specific changes made 

49 - diff: Unified diff showing changes 

50 - confidence: Confidence score (0.0-1.0) 

51 - security_notes: Important security considerations 

52 """ 

53 try: 

54 logger.info(f"Generating fix for {vulnerability.get('type')} in {filename}") 

55 

56 # Extract relevant code context 

57 vuln_line = vulnerability.get("line", 0) 

58 code_lines = code.split("\n") 

59 

60 # Get context around vulnerability 

61 start_line = max(0, vuln_line - context_lines) 

62 end_line = min(len(code_lines), vuln_line + context_lines) 

63 context = "\n".join(code_lines[start_line:end_line]) 

64 

65 # Build fix generation prompt 

66 prompt = self._build_fix_prompt( 

67 context=context, 

68 vulnerability=vulnerability, 

69 filename=filename, 

70 full_code_length=len(code_lines) 

71 ) 

72 

73 # Generate fix using LLM 

74 logger.debug("Calling LLM for fix generation...") 

75 response = self.llm.chat( 

76 messages=[{"role": "user", "content": prompt}], 

77 system_prompt=self._get_system_prompt(), 

78 temperature=0.3, # Lower temperature for more deterministic fixes 

79 max_tokens=2000 

80 ) 

81 

82 # Parse LLM response 

83 fix_data = self._parse_fix_response(response) 

84 

85 # Generate diff 

86 fix_data["diff"] = self._generate_diff(context, fix_data.get("fixed_code", context)) 

87 

88 # Calculate confidence score 

89 fix_data["confidence"] = self._calculate_confidence(vulnerability, fix_data) 

90 

91 # Add metadata 

92 fix_data["vulnerability_type"] = vulnerability.get("type") 

93 fix_data["severity"] = vulnerability.get("severity") 

94 fix_data["filename"] = filename 

95 fix_data["line"] = vuln_line 

96 

97 logger.info(f"Fix generated with confidence: {fix_data['confidence']:.2f}") 

98 return fix_data 

99 

100 except Exception as e: 

101 logger.error(f"Error generating fix: {e}") 

102 return { 

103 "error": str(e), 

104 "fixed_code": code, 

105 "explanation": "Could not generate fix due to error", 

106 "confidence": 0.0 

107 } 

108 

109 def generate_multiple_fixes( 

110 self, 

111 findings: List[Dict], 

112 file_contents: Dict[str, str] 

113 ) -> Dict[str, List[Dict]]: 

114 """ 

115 Generate fixes for multiple vulnerabilities across files. 

116 

117 Args: 

118 findings: List of vulnerability findings 

119 file_contents: Dict mapping file paths to their content 

120 

121 Returns: 

122 Dict mapping file paths to list of fix suggestions 

123 """ 

124 fixes_by_file = {} 

125 

126 for finding in findings: 

127 file_path = finding.get("location", "").split(":")[0] 

128 

129 if file_path not in file_contents: 

130 logger.warning(f"File not found: {file_path}") 

131 continue 

132 

133 code = file_contents[file_path] 

134 fix = self.generate_fix(code, finding, file_path) 

135 

136 if file_path not in fixes_by_file: 

137 fixes_by_file[file_path] = [] 

138 

139 fixes_by_file[file_path].append(fix) 

140 

141 return fixes_by_file 

142 

143 def _build_fix_prompt( 

144 self, 

145 context: str, 

146 vulnerability: Dict, 

147 filename: str, 

148 full_code_length: int 

149 ) -> str: 

150 """Build detailed prompt for fix generation.""" 

151 

152 vuln_type = vulnerability.get("type", "Security Issue") 

153 severity = vulnerability.get("severity", "MEDIUM") 

154 description = vulnerability.get("description", "") 

155 cwe = vulnerability.get("cwe", "") 

156 cvss = vulnerability.get("cvss_score", "") 

157 

158 prompt = f"""You are a security expert fixing code vulnerabilities. 

159 

160**VULNERABILITY DETAILS:** 

161- Type: {vuln_type} 

162- Severity: {severity} 

163- File: {filename} 

164- Line: {vulnerability.get('line', 'unknown')} 

165{f"- CWE: {cwe}" if cwe else ""} 

166{f"- CVSS Score: {cvss}/10.0" if cvss else ""} 

167 

168**ISSUE DESCRIPTION:** 

169{description} 

170 

171**VULNERABLE CODE:** 

172```{self._get_language_from_filename(filename)} 

173{context} 

174``` 

175 

176**YOUR TASK:** 

177Generate a secure fix that: 

1781. ✅ COMPLETELY resolves the {vuln_type} vulnerability 

1792. ✅ Maintains ALL original functionality 

1803. ✅ Follows security best practices 

1814. ✅ Preserves exact indentation and code style 

1825. ✅ Includes inline comments explaining the fix 

1836. ✅ Works with the existing codebase structure 

184 

185**IMPORTANT:** 

186- Return ONLY the fixed code section (same lines as shown above) 

187- Preserve exact indentation (spaces/tabs) 

188- Don't remove unrelated code 

189- Add security-focused comments 

190- Make minimal changes (only fix the vulnerability) 

191 

192**RETURN FORMAT:** 

193Return a JSON object with these fields: 

194 

195{{ 

196 "fixed_code": "complete fixed code with exact indentation", 

197 "explanation": "detailed explanation of why this fix is secure and how it works", 

198 "changes": ["specific change 1", "specific change 2", ...], 

199 "security_notes": ["important security consideration 1", "consideration 2", ...] 

200}} 

201 

202**EXAMPLES OF GOOD FIXES:** 

203 

204SQL Injection: 

205❌ query = f"SELECT * FROM users WHERE id = {{user_id}}" 

206✅ query = "SELECT * FROM users WHERE id = ?" 

207✅ cursor.execute(query, (user_id,)) 

208 

209XSS: 

210❌ return f"<div>{{user_input}}</div>" 

211✅ from html import escape 

212✅ return f"<div>{{escape(user_input)}}</div>" 

213 

214Hardcoded Secret: 

215❌ API_KEY = "sk_live_abc123" 

216✅ import os 

217✅ API_KEY = os.getenv("API_KEY") 

218✅ if not API_KEY: raise ValueError("API_KEY not set") 

219 

220Now generate the secure fix: 

221""" 

222 return prompt 

223 

224 def _get_system_prompt(self) -> str: 

225 """Get system prompt for LLM.""" 

226 return """You are an expert security engineer specializing in secure code remediation. 

227 

228Your expertise includes: 

229- OWASP Top 10 vulnerabilities 

230- CWE (Common Weakness Enumeration) 

231- Secure coding practices for all major languages 

232- Defense-in-depth security principles 

233- Zero-trust architecture 

234 

235When generating fixes: 

236✅ Prioritize security without breaking functionality 

237✅ Use well-established security libraries and patterns 

238✅ Add clear comments explaining security rationale 

239✅ Preserve code style and indentation exactly 

240✅ Return valid JSON format 

241 

242Return ONLY valid JSON. No markdown formatting, no explanatory text outside JSON.""" 

243 

244 def _parse_fix_response(self, response: str) -> Dict: 

245 """Parse LLM response into structured fix data.""" 

246 import json 

247 import re 

248 

249 try: 

250 # Remove markdown code blocks if present 

251 json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', response, re.DOTALL) 

252 if json_match: 

253 response = json_match.group(1) 

254 

255 # Parse JSON 

256 fix_data = json.loads(response) 

257 

258 # Validate required fields 

259 if "fixed_code" not in fix_data: 

260 logger.warning("LLM response missing 'fixed_code' field") 

261 fix_data["fixed_code"] = "" 

262 

263 # Ensure all expected fields exist 

264 fix_data.setdefault("explanation", "Fix generated by AI") 

265 fix_data.setdefault("changes", []) 

266 fix_data.setdefault("security_notes", []) 

267 

268 return fix_data 

269 

270 except json.JSONDecodeError as e: 

271 logger.error(f"Failed to parse LLM response as JSON: {e}") 

272 logger.debug(f"Response content: {response[:500]}") 

273 

274 # Fallback: try to extract code from response 

275 return { 

276 "fixed_code": self._extract_code_from_text(response), 

277 "explanation": "Fix generated, but response format was unexpected", 

278 "changes": ["Security improvement applied"], 

279 "security_notes": ["Please verify this fix manually"] 

280 } 

281 

282 def _extract_code_from_text(self, text: str) -> str: 

283 """Extract code from unstructured text response.""" 

284 import re 

285 

286 # Try to find code blocks 

287 code_match = re.search(r'```[\w]*\n(.*?)\n```', text, re.DOTALL) 

288 if code_match: 

289 return code_match.group(1) 

290 

291 # If no code blocks, look for fixed_code field 

292 fixed_match = re.search(r'"fixed_code"\s*:\s*"(.*?)"', text, re.DOTALL) 

293 if fixed_match: 

294 return fixed_match.group(1).replace('\\n', '\n') 

295 

296 # Return original text as fallback 

297 return text 

298 

299 def _generate_diff(self, original: str, fixed: str) -> str: 

300 """Generate unified diff between original and fixed code.""" 

301 original_lines = original.splitlines(keepends=True) 

302 fixed_lines = fixed.splitlines(keepends=True) 

303 

304 diff = difflib.unified_diff( 

305 original_lines, 

306 fixed_lines, 

307 fromfile="original", 

308 tofile="fixed", 

309 lineterm="" 

310 ) 

311 

312 return "".join(diff) 

313 

314 def _calculate_confidence(self, vulnerability: Dict, fix_data: Dict) -> float: 

315 """ 

316 Calculate confidence score for the generated fix (0.0-1.0). 

317 

318 Based on: 

319 - Vulnerability type (some are easier to fix) 

320 - Fix completeness 

321 - Explanation quality 

322 - Security considerations included 

323 """ 

324 confidence = 0.7 # Base confidence 

325 

326 vuln_type = vulnerability.get("type", "") 

327 

328 # High confidence for well-known patterns 

329 high_confidence_types = [ 

330 "SQL Injection", 

331 "Hardcoded Secret", 

332 "XSS", 

333 "Command Injection", 

334 "Path Traversal" 

335 ] 

336 if any(t.lower() in vuln_type.lower() for t in high_confidence_types): 

337 confidence += 0.15 

338 

339 # Lower confidence for complex vulnerabilities 

340 low_confidence_types = [ 

341 "Race Condition", 

342 "Business Logic", 

343 "Authentication Flow", 

344 "Complex Authorization" 

345 ] 

346 if any(t.lower() in vuln_type.lower() for t in low_confidence_types): 

347 confidence -= 0.2 

348 

349 # Increase confidence if fix has good explanation 

350 explanation = fix_data.get("explanation", "") 

351 if len(explanation) > 100 and "secure" in explanation.lower(): 

352 confidence += 0.05 

353 

354 # Increase if security notes provided 

355 if fix_data.get("security_notes") and len(fix_data["security_notes"]) > 0: 

356 confidence += 0.05 

357 

358 # Increase if specific changes listed 

359 if fix_data.get("changes") and len(fix_data["changes"]) > 0: 

360 confidence += 0.05 

361 

362 # Cap at 0.95 (never 100% certain) 

363 return min(0.95, max(0.0, confidence)) 

364 

365 def _get_language_from_filename(self, filename: str) -> str: 

366 """Determine programming language from filename.""" 

367 ext_map = { 

368 ".py": "python", 

369 ".js": "javascript", 

370 ".ts": "typescript", 

371 ".java": "java", 

372 ".go": "go", 

373 ".rb": "ruby", 

374 ".php": "php", 

375 ".cs": "csharp", 

376 ".cpp": "cpp", 

377 ".c": "c", 

378 ".rs": "rust", 

379 ".swift": "swift", 

380 ".kt": "kotlin", 

381 } 

382 

383 ext = Path(filename).suffix.lower() 

384 return ext_map.get(ext, "") 

385 

386 def apply_fix( 

387 self, 

388 filepath: str, 

389 fix_data: Dict, 

390 backup: bool = True 

391 ) -> bool: 

392 """ 

393 Apply a generated fix to a file. 

394 

395 Args: 

396 filepath: Path to file to fix 

397 fix_data: Fix data from generate_fix() 

398 backup: Create .backup file before modifying 

399 

400 Returns: 

401 True if successful, False otherwise 

402 """ 

403 try: 

404 filepath = Path(filepath) 

405 

406 # Read current file 

407 with open(filepath, 'r', encoding='utf-8') as f: 

408 current_code = f.read() 

409 

410 # Create backup if requested 

411 if backup: 

412 backup_path = filepath.with_suffix(filepath.suffix + ".backup") 

413 with open(backup_path, 'w', encoding='utf-8') as f: 

414 f.write(current_code) 

415 logger.info(f"Backup created: {backup_path}") 

416 

417 # Apply fix (write fixed code) 

418 fixed_code = fix_data.get("fixed_code", "") 

419 if not fixed_code: 

420 logger.error("No fixed code in fix_data") 

421 return False 

422 

423 with open(filepath, 'w', encoding='utf-8') as f: 

424 f.write(fixed_code) 

425 

426 logger.info(f"Fix applied to: {filepath}") 

427 return True 

428 

429 except Exception as e: 

430 logger.error(f"Failed to apply fix to {filepath}: {e}") 

431 return False 

432 

433 

434# Global fix generator instance 

435_fix_generator = None 

436 

437 

438def get_fix_generator() -> FixGenerator: 

439 """Get or create global fix generator instance.""" 

440 global _fix_generator 

441 if _fix_generator is None: 

442 _fix_generator = FixGenerator() 

443 return _fix_generator 

444 

445 

446# Convenience functions 

447def generate_fix(code: str, vulnerability: Dict, filename: str) -> Dict: 

448 """ 

449 Convenience function to generate a fix. 

450 

451 Args: 

452 code: Source code 

453 vulnerability: Vulnerability details 

454 filename: File name 

455 

456 Returns: 

457 Fix data dict 

458 """ 

459 generator = get_fix_generator() 

460 return generator.generate_fix(code, vulnerability, filename) 

461 

462 

463def apply_fix_to_file(filepath: str, fix_data: Dict, backup: bool = True) -> bool: 

464 """ 

465 Convenience function to apply a fix to a file. 

466 

467 Args: 

468 filepath: Path to file 

469 fix_data: Fix data from generate_fix() 

470 backup: Create backup file 

471 

472 Returns: 

473 True if successful 

474 """ 

475 generator = get_fix_generator() 

476 return generator.apply_fix(filepath, fix_data, backup)