Coverage for src/alprina_cli/services/fix_generator.py: 15%
142 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
1"""
2AI-Powered Fix Generator - Generates secure code fixes for vulnerabilities.
3Uses GPT-4/Claude to create context-aware security fixes with explanations.
4"""
import difflib
import json
import os
import re
from pathlib import Path
from typing import Dict, List, Optional

from loguru import logger

from ..llm_provider import get_llm_client
class FixGenerator:
    """
    Generate AI-powered fixes for security vulnerabilities.

    Uses existing LLM integration (GPT-4/Claude) to generate secure
    code alternatives with explanations and confidence scoring.
    """

    def __init__(self):
        """Initialize fix generator with LLM client."""
        self.llm = get_llm_client()
        logger.info(f"FixGenerator initialized with LLM provider: {self.llm.provider}")

    def generate_fix(
        self,
        code: str,
        vulnerability: Dict,
        filename: str,
        context_lines: int = 10
    ) -> Dict:
        """
        Generate AI-powered fix for a vulnerability.

        Args:
            code: Full source code of the file
            vulnerability: Vulnerability details (type, severity, line, etc.)
            filename: Name of the file being fixed
            context_lines: Lines of context around vulnerability (default: 10)

        Returns:
            Dict with:
                - fixed_code: Complete fixed code
                - explanation: Why this fix works
                - changes: List of specific changes made
                - diff: Unified diff showing changes
                - confidence: Confidence score (0.0-1.0)
                - security_notes: Important security considerations
            On failure, a dict with "error", the original code echoed back as
            "fixed_code", and confidence 0.0.
        """
        try:
            # Fix: include the actual filename in the log message
            # (previously a hardcoded placeholder).
            logger.info(f"Generating fix for {vulnerability.get('type')} in {filename}")

            # Extract relevant code context.
            # NOTE(review): "line" is used directly as a list index below; if
            # findings report 1-based line numbers the context window is
            # shifted by one line — confirm against the scanner's output.
            vuln_line = vulnerability.get("line", 0)
            code_lines = code.split("\n")

            # Clamp the context window to the file bounds.
            start_line = max(0, vuln_line - context_lines)
            end_line = min(len(code_lines), vuln_line + context_lines)
            context = "\n".join(code_lines[start_line:end_line])

            # Build fix generation prompt.
            prompt = self._build_fix_prompt(
                context=context,
                vulnerability=vulnerability,
                filename=filename,
                full_code_length=len(code_lines)
            )

            # Generate fix using LLM.
            logger.debug("Calling LLM for fix generation...")
            response = self.llm.chat(
                messages=[{"role": "user", "content": prompt}],
                system_prompt=self._get_system_prompt(),
                temperature=0.3,  # Lower temperature for more deterministic fixes
                max_tokens=2000
            )

            # Parse LLM response (expected to be a JSON string).
            fix_data = self._parse_fix_response(response)

            # Diff is computed against the extracted context, not the full file.
            fix_data["diff"] = self._generate_diff(context, fix_data.get("fixed_code", context))

            # Calculate confidence score.
            fix_data["confidence"] = self._calculate_confidence(vulnerability, fix_data)

            # Attach metadata so callers can trace the fix back to the finding.
            fix_data["vulnerability_type"] = vulnerability.get("type")
            fix_data["severity"] = vulnerability.get("severity")
            fix_data["filename"] = filename
            fix_data["line"] = vuln_line

            logger.info(f"Fix generated with confidence: {fix_data['confidence']:.2f}")
            return fix_data

        except Exception as e:
            # Best-effort API: surface the error in the payload instead of
            # raising, so batch callers (generate_multiple_fixes) keep going.
            logger.error(f"Error generating fix: {e}")
            return {
                "error": str(e),
                "fixed_code": code,
                "explanation": "Could not generate fix due to error",
                "confidence": 0.0
            }

    def generate_multiple_fixes(
        self,
        findings: List[Dict],
        file_contents: Dict[str, str]
    ) -> Dict[str, List[Dict]]:
        """
        Generate fixes for multiple vulnerabilities across files.

        Args:
            findings: List of vulnerability findings
            file_contents: Dict mapping file paths to their content

        Returns:
            Dict mapping file paths to list of fix suggestions
        """
        fixes_by_file: Dict[str, List[Dict]] = {}

        for finding in findings:
            # "location" appears to be "path:line"; keep only the path part.
            file_path = finding.get("location", "").split(":")[0]

            if file_path not in file_contents:
                logger.warning(f"File not found: {file_path}")
                continue

            code = file_contents[file_path]
            fix = self.generate_fix(code, finding, file_path)

            # setdefault replaces the manual "if key not in dict" dance.
            fixes_by_file.setdefault(file_path, []).append(fix)

        return fixes_by_file

    def _build_fix_prompt(
        self,
        context: str,
        vulnerability: Dict,
        filename: str,
        full_code_length: int
    ) -> str:
        """Build detailed prompt for fix generation.

        Args:
            context: Code excerpt surrounding the vulnerable line.
            vulnerability: Finding dict (type/severity/line/cwe/cvss_score).
            filename: File the excerpt came from (used for language fencing).
            full_code_length: Total line count of the file (currently unused
                in the prompt body; kept for interface stability).

        Returns:
            The complete user prompt string for the LLM.
        """
        vuln_type = vulnerability.get("type", "Security Issue")
        severity = vulnerability.get("severity", "MEDIUM")
        description = vulnerability.get("description", "")
        cwe = vulnerability.get("cwe", "")
        cvss = vulnerability.get("cvss_score", "")

        # Fix: the "- File:" line now interpolates the real filename
        # (previously a hardcoded placeholder).
        prompt = f"""You are a security expert fixing code vulnerabilities.

**VULNERABILITY DETAILS:**
- Type: {vuln_type}
- Severity: {severity}
- File: {filename}
- Line: {vulnerability.get('line', 'unknown')}
{f"- CWE: {cwe}" if cwe else ""}
{f"- CVSS Score: {cvss}/10.0" if cvss else ""}

**ISSUE DESCRIPTION:**
{description}

**VULNERABLE CODE:**
```{self._get_language_from_filename(filename)}
{context}
```

**YOUR TASK:**
Generate a secure fix that:
1. ✅ COMPLETELY resolves the {vuln_type} vulnerability
2. ✅ Maintains ALL original functionality
3. ✅ Follows security best practices
4. ✅ Preserves exact indentation and code style
5. ✅ Includes inline comments explaining the fix
6. ✅ Works with the existing codebase structure

**IMPORTANT:**
- Return ONLY the fixed code section (same lines as shown above)
- Preserve exact indentation (spaces/tabs)
- Don't remove unrelated code
- Add security-focused comments
- Make minimal changes (only fix the vulnerability)

**RETURN FORMAT:**
Return a JSON object with these fields:

{{
    "fixed_code": "complete fixed code with exact indentation",
    "explanation": "detailed explanation of why this fix is secure and how it works",
    "changes": ["specific change 1", "specific change 2", ...],
    "security_notes": ["important security consideration 1", "consideration 2", ...]
}}

**EXAMPLES OF GOOD FIXES:**

SQL Injection:
❌ query = f"SELECT * FROM users WHERE id = {{user_id}}"
✅ query = "SELECT * FROM users WHERE id = ?"
✅ cursor.execute(query, (user_id,))

XSS:
❌ return f"<div>{{user_input}}</div>"
✅ from html import escape
✅ return f"<div>{{escape(user_input)}}</div>"

Hardcoded Secret:
❌ API_KEY = "sk_live_abc123"
✅ import os
✅ API_KEY = os.getenv("API_KEY")
✅ if not API_KEY: raise ValueError("API_KEY not set")

Now generate the secure fix:
"""
        return prompt

    def _get_system_prompt(self) -> str:
        """Get system prompt for LLM (sets the security-engineer persona)."""
        return """You are an expert security engineer specializing in secure code remediation.

Your expertise includes:
- OWASP Top 10 vulnerabilities
- CWE (Common Weakness Enumeration)
- Secure coding practices for all major languages
- Defense-in-depth security principles
- Zero-trust architecture

When generating fixes:
✅ Prioritize security without breaking functionality
✅ Use well-established security libraries and patterns
✅ Add clear comments explaining security rationale
✅ Preserve code style and indentation exactly
✅ Return valid JSON format

Return ONLY valid JSON. No markdown formatting, no explanatory text outside JSON."""

    def _parse_fix_response(self, response: str) -> Dict:
        """Parse LLM response into structured fix data.

        Accepts raw JSON or JSON wrapped in a markdown code fence. On parse
        failure, falls back to best-effort code extraction so callers always
        get a usable dict.
        """
        try:
            # Strip a surrounding ```json ... ``` fence, if present.
            json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', response, re.DOTALL)
            if json_match:
                response = json_match.group(1)

            # Parse JSON.
            fix_data = json.loads(response)

            # Validate required fields.
            if "fixed_code" not in fix_data:
                logger.warning("LLM response missing 'fixed_code' field")
                fix_data["fixed_code"] = ""

            # Ensure all expected fields exist so downstream code can index freely.
            fix_data.setdefault("explanation", "Fix generated by AI")
            fix_data.setdefault("changes", [])
            fix_data.setdefault("security_notes", [])

            return fix_data

        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse LLM response as JSON: {e}")
            logger.debug(f"Response content: {response[:500]}")

            # Fallback: try to extract code from the unstructured response.
            return {
                "fixed_code": self._extract_code_from_text(response),
                "explanation": "Fix generated, but response format was unexpected",
                "changes": ["Security improvement applied"],
                "security_notes": ["Please verify this fix manually"]
            }

    def _extract_code_from_text(self, text: str) -> str:
        """Extract code from unstructured text response.

        Tries, in order: a fenced code block, a quoted "fixed_code" JSON
        field (unescaping only \\n), and finally the raw text itself.
        """
        # Try to find code blocks.
        code_match = re.search(r'```[\w]*\n(.*?)\n```', text, re.DOTALL)
        if code_match:
            return code_match.group(1)

        # If no code blocks, look for a fixed_code field.
        # NOTE(review): only \n escapes are undone here; other JSON escapes
        # (\t, \", \\) pass through verbatim.
        fixed_match = re.search(r'"fixed_code"\s*:\s*"(.*?)"', text, re.DOTALL)
        if fixed_match:
            return fixed_match.group(1).replace('\\n', '\n')

        # Return original text as fallback.
        return text

    def _generate_diff(self, original: str, fixed: str) -> str:
        """Generate unified diff between original and fixed code.

        Returns an empty string when both inputs are identical.
        """
        original_lines = original.splitlines(keepends=True)
        fixed_lines = fixed.splitlines(keepends=True)

        diff = difflib.unified_diff(
            original_lines,
            fixed_lines,
            fromfile="original",
            tofile="fixed",
            lineterm=""
        )

        return "".join(diff)

    def _calculate_confidence(self, vulnerability: Dict, fix_data: Dict) -> float:
        """
        Calculate confidence score for the generated fix (0.0-1.0).

        Based on:
        - Vulnerability type (some are easier to fix)
        - Fix completeness
        - Explanation quality
        - Security considerations included
        """
        confidence = 0.7  # Base confidence

        vuln_type = vulnerability.get("type", "")

        # High confidence for well-known, mechanical fix patterns.
        high_confidence_types = [
            "SQL Injection",
            "Hardcoded Secret",
            "XSS",
            "Command Injection",
            "Path Traversal"
        ]
        if any(t.lower() in vuln_type.lower() for t in high_confidence_types):
            confidence += 0.15

        # Lower confidence for vulnerabilities that need whole-system context.
        low_confidence_types = [
            "Race Condition",
            "Business Logic",
            "Authentication Flow",
            "Complex Authorization"
        ]
        if any(t.lower() in vuln_type.lower() for t in low_confidence_types):
            confidence -= 0.2

        # Increase confidence if fix has a substantive explanation.
        explanation = fix_data.get("explanation", "")
        if len(explanation) > 100 and "secure" in explanation.lower():
            confidence += 0.05

        # Increase if security notes provided.
        if fix_data.get("security_notes") and len(fix_data["security_notes"]) > 0:
            confidence += 0.05

        # Increase if specific changes listed.
        if fix_data.get("changes") and len(fix_data["changes"]) > 0:
            confidence += 0.05

        # Cap at 0.95 (never 100% certain).
        return min(0.95, max(0.0, confidence))

    def _get_language_from_filename(self, filename: str) -> str:
        """Determine programming language tag from filename extension.

        Returns "" for unknown extensions (renders as an unfenced block).
        """
        ext_map = {
            ".py": "python",
            ".js": "javascript",
            ".ts": "typescript",
            ".java": "java",
            ".go": "go",
            ".rb": "ruby",
            ".php": "php",
            ".cs": "csharp",
            ".cpp": "cpp",
            ".c": "c",
            ".rs": "rust",
            ".swift": "swift",
            ".kt": "kotlin",
        }

        ext = Path(filename).suffix.lower()
        return ext_map.get(ext, "")

    def apply_fix(
        self,
        filepath: str,
        fix_data: Dict,
        backup: bool = True
    ) -> bool:
        """
        Apply a generated fix to a file.

        Args:
            filepath: Path to file to fix
            fix_data: Fix data from generate_fix()
            backup: Create .backup file before modifying

        Returns:
            True if successful, False otherwise
        """
        try:
            path = Path(filepath)

            # Read current file (pathlib idiom replaces manual open/close).
            current_code = path.read_text(encoding='utf-8')

            # Create backup if requested — done before any validation so the
            # original content is preserved even on early exit.
            if backup:
                backup_path = path.with_suffix(path.suffix + ".backup")
                backup_path.write_text(current_code, encoding='utf-8')
                logger.info(f"Backup created: {backup_path}")

            # Apply fix (write fixed code).
            fixed_code = fix_data.get("fixed_code", "")
            if not fixed_code:
                logger.error("No fixed code in fix_data")
                return False

            path.write_text(fixed_code, encoding='utf-8')

            logger.info(f"Fix applied to: {path}")
            return True

        except Exception as e:
            logger.error(f"Failed to apply fix to {filepath}: {e}")
            return False
# Lazily-constructed module-wide singleton; access via get_fix_generator().
_fix_generator: Optional["FixGenerator"] = None


def get_fix_generator() -> FixGenerator:
    """Return the process-wide FixGenerator, constructing it on first use."""
    global _fix_generator
    if _fix_generator is not None:
        return _fix_generator
    _fix_generator = FixGenerator()
    return _fix_generator
# Convenience functions
def generate_fix(code: str, vulnerability: Dict, filename: str) -> Dict:
    """
    Module-level shortcut for FixGenerator.generate_fix().

    Args:
        code: Source code containing the vulnerability
        vulnerability: Vulnerability details dict
        filename: Name of the file the code came from

    Returns:
        Fix data dict as produced by FixGenerator.generate_fix()
    """
    return get_fix_generator().generate_fix(code, vulnerability, filename)
def apply_fix_to_file(filepath: str, fix_data: Dict, backup: bool = True) -> bool:
    """
    Module-level shortcut for FixGenerator.apply_fix().

    Args:
        filepath: Path of the file to rewrite
        fix_data: Fix data produced by generate_fix()
        backup: When True, write a .backup copy first

    Returns:
        True on success, False otherwise
    """
    return get_fix_generator().apply_fix(filepath, fix_data, backup)