Coverage for src/alprina_cli/api/services/usage_service.py: 0%

93 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2Usage Tracking Service 

3 

4Manages user usage tracking, limits, and enforcement. 

5""" 

6 

7from typing import Dict, Any, Optional, Tuple 

8from datetime import datetime, timedelta 

9from loguru import logger 

10from fastapi import HTTPException 

11 

12from .polar_service import polar_service 

13 

14 

class UsageService:
    """Service for tracking and enforcing per-user usage limits.

    Usage is stored as one record per user per month (``YYYY-MM``, UTC) and
    accessed through a ``db_service`` object passed into each call. Tier
    limits come from ``polar_service.get_tier_limits(tier)``; a limit value of
    ``None`` means "no limit".
    """

    def __init__(self):
        self.polar = polar_service

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Replacement for ``datetime.utcnow()`` (deprecated since Python 3.12).
        The result is kept naive so values handed to ``db_service`` have the
        same form as before.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _count(usage: Dict[str, Any], key: str) -> int:
        """Read a counter from a usage record, treating missing/NULL as 0."""
        return usage.get(key) or 0

    def get_current_month(self) -> str:
        """Get current month in YYYY-MM format (UTC)."""
        return self._utcnow().strftime("%Y-%m")

    async def get_or_create_usage_record(
        self,
        user_id: str,
        tier: str,
        db_service
    ) -> Dict[str, Any]:
        """
        Get or create usage record for current month.

        Args:
            user_id: User ID
            tier: User tier (used to seed limits when creating the record)
            db_service: Database service instance

        Returns:
            Usage record
        """
        current_month = self.get_current_month()

        # Try to get existing record
        usage = await db_service.get_usage_record(user_id, current_month)

        if usage:
            return usage

        # Create new record with tier limits
        limits = self.polar.get_tier_limits(tier)

        # NOTE(review): api_calls_limit is seeded from the *hourly* API limit;
        # confirm that is intended for a monthly record.
        usage = await db_service.create_usage_record(
            user_id=user_id,
            month=current_month,
            scans_limit=limits["scans_per_month"],
            api_calls_limit=limits["api_requests_per_hour"]
        )

        logger.info(f"Created usage record for user {user_id}, month {current_month}")
        return usage

    async def check_scan_limit(
        self,
        user_id: str,
        tier: str,
        db_service
    ) -> Tuple[bool, Optional[str], Dict[str, Any]]:
        """
        Check if user can perform a scan.

        Args:
            user_id: User ID
            tier: User tier
            db_service: Database service

        Returns:
            Tuple of (can_scan, error_message, usage_info). ``error_message``
            is None when the scan is allowed.
        """
        usage = await self.get_or_create_usage_record(user_id, tier, db_service)

        scans_count = self._count(usage, "scans_count")
        scans_limit = usage.get("scans_limit")

        # No limit (Pro/Enterprise)
        if scans_limit is None:
            # Soft limit check for Pro (warn at 1000) - allowed, only logged.
            if tier == "pro" and scans_count >= 1000:
                logger.warning(f"User {user_id} exceeded 1000 scans (soft limit)")

            return True, None, {
                "scans_used": scans_count,
                "scans_limit": "unlimited",
                "scans_remaining": "unlimited"
            }

        # Hard limit check (a limit of 0 blocks every scan)
        if scans_count >= scans_limit:
            return False, (
                f"Monthly scan limit reached ({scans_limit} scans). "
                f"Upgrade to Pro for unlimited scans."
            ), {
                "scans_used": scans_count,
                "scans_limit": scans_limit,
                "scans_remaining": 0
            }

        # Approaching limit warning (90%)
        if scans_count >= scans_limit * 0.9:
            logger.warning(
                f"User {user_id} approaching scan limit: "
                f"{scans_count}/{scans_limit}"
            )

        return True, None, {
            "scans_used": scans_count,
            "scans_limit": scans_limit,
            "scans_remaining": scans_limit - scans_count
        }

    async def check_workflow_access(
        self,
        tier: str,
        workflow_mode: str
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if user tier has access to workflow mode.

        Args:
            tier: User tier
            workflow_mode: Workflow mode (parallel, sequential, coordinated);
                any other value is allowed.

        Returns:
            Tuple of (has_access, error_message)
        """
        limits = self.polar.get_tier_limits(tier)

        # Multi-agent modes are gated by per-tier feature flags.
        if workflow_mode == "parallel":
            if not limits["parallel_scans"]:
                return False, "Parallel scans require Pro tier or higher"

        elif workflow_mode == "sequential":
            if not limits["sequential_scans"]:
                return False, "Sequential workflows require Pro tier or higher"

        elif workflow_mode == "coordinated":
            if not limits["coordinated_chains"]:
                return False, "Coordinated agent chains require Pro tier or higher"

        return True, None

    async def check_file_limit(
        self,
        tier: str,
        file_count: int
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if file count is within tier limits.

        Args:
            tier: User tier
            file_count: Number of files to scan

        Returns:
            Tuple of (within_limit, error_message)
        """
        limits = self.polar.get_tier_limits(tier)
        files_per_scan = limits["files_per_scan"]

        # No limit
        if files_per_scan is None:
            return True, None

        if file_count > files_per_scan:
            return False, (
                f"File count ({file_count}) exceeds tier limit "
                f"({files_per_scan} files per scan). "
                f"Upgrade to Pro for higher limits."
            )

        return True, None

    async def increment_scan_count(
        self,
        user_id: str,
        tier: str,
        workflow_mode: str,
        file_count: int,
        db_service
    ) -> Dict[str, Any]:
        """
        Increment scan count after successful scan.

        Args:
            user_id: User ID
            tier: User tier
            workflow_mode: Workflow mode used
            file_count: Files scanned
            db_service: Database service

        Returns:
            Updated usage record (as returned by ``db_service.update_usage_record``)
        """
        usage = await self.get_or_create_usage_record(user_id, tier, db_service)

        # NOTE(review): this is a read-modify-write of the counters; two
        # concurrent calls for the same user can read the same values and
        # lose an increment. An atomic increment in db_service would avoid it.
        updates = {
            "scans_count": self._count(usage, "scans_count") + 1,
            "files_scanned_total": self._count(usage, "files_scanned_total") + file_count
        }

        # Track workflow mode usage
        workflow_counters = {
            "parallel": "parallel_scans_count",
            "sequential": "sequential_scans_count",
            "coordinated": "coordinated_chains_count",
        }
        counter_key = workflow_counters.get(workflow_mode)
        if counter_key is not None:
            updates[counter_key] = self._count(usage, counter_key) + 1

        usage = await db_service.update_usage_record(
            user_id,
            self.get_current_month(),
            updates
        )

        logger.info(
            f"Incremented scan count for user {user_id}: "
            f"{updates['scans_count']} scans"
        )

        return usage

    async def record_scan(
        self,
        user_id: str,
        scan_data: Dict[str, Any],
        db_service
    ) -> Dict[str, Any]:
        """
        Record a scan in history.

        Args:
            user_id: User ID
            scan_data: Scan details; missing keys fall back to the defaults below
            db_service: Database service

        Returns:
            Scan history record
        """
        record = await db_service.create_scan_history(
            user_id=user_id,
            scan_type=scan_data.get("scan_type", "code"),
            agent_used=scan_data.get("agent", "unknown"),
            target=scan_data.get("target", ""),
            files_count=scan_data.get("files_count", 0),
            findings_count=scan_data.get("findings_count", 0),
            critical_findings=scan_data.get("critical_findings", 0),
            high_findings=scan_data.get("high_findings", 0),
            medium_findings=scan_data.get("medium_findings", 0),
            low_findings=scan_data.get("low_findings", 0),
            workflow_mode=scan_data.get("workflow_mode", "single"),
            duration_seconds=scan_data.get("duration", 0),
            status=scan_data.get("status", "completed")
        )

        return record

    async def get_usage_stats(
        self,
        user_id: str,
        tier: str,
        db_service
    ) -> Dict[str, Any]:
        """
        Get usage statistics for user.

        A ``scans_limit`` of ``None`` is reported as "unlimited" (same rule as
        ``check_scan_limit``); a limit of 0 is a real limit, reported as
        100% used.

        Args:
            user_id: User ID
            tier: User tier
            db_service: Database service

        Returns:
            Usage statistics
        """
        usage = await self.get_or_create_usage_record(user_id, tier, db_service)
        limits = self.polar.get_tier_limits(tier)

        # Get scan history summary
        current_month = self.get_current_month()
        scan_history = await db_service.get_user_scan_history(user_id, limit=10)

        scans_used = self._count(usage, "scans_count")
        scans_limit = usage.get("scans_limit")

        if scans_limit is None:
            limit_display: Any = "unlimited"
            scans_remaining: Any = "unlimited"
            usage_percentage = 0.0
        else:
            limit_display = scans_limit
            # Clamp at 0 so over-limit usage never reports a negative remainder.
            scans_remaining = max(scans_limit - scans_used, 0)
            # Guard the division: a zero limit means the quota is exhausted.
            usage_percentage = (
                (scans_used / scans_limit) * 100 if scans_limit > 0 else 100.0
            )

        return {
            "current_period": {
                "month": current_month,
                "scans_used": scans_used,
                "scans_limit": limit_display,
                "scans_remaining": scans_remaining,
                "usage_percentage": round(usage_percentage, 1),
                "files_scanned": usage.get("files_scanned_total", 0),
                "reports_generated": usage.get("reports_generated", 0)
            },
            "workflows": {
                "parallel_scans": usage.get("parallel_scans_count", 0),
                "sequential_scans": usage.get("sequential_scans_count", 0),
                "coordinated_chains": usage.get("coordinated_chains_count", 0)
            },
            "tier_limits": limits,
            "recent_scans": scan_history,
            "reset_date": self._get_next_reset_date()
        }

    def _get_next_reset_date(self) -> str:
        """Get next monthly reset date (first day of next UTC month, YYYY-MM-DD)."""
        now = self._utcnow()
        # Day 1 + 32 days always lands in the following month.
        next_month = now.replace(day=1) + timedelta(days=32)
        next_month = next_month.replace(day=1)
        return next_month.strftime("%Y-%m-%d")

    async def enforce_rate_limit(
        self,
        user_id: str,
        tier: str,
        db_service
    ) -> Tuple[bool, Optional[str]]:
        """
        Check API rate limit.

        Args:
            user_id: User ID
            tier: User tier
            db_service: Database service

        Returns:
            Tuple of (within_limit, error_message)
        """
        # Simple hourly rate limit check
        # In production, use Redis for better rate limiting
        limits = self.polar.get_tier_limits(tier)
        api_limit = limits["api_requests_per_hour"]

        if api_limit is None:
            return True, None  # No limit

        # Get API calls in last hour (naive UTC, as before)
        one_hour_ago = self._utcnow() - timedelta(hours=1)
        api_calls = await db_service.count_api_calls_since(user_id, one_hour_ago)

        if api_calls >= api_limit:
            return False, (
                f"API rate limit exceeded ({api_limit} requests/hour). "
                f"Please try again later or upgrade to Pro for higher limits."
            )

        return True, None

363 

364 

# Module-level UsageService instance, created once at import time and shared
# by everything that imports `usage_service` from this module.
usage_service: UsageService = UsageService()