Coverage for src/alprina_cli/guardrails/input_guardrails.py: 38%
112 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
1"""
2Input Guardrails
4Prevent malicious inputs from reaching tools.
5Detect and block: SQL injection, command injection, path traversal, XXE, etc.
6"""
8from abc import ABC, abstractmethod
9from typing import Any, Dict, Optional
10from pydantic import BaseModel
11from loguru import logger
12import re
class GuardrailResult(BaseModel):
    """Outcome of running one guardrail (or a chain of them) on a value.

    Fields:
        passed: True when the value was accepted.
        tripwire_triggered: True when a failure matched an attack pattern;
            the length and type guardrails in this module fail with it
            left False.
        reason: Human-readable explanation of a failure, None on success.
        severity: One of INFO, LOW, MEDIUM, HIGH, CRITICAL (stored as a
            plain string; not validated against that set here).
        sanitized_value: Optional cleaned replacement for the input. None
            of the guardrails visible in this module populate it.
    """

    passed: bool
    tripwire_triggered: bool = False
    reason: Optional[str] = None
    severity: str = "INFO"  # INFO, LOW, MEDIUM, HIGH, CRITICAL
    sanitized_value: Optional[Any] = None
class InputGuardrail(ABC):
    """
    Abstract interface every input guardrail implements.

    Context Engineering:
    - Fast checks (< 10ms per validation)
    - Clear pass/fail results
    - Provides sanitized alternatives when possible
    """

    # Label used when reporting which guardrail rejected a value.
    name: str = "InputGuardrail"

    @abstractmethod
    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Decide whether ``value`` is safe to pass on.

        Args:
            value: The input to inspect.
            param_name: Name of the parameter the value belongs to; used
                only for log and error messages.

        Returns:
            A GuardrailResult carrying the pass/fail verdict and,
            optionally, a sanitized value.
        """
        raise NotImplementedError
class SQLInjectionGuardrail(InputGuardrail):
    """
    Reject strings that match a known SQL injection signature.

    Signatures cover:
    - Boolean tautologies (OR 1=1)
    - Stacked statements (; DROP TABLE / DELETE FROM / UPDATE ... SET)
    - UNION SELECT
    - Comment syntax (-- and /* */)
    - EXEC / EXECUTE / xp_cmdshell
    - Time-based probes (SLEEP, WAITFOR DELAY)
    """

    name: str = "SQLInjection"

    # Regexes are searched case-insensitively, in this order; the first
    # hit is the one reported.
    SQL_PATTERNS = [
        r"(\bOR\b|\bAND\b)\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+['\"]?",  # OR 1=1
        r";\s*DROP\s+TABLE",  # DROP TABLE
        r";\s*DELETE\s+FROM",  # DELETE FROM
        r";\s*UPDATE\s+\w+\s+SET",  # UPDATE SET
        r"UNION\s+SELECT",  # UNION SELECT
        r"--\s*$",  # SQL comment
        r"/\*.*?\*/",  # Block comment
        r"'\s*OR\s+'",  # ' OR '
        r"'\s*;\s*--",  # '; --
        r"EXEC\s*\(",  # EXEC(
        r"EXECUTE\s*\(",  # EXECUTE(
        r"xp_cmdshell",  # xp_cmdshell
        r"SLEEP\s*\(",  # SLEEP( (time-based)
        r"WAITFOR\s+DELAY",  # WAITFOR DELAY
    ]

    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Search ``value`` for SQL injection signatures.

        Args:
            value: Input to inspect; non-strings and blank strings pass.
            param_name: Parameter name used in the warning log line.

        Returns:
            A passing result, or a CRITICAL tripwire failure naming the
            first signature that matched.
        """
        # Only non-blank strings can carry a signature.
        if not isinstance(value, str) or not value.strip():
            return GuardrailResult(passed=True)

        matched = next(
            (
                signature
                for signature in self.SQL_PATTERNS
                if re.search(signature, value, re.IGNORECASE)
            ),
            None,
        )
        if matched is None:
            return GuardrailResult(passed=True)

        # Log at most the first 100 characters of the offending input.
        logger.warning(f"SQL injection detected in {param_name}: {value[:100]}")
        return GuardrailResult(
            passed=False,
            tripwire_triggered=True,
            reason=f"SQL injection pattern detected: {matched}",
            severity="CRITICAL",
        )
class CommandInjectionGuardrail(InputGuardrail):
    """
    Reject strings that look like shell command injection.

    Signatures cover:
    - Command chaining (;, |, &&, ||)
    - Command substitution (backticks, $())
    - Redirection to/from absolute paths
    - eval / exec / system keywords
    - /dev/tcp and /dev/udp, wget/curl downloads, netcat
    - Interactive shells and interpreter one-liners
    """

    name: str = "CommandInjection"

    # Searched case-insensitively, in order; first hit wins.
    DANGEROUS_PATTERNS = [
        r";\s*\w+",  # Command chaining with ;
        r"\|\s*\w+",  # Pipe to command
        r"&&\s*\w+",  # AND command
        r"\|\|\s*\w+",  # OR command
        r"`[^`]+`",  # Backtick command substitution
        r"\$\([^)]+\)",  # $() command substitution
        r">\s*/",  # Redirect to file
        r"<\s*/",  # Read from file
        r"\beval\b",  # eval
        r"\bexec\b",  # exec
        r"\bsystem\b",  # system
        r"/dev/tcp/",  # TCP backdoor
        r"/dev/udp/",  # UDP backdoor
        r"\bwget\b.*http",  # wget download
        r"\bcurl\b.*http",  # curl download
        r"nc\s+-",  # netcat
        r"bash\s+-i",  # Interactive bash
        r"sh\s+-i",  # Interactive sh
        r"python\s+-c",  # Python one-liner
        r"perl\s+-e",  # Perl one-liner
        r"ruby\s+-e",  # Ruby one-liner
    ]

    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Search ``value`` for command injection signatures.

        Args:
            value: Input to inspect; non-strings and blank strings pass.
            param_name: Parameter name used in the warning log line.

        Returns:
            A passing result, or a CRITICAL tripwire failure naming the
            first signature that matched.
        """
        accepted = GuardrailResult(passed=True)

        if not isinstance(value, str):
            return accepted
        if value.strip() == "":
            return accepted

        for signature in self.DANGEROUS_PATTERNS:
            if re.search(signature, value, re.IGNORECASE) is None:
                continue

            # Log at most the first 100 characters of the offending input.
            logger.warning(f"Command injection detected in {param_name}: {value[:100]}")
            return GuardrailResult(
                passed=False,
                tripwire_triggered=True,
                reason=f"Command injection pattern detected: {signature}",
                severity="CRITICAL",
            )

        return accepted
class PathTraversalGuardrail(InputGuardrail):
    """
    Detect path traversal attempts.

    Patterns detected:
    - ../ and ..\\ sequences
    - References to sensitive files and directories
    - URL-encoded ".." (%2e%2e)
    - Absolute paths into Unix/Windows system locations
    """

    name: str = "PathTraversal"

    # Unanchored: a hit anywhere in the input is a failure.
    DANGEROUS_PATTERNS = [
        r"\.\./",  # ../
        r"\.\.",  # ..
        r"%2e%2e",  # URL encoded ..
        r"\.\.\\",  # ..\
        r"\\\.\\",  # \.\
        r"/etc/passwd",  # /etc/passwd
        r"/etc/shadow",  # /etc/shadow
        r"C:\\Windows",  # C:\Windows
        r"C:\\Program Files",  # C:\Program Files
        r"/proc/self",  # /proc/self
        r"/root/",  # /root/
    ]

    # Anchored: the input must start with one of these sensitive roots.
    ABSOLUTE_PATH_PATTERNS = [
        r"^/etc/",  # Unix system config
        r"^/root/",  # Root home directory
        r"^/var/",  # System var directory
        r"^/usr/",  # System usr directory
        r"^/boot/",  # Boot directory
        r"^/sys/",  # System directory
        r"^/proc/",  # Process directory
        r"^C:\\Windows\\",  # Windows directory
        r"^C:\\Program Files\\",  # Program Files
        r"^C:\\Users\\[^\\]+\\AppData\\",  # User app data
        r"^\\\\\.\\",  # Windows device path
    ]

    @staticmethod
    def _matches(pattern: str, candidates: tuple) -> bool:
        """Return True if ``pattern`` hits any candidate, ignoring case."""
        return any(
            re.search(pattern, candidate, re.IGNORECASE)
            for candidate in candidates
        )

    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Check ``value`` for path traversal and sensitive absolute paths.

        Every pattern is matched case-insensitively against two forms of
        the input: the text as given (so backslash-based Windows patterns
        can match) and a copy with backslashes turned into forward slashes
        (so Unix patterns also catch backslash-separated input).

        Args:
            value: Input to inspect; non-strings and blank strings pass.
            param_name: Parameter name used in the warning log line.

        Returns:
            A passing result, or a HIGH tripwire failure.
        """
        if not isinstance(value, str):
            return GuardrailResult(passed=True)

        # Surrounding whitespace must not let a path dodge the "^" anchors.
        trimmed = value.strip()
        if not trimmed:
            # Allow empty, will be handled by validation
            return GuardrailResult(passed=True)

        # Matching only the slash-normalized form would make every pattern
        # containing a backslash unreachable, so keep both forms.
        candidates = (trimmed, trimmed.replace("\\", "/"))

        for pattern in self.DANGEROUS_PATTERNS:
            if self._matches(pattern, candidates):
                logger.warning(f"Path traversal detected in {param_name}: {value[:100]}")
                return GuardrailResult(
                    passed=False,
                    tripwire_triggered=True,
                    reason=f"Path traversal pattern detected: {pattern}",
                    severity="HIGH",
                )

        for pattern in self.ABSOLUTE_PATH_PATTERNS:
            if self._matches(pattern, candidates):
                logger.warning(f"Unauthorized absolute path access in {param_name}: {value[:100]}")
                return GuardrailResult(
                    passed=False,
                    tripwire_triggered=True,
                    reason="Unauthorized absolute path access to sensitive location",
                    severity="HIGH",
                )

        return GuardrailResult(passed=True)
class XXEGuardrail(InputGuardrail):
    """
    Reject strings that look like XML External Entity (XXE) payloads.

    Signatures cover:
    - DOCTYPE declarations with an internal subset
    - ENTITY definitions (general and parameter)
    - SYSTEM / PUBLIC identifiers
    - file://, php://, expect://, data:// URIs
    """

    name: str = "XXE"

    # Searched case-insensitively, in order; first hit wins.
    XXE_PATTERNS = [
        r"<!DOCTYPE[^>]*\[",  # DOCTYPE declaration with [ (entity definition)
        r"<!ENTITY",  # ENTITY definition
        r"SYSTEM\s+['\"]",  # SYSTEM keyword with quote
        r"PUBLIC\s+['\"]",  # PUBLIC keyword with quote
        r"file://",  # File protocol
        r"php://",  # PHP protocol
        r"expect://",  # Expect protocol
        r"data://",  # Data protocol
        r"<!ENTITY[^>]*SYSTEM",  # ENTITY with SYSTEM
        r"<!ENTITY[^>]*%",  # Parameter entity
        r"&[a-zA-Z]+;.*SYSTEM",  # Entity reference with SYSTEM
    ]

    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Search ``value`` for XXE signatures.

        Args:
            value: Input to inspect; non-strings and blank strings pass.
            param_name: Parameter name used in the warning log line.

        Returns:
            A passing result, or a HIGH tripwire failure naming the first
            signature that matched.
        """
        is_checkable = isinstance(value, str) and bool(value.strip())
        if not is_checkable:
            return GuardrailResult(passed=True)

        for signature in self.XXE_PATTERNS:
            hit = re.compile(signature, re.IGNORECASE).search(value)
            if hit:
                # Log at most the first 100 characters of the offending input.
                logger.warning(f"XXE injection detected in {param_name}: {value[:100]}")
                return GuardrailResult(
                    passed=False,
                    tripwire_triggered=True,
                    reason=f"XXE injection pattern detected: {signature}",
                    severity="HIGH",
                )

        return GuardrailResult(passed=True)
class LengthGuardrail(InputGuardrail):
    """
    Validate input length to prevent DoS.

    Prevents:
    - Extremely long inputs
    - Resource exhaustion

    Strings and byte strings are measured with ``len``; built-in
    containers are measured by the length of their ``str()`` rendering.
    Any other type passes unmeasured.
    """

    name: str = "Length"

    def __init__(self, max_length: int = 10000):
        """
        Args:
            max_length: Largest accepted length; inputs strictly longer
                than this fail.
        """
        self.max_length = max_length

    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Check that ``value`` does not exceed ``max_length``.

        Args:
            value: Input to measure.
            param_name: Parameter name used in the warning log line.

        Returns:
            A passing result, or a MEDIUM failure without the tripwire
            flag (oversized input is not treated as malicious).
        """
        if isinstance(value, (str, bytes, bytearray)):
            # bytes-like payloads are sized directly, like str, so they
            # cannot slip past the limit.
            length = len(value)
        elif isinstance(value, (list, dict, tuple, set, frozenset)):
            # Element count alone would miss a few huge elements, so
            # measure the rendered size instead.
            length = len(str(value))
        else:
            return GuardrailResult(passed=True)

        if length > self.max_length:
            logger.warning(f"Input too long in {param_name}: {length} > {self.max_length}")
            return GuardrailResult(
                passed=False,
                tripwire_triggered=False,  # Not malicious, just too large
                reason=f"Input exceeds maximum length: {length} > {self.max_length}",
                severity="MEDIUM",
            )

        return GuardrailResult(passed=True)
class TypeGuardrail(InputGuardrail):
    """
    Validate input type.

    Ensures inputs match expected types.
    """

    name: str = "Type"

    def __init__(self, expected_type: type):
        """
        Args:
            expected_type: Type the value must be an instance of. A tuple
                of types, as accepted by ``isinstance``, also works.
        """
        self.expected_type = expected_type

    def _expected_name(self) -> str:
        """Render ``expected_type`` for messages; tuples have no ``__name__``."""
        expected = self.expected_type
        if isinstance(expected, tuple):
            return " | ".join(
                getattr(member, "__name__", repr(member)) for member in expected
            )
        return getattr(expected, "__name__", repr(expected))

    def check(self, value: Any, param_name: str = "") -> GuardrailResult:
        """
        Check that ``value`` is an instance of ``expected_type``.

        Args:
            value: Input to check.
            param_name: Parameter name used in the warning log line.

        Returns:
            A passing result, or a LOW failure without the tripwire flag.
        """
        if isinstance(value, self.expected_type):
            return GuardrailResult(passed=True)

        logger.warning(f"Type mismatch in {param_name}: expected {self.expected_type}, got {type(value)}")
        return GuardrailResult(
            passed=False,
            tripwire_triggered=False,
            reason=f"Type mismatch: expected {self._expected_name()}, got {type(value).__name__}",
            severity="LOW",
        )
# Guardrails applied, in this order, when a caller supplies no chain of
# its own. The instances are created once at import time and shared.
DEFAULT_INPUT_GUARDRAILS = [
    SQLInjectionGuardrail(),
    CommandInjectionGuardrail(),
    PathTraversalGuardrail(),
    XXEGuardrail(),
    LengthGuardrail(max_length=10000),
]
def validate_input(
    value: Any,
    param_name: str = "",
    guardrails: Optional[list[InputGuardrail]] = None
) -> GuardrailResult:
    """
    Run ``value`` through a chain of guardrails, stopping at the first failure.

    Args:
        value: Input value to validate.
        param_name: Parameter name, used in log messages.
        guardrails: Guardrails to apply in order; DEFAULT_INPUT_GUARDRAILS
            when None.

    Returns:
        The first failing GuardrailResult, or a passing result when every
        guardrail accepts the value.
    """
    chain = DEFAULT_INPUT_GUARDRAILS if guardrails is None else guardrails

    for rail in chain:
        outcome = rail.check(value, param_name)
        if outcome.passed:
            continue

        # Short-circuit: later guardrails are not consulted after a failure.
        logger.error(f"Guardrail {rail.name} failed for {param_name}: {outcome.reason}")
        return outcome

    return GuardrailResult(passed=True)
def validate_params(params: Dict[str, Any]) -> Dict[str, GuardrailResult]:
    """
    Validate every entry of ``params`` with the default guardrail chain.

    Args:
        params: Mapping of parameter name to value.

    Returns:
        Mapping of each parameter name to the GuardrailResult produced by
        ``validate_input`` for its value.
    """
    return {
        key: validate_input(candidate, key)
        for key, candidate in params.items()
    }