Coverage for src/alprina_cli/api/services/ai_fix_service.py: 24%
109 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
1"""
2AI Fix Service - Security-focused code fix generation using Kimi AI (primary) and OpenAI (fallback)
4IMPORTANT: This service is strictly limited to security fixes only.
5- Does NOT generate new features
6- Does NOT refactor non-security code
7- Does NOT act as a general code assistant
8- Enforces token limits to control costs
9"""
11import os
12import httpx
13from typing import Dict, List, Optional, Any
14from loguru import logger
15from datetime import datetime
class AIFixService:
    """
    AI-powered security fix generator with strict security-only scope.

    Uses Kimi API (Moonshot AI) as the primary provider with OpenAI as a
    fallback. Enforces token limits to control cost and validates that a
    finding is security-related before calling any provider.
    """

    # Token limits to control costs
    MAX_INPUT_TOKENS = 2000   # Max tokens in vulnerability context
    MAX_OUTPUT_TOKENS = 1000  # Max tokens in fix response

    # Security-only keywords used to confirm a finding is in scope
    SECURITY_KEYWORDS = [
        "sql injection", "xss", "cross-site scripting", "csrf", "authentication",
        "authorization", "secret", "password", "api key", "token", "encryption",
        "sanitize", "validate", "escape", "security", "vulnerability", "exploit",
        "hardcoded", "insecure", "cve", "owasp", "unsafe", "injection"
    ]

    def __init__(self) -> None:
        """Initialize the service, reading provider credentials from the environment.

        SECURITY FIX: a live-looking Kimi API key was previously committed as
        the ``os.getenv`` fallback value. Secrets must never live in source
        control; the key now comes exclusively from the ``KIMI_API_KEY``
        environment variable, and the Kimi provider is skipped when absent
        (ironic for a service whose own keyword list includes "hardcoded").
        The leaked key should be revoked with the provider.
        """
        self.kimi_api_key: Optional[str] = os.getenv("KIMI_API_KEY")
        self.openai_api_key: Optional[str] = os.getenv("OPENAI_API_KEY")

        if not self.kimi_api_key:
            logger.warning("KIMI_API_KEY not set - Kimi provider disabled, OpenAI fallback only")

        # Kimi API configuration
        self.kimi_base_url = "https://api.moonshot.cn/v1"
        self.kimi_model = "moonshot-v1-8k"  # Kimi's 8K context model

        # OpenAI configuration (fallback)
        self.openai_model = "gpt-4o-mini"  # Cost-effective model

        # Usage tracking (exposed via get_usage_stats)
        self.usage_stats: Dict[str, int] = {
            "kimi_calls": 0,
            "openai_calls": 0,
            "tokens_used": 0,
            "errors": 0
        }

        logger.info("AIFixService initialized (Kimi primary, OpenAI fallback)")

    async def generate_security_fix(
        self,
        vulnerability: Dict[str, Any],
        code_context: str,
        file_path: str
    ) -> Dict[str, Any]:
        """
        Generate a security fix for a vulnerability.

        Args:
            vulnerability: Vulnerability details (type, severity, line, description)
            code_context: Code snippet around the vulnerability
            file_path: Path to the file containing the vulnerability

        Returns:
            On success, a dict containing:
                - fixed_code: The corrected code
                - explanation: Why this fix addresses the vulnerability
                - confidence: Confidence score (0.0-1.0)
                - provider: Which AI provider was used (kimi/openai)
                - is_security_fix: True (validated as security-related)
            On refusal or provider failure, a dict with "error" and "message"
            keys and is_security_fix set to False. This method never raises.
        """
        try:
            # Scope guard: this service fixes security issues only.
            if not self._is_security_vulnerability(vulnerability):
                return {
                    "error": "Not a security vulnerability",
                    "message": "This service only generates fixes for security vulnerabilities",
                    "is_security_fix": False
                }

            # Truncate context to respect token limits
            truncated_context = self._truncate_context(code_context)

            # Build the security-focused prompt once; both providers share it.
            prompt = self._build_security_fix_prompt(
                vulnerability, truncated_context, file_path
            )

            # Try Kimi API first (skipped entirely when no key is configured,
            # so a missing credential is not counted as a provider error).
            if self.kimi_api_key:
                logger.info(f"Attempting Kimi AI fix for {vulnerability.get('type')}")
                try:
                    result = await self._call_kimi_api(prompt)
                    if result and result.get("fixed_code"):
                        self.usage_stats["kimi_calls"] += 1
                        result["provider"] = "kimi"
                        result["is_security_fix"] = True
                        logger.info("✅ Fix generated successfully using Kimi AI")
                        return result
                except Exception as e:
                    logger.warning(f"Kimi API failed: {e}, falling back to OpenAI")
                    self.usage_stats["errors"] += 1

            # Fallback to OpenAI
            if self.openai_api_key:
                logger.info("Falling back to OpenAI for fix generation")
                try:
                    result = await self._call_openai_api(prompt)
                    if result and result.get("fixed_code"):
                        self.usage_stats["openai_calls"] += 1
                        result["provider"] = "openai"
                        result["is_security_fix"] = True
                        logger.info("✅ Fix generated successfully using OpenAI")
                        return result
                except Exception as e:
                    logger.error(f"OpenAI API also failed: {e}")
                    self.usage_stats["errors"] += 1

            # Both providers failed (or neither is configured).
            return {
                "error": "All AI providers failed",
                "message": "Could not generate fix - please try again later",
                "is_security_fix": False
            }

        except Exception as e:
            # Last-resort guard so callers always get a dict, never an exception.
            logger.error(f"Error in generate_security_fix: {e}")
            return {
                "error": str(e),
                "message": "Internal error generating fix",
                "is_security_fix": False
            }

    def _is_security_vulnerability(self, vulnerability: Dict) -> bool:
        """
        Return True if the finding looks like a security vulnerability.

        We only fix security issues, not general code-quality problems.
        Matching is a case-insensitive keyword scan over the finding's
        type, description, and title.
        """
        vuln_type = vulnerability.get("type", "").lower()
        vuln_desc = vulnerability.get("description", "").lower()
        vuln_title = vulnerability.get("title", "").lower()

        combined_text = f"{vuln_type} {vuln_desc} {vuln_title}"

        # Check if any security keyword is present
        return any(keyword in combined_text for keyword in self.SECURITY_KEYWORDS)

    def _truncate_context(self, code_context: str) -> str:
        """
        Truncate code context to respect token limits.

        Rough estimate: 1 token ≈ 4 characters for English text; a
        "... (truncated)" marker is appended when clipping occurs.
        """
        max_chars = self.MAX_INPUT_TOKENS * 4
        if len(code_context) > max_chars:
            logger.warning(f"Truncating context from {len(code_context)} to {max_chars} chars")
            return code_context[:max_chars] + "\n... (truncated)"
        return code_context

    def _build_security_fix_prompt(
        self,
        vulnerability: Dict,
        code_context: str,
        file_path: str
    ) -> str:
        """
        Build a security-focused prompt for AI fix generation.

        The prompt explicitly limits scope to security fixes only and asks
        for a JSON response matching the shape _parse_ai_response expects.
        """
        vuln_type = vulnerability.get("type", "Unknown")
        vuln_severity = vulnerability.get("severity", "MEDIUM")
        vuln_desc = vulnerability.get("description", "No description")
        line = vulnerability.get("line", "N/A")

        return f"""You are a security expert tasked with fixing a specific security vulnerability.

**CRITICAL: You must ONLY fix security vulnerabilities. Do NOT:**
- Refactor unrelated code
- Add new features
- Improve code style/formatting
- Optimize performance
- Make non-security changes

**Vulnerability Details:**
- Type: {vuln_type}
- Severity: {vuln_severity}
- Line: {line}
- Description: {vuln_desc}
- File: {file_path}

**Vulnerable Code Context:**
```
{code_context}
```

**Your Task:**
1. Identify the EXACT security issue in the code
2. Provide a MINIMAL fix that addresses ONLY this security vulnerability
3. Explain WHY your fix is secure
4. Keep all other code unchanged

**Response Format (JSON):**
```json
{{
    "fixed_code": "... only the fixed code section ...",
    "explanation": "Brief explanation of the security fix",
    "security_principle": "Which security principle this addresses (e.g., 'Input validation', 'Least privilege')",
    "confidence": 0.95
}}
```

**Remember:** Make the SMALLEST possible change to fix the security issue. Do not refactor."""

    async def _call_kimi_api(self, prompt: str) -> Dict[str, Any]:
        """
        Call Kimi API (Moonshot AI) for fix generation.

        Kimi's endpoint is OpenAI-compatible, so this is a plain
        chat/completions POST. Raises on any non-200 response.
        """
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{self.kimi_base_url}/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.kimi_api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": self.kimi_model,
                    "messages": [
                        {
                            "role": "system",
                            "content": "You are a security expert who fixes vulnerabilities with minimal code changes. Return responses in JSON format."
                        },
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ],
                    "temperature": 0.3,  # Low temperature for consistent security fixes
                    "max_tokens": self.MAX_OUTPUT_TOKENS
                }
            )

            if response.status_code != 200:
                raise Exception(f"Kimi API error: {response.status_code} - {response.text}")

            result = response.json()

            # Track token usage
            usage = result.get("usage", {})
            tokens_used = usage.get("total_tokens", 0)
            self.usage_stats["tokens_used"] += tokens_used
            logger.info(f"Kimi API tokens used: {tokens_used}")

            # Parse response
            content = result["choices"][0]["message"]["content"]
            return self._parse_ai_response(content)

    async def _call_openai_api(self, prompt: str) -> Dict[str, Any]:
        """
        Call OpenAI API as fallback for fix generation.

        Imports the openai package lazily so the service works without it
        when only Kimi is configured; a missing package is surfaced as a
        regular Exception for the caller's fallback handling.
        """
        try:
            from openai import AsyncOpenAI

            client = AsyncOpenAI(api_key=self.openai_api_key)

            response = await client.chat.completions.create(
                model=self.openai_model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a security expert who fixes vulnerabilities with minimal code changes. Return responses in JSON format."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                temperature=0.3,
                max_tokens=self.MAX_OUTPUT_TOKENS
            )

            # Track token usage
            tokens_used = response.usage.total_tokens
            self.usage_stats["tokens_used"] += tokens_used
            logger.info(f"OpenAI API tokens used: {tokens_used}")

            content = response.choices[0].message.content
            return self._parse_ai_response(content)

        except ImportError:
            raise Exception("OpenAI package not installed. Install with: pip install openai")

    def _parse_ai_response(self, content: str) -> Dict[str, Any]:
        """
        Parse an AI response and extract fix details.

        Handles three shapes: a ```json fenced block, a bare JSON object,
        and plain text (wrapped into the fix dict with lower confidence).
        Parsing failures never raise; the raw content is returned instead.
        """
        import json  # local import kept: only needed on this path

        try:
            if "```json" in content:
                # Extract JSON from a markdown code fence
                json_str = content.split("```json")[1].split("```")[0].strip()
                return json.loads(json_str)
            elif content.strip().startswith("{"):
                return json.loads(content)
            else:
                # Plain text response - wrap it as-is
                return {
                    "fixed_code": content,
                    "explanation": "AI provided fix in plain text format",
                    "confidence": 0.7
                }
        except Exception as e:
            logger.warning(f"Could not parse AI response as JSON: {e}")
            return {
                "fixed_code": content,
                "explanation": "Raw AI response (parsing failed)",
                "confidence": 0.5
            }

    def get_usage_stats(self) -> Dict[str, int]:
        """Return a copy of the usage statistics for monitoring (safe to mutate)."""
        return self.usage_stats.copy()
# Global instance, created at module import time; all importers share this
# one object (and therefore its usage_stats counters).
ai_fix_service = AIFixService()