Coverage for src/alprina_cli/tools/base.py: 30%
174 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
1"""
2Base classes for Alprina CLI tools.
4Context Engineering Principles:
5- Tools are callable utilities (not full LLM agents)
6- Self-contained with clear input/output schemas
7- Minimal token footprint in context
8- MCP-compatible design
9- Progressive disclosure pattern
11Based on: Kimi-CLI CallableTool2 pattern
12"""
14from abc import ABC, abstractmethod
15from typing import TypeVar, Generic, Any, Optional
16from pydantic import BaseModel
17from loguru import logger
# Tool return types (lightweight wrappers)
class ToolResult:
    """
    Base result returned by a tool invocation.

    Attributes:
        content: Payload produced by the tool (any type).
        metadata: Extra key/value information. A fresh empty dict is used
            when no metadata (or an empty/falsy value) is supplied.
    """

    def __init__(self, content: Any, metadata: dict = None):
        self.content = content
        if not metadata:
            metadata = {}
        self.metadata = metadata

    def __str__(self):
        text = str(self.content)
        return text
class ToolOk(ToolResult):
    """
    Successful tool execution.

    Args:
        content: Structured or textual payload. When ``None``, ``output`` is
            stored as the content instead.
        output: Human-readable text. Defaults to ``str(content)`` of the
            stored content (``""`` when there is no content at all).
        metadata: Optional extra information.
    """

    def __init__(self, content: Any = None, output: str = "", metadata: dict = None):
        # Only fall back to `output` when content is truly absent: falsy
        # payloads such as {}, [] or 0 are legitimate results and must not
        # be replaced by the (possibly empty) output string.
        super().__init__(output if content is None else content, metadata)
        # Derive the text form from the *stored* content so `content` and
        # `output` stay consistent (ToolOk() -> "" rather than "None").
        self.output = output if output else str(self.content)
class ToolError(ToolResult):
    """
    Failed tool execution.

    Args:
        message: Full error description.
        brief: Short summary; falls back to ``message`` when empty or None.
        output: Optional raw output captured before the failure.
        metadata: Optional extra information (e.g. error codes).

    The result's ``content`` is the dict ``{"error": message, "brief": brief}``.
    """

    def __init__(self, message: str, brief: str = None, output: str = "", metadata: dict = None):
        summary = brief if brief else message
        payload = {"error": message, "brief": summary}
        super().__init__(payload, metadata)
        self.message = message
        self.brief = summary
        self.output = output
# Type variable for tool parameters: each tool is parameterized by the
# pydantic model class that describes its input schema.
TParams = TypeVar("TParams", bound=BaseModel)
class AlprinaToolBase(ABC, Generic[TParams]):
    """
    Base class for all Alprina CLI tools.

    Context Engineering Benefits:
    - Lightweight: No embedded LLM, just callable functions
    - Clear contracts: Pydantic schemas for inputs/outputs
    - Composable: Can be used by multiple agents/commands
    - Testable: Pure functions with minimal dependencies
    - MCP-ready: Compatible with Model Context Protocol

    Subclasses implement only ``execute()``. Awaiting the tool instance
    (``await tool(params)``) wraps ``execute()`` in a pipeline of optional
    API-key authentication, scan-limit checks, scan-record persistence,
    input/output guardrails and usage tracking.

    Usage:
        ```python
        class ScanParams(BaseModel):
            target: str = Field(description="Target to scan")

        class ScanTool(AlprinaToolBase[ScanParams]):
            name: str = "Scan"
            description: str = "Perform security scan"
            params: type[ScanParams] = ScanParams

            async def execute(self, params: ScanParams):
                result = await perform_scan(params.target)
                return ToolOk(content=result)
        ```
    """

    # Tool metadata (subclasses must define)
    name: str = "Tool"
    description: str = "Base tool"
    params: type[TParams] = BaseModel  # type: ignore

    # Optional: Guardrails for security validation.
    # NOTE: these class-level lists are shared by every subclass that does
    # not override them -- replace them, never mutate them in place.
    input_guardrails: list = []
    output_guardrails: list = []
    enable_guardrails: bool = True

    # Optional: Memory service for context persistence
    memory_service: Optional[Any] = None

    # Optional: Database client for persistence
    database_client: Optional[Any] = None
    enable_database: bool = True

    # Optional: API key for authentication
    api_key: Optional[str] = None

    # Credits charged per tracked scan. Override per tool (or pass as a
    # keyword argument to __init__) to change billing cost.
    credits_per_scan: int = 1

    def __init__(
        self,
        memory_service: Optional[Any] = None,
        enable_guardrails: bool = True,
        database_client: Optional[Any] = None,
        enable_database: bool = True,
        api_key: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize tool with optional configuration.

        Args:
            memory_service: Optional MemoryService instance for context persistence
            enable_guardrails: Enable/disable guardrails (default: True)
            database_client: Optional NeonDatabaseClient for scan persistence
            enable_database: Enable/disable database persistence (default: True)
            api_key: Optional API key for authentication
            **kwargs: Additional configuration, set verbatim as instance attributes
        """
        self.memory_service = memory_service
        self.enable_guardrails = enable_guardrails
        self.database_client = database_client
        self.enable_database = enable_database
        self.api_key = api_key

        # Initialize default guardrails if not set
        if not self.input_guardrails and enable_guardrails:
            from alprina_cli.guardrails import DEFAULT_INPUT_GUARDRAILS
            self.input_guardrails = DEFAULT_INPUT_GUARDRAILS

        if not self.output_guardrails and enable_guardrails:
            from alprina_cli.guardrails import DEFAULT_OUTPUT_GUARDRAILS
            self.output_guardrails = DEFAULT_OUTPUT_GUARDRAILS

        for key, value in kwargs.items():
            setattr(self, key, value)

    @abstractmethod
    async def execute(self, params: TParams) -> ToolResult:
        """
        Execute the tool with given parameters.

        Context: This is where the actual work happens.
        Should return high-signal results (not verbose logs).

        Args:
            params: Validated parameters (Pydantic model)

        Returns:
            ToolOk: Success with result
            ToolError: Failure with error message
        """
        raise NotImplementedError

    async def __call__(self, params: TParams) -> ToolResult:
        """
        Run the tool through the guarded, persisted pipeline.

        Steps: authenticate -> check scan limit -> create scan record ->
        input guardrails -> execute -> output guardrails -> save results ->
        track usage. Database and guardrail errors are logged and otherwise
        ignored (best effort). Only an invalid API key, an exhausted scan
        quota or a guardrail tripwire short-circuits with a ToolError.

        Args:
            params: Validated parameters (Pydantic model).

        Returns:
            The result of ``execute()`` (possibly sanitized), with
            ``scan_id``/``duration_ms`` added to its metadata when a scan
            record was created; or a ToolError if execution was blocked.

        Raises:
            Any exception raised by ``execute()`` is re-raised after the
            scan record (if one exists) has been marked as failed.
        """
        import time

        start_time = time.perf_counter()  # monotonic, unlike time.time()
        user_id = None
        scan_id = None

        # Step 1: Authenticate user (if API key provided)
        if self._db_enabled() and self.api_key:
            try:
                auth_result = await self.database_client.authenticate_api_key(self.api_key)
                if not auth_result:
                    return ToolError(
                        message="Invalid or expired API key",
                        brief="Authentication failed",
                        metadata={"code": "AUTH_FAILED"}
                    )
                user_id = auth_result.get('user_id')
                logger.debug(f"Authenticated user: {user_id}")
            except Exception as e:
                # Best effort: continue without authentication (local usage)
                logger.error(f"Authentication error: {e}")

        # Step 2: Check scan limits (if authenticated)
        if user_id and self._db_enabled():
            try:
                can_scan, scans_used, scans_limit = await self.database_client.check_scan_limit(user_id)
                if not can_scan:
                    return ToolError(
                        message=f"Monthly scan limit exceeded ({scans_used}/{scans_limit})",
                        brief="Scan limit reached",
                        metadata={"scans_used": scans_used, "scans_limit": scans_limit}
                    )
                logger.debug(f"Scan limit check: {scans_used}/{scans_limit}")
            except Exception as e:
                # Don't block on limit-check errors
                logger.error(f"Scan limit check error: {e}")

        # Steps 3-4: Create the scan record (pending) and mark it running
        if user_id and self._db_enabled():
            scan_id = await self._open_scan_record(user_id, params)

        # Step 5: Apply input guardrails (may block execution)
        blocked = await self._check_input_guardrails(params, scan_id)
        if blocked is not None:
            return blocked

        # Step 6: Execute tool
        try:
            result = await self.execute(params)
        except Exception as e:
            # Don't leave the scan record stuck in "running" forever
            await self._mark_scan_failed(scan_id, f"{type(e).__name__}: {e}")
            raise
        duration_ms = int((time.perf_counter() - start_time) * 1000)

        # Step 7: Apply output guardrails (sanitize sensitive data in place)
        self._sanitize_output(result)

        # Steps 8-9: Save scan results, then track usage for billing
        if scan_id and user_id and self._db_enabled():
            findings_count = await self._save_scan_results(scan_id, result)
            await self._track_usage(user_id, scan_id, findings_count, duration_ms)

        # Add scan_id to result metadata if available
        if scan_id and isinstance(result, ToolResult):
            if not result.metadata:
                result.metadata = {}
            result.metadata['scan_id'] = scan_id
            result.metadata['duration_ms'] = duration_ms

        return result

    # ------------------------------------------------------------------
    # Pipeline helpers (private)
    # ------------------------------------------------------------------

    def _db_enabled(self) -> bool:
        """Return True when database persistence is both enabled and configured."""
        return bool(self.enable_database and self.database_client)

    @staticmethod
    def _params_to_dict(params: Any) -> dict:
        """Dump a pydantic model to a dict (v2 ``model_dump``, else v1 ``dict``)."""
        return params.model_dump() if hasattr(params, 'model_dump') else params.dict()

    @staticmethod
    def _count_findings(content: Any) -> int:
        """Count findings in a dict payload: ``findings``, else ``vulnerabilities``."""
        if not isinstance(content, dict):
            return 0
        # `or []` tolerates keys that are present but set to None
        count = len(content.get('findings') or [])
        if count == 0 and 'vulnerabilities' in content:
            count = len(content.get('vulnerabilities') or [])
        return count

    async def _open_scan_record(self, user_id: Any, params: TParams) -> Optional[Any]:
        """Create a scan record and mark it running; return its id, or None on failure."""
        try:
            params_dict = self._params_to_dict(params)
            target = params_dict.get('target', 'unknown')
            scan_id = await self.database_client.create_scan(
                user_id=user_id,
                tool_name=self.name,
                target=target,
                params=params_dict
            )
            logger.info(f"Created scan record: {scan_id}")
        except Exception as e:
            # Continue anyway (don't block on DB errors)
            logger.error(f"Failed to create scan record: {e}")
            return None

        if scan_id:
            try:
                await self.database_client.update_scan_status(scan_id, "running")
            except Exception as e:
                logger.error(f"Failed to update scan status: {e}")
        return scan_id

    async def _mark_scan_failed(self, scan_id: Any, reason: Any) -> None:
        """Best effort: record the scan as failed with ``reason`` as the error."""
        if not (scan_id and self._db_enabled()):
            return
        try:
            await self.database_client.save_scan_results(
                scan_id=scan_id,
                findings={"error": reason},
                findings_count=0,
                status="failed"
            )
        except Exception as e:
            logger.error(f"Failed to save scan failure: {e}")

    async def _check_input_guardrails(self, params: TParams, scan_id: Any) -> Optional[ToolError]:
        """Validate every parameter; return a ToolError if a tripwire fires, else None."""
        if not (self.enable_guardrails and self.input_guardrails):
            return None
        try:
            from alprina_cli.guardrails import validate_input

            for param_name, value in self._params_to_dict(params).items():
                validation_result = validate_input(value, param_name, guardrails=self.input_guardrails)
                if validation_result.passed:
                    continue

                logger.warning(
                    f"Input guardrail triggered in {self.name}.{param_name}: {validation_result.reason}"
                )
                if validation_result.tripwire_triggered:
                    # Critical violation - block execution and mark scan as failed
                    await self._mark_scan_failed(scan_id, validation_result.reason)
                    return ToolError(
                        message=f"Security violation: {validation_result.reason}",
                        brief="Input guardrail blocked execution",
                        metadata={"severity": validation_result.severity, "param": param_name}
                    )
        except Exception as e:
            # NOTE(review): guardrail errors fail open (execution continues);
            # this mirrors the original design but is worth revisiting.
            logger.error(f"Error applying input guardrails in {self.name}: {e}")
        return None

    def _sanitize_output(self, result: ToolResult) -> None:
        """Best effort: redact sensitive data from a ToolOk's content and output text."""
        if not (self.enable_guardrails and self.output_guardrails and isinstance(result, ToolOk)):
            return
        try:
            from alprina_cli.guardrails import sanitize_output, sanitize_dict

            if isinstance(result.content, str):
                sanitization_result = sanitize_output(result.content, guardrails=self.output_guardrails)
                if sanitization_result.redactions_made > 0:
                    logger.info(
                        f"Output sanitized in {self.name}: {sanitization_result.redactions_made} redactions "
                        f"({', '.join(sanitization_result.redaction_types)})"
                    )
                result.content = sanitization_result.sanitized_value
            elif isinstance(result.content, dict):
                sanitized_content, redactions = sanitize_dict(result.content, guardrails=self.output_guardrails)
                if redactions > 0:
                    logger.info(f"Output sanitized in {self.name}: {redactions} redactions")
                result.content = sanitized_content

            # `output` is a separate text rendering (str(content) for dict
            # payloads) and can leak the same values, so sanitize it too.
            output_text = getattr(result, 'output', None)
            if isinstance(output_text, str) and output_text:
                text_result = sanitize_output(output_text, guardrails=self.output_guardrails)
                result.output = text_result.sanitized_value
        except Exception as e:
            # Don't block on sanitization errors, return the result as-is
            logger.error(f"Error applying output guardrails in {self.name}: {e}")

    async def _save_scan_results(self, scan_id: Any, result: ToolResult) -> int:
        """Persist the scan outcome (completed/failed); return the findings count."""
        findings_count = 0
        try:
            if isinstance(result, ToolOk):
                content = result.content
                findings = content if isinstance(content, dict) else {"output": str(content)}
                findings_count = self._count_findings(content)
                await self.database_client.save_scan_results(
                    scan_id=scan_id,
                    findings=findings,
                    findings_count=findings_count,
                    status="completed"
                )
                logger.info(f"Saved scan results: {scan_id} ({findings_count} findings)")
            else:
                error = result.message if hasattr(result, 'message') else str(result.content)
                await self.database_client.save_scan_results(
                    scan_id=scan_id,
                    findings={"error": error},
                    findings_count=0,
                    status="failed"
                )
                logger.warning(f"Saved failed scan: {scan_id}")
        except Exception as e:
            logger.error(f"Failed to save scan results: {e}")
        return findings_count

    async def _track_usage(self, user_id: Any, scan_id: Any, findings_count: int, duration_ms: int) -> None:
        """Best effort: record billing usage and increment the user's scan count."""
        try:
            credits_used = self.credits_per_scan
            await self.database_client.track_scan_usage(
                user_id=user_id,
                scan_id=scan_id,
                tool_name=self.name,
                credits_used=credits_used,
                duration_ms=duration_ms,
                vulnerabilities_found=findings_count
            )
            logger.debug(f"Tracked usage: {credits_used} credits, {duration_ms}ms")

            # Increment scan count
            await self.database_client.increment_scan_count(user_id)
        except Exception as e:
            logger.error(f"Failed to track usage: {e}")

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------

    def _parameters_schema(self) -> dict:
        """JSON schema of the tool's parameter model (empty dict if unset)."""
        return self.params.model_json_schema() if self.params else {}

    def to_dict(self) -> dict:
        """
        Convert tool to dictionary representation.

        Context: Used for serialization and MCP integration.
        """
        return {
            "name": self.name,
            "description": self.description,
            "parameters": self._parameters_schema()
        }

    def to_mcp_schema(self) -> dict:
        """
        Convert tool to MCP-compatible schema.

        Context: Enables integration with Model Context Protocol.
        """
        return {
            "name": self.name,
            "description": self.description,
            "inputSchema": self._parameters_schema()
        }

    def __repr__(self):
        return f"<{self.__class__.__name__} name='{self.name}'>"
class SyncToolBase(ABC, Generic[TParams]):
    """
    Synchronous version of AlprinaToolBase.

    Context: Use for tools that don't need async (rare).
    Most tools should use AlprinaToolBase (async).

    Unlike AlprinaToolBase, calling an instance runs ``execute()`` directly:
    no authentication, guardrails or database persistence are applied.
    Derives from ABC so that ``@abstractmethod`` is actually enforced
    (it is silently ignored on classes without ABCMeta).
    """

    name: str = "SyncTool"
    description: str = "Synchronous base tool"
    params: type[TParams] = BaseModel  # type: ignore

    def __init__(self, **kwargs):
        # Arbitrary configuration is stored verbatim as instance attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

    @abstractmethod
    def execute(self, params: TParams) -> ToolResult:
        """Execute synchronously"""
        raise NotImplementedError

    def __call__(self, params: TParams) -> ToolResult:
        return self.execute(params)

    def _parameters_schema(self) -> dict:
        """JSON schema of the tool's parameter model (empty dict if unset)."""
        return self.params.model_json_schema() if self.params else {}

    def to_dict(self) -> dict:
        """Convert tool to dictionary representation."""
        return {
            "name": self.name,
            "description": self.description,
            "parameters": self._parameters_schema()
        }

    def to_mcp_schema(self) -> dict:
        """Convert tool to MCP-compatible schema (same shape as AlprinaToolBase)."""
        return {
            "name": self.name,
            "description": self.description,
            "inputSchema": self._parameters_schema()
        }

    def __repr__(self):
        return f"<{self.__class__.__name__} name='{self.name}'>"
# Public API of this module (what `from ... import *` exposes)
__all__ = [
    "AlprinaToolBase",
    "SyncToolBase",
    "ToolResult",
    "ToolOk",
    "ToolError",
    "TParams",
]