Coverage for src/alprina_cli/api/services/neon_service.py: 16%
347 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
1"""
2Neon Database Service
3Replaces SupabaseService with direct PostgreSQL access via asyncpg.
4"""
6import os
7import secrets
8import hashlib
9import bcrypt
10import asyncpg
11from typing import Optional, Dict, Any, List
12from datetime import datetime, timedelta
13from loguru import logger
class NeonService:
    """Service for Neon PostgreSQL database operations.

    All data access goes through a lazily created asyncpg connection pool.
    When ``DATABASE_URL`` is not set the service is "disabled": read and
    update methods return an empty/falsy value and create methods raise
    ``Exception("Database not configured")``.
    """

    def __init__(self):
        """Read ``DATABASE_URL`` and record whether the database is usable."""
        self.database_url = os.getenv("DATABASE_URL")
        self.enabled = bool(self.database_url)
        # The pool is never created here; get_pool() builds it on first use.
        self.pool = None

        if self.enabled:
            logger.info("Neon service initialized")
        else:
            logger.warning("DATABASE_URL not found - database features disabled")

    async def get_pool(self) -> "asyncpg.Pool":
        """Return the shared connection pool, creating it on first call.

        Raises:
            Exception: whatever ``asyncpg.create_pool`` raised; the failure
                is logged and re-raised unchanged.
        """
        if self.pool is None:
            try:
                self.pool = await asyncpg.create_pool(
                    self.database_url,
                    min_size=2,
                    max_size=10,
                    command_timeout=60,
                )
                logger.info("✅ Neon connection pool created successfully")
            except Exception as e:
                logger.error(f"❌ Failed to create Neon connection pool: {e}")
                raise
        return self.pool

    def is_enabled(self) -> bool:
        """Check if database is configured."""
        return self.enabled

    @staticmethod
    def _to_json(value: Any) -> str:
        """Serialize *value* to a JSON string for a json/jsonb query argument.

        asyncpg's default json/jsonb codecs take ``str``, and the pool built
        in get_pool() registers no custom codec, so dicts must be dumped
        before being passed as query arguments.
        """
        import json

        return json.dumps(value)

    # ==========================================
    # User Management
    # ==========================================

    async def create_user(
        self,
        email: str,
        password: str,
        full_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Create a new user with a bcrypt-hashed password and a default API key.

        Returns:
            Dict with ``user_id``, ``email``, ``full_name``, ``tier``,
            ``api_key`` (the plaintext key, only available here) and
            ``created_at`` (ISO-8601 string).

        Raises:
            Exception: if the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        # bcrypt works on bytes; accept an already-encoded password as-is.
        password_bytes = password.encode('utf-8') if isinstance(password, str) else password
        password_hash = bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode('utf-8')

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                """
                INSERT INTO users (email, password_hash, full_name, tier)
                VALUES ($1, $2, $3, 'none')
                RETURNING id, email, full_name, tier, created_at
                """,
                email, password_hash, full_name
            )

        # create_api_key() acquires its own connection, so call it after the
        # one above has been released.
        api_key = self.generate_api_key()
        await self.create_api_key(str(user['id']), api_key, "Default API Key")

        logger.info(f"Created user: {email}")

        return {
            "user_id": str(user['id']),
            "email": user['email'],
            "full_name": user['full_name'],
            "tier": user['tier'],
            "api_key": api_key,
            "created_at": user['created_at'].isoformat()
        }

    async def authenticate_user(
        self,
        email: str,
        password: str
    ) -> Optional[Dict[str, Any]]:
        """Authenticate user with email/password.

        Returns:
            The full ``users`` row as a dict when the password matches,
            otherwise ``None`` (unknown email, no stored hash, wrong
            password, bcrypt error, or database disabled).
        """
        if not self.is_enabled():
            return None

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                "SELECT * FROM users WHERE email = $1",
                email
            )

        if not user or not user['password_hash']:
            logger.debug(f"User not found or no password hash for {email}")
            return None

        if isinstance(password, str):
            password = password.encode('utf-8')

        logger.debug(f"Checking password for {email}")
        try:
            password_match = bcrypt.checkpw(password, user['password_hash'].encode('utf-8'))
        except Exception as e:
            # A malformed stored hash is treated as a failed login, not a crash.
            logger.error(f"Password check error: {e}")
            return None

        logger.debug(f"Password match: {password_match}")
        return dict(user) if password_match else None

    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the ``users`` row with the given id as a dict, or ``None``."""
        if not self.is_enabled():
            return None

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                "SELECT * FROM users WHERE id = $1",
                user_id
            )
        return dict(user) if user else None

    async def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
        """Return the ``users`` row with the given email as a dict, or ``None``."""
        if not self.is_enabled():
            return None

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                "SELECT * FROM users WHERE email = $1",
                email
            )
        return dict(user) if user else None

    async def get_user_by_subscription(
        self,
        subscription_id: str
    ) -> Optional[Dict[str, Any]]:
        """Get user by Polar subscription ID."""
        if not self.is_enabled():
            return None

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                "SELECT * FROM users WHERE polar_subscription_id = $1",
                subscription_id
            )
        return dict(user) if user else None

    async def create_user_from_subscription(
        self,
        email: str,
        polar_customer_id: str,
        polar_subscription_id: str,
        tier: str,
        billing_period: str = "monthly",
        has_metering: bool = True,
        scans_included: int = 0,
        period_start: Optional[datetime] = None,
        period_end: Optional[datetime] = None,
        seats_included: int = 1
    ) -> Dict[str, Any]:
        """Create user from Polar subscription (no password needed).

        The row is inserted with ``subscription_status = 'active'``,
        ``scans_used_this_period = 0``, ``seats_used = 1`` and
        ``extra_seats = 0``.

        Returns:
            Dict with the new row's ``id``, ``email``, ``tier`` and
            ``created_at``.

        Raises:
            Exception: if the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                """
                INSERT INTO users (
                    email,
                    tier,
                    polar_customer_id,
                    polar_subscription_id,
                    subscription_status,
                    billing_period,
                    has_metering,
                    scans_included,
                    scans_used_this_period,
                    period_start,
                    period_end,
                    seats_included,
                    seats_used,
                    extra_seats
                )
                VALUES ($1, $2, $3, $4, 'active', $5, $6, $7, 0, $8, $9, $10, 1, 0)
                RETURNING id, email, tier, created_at
                """,
                email, tier, polar_customer_id, polar_subscription_id,
                billing_period, has_metering, scans_included,
                period_start, period_end, seats_included
            )

        logger.info(f"Created user from subscription: {email}")

        return dict(user)

    async def update_user(
        self,
        user_id: str,
        updates: Dict[str, Any]
    ) -> bool:
        """Update columns of one ``users`` row.

        Args:
            user_id: id of the row to update.
            updates: mapping of column name to new value. Values are sent
                as bound parameters; column names are validated because
                they have to be interpolated into the SQL text.

        Returns:
            ``True`` when exactly one row was updated. ``False`` when the
            database is disabled, ``updates`` is empty, or no row matched.

        Raises:
            ValueError: if a key of ``updates`` is not a plain ASCII
                identifier (letters, digits, underscore).
        """
        if not self.is_enabled():
            return False

        # An empty SET clause is invalid SQL; nothing to do means no update.
        if not updates:
            return False

        columns = list(updates)
        for column in columns:
            # Column names cannot be bound as parameters, so refuse anything
            # that could smuggle SQL into the statement.
            if not (isinstance(column, str) and column.isascii() and column.isidentifier()):
                raise ValueError(f"Invalid column name: {column!r}")

        pool = await self.get_pool()

        assignments = ", ".join(
            f"{column} = ${position}"
            for position, column in enumerate(columns, start=1)
        )
        query = f"""
            UPDATE users
            SET {assignments}
            WHERE id = ${len(columns) + 1}
        """

        async with pool.acquire() as conn:
            result = await conn.execute(query, *updates.values(), user_id)
        return result == "UPDATE 1"

    async def initialize_usage_tracking(
        self,
        user_id: str,
        tier: str
    ):
        """Store the tier's request and scan limits on the user row."""
        # Imported here rather than at module level, as in the original code.
        from ..services.polar_service import polar_service
        limits = polar_service.get_tier_limits(tier)

        await self.update_user(
            user_id,
            {
                "requests_per_hour": limits.get("api_requests_per_hour", 0),
                "scans_per_month": limits.get("scans_per_month", 0)
            }
        )

    async def increment_user_scans(self, user_id: str):
        """Increment the user's scan usage counter for the current period."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            # scans_per_month holds the tier *limit* (written by
            # initialize_usage_tracking, reported as "scans_limit"), so usage
            # is counted in scans_used_this_period instead of raising the limit.
            await conn.execute(
                """
                UPDATE users
                SET scans_used_this_period = COALESCE(scans_used_this_period, 0) + 1
                WHERE id = $1
                """,
                user_id
            )

    # ==========================================
    # API Key Management
    # ==========================================

    def generate_api_key(self) -> str:
        """Generate a new random API key with the ``alprina_sk_`` prefix."""
        return f"alprina_sk_{secrets.token_urlsafe(32)}"

    async def create_api_key(
        self,
        user_id: str,
        api_key: str,
        name: str,
        expires_at: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """Store a new API key for a user.

        Only the SHA-256 hex digest and the first 16 characters (for
        display) are persisted; the plaintext key is not stored.

        Returns:
            Dict with ``id``, ``name``, ``key_prefix`` and ``created_at``
            (ISO-8601 string).

        Raises:
            Exception: if the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        key_hash = hashlib.sha256(api_key.encode()).hexdigest()
        key_prefix = api_key[:16]  # First 16 chars for display

        async with pool.acquire() as conn:
            key = await conn.fetchrow(
                """
                INSERT INTO api_keys (user_id, key_hash, key_prefix, name, expires_at)
                VALUES ($1, $2, $3, $4, $5)
                RETURNING id, name, key_prefix, created_at
                """,
                user_id, key_hash, key_prefix, name, expires_at
            )

        return {
            "id": str(key['id']),
            "name": key['name'],
            "key_prefix": key['key_prefix'],
            "created_at": key['created_at'].isoformat()
        }

    async def verify_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Verify an API key and return its owner.

        Returns:
            The owning ``users`` row plus ``key_id`` as a dict when the key
            is active and unexpired (``last_used_at`` is refreshed as a
            side effect); otherwise ``None``.
        """
        if not self.is_enabled():
            return None

        pool = await self.get_pool()

        key_hash = hashlib.sha256(api_key.encode()).hexdigest()

        async with pool.acquire() as conn:
            result = await conn.fetchrow(
                """
                SELECT u.*, k.id as key_id
                FROM users u
                JOIN api_keys k ON k.user_id = u.id
                WHERE k.key_hash = $1
                AND k.is_active = true
                AND (k.expires_at IS NULL OR k.expires_at > NOW())
                """,
                key_hash
            )

            if not result:
                return None

            await conn.execute(
                "UPDATE api_keys SET last_used_at = NOW() WHERE id = $1",
                result['key_id']
            )
            return dict(result)

    async def list_api_keys(self, user_id: str) -> List[Dict[str, Any]]:
        """List a user's API keys (metadata only), newest first."""
        if not self.is_enabled():
            return []

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            keys = await conn.fetch(
                """
                SELECT id, name, key_prefix, is_active, created_at, last_used_at, expires_at
                FROM api_keys
                WHERE user_id = $1
                ORDER BY created_at DESC
                """,
                user_id
            )

        return [dict(key) for key in keys]

    async def deactivate_api_key(self, key_id: str, user_id: str) -> bool:
        """Deactivate an API key owned by ``user_id``; ``True`` if one row changed."""
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE api_keys
                SET is_active = false
                WHERE id = $1 AND user_id = $2
                """,
                key_id, user_id
            )

        return result == "UPDATE 1"

    # ==========================================
    # Scan Management
    # ==========================================

    async def create_scan(
        self,
        user_id: str,
        scan_type: str,
        workflow_mode: str,
        metadata: Optional[Dict] = None
    ) -> str:
        """Create a new scan row and return its id as a string.

        Raises:
            Exception: if the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            scan = await conn.fetchrow(
                """
                INSERT INTO scans (user_id, scan_type, workflow_mode, metadata)
                VALUES ($1, $2, $3, $4)
                RETURNING id
                """,
                # Serialized like log_webhook_event does: asyncpg rejects a
                # raw dict for a JSON argument.
                user_id, scan_type, workflow_mode, self._to_json(metadata or {})
            )

        return str(scan['id'])

    async def save_scan(
        self,
        scan_id: str,
        status: str,
        findings: Optional[Dict] = None,
        findings_count: int = 0,
        metadata: Optional[Dict] = None
    ) -> bool:
        """Store scan results and stamp ``completed_at``.

        Returns:
            ``True`` when exactly one scan row was updated.
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE scans
                SET status = $1,
                    findings = $2,
                    findings_count = $3,
                    metadata = $4,
                    completed_at = NOW()
                WHERE id = $5
                """,
                status,
                self._to_json(findings or {}),
                findings_count,
                self._to_json(metadata or {}),
                scan_id
            )

        return result == "UPDATE 1"

    async def get_scan(
        self,
        scan_id: str,
        user_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Get scan by ID, optionally restricted to scans owned by ``user_id``."""
        if not self.is_enabled():
            return None

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            if user_id:
                scan = await conn.fetchrow(
                    "SELECT * FROM scans WHERE id = $1 AND user_id = $2",
                    scan_id, user_id
                )
            else:
                scan = await conn.fetchrow(
                    "SELECT * FROM scans WHERE id = $1",
                    scan_id
                )

        return dict(scan) if scan else None

    async def list_scans(
        self,
        user_id: str,
        limit: int = 10,
        offset: int = 0,
        scan_type: Optional[str] = None,
        workflow_mode: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """List user scans with pagination, newest first.

        ``scan_type`` and ``workflow_mode`` are optional equality filters.
        """
        if not self.is_enabled():
            return []

        pool = await self.get_pool()

        # Only fixed SQL fragments are concatenated; every value is bound.
        query = "SELECT * FROM scans WHERE user_id = $1"
        params: List[Any] = [user_id]

        if scan_type:
            params.append(scan_type)
            query += f" AND scan_type = ${len(params)}"

        if workflow_mode:
            params.append(workflow_mode)
            query += f" AND workflow_mode = ${len(params)}"

        params.extend([limit, offset])
        query += f" ORDER BY created_at DESC LIMIT ${len(params) - 1} OFFSET ${len(params)}"

        async with pool.acquire() as conn:
            scans = await conn.fetch(query, *params)
        return [dict(scan) for scan in scans]

    # ==========================================
    # Rate Limiting & Usage Tracking
    # ==========================================

    async def check_rate_limit(self, user_id: str) -> Dict[str, Any]:
        """Check if user is within their hourly request limit.

        Returns:
            ``{"allowed", "remaining", "limit", "used"}`` for a known user;
            ``{"allowed": False, "remaining": 0}`` for an unknown user;
            ``{"allowed": True, "remaining": 0}`` when the database is
            disabled.
        """
        if not self.is_enabled():
            return {"allowed": True, "remaining": 0}

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                "SELECT tier, requests_per_hour FROM users WHERE id = $1",
                user_id
            )

            if not user:
                return {"allowed": False, "remaining": 0}

            # The one-hour window is computed by the database so it does not
            # depend on the application host's clock or time zone.
            count = await conn.fetchval(
                """
                SELECT COUNT(*) FROM usage_logs
                WHERE user_id = $1 AND created_at > NOW() - INTERVAL '1 hour'
                """,
                user_id
            )

        # A NULL limit (never initialised) is treated as 0 instead of
        # crashing on None - int.
        limit = user['requests_per_hour'] or 0

        return {
            "allowed": count < limit,
            "remaining": max(0, limit - count),
            "limit": limit,
            "used": count
        }

    async def log_request(
        self,
        user_id: str,
        endpoint: str,
        method: str,
        status_code: int,
        duration_ms: float
    ):
        """Insert one row into ``usage_logs`` for an API request."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO usage_logs (user_id, endpoint, method, status_code, duration_ms)
                VALUES ($1, $2, $3, $4, $5)
                """,
                user_id, endpoint, method, status_code, duration_ms
            )

    async def get_user_stats(self, user_id: str) -> Dict[str, Any]:
        """Get user statistics.

        Returns:
            Dict with ``tier``, ``scans_this_month``, ``total_scans``,
            ``vulnerabilities_found``, ``scans_limit`` and
            ``requests_limit``; an empty dict when the user is unknown or
            the database is disabled.
        """
        if not self.is_enabled():
            return {}

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                "SELECT * FROM users WHERE id = $1",
                user_id
            )

            if not user:
                return {}

            # Month start is taken from the database clock (in the session
            # time zone) rather than a naive Python datetime.
            scans_count = await conn.fetchval(
                """
                SELECT COUNT(*) FROM scans
                WHERE user_id = $1 AND created_at >= date_trunc('month', NOW())
                """,
                user_id
            )

            total_scans = await conn.fetchval(
                "SELECT COUNT(*) FROM scans WHERE user_id = $1",
                user_id
            )

            # Sum of findings_count over all of the user's scans.
            vulns = await conn.fetchval(
                "SELECT COALESCE(SUM(findings_count), 0) FROM scans WHERE user_id = $1",
                user_id
            )

        return {
            "tier": user['tier'],
            "scans_this_month": scans_count,
            "total_scans": total_scans,
            "vulnerabilities_found": vulns,
            "scans_limit": user['scans_per_month'],
            "requests_limit": user['requests_per_hour']
        }

    # ==========================================
    # Webhook Event Logging
    # ==========================================

    async def log_webhook_event(
        self,
        event_type: str,
        event_id: str,
        payload: Dict[str, Any]
    ):
        """Record a webhook event; a repeated ``event_id`` is ignored."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        payload_json = self._to_json(payload)

        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO webhook_events (event_type, event_id, payload)
                VALUES ($1, $2, $3::jsonb)
                ON CONFLICT (event_id) DO NOTHING
                """,
                event_type, event_id, payload_json
            )

    async def mark_webhook_processed(self, event_id: str):
        """Mark webhook as processed and stamp ``processed_at``."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE webhook_events
                SET processed = true, processed_at = NOW()
                WHERE event_id = $1
                """,
                event_id
            )

    async def mark_webhook_error(self, event_id: str, error: str):
        """Store an error message on a webhook event."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE webhook_events
                SET error = $1
                WHERE event_id = $2
                """,
                error, event_id
            )

    # ==========================================
    # Device Authorization (CLI OAuth)
    # ==========================================

    def generate_device_codes(self) -> tuple[str, str]:
        """Generate a ``(device_code, user_code)`` pair.

        ``device_code`` is a URL-safe random token; ``user_code`` is four
        random decimal digits so it is easy to type.
        """
        device_code = secrets.token_urlsafe(32)
        # NOTE(review): four digits give only 10,000 codes, so concurrent
        # pending authorizations can collide and the code is guessable -
        # consider a longer code if the verification page allows it.
        user_code = ''.join(secrets.choice('0123456789') for _ in range(4))
        return device_code, user_code

    async def create_device_authorization(self) -> Dict[str, Any]:
        """Create a pending device authorization for the CLI.

        Returns:
            Dict with ``device_code``, ``user_code``, ``verification_uri``
            and ``expires_in`` (seconds).

        Raises:
            Exception: if the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        device_code, user_code = self.generate_device_codes()

        async with pool.acquire() as conn:
            # Expiry is computed with the database clock so it agrees with
            # the "expires_at > NOW()" test in authorize_device().
            auth = await conn.fetchrow(
                """
                INSERT INTO device_codes (device_code, user_code, expires_at)
                VALUES ($1, $2, NOW() + INTERVAL '15 minutes')
                RETURNING device_code, user_code, expires_at
                """,
                device_code, user_code
            )

        return {
            "device_code": auth['device_code'],
            "user_code": auth['user_code'],
            "verification_uri": "https://www.alprina.com/device",
            "expires_in": 900  # 15 minutes
        }

    async def check_device_authorization(
        self,
        device_code: str
    ) -> Optional[Dict[str, Any]]:
        """Check if device has been authorized.

        Returns:
            ``None`` for an unknown device code (or disabled database),
            otherwise a dict whose ``status`` is ``"expired"``,
            ``"authorized"`` (with ``user_id`` and ``user_code``) or
            ``"pending"``.
        """
        if not self.is_enabled():
            return None

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            # Expiry is evaluated in SQL: comparing the returned timestamp
            # with a naive Python datetime fails for time-zone-aware values.
            auth = await conn.fetchrow(
                """
                SELECT *, (expires_at < NOW()) AS expired_now
                FROM device_codes
                WHERE device_code = $1
                """,
                device_code
            )

        if not auth:
            return None

        if auth['expired_now']:
            return {"status": "expired"}

        if auth['authorized'] and auth['user_id']:
            return {
                "status": "authorized",
                "user_id": str(auth['user_id']),
                "user_code": auth['user_code']
            }

        return {"status": "pending"}

    async def authorize_device(self, user_code: str, user_id: str) -> bool:
        """Authorize the pending, unexpired device with this user code.

        Returns:
            ``True`` only when exactly one row was updated.
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            # NOTE(review): if two pending rows share a user_code both are
            # updated and this returns False - confirm whether user_code is
            # unique in device_codes.
            result = await conn.execute(
                """
                UPDATE device_codes
                SET user_id = $1, authorized = true
                WHERE user_code = $2
                AND expires_at > NOW()
                AND authorized = false
                """,
                user_id, user_code
            )

        return result == "UPDATE 1"

    # ==========================================
    # Team Management
    # ==========================================

    async def get_team_members(self, owner_id: str) -> List[Dict[str, Any]]:
        """Get all team members for a team owner, oldest membership first.

        Each dict has ``id``, ``email``, ``full_name``, ``role`` and
        ``joined_at``.
        """
        if not self.is_enabled():
            return []

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            # team_members is joined with users for email and name; the team
            # is every member attached to any of the owner's subscriptions.
            members = await conn.fetch(
                """
                SELECT
                    u.id,
                    u.email,
                    u.full_name,
                    tm.role,
                    tm.created_at as joined_at
                FROM team_members tm
                JOIN users u ON u.id = tm.user_id
                WHERE tm.subscription_id IN (
                    SELECT id FROM user_subscriptions
                    WHERE user_id = $1
                )
                ORDER BY tm.created_at ASC
                """,
                owner_id
            )

        return [dict(member) for member in members]

    async def get_team_member_by_email(
        self,
        owner_id: str,
        email: str
    ) -> Optional[Dict[str, Any]]:
        """Return the owner's team member with this email (case-insensitive), or ``None``."""
        if not self.is_enabled():
            return None

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            member = await conn.fetchrow(
                """
                SELECT
                    u.id,
                    u.email,
                    tm.role
                FROM team_members tm
                JOIN users u ON u.id = tm.user_id
                WHERE tm.subscription_id IN (
                    SELECT id FROM user_subscriptions
                    WHERE user_id = $1
                )
                AND LOWER(u.email) = LOWER($2)
                """,
                owner_id, email
            )

        return dict(member) if member else None

    async def create_team_invitation(
        self,
        owner_id: str,
        invitee_email: str,
        role: str
    ) -> Dict[str, Any]:
        """Create a team invitation with a random token.

        Returns:
            Dict with ``id``, ``owner_id``, ``owner_email``,
            ``invitee_email``, ``role``, ``token`` and ``created_at``.

        Raises:
            Exception: if the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        invitation_token = secrets.token_urlsafe(32)

        # The owner's email is stored with the invitation; it stays NULL
        # when the owner row cannot be found.
        owner = await self.get_user_by_id(owner_id)
        owner_email = owner.get("email") if owner else None

        async with pool.acquire() as conn:
            # NOTE(review): this DDL runs on every invitation; it is a no-op
            # once the table exists but would be better placed in a migration.
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS team_invitations (
                    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                    owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
                    owner_email TEXT,
                    invitee_email TEXT NOT NULL,
                    role TEXT NOT NULL CHECK (role IN ('admin', 'member')),
                    token TEXT NOT NULL UNIQUE,
                    status TEXT DEFAULT 'pending' CHECK (status IN ('pending', 'accepted', 'expired')),
                    expires_at TIMESTAMPTZ DEFAULT NOW() + INTERVAL '7 days',
                    created_at TIMESTAMPTZ DEFAULT NOW()
                );

                CREATE INDEX IF NOT EXISTS idx_team_invitations_token ON team_invitations(token);
                CREATE INDEX IF NOT EXISTS idx_team_invitations_owner ON team_invitations(owner_id);
                CREATE INDEX IF NOT EXISTS idx_team_invitations_email ON team_invitations(invitee_email);
                """
            )

            invitation = await conn.fetchrow(
                """
                INSERT INTO team_invitations (
                    owner_id, owner_email, invitee_email, role, token
                )
                VALUES ($1, $2, $3, $4, $5)
                RETURNING id, owner_id, owner_email, invitee_email, role, token, created_at
                """,
                owner_id, owner_email, invitee_email, role, invitation_token
            )

        return dict(invitation)

    async def get_team_invitation(self, token: str) -> Optional[Dict[str, Any]]:
        """Get a pending, unexpired team invitation by token, or ``None``."""
        if not self.is_enabled():
            return None

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            invitation = await conn.fetchrow(
                """
                SELECT * FROM team_invitations
                WHERE token = $1
                AND status = 'pending'
                AND expires_at > NOW()
                """,
                token
            )

        return dict(invitation) if invitation else None

    async def delete_team_invitation(self, token: str) -> bool:
        """Mark the invitation as accepted (the row is kept, not deleted).

        Returns:
            ``True`` when exactly one invitation row was updated.
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE team_invitations
                SET status = 'accepted'
                WHERE token = $1
                """,
                token
            )

        return result == "UPDATE 1"

    async def add_team_member(
        self,
        owner_id: str,
        member_id: str,
        role: str
    ) -> bool:
        """Add a member to the owner's most recent subscription.

        Returns:
            ``False`` when the database is disabled or the owner has no
            subscription; ``True`` otherwise, including when the member was
            already present (the insert is ``ON CONFLICT DO NOTHING``).
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            subscription = await conn.fetchrow(
                """
                SELECT id FROM user_subscriptions
                WHERE user_id = $1
                ORDER BY created_at DESC
                LIMIT 1
                """,
                owner_id
            )

            if not subscription:
                logger.error(f"No subscription found for owner {owner_id}")
                return False

            await conn.execute(
                """
                INSERT INTO team_members (subscription_id, user_id, role)
                VALUES ($1, $2, $3)
                ON CONFLICT (subscription_id, user_id) DO NOTHING
                """,
                subscription["id"], member_id, role
            )

        logger.info(f"Added team member {member_id} to subscription {subscription['id']}")
        return True

    async def remove_team_member(
        self,
        owner_id: str,
        member_id: str
    ) -> bool:
        """Remove a member from the owner's subscriptions.

        Returns:
            ``True`` when exactly one membership row was deleted.
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                DELETE FROM team_members
                WHERE subscription_id IN (
                    SELECT id FROM user_subscriptions
                    WHERE user_id = $1
                )
                AND user_id = $2
                """,
                owner_id, member_id
            )

        return result == "DELETE 1"

    # ==========================================
    # Cleanup
    # ==========================================

    async def close(self):
        """Close the connection pool, if one was created."""
        if self.pool:
            await self.pool.close()
            # Drop the reference so a later get_pool() builds a fresh pool
            # instead of handing out the closed one.
            self.pool = None
            logger.info("Neon connection pool closed")
# Module-level singleton shared by importers. Constructing it only reads
# DATABASE_URL and logs the outcome; the connection pool itself is created
# later, on the first get_pool() call.
neon_service = NeonService()