Coverage for src/alprina_cli/database/neon_client.py: 37%
98 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
1"""
2Neon Database Client for CLI Tools
4Context Engineering:
5- Lightweight wrapper around NeonService
6- Fast operations (< 50ms target)
7- Connection pooling for performance
8- Minimal token footprint in responses
9"""
11import os
12from typing import Optional, Dict, Any, List, Tuple
13from uuid import UUID
14from datetime import datetime
15from loguru import logger
17# Import existing NeonService
18import sys
19sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'api', 'services'))
20from neon_service import NeonService
class NeonDatabaseClient:
    """
    Database client for CLI tool integration.

    Thin async wrapper around ``NeonService``. Where the service offers a
    matching helper the call is delegated; otherwise SQL is executed directly
    on a connection acquired from the service's pool.

    Context Engineering (design targets, not enforced by this class):
    - All methods < 50ms (p95)
    - Connection pooling (reuse connections)
    - Minimal data transfer (only what's needed)
    - Async-first for non-blocking operations
    """

    # Monthly scan allowance applied when a user has no active subscription.
    FREE_TIER_MONTHLY_SCANS = 10

    def __init__(self):
        """Initialize database client."""
        self.service = NeonService()
        # Version stamped on every scan record; overridable via environment.
        self._cli_version = os.getenv("ALPRINA_CLI_VERSION", "0.1.0")

    # ==========================================
    # Internal helpers
    # ==========================================

    @staticmethod
    def _scan_type_for(tool_name: str) -> str:
        """Derive the stored scan type from a tool name ("ScanTool" -> "scan")."""
        return tool_name.lower().replace("tool", "")

    @staticmethod
    def _rows_affected(status: Any) -> int:
        """
        Extract the row count from a command status tag such as "UPDATE 1".

        Returns 0 when the tag does not end in a number.
        """
        tail = str(status).rsplit(" ", 1)[-1]
        return int(tail) if tail.isdigit() else 0

    async def is_available(self) -> bool:
        """Check if database is configured and available."""
        return self.service.is_enabled()

    # ==========================================
    # Authentication Methods
    # ==========================================

    async def authenticate_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
        """
        Authenticate user via API key.

        Args:
            api_key: Raw API key (e.g., "alprina_...")

        Returns:
            Whatever ``NeonService.verify_api_key`` returns: a user dict if
            valid, None otherwise.
        """
        return await self.service.verify_api_key(api_key)

    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Get user by ID (delegates to ``NeonService.get_user_by_id``)."""
        return await self.service.get_user_by_id(user_id)

    # ==========================================
    # Scan Lifecycle Methods
    # ==========================================

    async def create_scan(
        self,
        user_id: str,
        tool_name: str,
        target: str,
        params: Dict[str, Any]
    ) -> str:
        """
        Create a new scan record and stamp it with CLI-specific columns.

        The initial status is decided by ``NeonService.create_scan``
        (presumably "pending" - confirm against the service).

        Args:
            user_id: User UUID
            tool_name: Name of tool (e.g., "ScanTool", "ReconTool")
            target: Scan target (domain, IP, file path)
            params: Tool parameters as dict

        Returns:
            Scan ID as returned by the service.
        """
        metadata = {
            "tool_name": tool_name,
            "target": target,
            "params": params,
            "cli_version": self._cli_version,
            "guardrails_enabled": True
        }

        scan_id = await self.service.create_scan(
            user_id=user_id,
            scan_type=self._scan_type_for(tool_name),
            workflow_mode="cli",
            metadata=metadata
        )

        # The service call does not receive these dedicated columns, so they
        # are written in a follow-up statement. NOTE(review): this is a second
        # round trip and is not atomic with the insert above.
        pool = await self.service.get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE scans
                SET cli_version = $1,
                    tool_name = $2,
                    target = $3
                WHERE id = $4
                """,
                self._cli_version, tool_name, target, scan_id
            )

        logger.debug(f"Created scan {scan_id} for user {user_id}")
        return scan_id

    async def update_scan_status(
        self,
        scan_id: str,
        status: str
    ) -> bool:
        """
        Update scan status.

        Args:
            scan_id: Scan UUID
            status: New status (pending/running/completed/failed)

        Returns:
            True if exactly one row was updated
        """
        pool = await self.service.get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE scans
                SET status = $1
                WHERE id = $2
                """,
                status, scan_id
            )
            return self._rows_affected(result) == 1

    async def save_scan_results(
        self,
        scan_id: str,
        findings: Dict[str, Any],
        findings_count: int,
        status: str = "completed"
    ) -> bool:
        """
        Save scan findings and set the final status.

        Args:
            scan_id: Scan UUID
            findings: Scan results as dict
            findings_count: Number of findings
            status: Final status (completed/failed)

        Returns:
            Result of ``NeonService.save_scan`` (True if saved)
        """
        return await self.service.save_scan(
            scan_id=scan_id,
            status=status,
            findings=findings,
            findings_count=findings_count
        )

    async def get_scan(
        self,
        scan_id: str,
        user_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Retrieve scan by ID.

        Args:
            scan_id: Scan UUID
            user_id: Optional user ID for access control (forwarded to the
                service unchanged)

        Returns:
            Scan dict or None
        """
        return await self.service.get_scan(scan_id, user_id)

    async def list_user_scans(
        self,
        user_id: str,
        limit: int = 20,
        offset: int = 0,
        tool_name: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        List user's scans, paginated.

        Args:
            user_id: User UUID
            limit: Max results (default: 20)
            offset: Pagination offset
            tool_name: Optional filter by tool; converted to a scan type the
                same way ``create_scan`` does

        Returns:
            List of scan dicts
        """
        return await self.service.list_scans(
            user_id=user_id,
            limit=limit,
            offset=offset,
            scan_type=self._scan_type_for(tool_name) if tool_name else None
        )

    # ==========================================
    # Usage Tracking Methods
    # ==========================================

    async def track_scan_usage(
        self,
        user_id: str,
        scan_id: str,
        tool_name: str,
        credits_used: int = 1,
        duration_ms: Optional[int] = None,
        vulnerabilities_found: int = 0
    ) -> bool:
        """
        Track scan usage for metering.

        Inserts one ``scan_usage`` row, linked to the user's most recently
        created active subscription when one exists (NULL otherwise).

        Args:
            user_id: User UUID
            scan_id: Scan UUID
            tool_name: Tool name
            credits_used: Credits consumed (default: 1)
            duration_ms: Execution time in milliseconds
            vulnerabilities_found: Number of vulnerabilities found

        Returns:
            True once the row has been inserted (database errors propagate)
        """
        pool = await self.service.get_pool()

        async with pool.acquire() as conn:
            # Get subscription_id
            subscription = await conn.fetchrow(
                """
                SELECT id FROM user_subscriptions
                WHERE user_id = $1
                AND status = 'active'
                ORDER BY created_at DESC
                LIMIT 1
                """,
                user_id
            )

            subscription_id = str(subscription['id']) if subscription else None

            # Insert usage record
            await conn.execute(
                """
                INSERT INTO scan_usage (
                    user_id, subscription_id, scan_id,
                    scan_type, workflow_mode,
                    credits_used, duration_ms, vulnerabilities_found
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                """,
                user_id, subscription_id, scan_id,
                self._scan_type_for(tool_name), "cli",
                credits_used, duration_ms, vulnerabilities_found
            )

        logger.debug(f"Tracked usage for scan {scan_id}: {credits_used} credits")
        return True

    async def check_scan_limit(
        self,
        user_id: str
    ) -> Tuple[bool, int, int]:
        """
        Check if user can perform another scan.

        Without an active subscription the free tier applies: scans created
        in the current calendar month are counted against
        ``FREE_TIER_MONTHLY_SCANS``. With one, ``scan_usage`` rows inside the
        subscription's current period are counted against its ``scans_limit``.

        Args:
            user_id: User UUID

        Returns:
            Tuple of (can_scan, scans_used, scans_limit)
        """
        pool = await self.service.get_pool()

        async with pool.acquire() as conn:
            # Get active subscription
            subscription = await conn.fetchrow(
                """
                SELECT
                    scans_used,
                    scans_limit,
                    current_period_start,
                    current_period_end
                FROM user_subscriptions
                WHERE user_id = $1
                AND status = 'active'
                ORDER BY created_at DESC
                LIMIT 1
                """,
                user_id
            )

            if not subscription:
                # No active subscription - free tier; count scans this month
                free_limit = self.FREE_TIER_MONTHLY_SCANS
                scans_count = await conn.fetchval(
                    """
                    SELECT COUNT(*)
                    FROM scans
                    WHERE user_id = $1
                    AND created_at >= date_trunc('month', CURRENT_DATE)
                    """,
                    user_id
                )
                return (scans_count < free_limit, scans_count, free_limit)

            # Count scans in current period (the stored scans_used column is
            # selected above but the live count is what gets compared).
            scans_used = await conn.fetchval(
                """
                SELECT COUNT(*)
                FROM scan_usage
                WHERE user_id = $1
                AND created_at >= $2
                AND created_at < $3
                """,
                user_id,
                subscription['current_period_start'],
                subscription['current_period_end']
            )

            scans_limit = subscription['scans_limit']
            can_scan = scans_used < scans_limit

            return (can_scan, scans_used, scans_limit)

    async def increment_scan_count(self, user_id: str) -> bool:
        """
        Increment scan count for user's active subscription.

        Args:
            user_id: User UUID

        Returns:
            True if at least one subscription row was updated; False when the
            user has no active subscription.
        """
        pool = await self.service.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE user_subscriptions
                SET scans_used = scans_used + 1
                WHERE user_id = $1
                AND status = 'active'
                """,
                user_id
            )
            # A bare "UPDATE" substring check would also accept "UPDATE 0",
            # so the affected-row count is inspected instead.
            return self._rows_affected(result) > 0

    # ==========================================
    # CLI Session Tracking
    # ==========================================

    async def create_cli_session(
        self,
        user_id: str,
        cli_version: str,
        os_info: str,
        python_version: str
    ) -> str:
        """
        Create new CLI session for analytics.

        Args:
            user_id: User UUID
            cli_version: CLI version string
            os_info: Operating system
            python_version: Python version

        Returns:
            Session ID (stringified ``id`` of the inserted row)
        """
        pool = await self.service.get_pool()

        async with pool.acquire() as conn:
            session = await conn.fetchrow(
                """
                INSERT INTO cli_sessions (
                    user_id, cli_version, os, python_version
                ) VALUES ($1, $2, $3, $4)
                RETURNING id
                """,
                user_id, cli_version, os_info, python_version
            )
            return str(session['id'])

    async def update_session_activity(self, session_id: str) -> None:
        """Set the session's last_activity to NOW() and bump commands_run by one."""
        pool = await self.service.get_pool()

        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE cli_sessions
                SET last_activity = NOW(),
                    commands_run = commands_run + 1
                WHERE id = $1
                """,
                session_id
            )

    # ==========================================
    # API Key Methods
    # ==========================================

    async def list_api_keys(self, user_id: str) -> List[Dict[str, Any]]:
        """List user's API keys."""
        return await self.service.list_api_keys(user_id)

    async def create_api_key(
        self,
        user_id: str,
        name: str
    ) -> Tuple[str, str]:
        """
        Create new API key.

        Args:
            user_id: User UUID
            name: Key name/description

        Returns:
            Tuple of (raw_key, key_id)

        Raises:
            RuntimeError: If the key cannot be found by hash right after the
                service reported creating it.
        """
        api_key = self.service.generate_api_key()
        await self.service.create_api_key(user_id, api_key, name)

        # The service call's return value is not used here, so the new row's
        # ID is looked up by the key hash.
        pool = await self.service.get_pool()
        async with pool.acquire() as conn:
            key_hash = self.service.hash_api_key(api_key)
            key = await conn.fetchrow(
                "SELECT id FROM api_keys WHERE key_hash = $1",
                key_hash
            )

        if key is None:
            # Fail with a clear message instead of "NoneType is not subscriptable".
            raise RuntimeError("API key was created but could not be read back")
        return (api_key, str(key['id']))

    async def revoke_api_key(self, key_id: str, user_id: str) -> bool:
        """Revoke (deactivate) API key."""
        return await self.service.deactivate_api_key(key_id, user_id)

    # ==========================================
    # Analytics Methods
    # ==========================================

    async def get_user_stats(self, user_id: str) -> Dict[str, Any]:
        """
        Get user statistics for dashboard.

        Returns:
            Whatever ``NeonService.get_user_stats`` returns for the user.
        """
        return await self.service.get_user_stats(user_id)

    async def get_scan_analytics(
        self,
        user_id: str,
        period_days: int = 30
    ) -> Dict[str, Any]:
        """
        Get scan analytics for charts.

        Args:
            user_id: User UUID
            period_days: Number of days to analyze

        Returns:
            Dict with three keys:
            - "scans_over_time": rows of (date, count, scan_type), newest first
            - "vulnerabilities": summed findings_count and number of
              completed scans in the window
            - "top_targets": up to 10 most-scanned targets with their counts

        Raises:
            ValueError: If period_days cannot be converted to an integer.
        """
        # The window is sent as a bound parameter rather than formatted into
        # the SQL text, so a non-numeric value can never alter the query.
        days = int(period_days)
        pool = await self.service.get_pool()

        async with pool.acquire() as conn:
            # Scans over time (daily counts)
            scans_over_time = await conn.fetch(
                """
                SELECT
                    DATE(created_at) as date,
                    COUNT(*) as count,
                    scan_type
                FROM scans
                WHERE user_id = $1
                AND created_at >= NOW() - ($2::int * INTERVAL '1 day')
                GROUP BY DATE(created_at), scan_type
                ORDER BY date DESC
                """,
                user_id, days
            )

            # Vulnerabilities by severity
            # (This would need to parse JSONB findings - simplified here)
            vuln_breakdown = await conn.fetchrow(
                """
                SELECT
                    SUM(findings_count) as total,
                    COUNT(*) as scan_count
                FROM scans
                WHERE user_id = $1
                AND status = 'completed'
                AND created_at >= NOW() - ($2::int * INTERVAL '1 day')
                """,
                user_id, days
            )

            # Top targets
            top_targets = await conn.fetch(
                """
                SELECT
                    target,
                    COUNT(*) as scan_count
                FROM scans
                WHERE user_id = $1
                AND target IS NOT NULL
                AND created_at >= NOW() - ($2::int * INTERVAL '1 day')
                GROUP BY target
                ORDER BY scan_count DESC
                LIMIT 10
                """,
                user_id, days
            )

        return {
            "scans_over_time": [dict(row) for row in scans_over_time],
            "vulnerabilities": dict(vuln_breakdown) if vuln_breakdown else {},
            "top_targets": [dict(row) for row in top_targets]
        }

    # ==========================================
    # Cleanup
    # ==========================================

    async def close(self) -> None:
        """Close database connection pool."""
        await self.service.close()
# Lazily-created, module-wide client shared by every caller.
_client_instance = None


def get_database_client() -> NeonDatabaseClient:
    """Return the shared NeonDatabaseClient, constructing it on first use."""
    global _client_instance
    client = _client_instance
    if client is None:
        client = NeonDatabaseClient()
        _client_instance = client
    return client