Coverage for src/alprina_cli/api/services/usage_service.py: 0%
93 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
"""
Usage Tracking Service

Manages user usage tracking, limits, and enforcement.
"""
7from typing import Dict, Any, Optional, Tuple
8from datetime import datetime, timedelta
9from loguru import logger
10from fastapi import HTTPException
12from .polar_service import polar_service
class UsageService:
    """Service for tracking and enforcing usage limits.

    Tier-specific values are read from ``self.polar.get_tier_limits(tier)``;
    persistence is delegated to the ``db_service`` object passed to each call.
    """

    # Scan count at which an unlimited "pro" account triggers a warning log.
    PRO_SOFT_SCAN_LIMIT = 1000

    # Fraction of the monthly quota at which an "approaching limit" warning
    # is logged.
    LIMIT_WARNING_RATIO = 0.9

    # workflow mode -> (tier-limits key gating it, message returned on denial)
    _WORKFLOW_GATES = {
        "parallel": (
            "parallel_scans",
            "Parallel scans require Pro tier or higher",
        ),
        "sequential": (
            "sequential_scans",
            "Sequential workflows require Pro tier or higher",
        ),
        "coordinated": (
            "coordinated_chains",
            "Coordinated agent chains require Pro tier or higher",
        ),
    }

    # workflow mode -> usage-record counter incremented for that mode
    _WORKFLOW_COUNTERS = {
        "parallel": "parallel_scans_count",
        "sequential": "sequential_scans_count",
        "coordinated": "coordinated_chains_count",
    }

    def __init__(self):
        self.polar = polar_service

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Yields the same value ``datetime.utcnow()`` would, without calling
        that deprecated (Python 3.12+) function.
        """
        # Function-scope import: the module header only imports
        # ``datetime`` and ``timedelta``.
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    def get_current_month(self) -> str:
        """Get current month (UTC) in YYYY-MM format."""
        return self._utc_now().strftime("%Y-%m")

    async def get_or_create_usage_record(
        self,
        user_id: str,
        tier: str,
        db_service
    ) -> Dict[str, Any]:
        """
        Get or create usage record for current month.

        Args:
            user_id: User ID
            tier: User tier
            db_service: Database service instance

        Returns:
            Usage record
        """
        current_month = self.get_current_month()

        # Try to get existing record
        usage = await db_service.get_usage_record(user_id, current_month)
        if usage:
            return usage

        # Create new record seeded with the limits of the user's tier
        limits = self.polar.get_tier_limits(tier)
        usage = await db_service.create_usage_record(
            user_id=user_id,
            month=current_month,
            scans_limit=limits["scans_per_month"],
            api_calls_limit=limits["api_requests_per_hour"]
        )

        logger.info(f"Created usage record for user {user_id}, month {current_month}")
        return usage

    async def check_scan_limit(
        self,
        user_id: str,
        tier: str,
        db_service
    ) -> Tuple[bool, Optional[str], Dict[str, Any]]:
        """
        Check if user can perform a scan.

        Args:
            user_id: User ID
            tier: User tier
            db_service: Database service

        Returns:
            Tuple of (can_scan, error_message, usage_info)
        """
        usage = await self.get_or_create_usage_record(user_id, tier, db_service)

        scans_count = usage.get("scans_count", 0)
        scans_limit = usage.get("scans_limit")

        # A limit of None means the record carries no quota (unlimited).
        if scans_limit is None:
            # Soft limit for Pro: the scan is still allowed, only logged.
            if tier == "pro" and scans_count >= self.PRO_SOFT_SCAN_LIMIT:
                logger.warning(
                    f"User {user_id} exceeded {self.PRO_SOFT_SCAN_LIMIT} scans (soft limit)"
                )

            return True, None, {
                "scans_used": scans_count,
                "scans_limit": "unlimited",
                "scans_remaining": "unlimited"
            }

        # Hard limit check
        if scans_count >= scans_limit:
            return False, (
                f"Monthly scan limit reached ({scans_limit} scans). "
                f"Upgrade to Pro for unlimited scans."
            ), {
                "scans_used": scans_count,
                "scans_limit": scans_limit,
                "scans_remaining": 0
            }

        # Approaching limit warning
        if scans_count >= scans_limit * self.LIMIT_WARNING_RATIO:
            logger.warning(
                f"User {user_id} approaching scan limit: "
                f"{scans_count}/{scans_limit}"
            )

        return True, None, {
            "scans_used": scans_count,
            "scans_limit": scans_limit,
            "scans_remaining": scans_limit - scans_count
        }

    async def check_workflow_access(
        self,
        tier: str,
        workflow_mode: str
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if user tier has access to workflow mode.

        Args:
            tier: User tier
            workflow_mode: Workflow mode (parallel, sequential, coordinated)

        Returns:
            Tuple of (has_access, error_message)
        """
        limits = self.polar.get_tier_limits(tier)

        # Modes without a gate entry (e.g. single-agent scans) are always allowed.
        gate = self._WORKFLOW_GATES.get(workflow_mode)
        if gate is None:
            return True, None

        limit_key, denial_message = gate
        if not limits[limit_key]:
            return False, denial_message

        return True, None

    async def check_file_limit(
        self,
        tier: str,
        file_count: int
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if file count is within tier limits.

        Args:
            tier: User tier
            file_count: Number of files to scan

        Returns:
            Tuple of (within_limit, error_message)
        """
        limits = self.polar.get_tier_limits(tier)
        files_per_scan = limits["files_per_scan"]

        # None means the tier has no per-scan file limit.
        if files_per_scan is None:
            return True, None

        if file_count > files_per_scan:
            return False, (
                f"File count ({file_count}) exceeds tier limit "
                f"({files_per_scan} files per scan). "
                f"Upgrade to Pro for higher limits."
            )

        return True, None

    async def increment_scan_count(
        self,
        user_id: str,
        tier: str,
        workflow_mode: str,
        file_count: int,
        db_service
    ) -> Dict[str, Any]:
        """
        Increment scan count after successful scan.

        Args:
            user_id: User ID
            tier: User tier
            workflow_mode: Workflow mode used
            file_count: Files scanned
            db_service: Database service

        Returns:
            Updated usage record
        """
        usage = await self.get_or_create_usage_record(user_id, tier, db_service)

        # NOTE(review): this is a read-modify-write; whether concurrent scans
        # can lose an increment depends on db_service -- confirm.
        updates = {
            # .get() with a 0 default matches how check_scan_limit reads
            # the same field.
            "scans_count": usage.get("scans_count", 0) + 1,
            "files_scanned_total": usage.get("files_scanned_total", 0) + file_count
        }

        # Track workflow mode usage; modes without a counter are not tracked.
        counter_key = self._WORKFLOW_COUNTERS.get(workflow_mode)
        if counter_key is not None:
            updates[counter_key] = usage.get(counter_key, 0) + 1

        usage = await db_service.update_usage_record(
            user_id,
            self.get_current_month(),
            updates
        )

        logger.info(
            f"Incremented scan count for user {user_id}: "
            f"{updates['scans_count']} scans"
        )

        return usage

    async def record_scan(
        self,
        user_id: str,
        scan_data: Dict[str, Any],
        db_service
    ) -> Dict[str, Any]:
        """
        Record a scan in history.

        Keys missing from ``scan_data`` fall back to the defaults visible
        in the call below.

        Args:
            user_id: User ID
            scan_data: Scan details
            db_service: Database service

        Returns:
            Scan history record
        """
        record = await db_service.create_scan_history(
            user_id=user_id,
            scan_type=scan_data.get("scan_type", "code"),
            agent_used=scan_data.get("agent", "unknown"),
            target=scan_data.get("target", ""),
            files_count=scan_data.get("files_count", 0),
            findings_count=scan_data.get("findings_count", 0),
            critical_findings=scan_data.get("critical_findings", 0),
            high_findings=scan_data.get("high_findings", 0),
            medium_findings=scan_data.get("medium_findings", 0),
            low_findings=scan_data.get("low_findings", 0),
            workflow_mode=scan_data.get("workflow_mode", "single"),
            duration_seconds=scan_data.get("duration", 0),
            status=scan_data.get("status", "completed")
        )

        return record

    async def get_usage_stats(
        self,
        user_id: str,
        tier: str,
        db_service
    ) -> Dict[str, Any]:
        """
        Get usage statistics for user.

        Args:
            user_id: User ID
            tier: User tier
            db_service: Database service

        Returns:
            Usage statistics
        """
        usage = await self.get_or_create_usage_record(user_id, tier, db_service)
        limits = self.polar.get_tier_limits(tier)

        # Get scan history summary
        current_month = self.get_current_month()
        scan_history = await db_service.get_user_scan_history(user_id, limit=10)

        scans_used = usage.get("scans_count", 0)
        scans_limit = usage.get("scans_limit")

        # Only None means "unlimited" (same rule as check_scan_limit); a
        # quota of 0 is a real, fully exhausted quota and must not be
        # reported as unlimited.
        if scans_limit is None:
            limit_display = "unlimited"
            scans_remaining = "unlimited"
            usage_percentage = 0
        elif scans_limit <= 0:
            limit_display = scans_limit
            scans_remaining = 0
            # Nothing can be consumed, so report the quota as fully used
            # instead of dividing by zero.
            usage_percentage = 100.0
        else:
            limit_display = scans_limit
            # Clamp at 0 so an overrun never shows a negative remainder.
            scans_remaining = max(scans_limit - scans_used, 0)
            usage_percentage = (scans_used / scans_limit) * 100

        return {
            "current_period": {
                "month": current_month,
                "scans_used": scans_used,
                "scans_limit": limit_display,
                "scans_remaining": scans_remaining,
                "usage_percentage": round(usage_percentage, 1),
                "files_scanned": usage.get("files_scanned_total", 0),
                "reports_generated": usage.get("reports_generated", 0)
            },
            "workflows": {
                "parallel_scans": usage.get("parallel_scans_count", 0),
                "sequential_scans": usage.get("sequential_scans_count", 0),
                "coordinated_chains": usage.get("coordinated_chains_count", 0)
            },
            "tier_limits": limits,
            "recent_scans": scan_history,
            "reset_date": self._get_next_reset_date()
        }

    def _get_next_reset_date(self) -> str:
        """Get next monthly reset date (first of next month, YYYY-MM-DD)."""
        first_of_month = self._utc_now().replace(day=1)
        # 32 days past the 1st always lands in the following month, whatever
        # the month length; snapping back to day 1 gives the reset date.
        next_month = (first_of_month + timedelta(days=32)).replace(day=1)
        return next_month.strftime("%Y-%m-%d")

    async def enforce_rate_limit(
        self,
        user_id: str,
        tier: str,
        db_service
    ) -> Tuple[bool, Optional[str]]:
        """
        Check API rate limit.

        Args:
            user_id: User ID
            tier: User tier
            db_service: Database service

        Returns:
            Tuple of (within_limit, error_message)
        """
        # Simple hourly rate limit check
        # In production, use Redis for better rate limiting
        limits = self.polar.get_tier_limits(tier)
        api_limit = limits["api_requests_per_hour"]

        if api_limit is None:
            return True, None  # No limit

        # Count API calls in the trailing hour (naive UTC timestamp, as before)
        one_hour_ago = self._utc_now() - timedelta(hours=1)
        api_calls = await db_service.count_api_calls_since(user_id, one_hour_ago)

        if api_calls >= api_limit:
            return False, (
                f"API rate limit exceeded ({api_limit} requests/hour). "
                f"Please try again later or upgrade to Pro for higher limits."
            )

        return True, None
# Create singleton instance
usage_service = UsageService()