Coverage for src/alprina_cli/api/services/ai_fix_service.py: 24%

109 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2AI Fix Service - Security-focused code fix generation using Kimi AI (primary) and OpenAI (fallback) 

3 

4IMPORTANT: This service is strictly limited to security fixes only. 

5- Does NOT generate new features 

6- Does NOT refactor non-security code 

7- Does NOT act as a general code assistant 

8- Enforces token limits to control costs 

9""" 

10 

import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

import httpx
from loguru import logger

16 

17 

class AIFixService:
    """
    AI-powered security fix generator with a strict security-only scope.

    Uses the Kimi API (Moonshot AI) as the primary provider with OpenAI as a
    fallback. Enforces token limits to control costs and validates that every
    request concerns a security vulnerability, not general code assistance.
    """

    # Token limits to control costs.
    MAX_INPUT_TOKENS = 2000   # Max tokens accepted in the vulnerability context.
    MAX_OUTPUT_TOKENS = 1000  # Max tokens allowed in the fix response.

    # Keywords used to confirm a finding is security-related before fixing it.
    SECURITY_KEYWORDS = [
        "sql injection", "xss", "cross-site scripting", "csrf", "authentication",
        "authorization", "secret", "password", "api key", "token", "encryption",
        "sanitize", "validate", "escape", "security", "vulnerability", "exploit",
        "hardcoded", "insecure", "cve", "owasp", "unsafe", "injection"
    ]

    def __init__(self):
        """Initialize the service, reading provider credentials from the environment."""
        # SECURITY FIX: a previous revision shipped a hardcoded Kimi API key as
        # the getenv() default. Secrets must never live in source code; if
        # KIMI_API_KEY is unset we skip the Kimi provider at call time and fall
        # back to OpenAI (or report an error).
        self.kimi_api_key = os.getenv("KIMI_API_KEY")
        self.openai_api_key = os.getenv("OPENAI_API_KEY")

        # Kimi API configuration (OpenAI-compatible endpoint).
        self.kimi_base_url = "https://api.moonshot.cn/v1"
        self.kimi_model = "moonshot-v1-8k"  # Kimi's 8K context model.

        # OpenAI configuration (fallback).
        self.openai_model = "gpt-4o-mini"  # Cost-effective model.

        # Usage tracking for monitoring.
        self.usage_stats = {
            "kimi_calls": 0,
            "openai_calls": 0,
            "tokens_used": 0,
            "errors": 0
        }

        logger.info("AIFixService initialized (Kimi primary, OpenAI fallback)")

    async def generate_security_fix(
        self,
        vulnerability: Dict[str, Any],
        code_context: str,
        file_path: str
    ) -> Dict[str, Any]:
        """
        Generate a security fix for a vulnerability.

        Args:
            vulnerability: Vulnerability details (type, severity, line, description)
            code_context: Code snippet around the vulnerability
            file_path: Path to the file containing the vulnerability

        Returns:
            On success, a dict containing:
                - fixed_code: The corrected code
                - explanation: Why this fix addresses the vulnerability
                - confidence: Confidence score (0.0-1.0)
                - provider: Which AI provider was used ("kimi"/"openai")
                - is_security_fix: True (the request passed the scope check)
            On failure, a dict with "error", "message" and
            "is_security_fix": False. Never raises to the caller.
        """
        try:
            # Scope guard: this service only fixes security vulnerabilities.
            if not self._is_security_vulnerability(vulnerability):
                return {
                    "error": "Not a security vulnerability",
                    "message": "This service only generates fixes for security vulnerabilities",
                    "is_security_fix": False
                }

            # Truncate context to respect token limits.
            truncated_context = self._truncate_context(code_context)

            # Build the security-focused prompt.
            prompt = self._build_security_fix_prompt(
                vulnerability, truncated_context, file_path
            )

            # Try Kimi first, but only when a key is configured (see __init__).
            if self.kimi_api_key:
                logger.info(f"Attempting Kimi AI fix for {vulnerability.get('type')}")
                try:
                    result = await self._call_kimi_api(prompt)
                    if result and result.get("fixed_code"):
                        self.usage_stats["kimi_calls"] += 1
                        result["provider"] = "kimi"
                        result["is_security_fix"] = True
                        logger.info("✅ Fix generated successfully using Kimi AI")
                        return result
                except Exception as e:
                    logger.warning(f"Kimi API failed: {e}, falling back to OpenAI")
                    self.usage_stats["errors"] += 1
            else:
                logger.warning("KIMI_API_KEY not configured - skipping Kimi provider")

            # Fallback to OpenAI.
            if self.openai_api_key:
                logger.info("Falling back to OpenAI for fix generation")
                try:
                    result = await self._call_openai_api(prompt)
                    if result and result.get("fixed_code"):
                        self.usage_stats["openai_calls"] += 1
                        result["provider"] = "openai"
                        result["is_security_fix"] = True
                        logger.info("✅ Fix generated successfully using OpenAI")
                        return result
                except Exception as e:
                    logger.error(f"OpenAI API also failed: {e}")
                    self.usage_stats["errors"] += 1

            # Both providers failed (or neither was configured/usable).
            return {
                "error": "All AI providers failed",
                "message": "Could not generate fix - please try again later",
                "is_security_fix": False
            }

        except Exception as e:
            # Last-resort guard so callers always receive a structured result.
            logger.error(f"Error in generate_security_fix: {e}")
            return {
                "error": str(e),
                "message": "Internal error generating fix",
                "is_security_fix": False
            }

    def _is_security_vulnerability(self, vulnerability: Dict) -> bool:
        """
        Return True when the finding looks security-related.

        We only fix security issues, not general code-quality problems; the
        check is a simple keyword match over type, description, and title.
        """
        vuln_type = vulnerability.get("type", "").lower()
        vuln_desc = vulnerability.get("description", "").lower()
        vuln_title = vulnerability.get("title", "").lower()

        combined_text = f"{vuln_type} {vuln_desc} {vuln_title}"

        return any(keyword in combined_text for keyword in self.SECURITY_KEYWORDS)

    def _truncate_context(self, code_context: str) -> str:
        """
        Truncate code context to respect the input token limit.

        Rough estimate: 1 token ≈ 4 characters for English text.
        """
        max_chars = self.MAX_INPUT_TOKENS * 4
        if len(code_context) > max_chars:
            logger.warning(f"Truncating context from {len(code_context)} to {max_chars} chars")
            return code_context[:max_chars] + "\n... (truncated)"
        return code_context

    def _build_security_fix_prompt(
        self,
        vulnerability: Dict,
        code_context: str,
        file_path: str
    ) -> str:
        """
        Build a security-focused prompt for AI fix generation.

        The prompt explicitly limits the model's scope to security fixes only
        and requests a JSON-formatted response.
        """
        vuln_type = vulnerability.get("type", "Unknown")
        vuln_severity = vulnerability.get("severity", "MEDIUM")
        vuln_desc = vulnerability.get("description", "No description")
        line = vulnerability.get("line", "N/A")

        return f"""You are a security expert tasked with fixing a specific security vulnerability.

**CRITICAL: You must ONLY fix security vulnerabilities. Do NOT:**
- Refactor unrelated code
- Add new features
- Improve code style/formatting
- Optimize performance
- Make non-security changes

**Vulnerability Details:**
- Type: {vuln_type}
- Severity: {vuln_severity}
- Line: {line}
- Description: {vuln_desc}
- File: {file_path}

**Vulnerable Code Context:**
```
{code_context}
```

**Your Task:**
1. Identify the EXACT security issue in the code
2. Provide a MINIMAL fix that addresses ONLY this security vulnerability
3. Explain WHY your fix is secure
4. Keep all other code unchanged

**Response Format (JSON):**
```json
{{
    "fixed_code": "... only the fixed code section ...",
    "explanation": "Brief explanation of the security fix",
    "security_principle": "Which security principle this addresses (e.g., 'Input validation', 'Least privilege')",
    "confidence": 0.95
}}
```

**Remember:** Make the SMALLEST possible change to fix the security issue. Do not refactor."""

    async def _call_kimi_api(self, prompt: str) -> Dict[str, Any]:
        """
        Call the Kimi API (Moonshot AI) for fix generation.

        Kimi's API is OpenAI-compatible, so a plain chat/completions POST works.

        Raises:
            RuntimeError: If no API key is configured or the API returns a
                non-200 status (caught by the caller's fallback logic).
        """
        if not self.kimi_api_key:
            raise RuntimeError("KIMI_API_KEY is not configured")

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{self.kimi_base_url}/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.kimi_api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": self.kimi_model,
                    "messages": [
                        {
                            "role": "system",
                            "content": "You are a security expert who fixes vulnerabilities with minimal code changes. Return responses in JSON format."
                        },
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ],
                    "temperature": 0.3,  # Low temperature for consistent security fixes.
                    "max_tokens": self.MAX_OUTPUT_TOKENS
                }
            )

            if response.status_code != 200:
                raise RuntimeError(f"Kimi API error: {response.status_code} - {response.text}")

            result = response.json()

            # Track token usage for cost monitoring.
            usage = result.get("usage", {})
            tokens_used = usage.get("total_tokens", 0)
            self.usage_stats["tokens_used"] += tokens_used
            logger.info(f"Kimi API tokens used: {tokens_used}")

            content = result["choices"][0]["message"]["content"]
            return self._parse_ai_response(content)

    async def _call_openai_api(self, prompt: str) -> Dict[str, Any]:
        """
        Call the OpenAI API as a fallback for fix generation.

        Raises:
            RuntimeError: If the optional ``openai`` package is not installed.
        """
        try:
            # Imported lazily: the openai package is optional (fallback only).
            from openai import AsyncOpenAI
        except ImportError as e:
            raise RuntimeError("OpenAI package not installed. Install with: pip install openai") from e

        client = AsyncOpenAI(api_key=self.openai_api_key)

        response = await client.chat.completions.create(
            model=self.openai_model,
            messages=[
                {
                    "role": "system",
                    "content": "You are a security expert who fixes vulnerabilities with minimal code changes. Return responses in JSON format."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            temperature=0.3,
            max_tokens=self.MAX_OUTPUT_TOKENS
        )

        # Track token usage for cost monitoring.
        tokens_used = response.usage.total_tokens
        self.usage_stats["tokens_used"] += tokens_used
        logger.info(f"OpenAI API tokens used: {tokens_used}")

        content = response.choices[0].message.content
        return self._parse_ai_response(content)

    def _parse_ai_response(self, content: str) -> Dict[str, Any]:
        """
        Parse an AI response and extract fix details.

        Handles JSON inside a ```json fenced block, bare JSON, and plain text.
        Plain text and unparseable responses are returned verbatim as
        ``fixed_code`` with a reduced confidence score rather than raising.
        """
        try:
            if "```json" in content:
                # Extract JSON from a markdown code block.
                json_str = content.split("```json")[1].split("```")[0].strip()
                return json.loads(json_str)
            elif content.strip().startswith("{"):
                return json.loads(content)
            else:
                # Plain-text response: pass it through with lower confidence.
                return {
                    "fixed_code": content,
                    "explanation": "AI provided fix in plain text format",
                    "confidence": 0.7
                }
        except Exception as e:
            logger.warning(f"Could not parse AI response as JSON: {e}")
            return {
                "fixed_code": content,
                "explanation": "Raw AI response (parsing failed)",
                "confidence": 0.5
            }

    def get_usage_stats(self) -> Dict[str, int]:
        """Return a copy of the usage statistics for monitoring."""
        return self.usage_stats.copy()

341 

342 

# Module-level singleton instance, created at import time. NOTE: construction
# reads KIMI_API_KEY / OPENAI_API_KEY from the environment and emits a log
# line as a side effect of importing this module.
ai_fix_service = AIFixService()