Coverage for src/alprina_cli/api/services/neon_service.py: 16%

347 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2Neon Database Service 

3Replaces SupabaseService with direct PostgreSQL access via asyncpg. 

4""" 

5 

6import os 

7import secrets 

8import hashlib 

9import bcrypt 

10import asyncpg 

11from typing import Optional, Dict, Any, List 

12from datetime import datetime, timedelta 

13from loguru import logger 

14 

15 

class NeonService:
    """Service for Neon PostgreSQL database operations.

    When ``DATABASE_URL`` is unset the service is disabled: the ``create_*``
    methods raise ``Exception("Database not configured")`` and the other
    methods return a neutral default (``None``, ``False``, ``[]``, ``{}``...)
    without touching the database. The asyncpg pool is created lazily by
    :meth:`get_pool`.
    """

    # Class-level defaults so instances built without __init__ (e.g. via
    # ``__new__`` in tests) still have every attribute the methods read.
    pool = None
    _pool_lock = None
    _team_invitations_ready = False

    # How many times create_device_authorization() re-rolls a 4-digit user code
    # that is already held by another still-pending authorization.
    _USER_CODE_ATTEMPTS = 10

    def __init__(self):
        """Read ``DATABASE_URL`` and prepare (but do not open) the connection pool."""
        self.database_url = os.getenv("DATABASE_URL")
        self.pool = None  # Created on first use by get_pool()
        self._pool_lock = None  # asyncio.Lock, created lazily inside the event loop
        self._team_invitations_ready = False

        if not self.database_url:
            logger.warning("DATABASE_URL not found - database features disabled")
            self.enabled = False
        else:
            self.enabled = True
            logger.info("Neon service initialized")

    async def get_pool(self) -> "asyncpg.Pool":
        """Return the shared connection pool, creating it on the first call.

        Concurrent first calls are serialized with a lock. Without it, two
        coroutines could both see ``self.pool is None`` across the ``await`` and
        each build a pool, leaking one of them.

        Raises:
            Exception: whatever ``asyncpg.create_pool`` raises (logged, then re-raised).
        """
        if self.pool is not None:
            return self.pool

        if self._pool_lock is None:
            import asyncio
            # Built lazily so it is constructed inside the running event loop
            # rather than at import time.
            self._pool_lock = asyncio.Lock()

        async with self._pool_lock:
            if self.pool is None:  # another coroutine may have created it meanwhile
                try:
                    self.pool = await asyncpg.create_pool(
                        self.database_url,
                        min_size=2,
                        max_size=10,
                        command_timeout=60,
                    )
                    logger.info("✅ Neon connection pool created successfully")
                except Exception as e:
                    logger.error(f"❌ Failed to create Neon connection pool: {e}")
                    raise
        return self.pool

    def is_enabled(self) -> bool:
        """Return True when ``DATABASE_URL`` was configured at construction time."""
        return self.enabled

    # ==========================================
    # Internal helpers
    # ==========================================

    async def _fetch_one(self, query: str, *args: Any) -> Optional[Dict[str, Any]]:
        """Run *query* on a pooled connection and return the first row as a dict, or None."""
        pool = await self.get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(query, *args)
        return dict(row) if row else None

    async def _fetch_all(self, query: str, *args: Any) -> List[Dict[str, Any]]:
        """Run *query* on a pooled connection and return every row as a dict."""
        pool = await self.get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(query, *args)
        return [dict(row) for row in rows]

    @staticmethod
    def _to_json(value: Any) -> str:
        """Serialize a JSON payload for use as a query parameter.

        ``get_pool`` registers no JSON type codec, so asyncpg only accepts
        ``str`` for json/jsonb parameters. Dicts must therefore be dumped first.
        ``None`` becomes ``"{}"``. A ``str`` is assumed to already be JSON and
        is passed through unchanged.
        """
        import json
        if isinstance(value, str):
            return value
        return json.dumps(value or {})

    @staticmethod
    def _is_safe_column_name(name: Any) -> bool:
        """True if *name* is a plain ASCII identifier (letters, digits, ``_``)."""
        return isinstance(name, str) and name.isascii() and name.isidentifier()

    @staticmethod
    def _hash_api_key(api_key: str) -> str:
        """Return the SHA-256 hex digest under which an API key is stored and looked up."""
        return hashlib.sha256(api_key.encode()).hexdigest()

    async def _insert_api_key(
        self,
        conn: "asyncpg.Connection",
        user_id: str,
        api_key: str,
        name: str,
        expires_at: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """Insert a hashed API key using *conn* and return its public metadata.

        Only the hash and the first 16 characters (for display) are stored;
        the plain-text key is never written to the database.
        """
        key = await conn.fetchrow(
            """
            INSERT INTO api_keys (user_id, key_hash, key_prefix, name, expires_at)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, name, key_prefix, created_at
            """,
            user_id, self._hash_api_key(api_key), api_key[:16], name, expires_at
        )
        return {
            "id": str(key['id']),
            "name": key['name'],
            "key_prefix": key['key_prefix'],
            "created_at": key['created_at'].isoformat()
        }

    # ==========================================
    # User Management
    # ==========================================

    async def create_user(
        self,
        email: str,
        password: str,
        full_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Create a password-authenticated user together with a default API key.

        The user row and its "Default API Key" are inserted in one transaction,
        so either both exist afterwards or neither does.

        Args:
            email: Login e-mail, stored verbatim.
            password: Plain-text password. A ``str`` is UTF-8 encoded and
                ``bytes`` are used as-is. Only the bcrypt hash is stored.
            full_name: Optional display name.

        Returns:
            ``user_id``, ``email``, ``full_name``, ``tier`` (``'none'`` for new
            users), the plain-text ``api_key`` and ``created_at`` (ISO-8601).

        Raises:
            Exception: if the database is not configured. asyncpg errors from
                the inserts propagate, and the transaction is rolled back.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        import asyncio
        password_bytes = password.encode('utf-8') if isinstance(password, str) else password
        # bcrypt is deliberately CPU-expensive; hash in a worker thread so the
        # event loop keeps serving other requests meanwhile.
        password_hash = (
            await asyncio.to_thread(bcrypt.hashpw, password_bytes, bcrypt.gensalt())
        ).decode('utf-8')

        api_key = self.generate_api_key()

        async with pool.acquire() as conn:
            async with conn.transaction():
                user = await conn.fetchrow(
                    """
                    INSERT INTO users (email, password_hash, full_name, tier)
                    VALUES ($1, $2, $3, 'none')
                    RETURNING id, email, full_name, tier, created_at
                    """,
                    email, password_hash, full_name
                )
                await self._insert_api_key(conn, str(user['id']), api_key, "Default API Key")

        logger.info(f"Created user: {email}")

        return {
            "user_id": str(user['id']),
            "email": user['email'],
            "full_name": user['full_name'],
            "tier": user['tier'],
            "api_key": api_key,
            "created_at": user['created_at'].isoformat()
        }

    async def authenticate_user(
        self,
        email: str,
        password: str
    ) -> Optional[Dict[str, Any]]:
        """Return the full user row if *password* matches the stored bcrypt hash.

        Returns None when the service is disabled, the e-mail is unknown, the
        user has no password (e.g. created from a subscription), the password
        is wrong, or the stored hash cannot be checked.
        """
        if not self.is_enabled():
            return None

        user = await self._fetch_one("SELECT * FROM users WHERE email = $1", email)

        if not user or not user['password_hash']:
            logger.debug(f"User not found or no password hash for {email}")
            return None

        if isinstance(password, str):
            password = password.encode('utf-8')

        logger.debug(f"Checking password for {email}")
        import asyncio
        try:
            # Run the expensive bcrypt check off the event loop.
            password_match = await asyncio.to_thread(
                bcrypt.checkpw, password, user['password_hash'].encode('utf-8')
            )
            logger.debug(f"Password match: {password_match}")
            if password_match:
                return user
        except Exception as e:
            # A malformed stored hash is treated as a failed login, not a crash.
            logger.error(f"Password check error: {e}")

        return None

    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the user row with primary key *user_id*, or None."""
        if not self.is_enabled():
            return None
        return await self._fetch_one("SELECT * FROM users WHERE id = $1", user_id)

    async def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
        """Return the user row whose e-mail equals *email* exactly, or None."""
        if not self.is_enabled():
            return None
        return await self._fetch_one("SELECT * FROM users WHERE email = $1", email)

    async def get_user_by_subscription(
        self,
        subscription_id: str
    ) -> Optional[Dict[str, Any]]:
        """Return the user linked to Polar subscription *subscription_id*, or None."""
        if not self.is_enabled():
            return None
        return await self._fetch_one(
            "SELECT * FROM users WHERE polar_subscription_id = $1",
            subscription_id
        )

    async def create_user_from_subscription(
        self,
        email: str,
        polar_customer_id: str,
        polar_subscription_id: str,
        tier: str,
        billing_period: str = "monthly",
        has_metering: bool = True,
        scans_included: int = 0,
        period_start: Optional[datetime] = None,
        period_end: Optional[datetime] = None,
        seats_included: int = 1
    ) -> Dict[str, Any]:
        """Create a password-less user from a Polar subscription.

        The subscription starts ``'active'`` with zero scans used, one seat used
        and no extra seats.

        Returns:
            The inserted row's ``id``, ``email``, ``tier`` and ``created_at``.

        Raises:
            Exception: if the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                """
                INSERT INTO users (
                    email,
                    tier,
                    polar_customer_id,
                    polar_subscription_id,
                    subscription_status,
                    billing_period,
                    has_metering,
                    scans_included,
                    scans_used_this_period,
                    period_start,
                    period_end,
                    seats_included,
                    seats_used,
                    extra_seats
                )
                VALUES ($1, $2, $3, $4, 'active', $5, $6, $7, 0, $8, $9, $10, 1, 0)
                RETURNING id, email, tier, created_at
                """,
                email, tier, polar_customer_id, polar_subscription_id,
                billing_period, has_metering, scans_included,
                period_start, period_end, seats_included
            )

        logger.info(f"Created user from subscription: {email}")

        return dict(user)

    async def update_user(
        self,
        user_id: str,
        updates: Dict[str, Any]
    ) -> bool:
        """Set the given columns on one user row.

        Column names are interpolated into the SQL text (values are bound as
        parameters). They are therefore restricted to plain identifiers so that
        a caller-supplied key cannot inject SQL.

        Args:
            user_id: Primary key of the user to update.
            updates: Mapping of column name to new value.

        Returns:
            True if exactly one row was updated. False if the service is
            disabled, *updates* is empty, or no row matched.

        Raises:
            ValueError: if a key in *updates* is not a plain identifier.
        """
        if not self.is_enabled():
            return False
        if not updates:
            # "SET" with no assignments would be a SQL syntax error.
            return False

        bad_columns = [key for key in updates if not self._is_safe_column_name(key)]
        if bad_columns:
            raise ValueError(f"Invalid column name(s) for users update: {bad_columns!r}")

        columns = list(updates)
        set_clause = ", ".join(f"{column} = ${i}" for i, column in enumerate(columns, start=1))
        values = [updates[column] for column in columns]
        query = f"""
            UPDATE users
            SET {set_clause}
            WHERE id = ${len(columns) + 1}
        """

        pool = await self.get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(query, *values, user_id)
        return result == "UPDATE 1"

    async def initialize_usage_tracking(
        self,
        user_id: str,
        tier: str
    ):
        """Copy *tier*'s request and scan limits from polar_service onto the user.

        Limits missing from the tier definition are stored as 0.
        """
        from ..services.polar_service import polar_service
        limits = polar_service.get_tier_limits(tier)

        await self.update_user(
            user_id,
            {
                "requests_per_hour": limits.get("api_requests_per_hour", 0),
                "scans_per_month": limits.get("scans_per_month", 0)
            }
        )

    async def increment_user_scans(self, user_id: str):
        """Add one to the user's ``scans_per_month`` column."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        # NOTE(review): this increments ``scans_per_month``, which
        # initialize_usage_tracking() writes as a *limit* and get_user_stats()
        # reports as ``scans_limit``. Each scan therefore appears to raise the
        # user's quota. ``scans_used_this_period`` (set by
        # create_user_from_subscription) looks like the intended counter;
        # confirm against other readers of scans_per_month before changing.
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE users
                SET scans_per_month = scans_per_month + 1
                WHERE id = $1
                """,
                user_id
            )

    # ==========================================
    # API Key Management
    # ==========================================

    def generate_api_key(self) -> str:
        """Return a new random API key of the form ``alprina_sk_<token>``."""
        return f"alprina_sk_{secrets.token_urlsafe(32)}"

    async def create_api_key(
        self,
        user_id: str,
        api_key: str,
        name: str,
        expires_at: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """Store *api_key* (hashed) for *user_id* and return its public metadata.

        Raises:
            Exception: if the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()
        async with pool.acquire() as conn:
            return await self._insert_api_key(conn, user_id, api_key, name, expires_at)

    async def verify_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Return the owning user row (plus ``key_id``) for an active, unexpired key.

        On success the key's ``last_used_at`` is set to the database's NOW().
        Returns None if the key is unknown, inactive, or expired.
        """
        if not self.is_enabled():
            return None

        pool = await self.get_pool()
        key_hash = self._hash_api_key(api_key)

        async with pool.acquire() as conn:
            result = await conn.fetchrow(
                """
                SELECT u.*, k.id as key_id
                FROM users u
                JOIN api_keys k ON k.user_id = u.id
                WHERE k.key_hash = $1
                AND k.is_active = true
                AND (k.expires_at IS NULL OR k.expires_at > NOW())
                """,
                key_hash
            )

            if not result:
                return None

            await conn.execute(
                "UPDATE api_keys SET last_used_at = NOW() WHERE id = $1",
                result['key_id']
            )

        return dict(result)

    async def list_api_keys(self, user_id: str) -> List[Dict[str, Any]]:
        """Return metadata (never hashes) for all of the user's keys, newest first."""
        if not self.is_enabled():
            return []

        return await self._fetch_all(
            """
            SELECT id, name, key_prefix, is_active, created_at, last_used_at, expires_at
            FROM api_keys
            WHERE user_id = $1
            ORDER BY created_at DESC
            """,
            user_id
        )

    async def deactivate_api_key(self, key_id: str, user_id: str) -> bool:
        """Mark key *key_id* inactive if it belongs to *user_id*; True if a row changed."""
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE api_keys
                SET is_active = false
                WHERE id = $1 AND user_id = $2
                """,
                key_id, user_id
            )

        return result == "UPDATE 1"

    # ==========================================
    # Scan Management
    # ==========================================

    async def create_scan(
        self,
        user_id: str,
        scan_type: str,
        workflow_mode: str,
        metadata: Optional[Dict] = None
    ) -> str:
        """Insert a scan row and return its id as a string.

        *metadata* is serialized to JSON (``None`` becomes ``{}``).

        Raises:
            Exception: if the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            scan = await conn.fetchrow(
                """
                INSERT INTO scans (user_id, scan_type, workflow_mode, metadata)
                VALUES ($1, $2, $3, $4)
                RETURNING id
                """,
                user_id, scan_type, workflow_mode, self._to_json(metadata)
            )

        return str(scan['id'])

    async def save_scan(
        self,
        scan_id: str,
        status: str,
        findings: Optional[Dict] = None,
        findings_count: int = 0,
        metadata: Optional[Dict] = None
    ) -> bool:
        """Record a scan's results and set ``completed_at`` to the database NOW().

        *findings* and *metadata* are serialized to JSON (``None`` becomes ``{}``).
        Returns True if exactly one scan row was updated.
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE scans
                SET status = $1,
                    findings = $2,
                    findings_count = $3,
                    metadata = $4,
                    completed_at = NOW()
                WHERE id = $5
                """,
                status, self._to_json(findings), findings_count,
                self._to_json(metadata), scan_id
            )

        return result == "UPDATE 1"

    async def get_scan(
        self,
        scan_id: str,
        user_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Return scan *scan_id*, restricted to *user_id*'s scans when given."""
        if not self.is_enabled():
            return None

        if user_id:
            return await self._fetch_one(
                "SELECT * FROM scans WHERE id = $1 AND user_id = $2",
                scan_id, user_id
            )
        return await self._fetch_one("SELECT * FROM scans WHERE id = $1", scan_id)

    async def list_scans(
        self,
        user_id: str,
        limit: int = 10,
        offset: int = 0,
        scan_type: Optional[str] = None,
        workflow_mode: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Return one page of the user's scans, newest first, with optional filters."""
        if not self.is_enabled():
            return []

        # Column names below are fixed literals; only values are parameters.
        filters = ["user_id = $1"]
        params: List[Any] = [user_id]
        for column, value in (("scan_type", scan_type), ("workflow_mode", workflow_mode)):
            if value:
                params.append(value)
                filters.append(f"{column} = ${len(params)}")

        params.extend([limit, offset])
        query = (
            f"SELECT * FROM scans WHERE {' AND '.join(filters)} "
            f"ORDER BY created_at DESC LIMIT ${len(params) - 1} OFFSET ${len(params)}"
        )
        return await self._fetch_all(query, *params)

    # ==========================================
    # Rate Limiting & Usage Tracking
    # ==========================================

    async def check_rate_limit(self, user_id: str) -> Dict[str, Any]:
        """Compare the user's requests in the last hour against ``requests_per_hour``.

        Returns:
            ``allowed``, ``remaining``, ``limit`` and ``used``. An unknown user
            yields ``{"allowed": False, "remaining": 0}``, and a disabled service
            yields ``{"allowed": True, "remaining": 0}``. A NULL limit is treated
            as 0, the same default initialize_usage_tracking() uses.
        """
        if not self.is_enabled():
            return {"allowed": True, "remaining": 0}

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                "SELECT tier, requests_per_hour FROM users WHERE id = $1",
                user_id
            )

            if not user:
                return {"allowed": False, "remaining": 0}

            # The one-hour window uses the database clock, the same clock that
            # fills usage_logs.created_at (log_request does not set it). This
            # avoids passing a naive app-side datetime whose meaning depends on
            # the column type and server timezone.
            count = await conn.fetchval(
                """
                SELECT COUNT(*) FROM usage_logs
                WHERE user_id = $1 AND created_at > NOW() - INTERVAL '1 hour'
                """,
                user_id
            )

        limit = user['requests_per_hour'] or 0
        return {
            "allowed": count < limit,
            "remaining": max(0, limit - count),
            "limit": limit,
            "used": count
        }

    async def log_request(
        self,
        user_id: str,
        endpoint: str,
        method: str,
        status_code: int,
        duration_ms: float
    ):
        """Append one row to ``usage_logs`` (counted by check_rate_limit)."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO usage_logs (user_id, endpoint, method, status_code, duration_ms)
                VALUES ($1, $2, $3, $4, $5)
                """,
                user_id, endpoint, method, status_code, duration_ms
            )

    async def get_user_stats(self, user_id: str) -> Dict[str, Any]:
        """Return the user's tier, scan counts, total findings and configured limits.

        Returns ``{}`` when the service is disabled or the user does not exist.
        """
        if not self.is_enabled():
            return {}

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                "SELECT * FROM users WHERE id = $1",
                user_id
            )

            if not user:
                return {}

            # Start of the current calendar month in UTC (naive datetime).
            first_day = datetime.utcnow().replace(day=1, hour=0, minute=0, second=0, microsecond=0)
            # One pass over the user's scans instead of three separate queries.
            counts = await conn.fetchrow(
                """
                SELECT
                    COUNT(*) FILTER (WHERE created_at >= $2) AS scans_this_month,
                    COUNT(*) AS total_scans,
                    COALESCE(SUM(findings_count), 0) AS vulnerabilities_found
                FROM scans
                WHERE user_id = $1
                """,
                user_id, first_day
            )

        return {
            "tier": user['tier'],
            "scans_this_month": counts['scans_this_month'],
            "total_scans": counts['total_scans'],
            "vulnerabilities_found": counts['vulnerabilities_found'],
            "scans_limit": user['scans_per_month'],
            "requests_limit": user['requests_per_hour']
        }

    # ==========================================
    # Webhook Event Logging
    # ==========================================

    async def log_webhook_event(
        self,
        event_type: str,
        event_id: str,
        payload: Dict[str, Any]
    ):
        """Store a webhook payload as jsonb, ignoring duplicate *event_id*s."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        import json
        payload_json = json.dumps(payload)

        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO webhook_events (event_type, event_id, payload)
                VALUES ($1, $2, $3::jsonb)
                ON CONFLICT (event_id) DO NOTHING
                """,
                event_type, event_id, payload_json
            )

    async def mark_webhook_processed(self, event_id: str):
        """Flag webhook *event_id* as processed, stamping ``processed_at`` with NOW()."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE webhook_events
                SET processed = true, processed_at = NOW()
                WHERE event_id = $1
                """,
                event_id
            )

    async def mark_webhook_error(self, event_id: str, error: str):
        """Record *error* on webhook *event_id*."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE webhook_events
                SET error = $1
                WHERE event_id = $2
                """,
                error, event_id
            )

    # ==========================================
    # Device Authorization (CLI OAuth)
    # ==========================================

    def generate_device_codes(self) -> tuple[str, str]:
        """Return ``(device_code, user_code)``: a secret token and a 4-digit code."""
        device_code = secrets.token_urlsafe(32)
        # User code: 4 digits (like GitHub CLI) - easy to type
        user_code = ''.join(secrets.choice('0123456789') for _ in range(4))
        return device_code, user_code

    async def create_device_authorization(self) -> Dict[str, Any]:
        """Start a CLI device authorization that expires 15 minutes from now.

        There are only 10,000 possible user codes, so a code already held by
        another pending, unexpired authorization is re-rolled (up to
        ``_USER_CODE_ATTEMPTS`` times). authorize_device() refuses ambiguous
        codes as a second line of defence.

        Raises:
            Exception: if the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            for _ in range(self._USER_CODE_ATTEMPTS):
                device_code, user_code = self.generate_device_codes()
                in_use = await conn.fetchval(
                    """
                    SELECT EXISTS (
                        SELECT 1 FROM device_codes
                        WHERE user_code = $1
                          AND authorized = false
                          AND expires_at > NOW()
                    )
                    """,
                    user_code
                )
                if not in_use:
                    break

            # Expiry is computed with the database clock, the same clock
            # authorize_device() and check_device_authorization() compare against.
            # Keep '15 minutes' in sync with "expires_in" below.
            auth = await conn.fetchrow(
                """
                INSERT INTO device_codes (device_code, user_code, expires_at)
                VALUES ($1, $2, NOW() + INTERVAL '15 minutes')
                RETURNING device_code, user_code, expires_at
                """,
                device_code, user_code
            )

        return {
            "device_code": auth['device_code'],
            "user_code": auth['user_code'],
            "verification_uri": "https://www.alprina.com/device",
            "expires_in": 900  # 15 minutes
        }

    async def check_device_authorization(
        self,
        device_code: str
    ) -> Optional[Dict[str, Any]]:
        """Poll a device authorization.

        Returns:
            None for an unknown *device_code*. Otherwise ``{"status": "expired"}``,
            ``{"status": "pending"}``, or ``{"status": "authorized", "user_id",
            "user_code"}``.
        """
        if not self.is_enabled():
            return None

        # Expiry is evaluated in SQL. A Python comparison against a naive
        # utcnow() raises TypeError when the column returns aware datetimes.
        auth = await self._fetch_one(
            """
            SELECT *, (expires_at < NOW()) AS is_expired_now
            FROM device_codes
            WHERE device_code = $1
            """,
            device_code
        )

        if not auth:
            return None

        if auth['is_expired_now']:
            return {"status": "expired"}

        if auth['authorized'] and auth['user_id']:
            return {
                "status": "authorized",
                "user_id": str(auth['user_id']),
                "user_code": auth['user_code']
            }

        return {"status": "pending"}

    async def authorize_device(self, user_code: str, user_id: str) -> bool:
        """Bind the pending authorization identified by *user_code* to *user_id*.

        The update only happens when exactly one pending, unexpired row carries
        *user_code*. With colliding 4-digit codes, an unguarded update would
        authorize every matching device (possibly someone else's) for this user.

        Returns:
            True if exactly one device was authorized.
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE device_codes
                SET user_id = $1, authorized = true
                WHERE user_code = $2
                AND expires_at > NOW()
                AND authorized = false
                AND (
                    SELECT COUNT(*) FROM device_codes
                    WHERE user_code = $2
                    AND expires_at > NOW()
                    AND authorized = false
                ) = 1
                """,
                user_id, user_code
            )

        return result == "UPDATE 1"

    # ==========================================
    # Team Management
    # ==========================================

    async def get_team_members(self, owner_id: str) -> List[Dict[str, Any]]:
        """Return members of any of *owner_id*'s subscriptions, oldest first."""
        if not self.is_enabled():
            return []

        return await self._fetch_all(
            """
            SELECT 
                u.id,
                u.email,
                u.full_name,
                tm.role,
                tm.created_at as joined_at
            FROM team_members tm
            JOIN users u ON u.id = tm.user_id
            WHERE tm.subscription_id IN (
                SELECT id FROM user_subscriptions 
                WHERE user_id = $1
            )
            ORDER BY tm.created_at ASC
            """,
            owner_id
        )

    async def get_team_member_by_email(
        self,
        owner_id: str,
        email: str
    ) -> Optional[Dict[str, Any]]:
        """Return the owner's team member with *email* (case-insensitive), or None."""
        if not self.is_enabled():
            return None

        return await self._fetch_one(
            """
            SELECT 
                u.id,
                u.email,
                tm.role
            FROM team_members tm
            JOIN users u ON u.id = tm.user_id
            WHERE tm.subscription_id IN (
                SELECT id FROM user_subscriptions 
                WHERE user_id = $1
            )
            AND LOWER(u.email) = LOWER($2)
            """,
            owner_id, email
        )

    async def _ensure_team_invitations_table(self, conn: "asyncpg.Connection") -> None:
        """Create ``team_invitations`` and its indexes once per service instance.

        The DDL is idempotent (``IF NOT EXISTS``), but it only needs to run once.
        A flag skips it after the first success.
        """
        if self._team_invitations_ready:
            return

        await conn.execute(
            """
            CREATE TABLE IF NOT EXISTS team_invitations (
                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
                owner_email TEXT,
                invitee_email TEXT NOT NULL,
                role TEXT NOT NULL CHECK (role IN ('admin', 'member')),
                token TEXT NOT NULL UNIQUE,
                status TEXT DEFAULT 'pending' CHECK (status IN ('pending', 'accepted', 'expired')),
                expires_at TIMESTAMPTZ DEFAULT NOW() + INTERVAL '7 days',
                created_at TIMESTAMPTZ DEFAULT NOW()
            );
            
            CREATE INDEX IF NOT EXISTS idx_team_invitations_token ON team_invitations(token);
            CREATE INDEX IF NOT EXISTS idx_team_invitations_owner ON team_invitations(owner_id);
            CREATE INDEX IF NOT EXISTS idx_team_invitations_email ON team_invitations(invitee_email);
            """
        )
        self._team_invitations_ready = True

    async def create_team_invitation(
        self,
        owner_id: str,
        invitee_email: str,
        role: str
    ) -> Dict[str, Any]:
        """Create a pending team invitation with a random URL-safe token.

        The owner's e-mail is looked up and stored alongside the invitation
        (None if the owner is not found).

        Returns:
            The inserted invitation row as a dict.

        Raises:
            Exception: if the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        invitation_token = secrets.token_urlsafe(32)

        owner = await self.get_user_by_id(owner_id)
        owner_email = owner.get("email") if owner else None

        pool = await self.get_pool()
        async with pool.acquire() as conn:
            await self._ensure_team_invitations_table(conn)

            invitation = await conn.fetchrow(
                """
                INSERT INTO team_invitations (
                    owner_id, owner_email, invitee_email, role, token
                )
                VALUES ($1, $2, $3, $4, $5)
                RETURNING id, owner_id, owner_email, invitee_email, role, token, created_at
                """,
                owner_id, owner_email, invitee_email, role, invitation_token
            )

        return dict(invitation)

    async def get_team_invitation(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the pending, unexpired invitation with *token*, or None."""
        if not self.is_enabled():
            return None

        return await self._fetch_one(
            """
            SELECT * FROM team_invitations
            WHERE token = $1
            AND status = 'pending'
            AND expires_at > NOW()
            """,
            token
        )

    async def delete_team_invitation(self, token: str) -> bool:
        """Mark the invitation with *token* as accepted (the row is kept)."""
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE team_invitations
                SET status = 'accepted'
                WHERE token = $1
                """,
                token
            )

        return result == "UPDATE 1"

    async def add_team_member(
        self,
        owner_id: str,
        member_id: str,
        role: str
    ) -> bool:
        """Add *member_id* to the owner's most recently created subscription.

        Returns False if the owner has no subscription. Adding an existing
        member is a no-op that still returns True.
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            subscription = await conn.fetchrow(
                """
                SELECT id FROM user_subscriptions
                WHERE user_id = $1
                ORDER BY created_at DESC
                LIMIT 1
                """,
                owner_id
            )

            if not subscription:
                logger.error(f"No subscription found for owner {owner_id}")
                return False

            await conn.execute(
                """
                INSERT INTO team_members (subscription_id, user_id, role)
                VALUES ($1, $2, $3)
                ON CONFLICT (subscription_id, user_id) DO NOTHING
                """,
                subscription["id"], member_id, role
            )

        logger.info(f"Added team member {member_id} to subscription {subscription['id']}")
        return True

    async def remove_team_member(
        self,
        owner_id: str,
        member_id: str
    ) -> bool:
        """Remove *member_id* from all of the owner's subscriptions.

        Returns True if at least one membership row was deleted.
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                DELETE FROM team_members
                WHERE subscription_id IN (
                    SELECT id FROM user_subscriptions 
                    WHERE user_id = $1
                )
                AND user_id = $2
                """,
                owner_id, member_id
            )

        # Command status is "DELETE <n>". The member may belong to several of
        # the owner's subscriptions, so any n > 0 counts as removed.
        return int(result.split()[-1]) > 0

    # ==========================================
    # Cleanup
    # ==========================================

    async def close(self):
        """Close the connection pool if one was created.

        The cached pool is cleared so a later get_pool() builds a fresh pool
        instead of returning a closed one.
        """
        if self.pool is not None:
            await self.pool.close()
            self.pool = None
            logger.info("Neon connection pool closed")

1036 

1037 

# Shared module-level instance: callers import ``neon_service`` so they all use
# one service object (and therefore one lazily-created connection pool).
neon_service: NeonService = NeonService()