Coverage for src/alprina_cli/guardrails/input_guardrails.py: 38%

112 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2Input Guardrails 

3 

4Prevent malicious inputs from reaching tools. 

5Detect and block: SQL injection, command injection, path traversal, XXE, etc. 

6""" 

7 

8from abc import ABC, abstractmethod 

9from typing import Any, Dict, Optional 

10from pydantic import BaseModel 

11from loguru import logger 

12import re 

13 

14 

class GuardrailResult(BaseModel):
    """Result from guardrail check.

    Returned by every ``InputGuardrail.check`` implementation in this module
    and by ``validate_input``.
    """

    # True when the checked value was accepted.
    passed: bool
    # The pattern-based guardrails in this module (SQL, command, path, XXE) set
    # this to True on a match; the length and type guardrails leave it False.
    tripwire_triggered: bool = False
    # Human-readable explanation of a failure; left as None on success.
    reason: Optional[str] = None
    severity: str = "INFO"  # INFO, LOW, MEDIUM, HIGH, CRITICAL
    # Optional cleaned-up replacement for the input. None of the guardrails
    # defined in this module populate it.
    sanitized_value: Optional[Any] = None

22 

23 

class InputGuardrail(ABC):
    """
    Base class for input guardrails.

    Context Engineering:
    - Fast checks (< 10ms per validation)
    - Clear pass/fail results
    - Provides sanitized alternatives when possible

    Subclasses override ``name`` and implement ``check``.
    """

    # Identifier of the guardrail; validate_input() includes it in the error
    # it logs when a check fails. Subclasses override it.
    name: str = "InputGuardrail"

    @abstractmethod
    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Check if input is safe.

        Args:
            value: Input value to check
            param_name: Name of parameter being checked

        Returns:
            GuardrailResult with pass/fail and optional sanitized value

        Raises:
            NotImplementedError: When this base implementation is invoked
                directly (e.g. via ``super().check(...)``).
        """
        raise NotImplementedError

49 

50 

class SQLInjectionGuardrail(InputGuardrail):
    """
    Guardrail that flags strings containing SQL injection signatures.

    Signatures covered:
    - SQL keywords in unexpected places
    - Comment syntax (-- /* */)
    - Union-based injection
    - Boolean-based injection
    - Time-based injection
    """

    name: str = "SQLInjection"

    # Regexes are tried in this order; the first one that matches is reported.
    SQL_PATTERNS = [
        r"(\bOR\b|\bAND\b)\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+['\"]?",  # OR 1=1
        r";\s*DROP\s+TABLE",  # DROP TABLE
        r";\s*DELETE\s+FROM",  # DELETE FROM
        r";\s*UPDATE\s+\w+\s+SET",  # UPDATE SET
        r"UNION\s+SELECT",  # UNION SELECT
        r"--\s*$",  # SQL comment
        r"/\*.*?\*/",  # Block comment
        r"'\s*OR\s+'",  # ' OR '
        r"'\s*;\s*--",  # '; --
        r"EXEC\s*\(",  # EXEC(
        r"EXECUTE\s*\(",  # EXECUTE(
        r"xp_cmdshell",  # xp_cmdshell
        r"SLEEP\s*\(",  # SLEEP( (time-based)
        r"WAITFOR\s+DELAY",  # WAITFOR DELAY
    ]

    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Scan a string for SQL injection signatures (case-insensitive).

        Args:
            value: Input value to check; non-strings and blank strings pass.
            param_name: Name of parameter being checked (used in the log line).

        Returns:
            A passing GuardrailResult, or a CRITICAL tripwire result naming
            the first matching pattern.
        """
        # Nothing to scan for non-string or whitespace-only input.
        if not isinstance(value, str) or not value.strip():
            return GuardrailResult(passed=True)

        first_hit = next(
            (
                candidate
                for candidate in self.SQL_PATTERNS
                if re.search(candidate, value, re.IGNORECASE)
            ),
            None,
        )
        if first_hit is None:
            return GuardrailResult(passed=True)

        # Only the first 100 characters of the input are logged.
        logger.warning(f"SQL injection detected in {param_name}: {value[:100]}")
        return GuardrailResult(
            passed=False,
            tripwire_triggered=True,
            reason=f"SQL injection pattern detected: {first_hit}",
            severity="CRITICAL"
        )

104 

105 

class CommandInjectionGuardrail(InputGuardrail):
    """
    Guardrail that flags strings containing shell command injection signatures.

    Signatures covered:
    - Shell metacharacters (;, |, &, `, $)
    - Command chaining
    - Subshell execution
    - Environment variable injection
    """

    name: str = "CommandInjection"

    # Regexes are tried in this order; the first one that matches is reported.
    DANGEROUS_PATTERNS = [
        r";\s*\w+",  # Command chaining with ;
        r"\|\s*\w+",  # Pipe to command
        r"&&\s*\w+",  # AND command
        r"\|\|\s*\w+",  # OR command
        r"`[^`]+`",  # Backtick command substitution
        r"\$\([^)]+\)",  # $() command substitution
        r">\s*/",  # Redirect to file
        r"<\s*/",  # Read from file
        r"\beval\b",  # eval
        r"\bexec\b",  # exec
        r"\bsystem\b",  # system
        r"/dev/tcp/",  # TCP backdoor
        r"/dev/udp/",  # UDP backdoor
        r"\bwget\b.*http",  # wget download
        r"\bcurl\b.*http",  # curl download
        r"nc\s+-",  # netcat
        r"bash\s+-i",  # Interactive bash
        r"sh\s+-i",  # Interactive sh
        r"python\s+-c",  # Python one-liner
        r"perl\s+-e",  # Perl one-liner
        r"ruby\s+-e",  # Ruby one-liner
    ]

    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Scan a string for command injection signatures (case-insensitive).

        Args:
            value: Input value to check; non-strings and blank strings pass.
            param_name: Name of parameter being checked (used in the log line).

        Returns:
            A passing GuardrailResult, or a CRITICAL tripwire result naming
            the first matching pattern.
        """
        # Nothing to scan for non-string or whitespace-only input.
        if not isinstance(value, str) or not value.strip():
            return GuardrailResult(passed=True)

        first_hit = next(
            (
                candidate
                for candidate in self.DANGEROUS_PATTERNS
                if re.search(candidate, value, re.IGNORECASE)
            ),
            None,
        )
        if first_hit is None:
            return GuardrailResult(passed=True)

        # Only the first 100 characters of the input are logged.
        logger.warning(f"Command injection detected in {param_name}: {value[:100]}")
        return GuardrailResult(
            passed=False,
            tripwire_triggered=True,
            reason=f"Command injection pattern detected: {first_hit}",
            severity="CRITICAL"
        )

164 

165 

class PathTraversalGuardrail(InputGuardrail):
    """
    Detect path traversal attempts.

    Patterns detected:
    - ../ sequences
    - Absolute paths to sensitive locations
    - URL encoding tricks
    - Windows/Unix path tricks

    Every pattern is matched case-insensitively against three spellings of
    the input: as given, with backslashes rewritten to forward slashes, and
    with forward slashes rewritten to backslashes. This lets patterns written
    with either separator catch paths written with the other one, and stops
    Windows paths from slipping through by changing letter case.
    """

    name: str = "PathTraversal"

    DANGEROUS_PATTERNS = [
        r"\.\./",  # ../
        r"\.\.",  # ..
        r"%2e%2e",  # URL encoded ..
        r"\.\.\\",  # ..\
        r"\\\.\\",  # \.\
        r"/etc/passwd",  # /etc/passwd
        r"/etc/shadow",  # /etc/shadow
        r"C:\\Windows",  # C:\Windows
        r"C:\\Program Files",  # C:\Program Files
        r"/proc/self",  # /proc/self
        r"/root/",  # /root/
    ]

    # Absolute path patterns for sensitive system locations
    ABSOLUTE_PATH_PATTERNS = [
        r"^/etc/",  # Unix system config
        r"^/root/",  # Root home directory
        r"^/var/",  # System var directory
        r"^/usr/",  # System usr directory (except common public paths)
        r"^/boot/",  # Boot directory
        r"^/sys/",  # System directory
        r"^/proc/",  # Process directory
        r"^C:\\Windows\\",  # Windows directory
        r"^C:\\Program Files\\",  # Program Files
        r"^C:\\Users\\[^\\]+\\AppData\\",  # User app data
        r"^\\\\\.\\",  # Windows device path
    ]

    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Check for path traversal patterns.

        Args:
            value: Input value to check; non-strings and blank strings pass.
            param_name: Name of parameter being checked (used in the log line).

        Returns:
            A passing GuardrailResult, or a HIGH-severity tripwire result.
            ``DANGEROUS_PATTERNS`` are evaluated first and report the matching
            pattern; ``ABSOLUTE_PATH_PATTERNS`` report a generic reason.
        """
        if not isinstance(value, str):
            return GuardrailResult(passed=True)

        # Empty input check
        if not value.strip():
            return GuardrailResult(passed=True)  # Allow empty, will be handled by validation

        # Spellings of the same path to test. Previously only the
        # forward-slash form was searched for DANGEROUS_PATTERNS, so the
        # backslash-based patterns there could never match, and the raw value
        # was searched case-sensitively, so lower-case Windows paths passed.
        spellings = (
            value,
            value.replace("\\", "/"),
            value.replace("/", "\\"),
        )

        def matches(pattern: str) -> bool:
            # Windows paths are case-insensitive, so always ignore case.
            return any(
                re.search(pattern, spelling, re.IGNORECASE)
                for spelling in spellings
            )

        # Check dangerous patterns
        for pattern in self.DANGEROUS_PATTERNS:
            if matches(pattern):
                logger.warning(f"Path traversal detected in {param_name}: {value[:100]}")
                return GuardrailResult(
                    passed=False,
                    tripwire_triggered=True,
                    reason=f"Path traversal pattern detected: {pattern}",
                    severity="HIGH"
                )

        # Check absolute path patterns (Unix and Windows)
        for pattern in self.ABSOLUTE_PATH_PATTERNS:
            if matches(pattern):
                logger.warning(f"Unauthorized absolute path access in {param_name}: {value[:100]}")
                return GuardrailResult(
                    passed=False,
                    tripwire_triggered=True,
                    reason="Unauthorized absolute path access to sensitive location",
                    severity="HIGH"
                )

        return GuardrailResult(passed=True)

244 

245 

class XXEGuardrail(InputGuardrail):
    """
    Guardrail that flags strings containing XML External Entity (XXE) signatures.

    Signatures covered:
    - DOCTYPE declarations
    - ENTITY definitions
    - External file references
    - SYSTEM keyword
    """

    name: str = "XXE"

    # Regexes are tried in this order; the first one that matches is reported.
    XXE_PATTERNS = [
        r"<!DOCTYPE[^>]*\[",  # DOCTYPE declaration with [ (entity definition)
        r"<!ENTITY",  # ENTITY definition
        r"SYSTEM\s+['\"]",  # SYSTEM keyword with quote
        r"PUBLIC\s+['\"]",  # PUBLIC keyword with quote
        r"file://",  # File protocol
        r"php://",  # PHP protocol
        r"expect://",  # Expect protocol
        r"data://",  # Data protocol
        r"<!ENTITY[^>]*SYSTEM",  # ENTITY with SYSTEM
        r"<!ENTITY[^>]*%",  # Parameter entity
        r"&[a-zA-Z]+;.*SYSTEM",  # Entity reference with SYSTEM
    ]

    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Scan a string for XXE signatures (case-insensitive).

        Args:
            value: Input value to check; non-strings and blank strings pass.
            param_name: Name of parameter being checked (used in the log line).

        Returns:
            A passing GuardrailResult, or a HIGH-severity tripwire result
            naming the first matching pattern.
        """
        # Nothing to scan for non-string or whitespace-only input.
        if not isinstance(value, str) or not value.strip():
            return GuardrailResult(passed=True)

        first_hit = next(
            (
                candidate
                for candidate in self.XXE_PATTERNS
                if re.search(candidate, value, re.IGNORECASE)
            ),
            None,
        )
        if first_hit is None:
            return GuardrailResult(passed=True)

        # Only the first 100 characters of the input are logged.
        logger.warning(f"XXE injection detected in {param_name}: {value[:100]}")
        return GuardrailResult(
            passed=False,
            tripwire_triggered=True,
            reason=f"XXE injection pattern detected: {first_hit}",
            severity="HIGH"
        )

294 

295 

class LengthGuardrail(InputGuardrail):
    """
    Guardrail that rejects oversized inputs to prevent DoS.

    Guards against:
    - Extremely long inputs
    - Resource exhaustion
    """

    name: str = "Length"

    def __init__(self, max_length: int = 10000):
        """Store the largest accepted size, measured in characters."""
        self.max_length = max_length

    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Compare the size of the input against ``max_length``.

        Strings are measured directly; lists and dicts are measured by the
        length of their ``str()`` rendering. Any other type passes unchecked.

        Args:
            value: Input value to check.
            param_name: Name of parameter being checked (used in the log line).

        Returns:
            A passing GuardrailResult, or a MEDIUM-severity failure that does
            not trip the tripwire.
        """
        if not isinstance(value, (str, list, dict)):
            return GuardrailResult(passed=True)

        rendered = value if isinstance(value, str) else str(value)
        size = len(rendered)

        if size > self.max_length:
            logger.warning(f"Input too long in {param_name}: {size} > {self.max_length}")
            return GuardrailResult(
                passed=False,
                tripwire_triggered=False,  # Not malicious, just too large
                reason=f"Input exceeds maximum length: {size} > {self.max_length}",
                severity="MEDIUM"
            )

        return GuardrailResult(passed=True)

329 

330 

class TypeGuardrail(InputGuardrail):
    """
    Guardrail that enforces an expected input type.

    A value passes when it is an instance of the configured type.
    """

    name: str = "Type"

    def __init__(self, expected_type: type):
        """Remember the type that checked values must be an instance of."""
        self.expected_type = expected_type

    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Verify that ``value`` is an instance of ``expected_type``.

        Args:
            value: Input value to check.
            param_name: Name of parameter being checked (used in the log line).

        Returns:
            A passing GuardrailResult, or a LOW-severity failure that does
            not trip the tripwire.
        """
        # Happy path first: matching type needs no logging.
        if isinstance(value, self.expected_type):
            return GuardrailResult(passed=True)

        logger.warning(f"Type mismatch in {param_name}: expected {self.expected_type}, got {type(value)}")
        mismatch = f"Type mismatch: expected {self.expected_type.__name__}, got {type(value).__name__}"
        return GuardrailResult(
            passed=False,
            tripwire_triggered=False,
            reason=mismatch,
            severity="LOW"
        )

355 

356 

# Default guardrail chain.
# validate_input() uses this list when no explicit chain is supplied. It runs
# the entries in order and returns the first failing result, so the
# pattern-based checks are evaluated before the length limit.
DEFAULT_INPUT_GUARDRAILS = [
    SQLInjectionGuardrail(),
    CommandInjectionGuardrail(),
    PathTraversalGuardrail(),
    XXEGuardrail(),
    LengthGuardrail(max_length=10000)
]

365 

366 

def validate_input(
    value: Any,
    param_name: str = "",
    guardrails: Optional[list[InputGuardrail]] = None
) -> GuardrailResult:
    """
    Run a value through a chain of guardrails, stopping at the first failure.

    Args:
        value: Input value to validate
        param_name: Name of parameter
        guardrails: Guardrails to apply, in order. ``None`` selects
            DEFAULT_INPUT_GUARDRAILS; an empty list applies no checks.

    Returns:
        The first failing GuardrailResult, or a passing result when every
        guardrail accepts the value.
    """
    # Only None falls back to the defaults; an explicit empty list is honoured.
    chain = DEFAULT_INPUT_GUARDRAILS if guardrails is None else guardrails

    for rail in chain:
        outcome = rail.check(value, param_name)
        if outcome.passed:
            continue
        logger.error(f"Guardrail {rail.name} failed for {param_name}: {outcome.reason}")
        return outcome

    return GuardrailResult(passed=True)

393 

394 

def validate_params(params: Dict[str, Any]) -> Dict[str, GuardrailResult]:
    """
    Validate every entry of a parameter dictionary with the default chain.

    Args:
        params: Dictionary of parameters to validate

    Returns:
        Dictionary with the same keys, each mapped to the GuardrailResult
        produced by ``validate_input`` for that value.
    """
    return {
        key: validate_input(candidate, key)
        for key, candidate in params.items()
    }