Coverage for src/alprina_cli/tools/base.py: 30%

174 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2Base classes for Alprina CLI tools. 

3 

4Context Engineering Principles: 

5- Tools are callable utilities (not full LLM agents) 

6- Self-contained with clear input/output schemas 

7- Minimal token footprint in context 

8- MCP-compatible design 

9- Progressive disclosure pattern 

10 

11Based on: Kimi-CLI CallableTool2 pattern 

12""" 

13 

14from abc import ABC, abstractmethod 

15from typing import TypeVar, Generic, Any, Optional 

16from pydantic import BaseModel 

17from loguru import logger 

18 

19 

# Tool return types (lightweight wrappers)
class ToolResult:
    """Base result from tool execution.

    Attributes:
        content: Payload produced by the tool (any type).
        metadata: Free-form dict of extra information; never ``None``.
    """

    def __init__(self, content: Any, metadata: Optional[dict] = None):
        self.content = content
        # Any falsy metadata (None or an empty dict) becomes a fresh dict.
        self.metadata = metadata or {}

    def __str__(self):
        return str(self.content)


class ToolOk(ToolResult):
    """Successful tool execution.

    Args:
        content: Result payload. When ``None`` or an empty string, ``output``
            is used as the content instead.
        output: Human-readable rendering. When empty, ``str`` of the resolved
            content is used.
        metadata: Optional extra information.
    """

    def __init__(self, content: Any = None, output: str = "", metadata: Optional[dict] = None):
        # Only fall back to ``output`` when there is no content at all.
        # Falsy payloads such as ``{}``, ``[]`` or ``0`` are legitimate results.
        # The previous ``content or output`` silently replaced them with a string.
        if content is None or (isinstance(content, str) and not content):
            resolved = output
        else:
            resolved = content
        super().__init__(resolved, metadata)
        # Render from the resolved content so ``ToolOk()`` yields "" rather than "None".
        self.output = output if output else str(resolved)


class ToolError(ToolResult):
    """Failed tool execution.

    Args:
        message: Full error description.
        brief: Short summary; defaults to ``message`` when empty.
        output: Optional human-readable output captured before the failure.
        metadata: Optional extra information (e.g. an error code).

    The ``content`` payload is ``{"error": message, "brief": brief}``.
    """

    def __init__(
        self,
        message: str,
        brief: Optional[str] = None,
        output: str = "",
        metadata: Optional[dict] = None,
    ):
        summary = brief or message
        super().__init__({"error": message, "brief": summary}, metadata)
        self.message = message
        self.brief = summary
        self.output = output

48 

49 

# Generic parameter type for a tool's input schema; must be a Pydantic model.
TParams = TypeVar("TParams", bound=BaseModel)

52 

53 

class AlprinaToolBase(ABC, Generic[TParams]):
    """
    Base class for all Alprina CLI tools.

    Context Engineering Benefits:
    - Lightweight: No embedded LLM, just callable functions
    - Clear contracts: Pydantic schemas for inputs/outputs
    - Composable: Can be used by multiple agents/commands
    - Testable: Pure functions with minimal dependencies
    - MCP-ready: Compatible with Model Context Protocol

    Calling an instance (``await tool(params)``) runs ``execute`` inside a
    pipeline with these steps:
    - API-key authentication
    - scan-limit check
    - scan-record creation
    - input guardrails
    - output sanitization
    - result persistence and usage tracking

    The database steps run only when ``enable_database`` is true and a
    ``database_client`` is set. Most steps are best-effort: errors are logged
    and execution continues.

    Usage:
    ```python
    class ScanParams(BaseModel):
        target: str = Field(description="Target to scan")

    class ScanTool(AlprinaToolBase[ScanParams]):
        name: str = "Scan"
        description: str = "Perform security scan"
        params: type[ScanParams] = ScanParams

        async def execute(self, params: ScanParams):
            result = await perform_scan(params.target)
            return ToolOk(content=result)
    ```
    """

    # Tool metadata (subclasses must define)
    name: str = "Tool"
    description: str = "Base tool"
    params: type[TParams] = BaseModel  # type: ignore

    # Optional: Guardrails for security validation.
    # Subclasses may override these lists. When they are left empty,
    # __init__ fills per-instance copies of the package defaults.
    input_guardrails: list = []
    output_guardrails: list = []
    enable_guardrails: bool = True

    # Optional: Memory service for context persistence
    memory_service: Optional[Any] = None

    # Optional: Database client for persistence
    database_client: Optional[Any] = None
    enable_database: bool = True

    # Optional: API key for authentication
    api_key: Optional[str] = None

    def __init__(
        self,
        memory_service: Optional[Any] = None,
        enable_guardrails: bool = True,
        database_client: Optional[Any] = None,
        enable_database: bool = True,
        api_key: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize tool with optional configuration.

        Args:
            memory_service: Optional MemoryService instance for context persistence
            enable_guardrails: Enable/disable guardrails (default: True)
            database_client: Optional NeonDatabaseClient for scan persistence
            enable_database: Enable/disable database persistence (default: True)
            api_key: Optional API key for authentication
            **kwargs: Additional configuration, set verbatim as instance attributes
        """
        self.memory_service = memory_service
        self.enable_guardrails = enable_guardrails
        self.database_client = database_client
        self.enable_database = enable_database
        self.api_key = api_key

        # Initialize default guardrails if not set.
        # Copy the defaults so that mutating one tool's list never alters the
        # shared module-level defaults (or other tools' guardrails).
        if not self.input_guardrails and enable_guardrails:
            from alprina_cli.guardrails import DEFAULT_INPUT_GUARDRAILS
            self.input_guardrails = list(DEFAULT_INPUT_GUARDRAILS)

        if not self.output_guardrails and enable_guardrails:
            from alprina_cli.guardrails import DEFAULT_OUTPUT_GUARDRAILS
            self.output_guardrails = list(DEFAULT_OUTPUT_GUARDRAILS)

        for key, value in kwargs.items():
            setattr(self, key, value)

    @abstractmethod
    async def execute(self, params: TParams) -> ToolResult:
        """
        Execute the tool with given parameters.

        Context: This is where the actual work happens.
        Should return high-signal results (not verbose logs).

        Args:
            params: Validated parameters (Pydantic model)

        Returns:
            ToolOk: Success with result
            ToolError: Failure with error message
        """
        raise NotImplementedError

    async def __call__(self, params: TParams) -> ToolResult:
        """
        Call interface for the tool.

        Context: Applies guardrails, authentication, and database persistence
        around ``execute``.

        Args:
            params: Validated parameters (Pydantic model).

        Returns:
            Normally, the result of ``execute``. Its output may be sanitized,
            and ``scan_id``/``duration_ms`` are added to ``metadata`` when a
            scan record was created. A ``ToolError`` is returned instead when
            authentication fails, the scan limit is reached, or an input
            guardrail tripwire fires.

        Raises:
            Whatever ``execute`` raises. If a scan record exists, it is marked
            as failed before the exception propagates.
        """
        import time

        # Monotonic clock: a duration must not jump with wall-clock adjustments.
        start_time = time.perf_counter()
        user_id = None
        scan_id = None

        # Step 1: Authenticate user (if API key provided)
        if self._database_enabled() and self.api_key:
            user_id, auth_error = await self._authenticate()
            if auth_error is not None:
                return auth_error

        # Step 2: Check scan limits (if authenticated)
        if user_id and self._database_enabled():
            limit_error = await self._check_scan_limit(user_id)
            if limit_error is not None:
                return limit_error

        # Steps 3-4: Create scan record (pending), then mark it running
        if user_id and self._database_enabled():
            scan_id = await self._create_scan_record(user_id, params)

        # Step 5: Apply input guardrails
        if self.enable_guardrails and self.input_guardrails:
            guardrail_error = await self._apply_input_guardrails(params, scan_id)
            if guardrail_error is not None:
                return guardrail_error

        # Step 6: Execute tool
        try:
            result = await self.execute(params)
        except Exception as exc:
            # Without this, a crashing tool leaves its scan record "running" forever.
            if scan_id:
                await self._mark_scan_failed(scan_id, f"{type(exc).__name__}: {exc}")
            raise
        duration_ms = int((time.perf_counter() - start_time) * 1000)

        # Step 7: Apply output guardrails (sanitize sensitive data)
        if self.enable_guardrails and self.output_guardrails and isinstance(result, ToolOk):
            self._apply_output_guardrails(result)

        # Steps 8-9: Save scan results and track usage for billing.
        # findings_count is computed once, outside any try, so step 9 can always use it.
        if scan_id and user_id and self._database_enabled():
            findings_count = self._count_findings(result)
            await self._save_scan_results(scan_id, result, findings_count)
            await self._track_usage(user_id, scan_id, duration_ms, findings_count)

        # Add scan_id to result metadata if available
        if scan_id and isinstance(result, ToolResult):
            if not result.metadata:
                result.metadata = {}
            result.metadata["scan_id"] = scan_id
            result.metadata["duration_ms"] = duration_ms

        return result

    # ------------------------------------------------------------------ #
    # Pipeline helpers (private)
    # ------------------------------------------------------------------ #

    def _database_enabled(self) -> bool:
        """True when database persistence is enabled and a client is configured."""
        return bool(self.enable_database and self.database_client)

    @staticmethod
    def _params_to_dict(params: Any) -> dict:
        """Dump a Pydantic model to a dict (v2 ``model_dump`` or v1 ``dict``)."""
        return params.model_dump() if hasattr(params, "model_dump") else params.dict()

    async def _authenticate(self) -> "tuple[Optional[Any], Optional[ToolError]]":
        """Resolve the user owning ``self.api_key``.

        Returns:
            ``(user_id, None)`` when the key is accepted.
            ``(None, ToolError)`` when the key is rejected.
            ``(None, None)`` when the lookup itself errors; execution then
            continues unauthenticated (local usage).
        """
        try:
            auth_result = await self.database_client.authenticate_api_key(self.api_key)
            if not auth_result:
                return None, ToolError(
                    message="Invalid or expired API key",
                    brief="Authentication failed",
                    metadata={"code": "AUTH_FAILED"},
                )
            user_id = auth_result.get("user_id")
            logger.debug(f"Authenticated user: {user_id}")
            return user_id, None
        except Exception as e:
            logger.error(f"Authentication error: {e}")
            # Continue without authentication (local usage)
            return None, None

    async def _check_scan_limit(self, user_id: Any) -> Optional[ToolError]:
        """Return a ``ToolError`` if the user's monthly scan limit is exhausted.

        Errors from the limit check itself are logged and do not block the scan.
        """
        try:
            can_scan, scans_used, scans_limit = await self.database_client.check_scan_limit(user_id)
            if not can_scan:
                return ToolError(
                    message=f"Monthly scan limit exceeded ({scans_used}/{scans_limit})",
                    brief="Scan limit reached",
                    metadata={"scans_used": scans_used, "scans_limit": scans_limit},
                )
            logger.debug(f"Scan limit check: {scans_used}/{scans_limit}")
        except Exception as e:
            logger.error(f"Scan limit check error: {e}")
            # Continue anyway (don't block on limit check errors)
        return None

    async def _create_scan_record(self, user_id: Any, params: Any) -> Optional[Any]:
        """Create a pending scan record and mark it running (best-effort).

        Returns:
            The new scan id, or ``None`` if the record could not be created.
        """
        scan_id = None
        try:
            params_dict = self._params_to_dict(params)
            target = params_dict.get("target", "unknown")
            scan_id = await self.database_client.create_scan(
                user_id=user_id,
                tool_name=self.name,
                target=target,
                params=params_dict,
            )
            logger.info(f"Created scan record: {scan_id}")
        except Exception as e:
            logger.error(f"Failed to create scan record: {e}")
            # Continue anyway (don't block on DB errors)
            return None

        if scan_id:
            try:
                await self.database_client.update_scan_status(scan_id, "running")
            except Exception as e:
                logger.error(f"Failed to update scan status: {e}")
        return scan_id

    async def _mark_scan_failed(self, scan_id: Any, reason: Any) -> None:
        """Best-effort: store ``reason`` as the error of scan ``scan_id`` (status "failed")."""
        if not self._database_enabled():
            return
        try:
            await self.database_client.save_scan_results(
                scan_id=scan_id,
                findings={"error": reason},
                findings_count=0,
                status="failed",
            )
        except Exception as e:
            logger.error(f"Failed to save scan failure: {e}")

    async def _apply_input_guardrails(self, params: Any, scan_id: Optional[Any]) -> Optional[ToolError]:
        """Validate every parameter against ``self.input_guardrails``.

        A failed check is logged. A failed check whose tripwire fired blocks
        execution: the scan is marked failed and a ``ToolError`` is returned.
        Errors inside the guardrail machinery are logged and do not block
        execution (deliberate fail-open).
        """
        try:
            from alprina_cli.guardrails import validate_input

            params_dict = self._params_to_dict(params)
            for param_name, value in params_dict.items():
                validation_result = validate_input(value, param_name, guardrails=self.input_guardrails)
                if validation_result.passed:
                    continue

                logger.warning(
                    f"Input guardrail triggered in {self.name}.{param_name}: {validation_result.reason}"
                )
                if not validation_result.tripwire_triggered:
                    continue

                # Critical violation - block execution and mark scan as failed
                if scan_id:
                    await self._mark_scan_failed(scan_id, validation_result.reason)
                return ToolError(
                    message=f"Security violation: {validation_result.reason}",
                    brief="Input guardrail blocked execution",
                    metadata={"severity": validation_result.severity, "param": param_name},
                )
        except Exception as e:
            logger.error(f"Error applying input guardrails in {self.name}: {e}")
            # Don't block on guardrail errors, just log
        return None

    def _apply_output_guardrails(self, result: ToolOk) -> None:
        """Redact sensitive data from ``result.content`` and ``result.output`` in place.

        Best-effort: on any error the error is logged and the result is left
        in whatever state it had reached.
        """
        try:
            from alprina_cli.guardrails import sanitize_output, sanitize_dict

            original_content = result.content
            if isinstance(original_content, str):
                sanitization_result = sanitize_output(original_content, guardrails=self.output_guardrails)
                if sanitization_result.redactions_made > 0:
                    logger.info(
                        f"Output sanitized in {self.name}: {sanitization_result.redactions_made} redactions "
                        f"({', '.join(sanitization_result.redaction_types)})"
                    )
                result.content = sanitization_result.sanitized_value

            elif isinstance(original_content, dict):
                sanitized_content, redactions = sanitize_dict(original_content, guardrails=self.output_guardrails)
                if redactions > 0:
                    logger.info(f"Output sanitized in {self.name}: {redactions} redactions")
                result.content = sanitized_content

            # ``output`` is a separate text copy of the result; ToolOk defaults
            # it to ``str(content)``. Redact it too, so unredacted data cannot
            # leak through it. Dict content previously left it untouched.
            output = getattr(result, "output", None)
            if isinstance(output, str) and output:
                if output == str(original_content):
                    # Derived from the content: re-derive from the sanitized content.
                    result.output = str(result.content)
                else:
                    output_result = sanitize_output(output, guardrails=self.output_guardrails)
                    if output_result.redactions_made > 0:
                        result.output = output_result.sanitized_value

        except Exception as e:
            logger.error(f"Error applying output guardrails in {self.name}: {e}")
            # Don't block on sanitization errors, return original result

    @staticmethod
    def _count_findings(result: Any) -> int:
        """Count findings in a successful dict result; 0 otherwise.

        Reads ``content["findings"]``. When that is missing or empty, it falls
        back to ``content["vulnerabilities"]``. Values without a length
        (e.g. ``None``) count as 0 instead of raising.
        """
        if not isinstance(result, ToolOk) or not isinstance(result.content, dict):
            return 0

        def sized(value: Any) -> int:
            try:
                return len(value)
            except TypeError:
                return 0

        count = sized(result.content.get("findings", []))
        if count == 0 and "vulnerabilities" in result.content:
            count = sized(result.content["vulnerabilities"])
        return count

    async def _save_scan_results(self, scan_id: Any, result: Any, findings_count: int) -> None:
        """Persist ``result`` as the outcome of ``scan_id``.

        A ``ToolOk`` is saved as "completed" and anything else as "failed".
        Database errors are logged, not raised.
        """
        try:
            if isinstance(result, ToolOk):
                content = result.content
                findings = content if isinstance(content, dict) else {"output": str(content)}
                # Save successful scan
                await self.database_client.save_scan_results(
                    scan_id=scan_id,
                    findings=findings,
                    findings_count=findings_count,
                    status="completed",
                )
                logger.info(f"Saved scan results: {scan_id} ({findings_count} findings)")
            else:
                error = getattr(result, "message", None)
                if error is None:
                    error = str(getattr(result, "content", result))
                # Save failed scan
                await self.database_client.save_scan_results(
                    scan_id=scan_id,
                    findings={"error": error},
                    findings_count=0,
                    status="failed",
                )
                logger.warning(f"Saved failed scan: {scan_id}")
        except Exception as e:
            logger.error(f"Failed to save scan results: {e}")

    async def _track_usage(self, user_id: Any, scan_id: Any, duration_ms: int, findings_count: int) -> None:
        """Record billing usage for the scan and increment the user's scan count (best-effort)."""
        try:
            # Determine credit cost (could be configurable per tool)
            credits_used = 1  # Default

            await self.database_client.track_scan_usage(
                user_id=user_id,
                scan_id=scan_id,
                tool_name=self.name,
                credits_used=credits_used,
                duration_ms=duration_ms,
                vulnerabilities_found=findings_count,
            )
            logger.debug(f"Tracked usage: {credits_used} credits, {duration_ms}ms")

            # Increment scan count
            await self.database_client.increment_scan_count(user_id)
        except Exception as e:
            logger.error(f"Failed to track usage: {e}")

    # ------------------------------------------------------------------ #
    # Serialization
    # ------------------------------------------------------------------ #

    def to_dict(self) -> dict:
        """
        Convert tool to dictionary representation.

        Context: Used for serialization and MCP integration.
        """
        return {
            "name": self.name,
            "description": self.description,
            "parameters": self.params.model_json_schema() if self.params else {},
        }

    def to_mcp_schema(self) -> dict:
        """
        Convert tool to MCP-compatible schema.

        Context: Enables integration with Model Context Protocol.
        """
        return {
            "name": self.name,
            "description": self.description,
            "inputSchema": self.params.model_json_schema() if self.params else {},
        }

    def __repr__(self):
        return f"<{self.__class__.__name__} name='{self.name}'>"

388 

389 

class SyncToolBase(ABC, Generic[TParams]):
    """
    Synchronous version of AlprinaToolBase.

    Context: Use for tools that don't need async (rare).
    Most tools should use AlprinaToolBase (async).

    Calling an instance simply runs ``execute``. The guardrail, authentication
    and database pipeline of AlprinaToolBase is NOT applied here.

    Inherits ``ABC`` so that ``@abstractmethod`` is actually enforced.
    Without it, a subclass missing ``execute`` could be instantiated and would
    only fail with NotImplementedError when called.
    """

    name: str = "SyncTool"
    description: str = "Synchronous base tool"
    params: type[TParams] = BaseModel  # type: ignore

    def __init__(self, **kwargs):
        """Set every keyword argument verbatim as an instance attribute."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    @abstractmethod
    def execute(self, params: TParams) -> ToolResult:
        """Execute synchronously"""
        raise NotImplementedError

    def __call__(self, params: TParams) -> ToolResult:
        """Run ``execute`` directly with ``params``."""
        return self.execute(params)

    def to_dict(self) -> dict:
        """Convert tool to dictionary representation (name, description, JSON schema)."""
        return {
            "name": self.name,
            "description": self.description,
            "parameters": self.params.model_json_schema() if self.params else {},
        }

    def to_mcp_schema(self) -> dict:
        """Convert tool to MCP-compatible schema (same shape as AlprinaToolBase.to_mcp_schema)."""
        return {
            "name": self.name,
            "description": self.description,
            "inputSchema": self.params.model_json_schema() if self.params else {},
        }

    def __repr__(self):
        return f"<{self.__class__.__name__} name='{self.name}'>"

420 

421 

# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "AlprinaToolBase",
    "SyncToolBase",
    "ToolResult",
    "ToolOk",
    "ToolError",
    "TParams",
]

430]