Coverage for src/alprina_cli/agents/llm_config.py: 46%
57 statements
« prev ^ index » next — coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
1"""
2LLM Configuration and Cost Controls
3Manages LLM usage limits, model selection, and cost optimization
4"""
6from typing import Dict, Any
7from enum import Enum
class ModelTier(Enum):
    """Model tiers for different use cases.

    NOTE: COMPLEX and BALANCED carry the same value string, so Enum member
    aliasing makes ``ModelTier.BALANCED`` an alias of ``ModelTier.COMPLEX``
    (they are the same member object).
    """
    COMPLEX = "claude-3-5-sonnet-20241022"   # Best quality, higher cost
    SIMPLE = "claude-3-haiku-20240307"       # Fast and cheap
    BALANCED = "claude-3-5-sonnet-20241022"  # Default choice (alias of COMPLEX)
class LLMConfig:
    """LLM cost controls and configuration.

    All settings are class-level constants and all helpers are static, so the
    class is used directly without instantiation.
    """

    # Limits
    MAX_VULNS_TO_ENHANCE = 5  # Only enhance top 5 vulnerabilities
    MAX_TOKENS = 2000         # Token limit per request
    TIMEOUT_SECONDS = 30      # API timeout

    # Model selection
    DEFAULT_MODEL = ModelTier.COMPLEX.value

    # Cost optimization flags
    USE_HAIKU_FOR_LOW_SEVERITY = True  # Use cheaper model for low/medium
    ENABLE_CACHING = True              # Cache responses (future feature)

    # Severity-based model selection (evaluated once at class-definition time,
    # so flipping USE_HAIKU_FOR_LOW_SEVERITY afterwards does not change it)
    SEVERITY_MODEL_MAP = {
        'critical': ModelTier.COMPLEX.value,
        'high': ModelTier.COMPLEX.value,
        'medium': ModelTier.SIMPLE.value if USE_HAIKU_FOR_LOW_SEVERITY else ModelTier.COMPLEX.value,
        'low': ModelTier.SIMPLE.value if USE_HAIKU_FOR_LOW_SEVERITY else ModelTier.COMPLEX.value,
    }

    # Cost estimates in USD per 1K tokens (as of Jan 2025)
    COST_PER_1K_TOKENS = {
        ModelTier.COMPLEX.value: {
            'input': 0.003,   # $3 per million tokens
            'output': 0.015,  # $15 per million tokens
        },
        ModelTier.SIMPLE.value: {
            'input': 0.00025,  # $0.25 per million tokens
            'output': 0.00125,  # $1.25 per million tokens
        },
    }

    @staticmethod
    def select_model(severity: str) -> str:
        """
        Select appropriate model based on vulnerability severity.

        Args:
            severity: Vulnerability severity (critical, high, medium, low);
                matching is case-insensitive.

        Returns:
            Model identifier string (DEFAULT_MODEL for unknown severities).
        """
        return LLMConfig.SEVERITY_MODEL_MAP.get(
            severity.lower(),
            LLMConfig.DEFAULT_MODEL
        )

    @staticmethod
    def estimate_cost(
        num_vulnerabilities: int,
        avg_input_tokens: int = 800,
        avg_output_tokens: int = 1500
    ) -> Dict[str, Any]:
        """
        Estimate LLM enhancement cost.

        Args:
            num_vulnerabilities: Number of vulnerabilities to enhance
            avg_input_tokens: Average input tokens per request
            avg_output_tokens: Average output tokens per request

        Returns:
            Dictionary with cost breakdown: total_cost, complex_vulns,
            complex_cost, simple_vulns, simple_cost, cost_per_vuln
            (all costs rounded to 4 decimal places).
        """
        # Assume 60% critical/high (Sonnet), 40% medium/low (Haiku if enabled).
        # FIX: previously a 'simple_ratio' was computed but never used, so the
        # 40% share was priced at Haiku rates even when
        # USE_HAIKU_FOR_LOW_SEVERITY was False. With the flag off, everything
        # is priced at the COMPLEX rate.
        if LLMConfig.USE_HAIKU_FOR_LOW_SEVERITY:
            complex_count = int(num_vulnerabilities * 0.6)
        else:
            complex_count = num_vulnerabilities
        simple_count = num_vulnerabilities - complex_count

        # Per-request cost = input tokens at input rate + output tokens at
        # output rate, rates quoted per 1K tokens.
        complex_cost = (
            complex_count * (
                (avg_input_tokens / 1000) * LLMConfig.COST_PER_1K_TOKENS[ModelTier.COMPLEX.value]['input'] +
                (avg_output_tokens / 1000) * LLMConfig.COST_PER_1K_TOKENS[ModelTier.COMPLEX.value]['output']
            )
        )

        simple_cost = (
            simple_count * (
                (avg_input_tokens / 1000) * LLMConfig.COST_PER_1K_TOKENS[ModelTier.SIMPLE.value]['input'] +
                (avg_output_tokens / 1000) * LLMConfig.COST_PER_1K_TOKENS[ModelTier.SIMPLE.value]['output']
            )
        )

        total_cost = complex_cost + simple_cost

        return {
            'total_cost': round(total_cost, 4),
            'complex_vulns': complex_count,
            'complex_cost': round(complex_cost, 4),
            'simple_vulns': simple_count,
            'simple_cost': round(simple_cost, 4),
            'cost_per_vuln': round(total_cost / num_vulnerabilities, 4) if num_vulnerabilities > 0 else 0,
        }

    @staticmethod
    def should_enhance(
        vulnerability: Dict[str, Any],
        current_count: int
    ) -> bool:
        """
        Determine if vulnerability should be enhanced with LLM.

        Args:
            vulnerability: Vulnerability data (currently unused: under the cap
                every severity is enhanced, so the cap is the only gate)
            current_count: Number of vulnerabilities already enhanced

        Returns:
            True if should enhance, False otherwise
        """
        # Hard cap on total enhancements; below the cap, enhance regardless of
        # severity (the previous severity branch returned True on every path).
        return current_count < LLMConfig.MAX_VULNS_TO_ENHANCE
class UsageTracker:
    """Track LLM API usage for cost monitoring."""

    def __init__(self):
        # reset() defines the canonical zero state; reuse it on construction.
        self.reset()

    def record_request(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int
    ):
        """Record an API request for cost tracking."""
        self.requests += 1
        self.total_input_tokens += input_tokens
        self.total_output_tokens += output_tokens

        # Price the request from the per-model rate table; an unknown model
        # falls back to empty rates, i.e. zero cost.
        rates = LLMConfig.COST_PER_1K_TOKENS.get(model, {})
        spent_on_input = (input_tokens / 1000) * rates.get('input', 0)
        spent_on_output = (output_tokens / 1000) * rates.get('output', 0)
        self.total_cost += spent_on_input + spent_on_output

    def get_summary(self) -> Dict[str, Any]:
        """Get usage summary."""
        if self.requests > 0:
            avg_cost = round(self.total_cost / self.requests, 4)
        else:
            avg_cost = 0
        return {
            'requests': self.requests,
            'total_input_tokens': self.total_input_tokens,
            'total_output_tokens': self.total_output_tokens,
            'total_cost': round(self.total_cost, 4),
            'avg_cost_per_request': avg_cost,
        }

    def reset(self):
        """Reset usage tracking."""
        self.requests = 0
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_cost = 0.0