Coverage for src/alprina_cli/api/services/database_service.py: 0%

129 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2Database Service for SQLAlchemy Operations 

3Handles direct database access for usage tracking and Polar integration. 

4""" 

5 

import os
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Dict, Any, Optional, List, Iterator

from loguru import logger
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import NullPool

from ..models.database import Base, User, UsageTracking, ScanHistory, APIKey, PolarWebhook

15 

16 

class DatabaseService:
    """Service for SQLAlchemy database operations.

    The engine and session factory are built from the ``DATABASE_URL``
    environment variable. When the variable is unset or engine creation
    raises, the service stays disabled: ``is_enabled()`` returns ``False``
    and every data method returns its fallback value instead of touching
    the database.

    NOTE(review): the data methods are declared ``async`` but execute
    synchronous SQLAlchemy calls, so they block the running event loop for
    the duration of each query — confirm this is acceptable for callers.
    """

    def __init__(self):
        """Initialize database connection."""
        # Use DATABASE_URL for direct database access
        self.database_url = os.getenv("DATABASE_URL")

        if not self.database_url:
            logger.warning("DATABASE_URL not set - database operations disabled")
            self._disable()
            return

        try:
            # Create engine with NullPool to avoid connection issues
            self.engine = create_engine(
                self.database_url,
                poolclass=NullPool,
                echo=False,
            )

            # Create session factory
            self.SessionLocal = sessionmaker(bind=self.engine)
            self.enabled = True

            logger.info("✅ Database service initialized successfully")

        except Exception as e:
            # Deliberate best-effort: a bad URL must not crash module import.
            logger.error(f"Failed to initialize database: {e}")
            self._disable()

    def _disable(self) -> None:
        """Put the service into the disabled state (no engine, no sessions)."""
        self.engine = None
        self.SessionLocal = None
        self.enabled = False

    def is_enabled(self) -> bool:
        """Check if database is available."""
        return self.enabled and self.engine is not None

    @contextmanager
    def get_session(self) -> "Iterator[Session]":
        """Yield a session that is committed, rolled back and closed for you.

        The session is committed when the ``with`` body finishes normally,
        rolled back when the body (or the commit) raises, and closed in
        either case.

        Raises:
            RuntimeError: if the service is disabled. The message is
                ``"Database not available"``.
        """
        if not self.is_enabled():
            raise RuntimeError("Database not available")

        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            # Bare raise keeps the original exception and traceback intact.
            raise
        finally:
            session.close()

    @staticmethod
    def _find_usage(session, user_id: str, month: str):
        """Return the first UsageTracking row for (user_id, month), or None."""
        return session.query(UsageTracking).filter(
            UsageTracking.user_id == user_id,
            UsageTracking.month == month,
        ).first()

    # ==========================================
    # Usage Tracking Operations
    # ==========================================

    async def get_usage_record(self, user_id: str, month: str) -> Optional[Dict[str, Any]]:
        """Get usage record for specific user and month.

        Returns:
            A dict of the record's counters and limits, or ``None`` when the
            service is disabled or no matching row exists.
        """
        if not self.is_enabled():
            return None

        with self.get_session() as session:
            record = self._find_usage(session, user_id, month)

            if not record:
                return None

            return {
                "id": record.id,
                "user_id": record.user_id,
                "month": record.month,
                "scans_count": record.scans_count,
                "scans_limit": record.scans_limit,
                "files_scanned_total": record.files_scanned_total,
                "api_calls_count": record.api_calls_count,
                "api_calls_limit": record.api_calls_limit,
                "parallel_scans_count": record.parallel_scans_count,
                "sequential_scans_count": record.sequential_scans_count,
                "coordinated_chains_count": record.coordinated_chains_count,
            }

    async def create_usage_record(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Create new usage record.

        Args:
            data: Keyword arguments passed straight to ``UsageTracking``.

        Returns:
            A summary dict of the new row, or ``data`` unchanged when the
            service is disabled.
        """
        if not self.is_enabled():
            return data

        with self.get_session() as session:
            record = UsageTracking(**data)
            session.add(record)
            # Flush before reading attributes back; presumably this is what
            # populates server-generated values such as ``id`` — verify.
            session.flush()

            return {
                "id": record.id,
                "user_id": record.user_id,
                "month": record.month,
                "scans_count": record.scans_count,
                "scans_limit": record.scans_limit,
            }

    async def update_usage_record(self, user_id: str, month: str, updates: Dict[str, Any]) -> bool:
        """Update usage record.

        Each key in ``updates`` is applied with ``setattr`` and is not
        validated against the model.

        Returns:
            ``True`` when a row was found and updated; ``False`` when the
            service is disabled or no row matches.
        """
        if not self.is_enabled():
            return False

        with self.get_session() as session:
            record = self._find_usage(session, user_id, month)

            if not record:
                return False

            for key, value in updates.items():
                setattr(record, key, value)

            return True

    async def increment_scan_count(self, user_id: str, month: str) -> bool:
        """Increment scan count for user.

        Returns:
            ``True`` when a row was found and incremented; ``False`` when the
            service is disabled or no row matches.
        """
        if not self.is_enabled():
            return False

        with self.get_session() as session:
            record = self._find_usage(session, user_id, month)

            if not record:
                return False

            # Treat a NULL counter as zero instead of raising TypeError.
            record.scans_count = (record.scans_count or 0) + 1
            return True

    # ==========================================
    # Scan History Operations
    # ==========================================

    async def create_scan_history(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Create scan history record.

        Args:
            data: Keyword arguments passed straight to ``ScanHistory``.

        Returns:
            ``{"id", "created_at"}`` of the new row, or ``data`` unchanged
            when the service is disabled.
        """
        if not self.is_enabled():
            return data

        with self.get_session() as session:
            scan = ScanHistory(**data)
            session.add(scan)
            session.flush()

            return {"id": scan.id, "created_at": scan.created_at}

    async def get_scan_history(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Get recent scan history for user.

        Returns:
            Up to ``limit`` scans ordered by ``created_at`` descending, or an
            empty list when the service is disabled.
        """
        if not self.is_enabled():
            return []

        with self.get_session() as session:
            scans = session.query(ScanHistory).filter(
                ScanHistory.user_id == user_id
            ).order_by(ScanHistory.created_at.desc()).limit(limit).all()

            return [
                {
                    "id": scan.id,
                    "scan_type": scan.scan_type,
                    "agent_used": scan.agent_used,
                    "findings_count": scan.findings_count,
                    "critical_findings": scan.critical_findings,
                    "created_at": scan.created_at,
                }
                for scan in scans
            ]

    # ==========================================
    # Polar Webhook Operations
    # ==========================================

    async def log_webhook_event(self, event_type: str, polar_event_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Log Polar webhook event.

        The customer and subscription ids are read from ``payload["data"]``
        when that key holds a mapping.

        Returns:
            ``{"id", "created_at"}`` of the stored event, or ``{}`` when the
            service is disabled.
        """
        if not self.is_enabled():
            return {}

        # ``"data"`` may be present but null; fall back to an empty mapping.
        event_data = payload.get("data") or {}

        with self.get_session() as session:
            webhook = PolarWebhook(
                event_type=event_type,
                polar_event_id=polar_event_id,
                payload=payload,
                polar_customer_id=event_data.get("customer_id"),
                polar_subscription_id=event_data.get("subscription_id"),
            )
            session.add(webhook)
            session.flush()

            return {"id": webhook.id, "created_at": webhook.created_at}

    async def mark_webhook_processed(self, polar_event_id: str, error_message: Optional[str] = None) -> bool:
        """Mark webhook as processed.

        ``processed`` is set to ``True`` only when ``error_message`` is
        ``None``. A non-empty ``error_message`` is stored on the row.

        Returns:
            ``True`` when the webhook row was found and updated; ``False``
            when the service is disabled or no row matches.
        """
        if not self.is_enabled():
            return False

        with self.get_session() as session:
            webhook = session.query(PolarWebhook).filter(
                PolarWebhook.polar_event_id == polar_event_id
            ).first()

            if not webhook:
                return False

            webhook.processed = error_message is None
            # The previous ``os.time()`` call does not exist and raised
            # AttributeError here. NOTE(review): assumes ``processed_at`` is
            # a DateTime column — confirm against the PolarWebhook model.
            webhook.processed_at = datetime.now(timezone.utc)
            if error_message:
                webhook.error_message = error_message

            return True

    # ==========================================
    # User Operations
    # ==========================================

    async def update_user(self, user_id: str, updates: Dict[str, Any]) -> bool:
        """Update user record.

        Each key in ``updates`` is applied with ``setattr`` and is not
        validated against the model.

        Returns:
            ``True`` when the user was found and updated; ``False`` when the
            service is disabled or no user matches.
        """
        if not self.is_enabled():
            return False

        with self.get_session() as session:
            user = session.query(User).filter(User.id == user_id).first()

            if not user:
                return False

            for key, value in updates.items():
                setattr(user, key, value)

            return True

    async def get_user_by_polar_customer(self, polar_customer_id: str) -> Optional[Dict[str, Any]]:
        """Get user by Polar customer ID.

        Returns:
            A dict of the user's id, email, tier and Polar subscription
            fields, or ``None`` when the service is disabled or no user
            matches.
        """
        if not self.is_enabled():
            return None

        with self.get_session() as session:
            user = session.query(User).filter(
                User.polar_customer_id == polar_customer_id
            ).first()

            if not user:
                return None

            return {
                "id": user.id,
                "email": user.email,
                "tier": user.tier,
                "polar_customer_id": user.polar_customer_id,
                "polar_subscription_id": user.polar_subscription_id,
                "subscription_status": user.subscription_status,
            }

279 

280 

# Shared service object, constructed once when this module is imported.
database_service: DatabaseService = DatabaseService()