Coverage for src/alprina_cli/policy.py: 23%
91 statements
coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
1"""
2Policy and guardrails engine for Alprina CLI.
3Enforces scope, safety rules, and compliance checks.
4"""
import copy
from ipaddress import AddressValueError, IPv4Address, IPv6Address, ip_address, ip_network
from pathlib import Path
from typing import List, Optional, Union
from urllib.parse import urlparse

import yaml
from rich.console import Console
from rich.panel import Panel
console = Console()

# Per-user Alprina directory and the policy file kept inside it.
ALPRINA_DIR = Path.home() / ".alprina"
POLICY_FILE = ALPRINA_DIR / "policy.yml"

# Built-in policy. load_policy() falls back to a copy of this when no usable
# policy file exists, and policy_init_command() writes it out as the starting
# template. Key order matters: save_policy() dumps with sort_keys=False.
DEFAULT_POLICY = dict(
    project="Alprina Security Audit",
    scope=dict(
        allow_domains=[],
        # The three RFC 1918 private IPv4 ranges.
        allow_cidrs=["10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"],
        forbid_ports=[22, 3389],  # SSH, RDP
    ),
    policies=dict(
        allow_intrusive=False,
        require_terms_ack=True,
        max_concurrent_scans=5,
    ),
    billing=dict(
        plan="free",
        max_scans_per_day=10,
    ),
)
def load_policy() -> dict:
    """Load the policy configuration from ``POLICY_FILE``.

    Falls back to a fresh copy of ``DEFAULT_POLICY`` when the file is missing
    or empty. It also falls back when the file cannot be read, is not valid
    YAML, or does not hold a mapping at its top level; in these cases a
    warning is printed.

    Returns:
        The policy dict. A returned default is a deep copy, so callers may
        mutate it (including nested sections) without altering
        ``DEFAULT_POLICY``.
    """
    if not POLICY_FILE.exists():
        # Deep copy: a shallow .copy() would share the nested scope/policies/
        # billing dicts and their lists with the module-level default.
        return copy.deepcopy(DEFAULT_POLICY)

    try:
        with open(POLICY_FILE, "r", encoding="utf-8") as f:
            # safe_load only builds plain YAML types, never arbitrary objects.
            policy = yaml.safe_load(f)
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
        console.print(f"[yellow]Warning: Could not load policy file: {e}[/yellow]")
        return copy.deepcopy(DEFAULT_POLICY)

    if not policy:
        # Empty file (None) or an empty document.
        return copy.deepcopy(DEFAULT_POLICY)

    if not isinstance(policy, dict):
        # e.g. a file that is a bare list or scalar; every consumer calls .get().
        console.print(
            "[yellow]Warning: Could not load policy file: expected a mapping at "
            f"the top level, got {type(policy).__name__}[/yellow]"
        )
        return copy.deepcopy(DEFAULT_POLICY)

    return policy
def save_policy(policy: dict):
    """Write *policy* to ``POLICY_FILE`` as YAML.

    Creates ``ALPRINA_DIR`` (and any missing parents) first. The file is
    written as UTF-8 in block style with keys in insertion order
    (``sort_keys=False``), so it mirrors the dict's layout.
    """
    ALPRINA_DIR.mkdir(parents=True, exist_ok=True)

    # Explicit encoding so the file round-trips with load_policy on every
    # platform instead of depending on the locale's default encoding.
    with open(POLICY_FILE, "w", encoding="utf-8") as f:
        yaml.dump(policy, f, default_flow_style=False, sort_keys=False)
def validate_target(target: str, policy: Optional[dict] = None) -> bool:
    """
    Validate a target against policy rules.

    Accepted forms:
      * bare IP address (IPv4 or IPv6), e.g. ``10.0.0.5``
      * URL with a scheme, e.g. ``https://app.example.com:8443/login``
      * host with an optional port/path, e.g. ``example.com``, ``10.0.0.5:22``

    IP hosts, whether bare or inside a URL, are checked against the CIDR
    rules. Host names are checked against the domain allowlist. An explicit
    port is checked against ``scope.forbid_ports``.

    Args:
        target: Target URL, domain, or IP
        policy: Policy dict (loads from file if not provided)

    Returns:
        True if target is allowed

    Raises:
        ValueError: If target violates policy, has a malformed port, or no
            host can be extracted from it
    """
    if policy is None:
        policy = load_policy()

    # `or {}` also covers a `scope:` key present with a null value.
    scope = policy.get("scope") or {}
    target = target.strip()

    # Bare IP address. This must come first: an IPv6 literal such as "::1"
    # would otherwise be misparsed as host:port below.
    try:
        ip = ip_address(target)
    except ValueError:  # AddressValueError is a ValueError subclass
        pass
    else:
        return _validate_ip(ip, scope)

    # Everything else is parsed as a URL. Targets without a scheme get a "//"
    # prefix so urlparse treats them as a network location; that exposes the
    # host and port of forms like "example.com:22" or "10.0.0.1:3389".
    parsed = urlparse(target if "://" in target else f"//{target}")
    host = parsed.hostname  # lowercased, IPv6 brackets removed
    if not host:
        raise ValueError(f"Could not determine a host from target {target!r}")

    port = parsed.port  # raises ValueError for non-numeric/out-of-range ports
    forbid_ports = scope.get("forbid_ports") or []
    if port is not None and port in forbid_ports:
        raise ValueError(f"Port {port} is forbidden by policy")

    # A URL whose host is an IP literal must go through the CIDR check, not
    # the domain allowlist (which permits everything when left empty).
    try:
        host_ip = ip_address(host)
    except ValueError:
        return _validate_domain(host, scope)
    return _validate_ip(host_ip, scope)
101def _validate_ip(ip: ip_address, scope: dict) -> bool:
102 """Validate IP address against allowed CIDR ranges."""
103 allow_cidrs = scope.get("allow_cidrs", [])
105 if not allow_cidrs:
106 # If no CIDRs specified, allow all private IPs
107 return ip.is_private
109 for cidr in allow_cidrs:
110 if ip in ip_network(cidr):
111 return True
113 raise ValueError(f"IP {ip} is not in allowed CIDR ranges: {allow_cidrs}")
116def _validate_domain(domain: str, scope: dict) -> bool:
117 """Validate domain against allowed domains."""
118 allow_domains = scope.get("allow_domains", [])
120 if not allow_domains:
121 # If no domains specified, allow all
122 console.print("[yellow]Warning: No domain restrictions in policy. All domains allowed.[/yellow]")
123 return True
125 # Check if domain matches or is subdomain of allowed domains
126 for allowed in allow_domains:
127 if domain == allowed or domain.endswith(f".{allowed}"):
128 return True
130 raise ValueError(f"Domain {domain} is not in allowed domains: {allow_domains}")
def validate_file(path: Path) -> bool:
    """Check that *path* has a file type Alprina can scan.

    The type is taken from the file suffix, case-insensitively. pathlib treats
    dotfiles as having no suffix (``Path(".env").suffix == ""``), so a
    suffix-less name starting with a dot is matched by its whole name. That
    way ``.env`` is accepted like ``prod.env``. A file named exactly
    ``Dockerfile`` is accepted alongside ``*.dockerfile``.

    Returns:
        True if the file type is supported.

    Raises:
        ValueError: If the file type is not supported.
    """
    allowed_extensions = frozenset({
        ".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".go", ".rs",
        ".env", ".yaml", ".yml", ".json", ".xml", ".ini", ".conf",
        ".sh", ".bash", ".zsh", ".dockerfile", ".tf", ".hcl",
    })
    # Conventional file names that carry no suffix at all.
    allowed_names = frozenset({"dockerfile"})

    name = path.name.lower()
    suffix = path.suffix.lower()
    if not suffix and name.startswith("."):
        suffix = name

    if suffix not in allowed_extensions and name not in allowed_names:
        raise ValueError(f"File type {path.suffix or path.name} is not supported for scanning")

    return True
def policy_init_command():
    """Write ``DEFAULT_POLICY`` to ``POLICY_FILE``.

    If the file already exists, the user is asked once whether to overwrite
    it. Nothing is written unless they confirm.
    """
    if POLICY_FILE.exists():
        from rich.prompt import Confirm

        # Only state the situation here; Confirm.ask below is the single prompt.
        # Previously the message also said "Overwrite? (y/N)", so the user saw
        # two prompts.
        console.print(f"[yellow]Policy file already exists at {POLICY_FILE}.[/yellow]")
        if not Confirm.ask("Overwrite existing policy?", default=False):
            return

    save_policy(DEFAULT_POLICY)

    console.print(Panel(
        "[green]✓ Policy file created[/green]\n\n"
        f"Location: {POLICY_FILE}\n\n"
        "Edit this file to customize your security scanning policies.",
        title="Policy Initialized",
    ))
def policy_test_command(target: str):
    """Report whether *target* is allowed by the current policy.

    Loads the policy from disk, runs ``validate_target``, and prints the
    verdict. ``validate_target`` returns a bool as well as raising
    ValueError, so a False result is reported as blocked rather than being
    assumed to mean "allowed".
    """
    console.print(f"Testing policy for target: [bold]{target}[/bold]\n")

    try:
        policy = load_policy()
        allowed = validate_target(target, policy)
    except ValueError as e:
        console.print(f"[red]✗ Target blocked by policy:[/red] {e}")
        return
    except Exception as e:
        # CLI boundary: show unexpected errors as a message, not a traceback.
        console.print(f"[red]Error: {e}[/red]")
        return

    if allowed:
        console.print("[green]✓ Target is allowed by policy[/green]")
    else:
        console.print("[red]✗ Target blocked by policy:[/red] not permitted by the configured scope")
def check_intrusive_allowed(policy: Optional[dict] = None) -> bool:
    """Return True only if the policy explicitly enables intrusive scans.

    Reads ``policies.allow_intrusive``. This is a safety switch, so it fails
    closed. Anything other than the boolean ``True`` yields False: a missing
    key, a null or non-mapping ``policies`` section, or a quoted YAML string
    such as ``"false"`` (which is truthy as a Python str).

    Args:
        policy: Policy dict (loads from file if not provided)

    Returns:
        True if intrusive scans are enabled, otherwise False.
    """
    if policy is None:
        policy = load_policy()

    policies = policy.get("policies")
    if not isinstance(policies, dict):
        return False

    return policies.get("allow_intrusive", False) is True
def get_scan_limits(policy: Optional[dict] = None) -> dict:
    """Return the scan limits configured in the policy.

    ``max_scans_per_day`` comes from the ``billing`` section (default 10) and
    ``max_concurrent_scans`` from the ``policies`` section (default 5). A
    missing, null, or non-mapping section, or a null value, falls back to the
    default. An explicit 0 is kept as-is.

    Args:
        policy: Policy dict (loads from file if not provided)

    Returns:
        Dict with keys ``max_scans_per_day`` and ``max_concurrent_scans``.
    """
    if policy is None:
        policy = load_policy()

    def _setting(section_name, key, default):
        # Look up policy[section_name][key], tolerating null/odd YAML shapes.
        section = policy.get(section_name)
        if not isinstance(section, dict):
            return default
        value = section.get(key)
        return default if value is None else value

    return {
        "max_scans_per_day": _setting("billing", "max_scans_per_day", 10),
        "max_concurrent_scans": _setting("policies", "max_concurrent_scans", 5),
    }