Coverage for src/alprina_cli/agents/llm_config.py: 46%

57 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2LLM Configuration and Cost Controls 

3Manages LLM usage limits, model selection, and cost optimization 

4""" 

5 

6from typing import Dict, Any 

7from enum import Enum 

8 

9 

class ModelTier(Enum):
    """Model tiers for different use cases.

    NOTE(review): COMPLEX and BALANCED share the same value, so Python's
    enum machinery makes BALANCED an *alias* of COMPLEX (they are the
    same member). Callers only use ``.value``, so this is harmless, but
    confirm the aliasing is intentional.
    """
    COMPLEX = "claude-3-5-sonnet-20241022"  # Best quality, higher cost
    SIMPLE = "claude-3-haiku-20240307"      # Fast and cheap
    BALANCED = "claude-3-5-sonnet-20241022" # Default choice (alias of COMPLEX)


class LLMConfig:
    """LLM cost controls and configuration.

    Central knobs for vulnerability enhancement: how many findings get
    LLM treatment, which model is selected per severity, and rough
    cost estimation for a scan.
    """

    # Limits
    MAX_VULNS_TO_ENHANCE = 5  # Only enhance top 5 vulnerabilities
    MAX_TOKENS = 2000         # Token limit per request
    TIMEOUT_SECONDS = 30      # API timeout

    # Model selection
    DEFAULT_MODEL = ModelTier.COMPLEX.value

    # Cost optimization flags
    USE_HAIKU_FOR_LOW_SEVERITY = True  # Use cheaper model for low/medium
    ENABLE_CACHING = True              # Cache responses (future feature)

    # Severity-based model selection
    SEVERITY_MODEL_MAP = {
        'critical': ModelTier.COMPLEX.value,
        'high': ModelTier.COMPLEX.value,
        'medium': ModelTier.SIMPLE.value if USE_HAIKU_FOR_LOW_SEVERITY else ModelTier.COMPLEX.value,
        'low': ModelTier.SIMPLE.value if USE_HAIKU_FOR_LOW_SEVERITY else ModelTier.COMPLEX.value,
    }

    # Cost estimates in USD per 1K tokens (as of Jan 2025)
    COST_PER_1K_TOKENS = {
        ModelTier.COMPLEX.value: {
            'input': 0.003,   # $3 per million tokens
            'output': 0.015,  # $15 per million tokens
        },
        ModelTier.SIMPLE.value: {
            'input': 0.00025,   # $0.25 per million tokens
            'output': 0.00125,  # $1.25 per million tokens
        }
    }

    @staticmethod
    def select_model(severity: str) -> str:
        """
        Select appropriate model based on vulnerability severity

        Args:
            severity: Vulnerability severity (critical, high, medium, low)

        Returns:
            Model identifier string (falls back to DEFAULT_MODEL for
            unknown severities)
        """
        return LLMConfig.SEVERITY_MODEL_MAP.get(
            severity.lower(),
            LLMConfig.DEFAULT_MODEL
        )

    @staticmethod
    def _request_cost(model: str, input_tokens: int, output_tokens: int) -> float:
        """Estimated USD cost of one request against *model*."""
        rates = LLMConfig.COST_PER_1K_TOKENS[model]
        return (
            (input_tokens / 1000) * rates['input'] +
            (output_tokens / 1000) * rates['output']
        )

    @staticmethod
    def estimate_cost(
        num_vulnerabilities: int,
        avg_input_tokens: int = 800,
        avg_output_tokens: int = 1500
    ) -> Dict[str, Any]:
        """
        Estimate LLM enhancement cost

        Args:
            num_vulnerabilities: Number of vulnerabilities to enhance
            avg_input_tokens: Average input tokens per request
            avg_output_tokens: Average output tokens per request

        Returns:
            Dictionary with cost breakdown
        """
        # Assume 60% critical/high priced at the COMPLEX (Sonnet) rate.
        # The remaining 40% are priced at the SIMPLE (Haiku) rate only
        # when USE_HAIKU_FOR_LOW_SEVERITY is enabled; otherwise every
        # request is priced at the COMPLEX rate. (The previous version
        # computed a `simple_ratio` it never used, so disabling the flag
        # still priced 40% of requests at Haiku rates.)
        complex_ratio = 0.6
        if LLMConfig.USE_HAIKU_FOR_LOW_SEVERITY:
            complex_count = int(num_vulnerabilities * complex_ratio)
        else:
            complex_count = num_vulnerabilities
        simple_count = num_vulnerabilities - complex_count

        complex_cost = complex_count * LLMConfig._request_cost(
            ModelTier.COMPLEX.value, avg_input_tokens, avg_output_tokens
        )
        simple_cost = simple_count * LLMConfig._request_cost(
            ModelTier.SIMPLE.value, avg_input_tokens, avg_output_tokens
        )
        total_cost = complex_cost + simple_cost

        return {
            'total_cost': round(total_cost, 4),
            'complex_vulns': complex_count,
            'complex_cost': round(complex_cost, 4),
            'simple_vulns': simple_count,
            'simple_cost': round(simple_cost, 4),
            'cost_per_vuln': round(total_cost / num_vulnerabilities, 4) if num_vulnerabilities > 0 else 0,
        }

    @staticmethod
    def should_enhance(
        vulnerability: Dict[str, Any],
        current_count: int
    ) -> bool:
        """
        Determine if vulnerability should be enhanced with LLM

        Args:
            vulnerability: Vulnerability data
            current_count: Number of vulnerabilities already enhanced

        Returns:
            True if should enhance, False otherwise
        """
        # Hard cap on total enhancements per scan
        if current_count >= LLMConfig.MAX_VULNS_TO_ENHANCE:
            return False

        # Under the cap, every severity qualifies — critical/high always,
        # medium/low because the budget still has room. (The original
        # re-tested `current_count < MAX_VULNS_TO_ENHANCE` here, which is
        # always True after the guard above.)
        return True

144 

145 

class UsageTracker:
    """Accumulate LLM API usage statistics for cost monitoring."""

    def __init__(self):
        # All counters start at zero; reset() restores this state.
        self.requests = 0
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_cost = 0.0

    def record_request(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int
    ):
        """Record one API call and add its estimated cost to the running total."""
        self.requests += 1
        self.total_input_tokens += input_tokens
        self.total_output_tokens += output_tokens

        # Price the call from the per-model rate table; an unrecognized
        # model simply contributes zero cost.
        rates = LLMConfig.COST_PER_1K_TOKENS.get(model, {})
        self.total_cost += (
            (input_tokens / 1000) * rates.get('input', 0)
            + (output_tokens / 1000) * rates.get('output', 0)
        )

    def get_summary(self) -> Dict[str, Any]:
        """Return cumulative usage figures as a plain dictionary."""
        if self.requests > 0:
            avg_cost = round(self.total_cost / self.requests, 4)
        else:
            avg_cost = 0
        return {
            'requests': self.requests,
            'total_input_tokens': self.total_input_tokens,
            'total_output_tokens': self.total_output_tokens,
            'total_cost': round(self.total_cost, 4),
            'avg_cost_per_request': avg_cost,
        }

    def reset(self):
        """Zero out every counter."""
        self.requests = 0
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_cost = 0.0