Coverage for src/alprina_cli/database/neon_client.py: 37%

98 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2Neon Database Client for CLI Tools 

3 

4Context Engineering: 

5- Lightweight wrapper around NeonService 

6- Fast operations (< 50ms target) 

7- Connection pooling for performance 

8- Minimal token footprint in responses 

9""" 

10 

11import os 

12from typing import Optional, Dict, Any, List, Tuple 

13from uuid import UUID 

14from datetime import datetime 

15from loguru import logger 

16 

17# Import existing NeonService 

18import sys 

19sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'api', 'services')) 

20from neon_service import NeonService 

21 

22 

class NeonDatabaseClient:
    """
    Database client for CLI tool integration.

    Thin async wrapper around ``NeonService``. It also runs a few direct SQL
    statements on the service's connection pool for CLI-specific tables
    (``cli_sessions``, ``scan_usage``) and columns.

    Context Engineering (design targets):
    - All methods < 50ms (p95)
    - Connection pooling (reuse connections)
    - Minimal data transfer (only what's needed)
    - Async-first for non-blocking operations
    """

    # Monthly scan allowance for users without an active subscription.
    FREE_TIER_SCAN_LIMIT = 10

    def __init__(self):
        """Initialize database client.

        The CLI version comes from the ``ALPRINA_CLI_VERSION`` environment
        variable (default ``"0.1.0"``). It is stamped onto every scan created.
        """
        self.service = NeonService()
        self._cli_version = os.getenv("ALPRINA_CLI_VERSION", "0.1.0")

    # ==========================================
    # Internal helpers
    # ==========================================

    @staticmethod
    def _scan_type(tool_name: str) -> str:
        """Derive the stored ``scan_type`` value from a tool name.

        ``"ScanTool"`` -> ``"scan"``, ``"ReconTool"`` -> ``"recon"``. Every
        occurrence of ``"tool"`` is removed after lower-casing. All writers
        and the list filter share this helper so their values always agree.
        """
        return tool_name.lower().replace("tool", "")

    @staticmethod
    def _affected_rows(command_tag: Optional[str]) -> int:
        """Return the row count from a command status string like ``"UPDATE 3"``.

        Returns 0 when the tag is empty/None or its last token is not an integer.
        """
        parts = command_tag.split() if command_tag else []
        if not parts:
            return 0
        try:
            return int(parts[-1])
        except ValueError:
            return 0

    async def is_available(self) -> bool:
        """Check if database is configured and available.

        Delegates to ``NeonService.is_enabled()``.
        """
        return self.service.is_enabled()

    # ==========================================
    # Authentication Methods
    # ==========================================

    async def authenticate_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
        """
        Authenticate user via API key.

        Args:
            api_key: Raw API key (e.g., "alprina_...")

        Returns:
            User dict if valid, None otherwise (as returned by
            ``NeonService.verify_api_key``).

        Context: Returns only essential user data
        """
        return await self.service.verify_api_key(api_key)

    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Get user by ID (delegates to ``NeonService.get_user_by_id``)."""
        return await self.service.get_user_by_id(user_id)

    # ==========================================
    # Scan Lifecycle Methods
    # ==========================================

    async def create_scan(
        self,
        user_id: str,
        tool_name: str,
        target: str,
        params: Dict[str, Any]
    ) -> str:
        """
        Create new scan record (status: pending).

        The row is created by ``NeonService.create_scan``. A second statement
        then fills the CLI-specific ``cli_version``, ``tool_name`` and
        ``target`` columns. The two writes are not wrapped in a transaction.

        Args:
            user_id: User UUID
            tool_name: Name of tool (e.g., "ScanTool", "ReconTool")
            target: Scan target (domain, IP, file path)
            params: Tool parameters as dict

        Returns:
            Scan ID (UUID)

        Context: Fast creation (< 20ms)
        """
        metadata = {
            "tool_name": tool_name,
            "target": target,
            "params": params,
            "cli_version": self._cli_version,
            "guardrails_enabled": True
        }

        scan_id = await self.service.create_scan(
            user_id=user_id,
            scan_type=self._scan_type(tool_name),
            workflow_mode="cli",
            metadata=metadata
        )

        # Also update new columns
        pool = await self.service.get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE scans
                SET cli_version = $1,
                    tool_name = $2,
                    target = $3
                WHERE id = $4
                """,
                self._cli_version, tool_name, target, scan_id
            )

        logger.debug("Created scan {} for user {}", scan_id, user_id)
        return scan_id

    async def update_scan_status(
        self,
        scan_id: str,
        status: str
    ) -> bool:
        """
        Update scan status.

        Args:
            scan_id: Scan UUID
            status: New status (pending/running/completed/failed)

        Returns:
            True if exactly one row was updated

        Context: Fast update (< 10ms)
        """
        pool = await self.service.get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE scans
                SET status = $1
                WHERE id = $2
                """,
                status, scan_id
            )
            return self._affected_rows(result) == 1

    async def save_scan_results(
        self,
        scan_id: str,
        findings: Dict[str, Any],
        findings_count: int,
        status: str = "completed"
    ) -> bool:
        """
        Save scan findings and set the final status.

        Args:
            scan_id: Scan UUID
            findings: Scan results as dict
            findings_count: Number of findings
            status: Final status (completed/failed)

        Returns:
            Result of ``NeonService.save_scan`` (True if saved)

        Context: Efficient JSONB storage
        """
        return await self.service.save_scan(
            scan_id=scan_id,
            status=status,
            findings=findings,
            findings_count=findings_count
        )

    async def get_scan(
        self,
        scan_id: str,
        user_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Retrieve scan by ID.

        Args:
            scan_id: Scan UUID
            user_id: Optional user ID for access control

        Returns:
            Scan dict or None
        """
        return await self.service.get_scan(scan_id, user_id)

    async def list_user_scans(
        self,
        user_id: str,
        limit: int = 20,
        offset: int = 0,
        tool_name: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        List user's recent scans.

        Args:
            user_id: User UUID
            limit: Max results (default: 20)
            offset: Pagination offset
            tool_name: Optional filter by tool. It is normalized the same way
                as when the scan was created (``"ScanTool"`` -> ``"scan"``).

        Returns:
            List of scan dicts

        Context: Paginated for performance
        """
        return await self.service.list_scans(
            user_id=user_id,
            limit=limit,
            offset=offset,
            scan_type=self._scan_type(tool_name) if tool_name else None
        )

    # ==========================================
    # Usage Tracking Methods
    # ==========================================

    async def track_scan_usage(
        self,
        user_id: str,
        scan_id: str,
        tool_name: str,
        credits_used: int = 1,
        duration_ms: Optional[int] = None,
        vulnerabilities_found: int = 0
    ) -> bool:
        """
        Track scan usage for metering.

        Inserts one ``scan_usage`` row. The row is linked to the user's most
        recently created active subscription, or to NULL if there is none.

        Args:
            user_id: User UUID
            scan_id: Scan UUID
            tool_name: Tool name
            credits_used: Credits consumed (default: 1)
            duration_ms: Execution time in milliseconds
            vulnerabilities_found: Number of vulnerabilities found

        Returns:
            True once the row is inserted (database errors propagate)

        Context: Essential for billing/limits
        """
        pool = await self.service.get_pool()

        async with pool.acquire() as conn:
            # Get subscription_id
            subscription = await conn.fetchrow(
                """
                SELECT id FROM user_subscriptions
                WHERE user_id = $1
                AND status = 'active'
                ORDER BY created_at DESC
                LIMIT 1
                """,
                user_id
            )

            subscription_id = str(subscription['id']) if subscription else None

            # Insert usage record
            await conn.execute(
                """
                INSERT INTO scan_usage (
                    user_id, subscription_id, scan_id,
                    scan_type, workflow_mode,
                    credits_used, duration_ms, vulnerabilities_found
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                """,
                user_id, subscription_id, scan_id,
                self._scan_type(tool_name), "cli",
                credits_used, duration_ms, vulnerabilities_found
            )

        logger.debug("Tracked usage for scan {}: {} credits", scan_id, credits_used)
        return True

    async def check_scan_limit(
        self,
        user_id: str
    ) -> Tuple[bool, int, int]:
        """
        Check if user can perform another scan.

        Without an active subscription, the user gets ``FREE_TIER_SCAN_LIMIT``
        scans per month. Usage is counted from ``scans`` created since
        ``date_trunc('month', CURRENT_DATE)``.

        With an active subscription (the most recently created one), usage is
        counted from ``scan_usage`` rows in
        [current_period_start, current_period_end). It is compared against
        that subscription's ``scans_limit``.

        Args:
            user_id: User UUID

        Returns:
            Tuple of (can_scan, scans_used, scans_limit)

        Context: Critical for rate limiting
        """
        pool = await self.service.get_pool()

        async with pool.acquire() as conn:
            # Get active subscription
            subscription = await conn.fetchrow(
                """
                SELECT
                    scans_used,
                    scans_limit,
                    current_period_start,
                    current_period_end
                FROM user_subscriptions
                WHERE user_id = $1
                AND status = 'active'
                ORDER BY created_at DESC
                LIMIT 1
                """,
                user_id
            )

            if not subscription:
                # No active subscription - free tier; count scans this month
                free_limit = self.FREE_TIER_SCAN_LIMIT
                scans_count = await conn.fetchval(
                    """
                    SELECT COUNT(*)
                    FROM scans
                    WHERE user_id = $1
                    AND created_at >= date_trunc('month', CURRENT_DATE)
                    """,
                    user_id
                )
                return (scans_count < free_limit, scans_count, free_limit)

            # Count scans in current period
            scans_used = await conn.fetchval(
                """
                SELECT COUNT(*)
                FROM scan_usage
                WHERE user_id = $1
                AND created_at >= $2
                AND created_at < $3
                """,
                user_id,
                subscription['current_period_start'],
                subscription['current_period_end']
            )

            # NOTE(review): if scans_limit can be NULL (e.g. "unlimited"),
            # the comparison below raises TypeError — confirm schema.
            scans_limit = subscription['scans_limit']
            can_scan = scans_used < scans_limit

            return (can_scan, scans_used, scans_limit)

    async def increment_scan_count(self, user_id: str) -> bool:
        """
        Increment ``scans_used`` on the user's active subscription(s).

        NOTE(review): the UPDATE matches every row with status 'active' for
        this user. ``check_scan_limit`` and ``track_scan_usage`` only look at
        the newest one — confirm a user can have at most one active row.

        Args:
            user_id: User UUID

        Returns:
            True if at least one subscription row was incremented
        """
        pool = await self.service.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE user_subscriptions
                SET scans_used = scans_used + 1
                WHERE user_id = $1
                AND status = 'active'
                """,
                user_id
            )
            # The status tag is "UPDATE <n>". Checking for the word "UPDATE"
            # alone would also report success for "UPDATE 0".
            return self._affected_rows(result) > 0

    # ==========================================
    # CLI Session Tracking
    # ==========================================

    async def create_cli_session(
        self,
        user_id: str,
        cli_version: str,
        os_info: str,
        python_version: str
    ) -> str:
        """
        Create new CLI session for analytics.

        Args:
            user_id: User UUID
            cli_version: CLI version string
            os_info: Operating system (stored in the ``os`` column)
            python_version: Python version

        Returns:
            Session ID (the inserted row's ``id`` as a string)
        """
        pool = await self.service.get_pool()

        async with pool.acquire() as conn:
            session = await conn.fetchrow(
                """
                INSERT INTO cli_sessions (
                    user_id, cli_version, os, python_version
                ) VALUES ($1, $2, $3, $4)
                RETURNING id
                """,
                user_id, cli_version, os_info, python_version
            )
            return str(session['id'])

    async def update_session_activity(self, session_id: str) -> None:
        """Set ``last_activity`` to NOW() and increment ``commands_run`` for a session."""
        pool = await self.service.get_pool()

        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE cli_sessions
                SET last_activity = NOW(),
                    commands_run = commands_run + 1
                WHERE id = $1
                """,
                session_id
            )

    # ==========================================
    # API Key Methods
    # ==========================================

    async def list_api_keys(self, user_id: str) -> List[Dict[str, Any]]:
        """List user's API keys (delegates to ``NeonService.list_api_keys``)."""
        return await self.service.list_api_keys(user_id)

    async def create_api_key(
        self,
        user_id: str,
        name: str
    ) -> Tuple[str, str]:
        """
        Create new API key.

        The raw key is generated and stored via ``NeonService``. The new row's
        id is then looked up by the key's hash.

        Args:
            user_id: User UUID
            name: Key name/description

        Returns:
            Tuple of (raw_key, key_id)
        """
        api_key = self.service.generate_api_key()
        await self.service.create_api_key(user_id, api_key, name)

        # Get key ID
        pool = await self.service.get_pool()
        async with pool.acquire() as conn:
            key_hash = self.service.hash_api_key(api_key)
            key = await conn.fetchrow(
                "SELECT id FROM api_keys WHERE key_hash = $1",
                key_hash
            )
            # NOTE(review): if the insert did not persist, ``key`` is None and
            # the subscript below raises TypeError.
            return (api_key, str(key['id']))

    async def revoke_api_key(self, key_id: str, user_id: str) -> bool:
        """Revoke (deactivate) API key."""
        return await self.service.deactivate_api_key(key_id, user_id)

    # ==========================================
    # Analytics Methods
    # ==========================================

    async def get_user_stats(self, user_id: str) -> Dict[str, Any]:
        """
        Get user statistics for dashboard.

        Returns:
            Whatever ``NeonService.get_user_stats`` returns (presumably scan
            counts, vulnerabilities, usage — see that service).
        """
        return await self.service.get_user_stats(user_id)

    async def get_scan_analytics(
        self,
        user_id: str,
        period_days: int = 30
    ) -> Dict[str, Any]:
        """
        Get scan analytics for charts.

        Args:
            user_id: User UUID
            period_days: Number of days to analyze (coerced with ``int()``)

        Returns:
            Dict with keys:
            - ``scans_over_time``: daily counts per scan_type, newest first
            - ``vulnerabilities``: total findings_count and completed scan count
            - ``top_targets``: up to 10 most-scanned targets

        Raises:
            ValueError / TypeError: if ``period_days`` is not int-convertible.
        """
        # The window is bound as a query parameter. The previous version
        # %-interpolated it into the SQL text, which allowed SQL injection.
        days = int(period_days)
        pool = await self.service.get_pool()

        async with pool.acquire() as conn:
            # Scans over time (daily counts)
            scans_over_time = await conn.fetch(
                """
                SELECT
                    DATE(created_at) as date,
                    COUNT(*) as count,
                    scan_type
                FROM scans
                WHERE user_id = $1
                AND created_at >= NOW() - ($2::int * INTERVAL '1 day')
                GROUP BY DATE(created_at), scan_type
                ORDER BY date DESC
                """,
                user_id, days
            )

            # Vulnerabilities by severity
            # (This would need to parse JSONB findings - simplified here)
            vuln_breakdown = await conn.fetchrow(
                """
                SELECT
                    SUM(findings_count) as total,
                    COUNT(*) as scan_count
                FROM scans
                WHERE user_id = $1
                AND status = 'completed'
                AND created_at >= NOW() - ($2::int * INTERVAL '1 day')
                """,
                user_id, days
            )

            # Top targets
            top_targets = await conn.fetch(
                """
                SELECT
                    target,
                    COUNT(*) as scan_count
                FROM scans
                WHERE user_id = $1
                AND target IS NOT NULL
                AND created_at >= NOW() - ($2::int * INTERVAL '1 day')
                GROUP BY target
                ORDER BY scan_count DESC
                LIMIT 10
                """,
                user_id, days
            )

        return {
            "scans_over_time": [dict(row) for row in scans_over_time],
            "vulnerabilities": dict(vuln_breakdown) if vuln_breakdown else {},
            "top_targets": [dict(row) for row in top_targets]
        }

    # ==========================================
    # Cleanup
    # ==========================================

    async def close(self) -> None:
        """Close the underlying service (and its connection pool)."""
        await self.service.close()

563 

564 

# Lazily-created, module-wide client shared by every caller.
_client_instance = None


def get_database_client() -> NeonDatabaseClient:
    """Return the shared NeonDatabaseClient, building it on the first call.

    Subsequent calls hand back the very same object. No lock guards the
    lazy construction.
    """
    global _client_instance
    if _client_instance is not None:
        return _client_instance
    _client_instance = NeonDatabaseClient()
    return _client_instance