Coverage for src/alprina_cli/api/services/github_scanner.py: 23%
86 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
1"""
2GitHub Scanner Service
3Scans PR changes and push commits for security vulnerabilities.
4"""
6import sys
7from pathlib import Path
8from typing import List, Dict, Optional
9from loguru import logger
10import uuid
12# Add parent directory to path
13sys.path.insert(0, str(Path(__file__).parent.parent.parent))
15from ...quick_scanner import quick_scan
16from .github_service import GitHubService
class GitHubScanner:
    """Scan GitHub pull-request and push changes for security issues.

    File contents are fetched through ``GitHubService``, written to a
    temporary file and analysed with ``quick_scan``. Every finding is
    enriched with the source file, detected language, a fix recommendation,
    a learning-resource URL and a coarse exploitability label.
    """

    # Extension -> language name. The keys double as the set of extensions
    # considered scannable, so the two lists can no longer drift apart.
    _EXTENSION_LANGUAGES = {
        '.py': 'python',
        '.js': 'javascript',
        '.ts': 'typescript',
        '.tsx': 'typescript',
        '.jsx': 'javascript',
        '.java': 'java',
        '.php': 'php',
        '.rb': 'ruby',
        '.go': 'go',
        '.rs': 'rust',
        '.c': 'c',
        '.cpp': 'cpp',
        '.cs': 'csharp',
        '.swift': 'swift',
        '.kt': 'kotlin',
    }

    # Vulnerability pattern -> remediation advice.
    _FIX_RECOMMENDATIONS = {
        "sql_injection": "Use parameterized queries or an ORM to prevent SQL injection",
        "hardcoded_secrets": "Store secrets in environment variables, never in code",
        "xss_vulnerability": "Sanitize all user input before rendering in HTML",
        "command_injection": "Avoid shell=True, use subprocess with argument lists",
        "path_traversal": "Validate and sanitize all file paths before use",
        "weak_crypto": "Use modern algorithms like SHA-256 or bcrypt for hashing",
        "insecure_random": "Use secrets module for cryptographic random numbers",
        "missing_auth": "Add authentication middleware to protect sensitive endpoints",
        "debug_enabled": "Disable debug mode in production environments",
        "exposed_endpoints": "Add authorization checks to admin/internal endpoints",
    }

    # Vulnerability pattern -> OWASP reference page.
    _LEARN_MORE_URLS = {
        "sql_injection": "https://owasp.org/www-community/attacks/SQL_Injection",
        "xss_vulnerability": "https://owasp.org/www-community/attacks/xss/",
        "command_injection": "https://owasp.org/www-community/attacks/Command_Injection",
        "hardcoded_secrets": "https://owasp.org/www-community/vulnerabilities/Use_of_hard-coded_password",
        "path_traversal": "https://owasp.org/www-community/attacks/Path_Traversal",
    }

    # Severity -> human-readable exploitability label.
    _RISK_BY_SEVERITY = {
        "critical": "High - Easily exploitable",
        "high": "Moderate - Exploitable with effort",
        "medium": "Low - Requires specific conditions",
        "low": "Minimal - Hard to exploit",
    }

    # Severity buckets reported (in this order) in every scan summary.
    _SUMMARY_SEVERITIES = ("critical", "high", "medium", "low")

    def __init__(self):
        self.github_service = GitHubService()

    async def scan_pr_changes(
        self,
        repo_full_name: str,
        pr_number: int,
        changed_files: List[Dict],
        base_sha: str,
        head_sha: str,
        access_token: str,
    ) -> Dict:
        """
        Scan files changed in a pull request.

        Args:
            repo_full_name: Full repository name (owner/repo)
            pr_number: Pull request number
            changed_files: List of changed files from GitHub API; each entry
                is read for its "filename" and "status" keys
            base_sha: Base commit SHA (currently unused)
            head_sha: Head commit SHA; file contents are fetched at this ref
            access_token: GitHub access token

        Returns:
            Dict with scan_id, pr_number, repo_full_name, files_scanned,
            findings, a per-severity summary and a UTC ISO-8601 scanned_at.
        """
        scan_id = str(uuid.uuid4())
        findings: List[Dict] = []
        files_scanned = 0

        logger.info(f"🔍 Scanning {len(changed_files)} files in PR #{pr_number}")

        for file_info in self._filter_scannable_files(changed_files):
            # Removed files have no content at head_sha to scan.
            if file_info.get("status") == "removed":
                continue

            filename = file_info.get("filename")
            file_content = await self.github_service.get_file_content(
                repo_full_name=repo_full_name,
                file_path=filename,
                ref=head_sha,
                access_token=access_token,
            )
            if not file_content:
                continue

            scan_result = self._run_quick_scan(filename, file_content)
            files_scanned += 1

            findings.extend(
                self._enrich_finding(
                    finding,
                    filename,
                    pr_number=pr_number,
                    # All findings in the PR's changed files are considered new.
                    is_new=True,
                )
                for finding in scan_result.get("findings", [])
            )

        logger.info(f"✅ Scan complete: {len(findings)} findings in {files_scanned} files")

        return {
            "scan_id": scan_id,
            "pr_number": pr_number,
            "repo_full_name": repo_full_name,
            "files_scanned": files_scanned,
            "findings": findings,
            "summary": self._summarize(findings),
            "scanned_at": self._utc_now_iso(),
        }

    async def scan_push_changes(
        self,
        repo_full_name: str,
        changed_files: List[str],
        access_token: str,
        ref: str = "main",
    ) -> Dict:
        """
        Scan files changed in a push.

        Similar to the PR scan but without PR context.

        Args:
            repo_full_name: Full repository name (owner/repo)
            changed_files: Paths of the changed files
            access_token: GitHub access token
            ref: Branch, tag or SHA to read file contents from. Defaults to
                "main"; pass e.g. "master" or the pushed commit SHA for
                repositories with a different default branch.

        Returns:
            Dict with scan_id, repo_full_name, files_scanned, findings,
            a per-severity summary and a UTC ISO-8601 scanned_at.
        """
        scan_id = str(uuid.uuid4())
        findings: List[Dict] = []
        files_scanned = 0

        logger.info(f"🔍 Scanning {len(changed_files)} files in push to {repo_full_name}")

        for filename in (path for path in changed_files if self._is_code_file(path)):
            file_content = await self.github_service.get_file_content(
                repo_full_name=repo_full_name,
                file_path=filename,
                ref=ref,
                access_token=access_token,
            )
            if not file_content:
                continue

            scan_result = self._run_quick_scan(filename, file_content)
            files_scanned += 1

            findings.extend(
                self._enrich_finding(finding, filename)
                for finding in scan_result.get("findings", [])
            )

        logger.info(f"✅ Scan complete: {len(findings)} findings in {files_scanned} files")

        return {
            "scan_id": scan_id,
            "repo_full_name": repo_full_name,
            "files_scanned": files_scanned,
            "findings": findings,
            "summary": self._summarize(findings),
            "scanned_at": self._utc_now_iso(),
        }

    def _run_quick_scan(self, filename: str, file_content: str) -> Dict:
        """Run ``quick_scan`` on *file_content* through a temporary file.

        The temporary file keeps *filename*'s extension (``.txt`` if it has
        none) and is always removed afterwards, even if writing or scanning
        raises.
        """
        import tempfile

        # NOTE(review): quick_scan is called without await, so it runs
        # synchronously inside the async callers and blocks the event loop
        # while it works; consider offloading if scans are slow.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir) / f"scan_target{self._get_file_extension(filename)}"
            # Explicit UTF-8 so non-ASCII source cannot fail under a
            # non-UTF-8 locale default encoding.
            tmp_path.write_text(file_content, encoding="utf-8")
            return quick_scan(str(tmp_path))

    def _enrich_finding(self, finding: Dict, filename: str, **context) -> Dict:
        """Return a copy of *finding* with an id, file context and guidance.

        Extra keyword arguments (e.g. ``pr_number``, ``is_new``) are inserted
        right after ``file``. Raises ``KeyError`` if *finding* lacks a
        ``pattern`` or ``severity`` key.
        """
        return {
            **finding,
            "id": str(uuid.uuid4()),
            "file": filename,
            **context,
            "language": self._detect_language(filename),
            "fix_recommendation": self._get_fix_recommendation(finding["pattern"]),
            "learn_more_url": self._get_learn_more_url(finding["pattern"]),
            "risk": self._assess_risk(finding["severity"]),
        }

    def _summarize(self, findings: List[Dict]) -> Dict[str, int]:
        """Count findings per severity bucket; unknown severities are ignored."""
        from collections import Counter

        counts = Counter(finding["severity"] for finding in findings)
        return {severity: counts[severity] for severity in self._SUMMARY_SEVERITIES}

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current time as a timezone-aware UTC ISO-8601 string."""
        from datetime import datetime, timezone

        return datetime.now(timezone.utc).isoformat()

    def _filter_scannable_files(self, changed_files: List[Dict]) -> List[Dict]:
        """Keep only entries whose "filename" has a scannable code extension."""
        return [
            file_info
            for file_info in changed_files
            if self._is_code_file(file_info.get("filename", ""))
        ]

    def _is_code_file(self, filename: str) -> bool:
        """Check (case-insensitively) if file has an extension we can scan."""
        return Path(filename).suffix.lower() in self._EXTENSION_LANGUAGES

    def _get_file_extension(self, filename: str) -> str:
        """Get file extension for temporary file ('.txt' when there is none)."""
        return Path(filename).suffix or '.txt'

    def _detect_language(self, filename: str) -> str:
        """Detect programming language from filename; 'text' if unknown."""
        return self._EXTENSION_LANGUAGES.get(Path(filename).suffix.lower(), 'text')

    def _get_fix_recommendation(self, pattern: str) -> str:
        """Get fix recommendation for vulnerability pattern."""
        return self._FIX_RECOMMENDATIONS.get(pattern, "Review and remediate this vulnerability")

    def _get_learn_more_url(self, pattern: str) -> str:
        """Get learning resource URL for vulnerability (OWASP Top Ten fallback)."""
        return self._LEARN_MORE_URLS.get(pattern, "https://owasp.org/www-project-top-ten/")

    def _assess_risk(self, severity: str) -> str:
        """Map a severity to an exploitability label ('Unknown' if unmapped)."""
        return self._RISK_BY_SEVERITY.get(severity, "Unknown")