Coverage for src/alprina_cli/policy.py: 23%

91 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2Policy and guardrails engine for Alprina CLI. 

3Enforces scope, safety rules, and compliance checks. 

4""" 

5 

6import yaml 

7from pathlib import Path 

8from typing import List, Optional 

9from ipaddress import ip_network, ip_address, AddressValueError 

10from urllib.parse import urlparse 

11from rich.console import Console 

12from rich.panel import Panel 

13 

# Shared Rich console; every user-facing message in this module is printed through it.
console = Console()

15 

# Per-user configuration directory (under the home directory) and the
# policy file that lives inside it.
ALPRINA_DIR = Path.home().joinpath(".alprina")
POLICY_FILE = ALPRINA_DIR.joinpath("policy.yml")

18 

# Built-in policy used when no policy file exists or it cannot be read.
# Key order is significant: the policy is dumped with sort_keys=False.
DEFAULT_POLICY = dict(
    project="Alprina Security Audit",
    # Which targets may be scanned.
    scope=dict(
        allow_domains=[],
        allow_cidrs=["10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"],
        forbid_ports=[22, 3389],  # SSH, RDP
    ),
    # Behavioural guardrails.
    policies=dict(
        allow_intrusive=False,
        require_terms_ack=True,
        max_concurrent_scans=5,
    ),
    # Plan-related limits.
    billing=dict(
        plan="free",
        max_scans_per_day=10,
    ),
)

36 

37 

def load_policy() -> dict:
    """Load policy configuration from file.

    Returns:
        The mapping parsed from ``POLICY_FILE``. When the file is missing,
        empty, unreadable, or does not contain a YAML mapping, an independent
        deep copy of ``DEFAULT_POLICY`` is returned instead.
    """
    # Function-scope import so this block does not depend on a file-level change.
    from copy import deepcopy

    if not POLICY_FILE.exists():
        # dict.copy() would be shallow: the nested "scope"/"policies"/"billing"
        # dicts would still be shared with DEFAULT_POLICY, so a caller editing
        # the returned policy would silently change the defaults.
        return deepcopy(DEFAULT_POLICY)

    try:
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(POLICY_FILE, "r", encoding="utf-8") as f:
            policy = yaml.safe_load(f)
    except Exception as e:
        # Best effort: a broken policy file must not crash the CLI.
        console.print(f"[yellow]Warning: Could not load policy file: {e}[/yellow]")
        return deepcopy(DEFAULT_POLICY)

    if not policy:
        # Empty file (safe_load returns None) falls back to the defaults.
        return deepcopy(DEFAULT_POLICY)

    if not isinstance(policy, dict):
        # A list or scalar at the top level would break every later .get() call.
        console.print(
            "[yellow]Warning: Could not load policy file: "
            "top-level YAML value is not a mapping[/yellow]"
        )
        return deepcopy(DEFAULT_POLICY)

    return policy

50 

51 

def save_policy(policy: dict) -> None:
    """Save policy configuration to file.

    Args:
        policy: Policy mapping to serialize as YAML into ``POLICY_FILE``.

    Raises:
        OSError: If the directory or the file cannot be created or written.
    """
    # parents=True so a missing intermediate directory does not abort the save.
    ALPRINA_DIR.mkdir(parents=True, exist_ok=True)

    # Explicit encoding keeps the written file independent of the platform locale.
    with open(POLICY_FILE, "w", encoding="utf-8") as f:
        # sort_keys=False preserves the key order of the policy mapping.
        yaml.dump(policy, f, default_flow_style=False, sort_keys=False)

58 

59 

def validate_target(target: str, policy: Optional[dict] = None) -> bool:
    """
    Validate a target against policy rules.

    A bare IP address, or a URL whose host is an IP literal, is checked
    against the allowed CIDR ranges. Any other host is checked against the
    allowed domains. For URLs the port is checked against ``forbid_ports``.

    Args:
        target: Target URL, domain, or IP
        policy: Policy dict (loads from file if not provided)

    Returns:
        True if target is allowed

    Raises:
        ValueError: If target violates policy, or is a URL without a host
            or with a malformed port
    """
    if policy is None:
        policy = load_policy()

    # A YAML key with no value parses to None; treat it like a missing section.
    scope = policy.get("scope") or {}

    # Try to parse as IP address. Only the parsing is guarded: the ValueError
    # raised when the address is out of scope must reach the caller. If it
    # were swallowed here, the IP would fall through to the domain check.
    try:
        ip = ip_address(target)
    except (AddressValueError, ValueError):
        ip = None
    if ip is not None:
        return _validate_ip(ip, scope)

    # Try to parse as URL
    if "://" in target:
        parsed = urlparse(target)
        host = parsed.hostname
        if not host:
            raise ValueError(f"Target {target} has no host component")

        # Accessing .port raises ValueError for a malformed or out-of-range port.
        port = parsed.port
        if port is not None and port in (scope.get("forbid_ports") or []):
            raise ValueError(f"Port {port} is forbidden by policy")

        # A URL host that is an IP literal must honour the CIDR restrictions,
        # not the domain allow-list.
        try:
            host_ip = ip_address(host)
        except (AddressValueError, ValueError):
            return _validate_domain(host, scope)
        return _validate_ip(host_ip, scope)

    # Treat as domain
    return _validate_domain(target, scope)

99 

100 

101def _validate_ip(ip: ip_address, scope: dict) -> bool: 

102 """Validate IP address against allowed CIDR ranges.""" 

103 allow_cidrs = scope.get("allow_cidrs", []) 

104 

105 if not allow_cidrs: 

106 # If no CIDRs specified, allow all private IPs 

107 return ip.is_private 

108 

109 for cidr in allow_cidrs: 

110 if ip in ip_network(cidr): 

111 return True 

112 

113 raise ValueError(f"IP {ip} is not in allowed CIDR ranges: {allow_cidrs}") 

114 

115 

116def _validate_domain(domain: str, scope: dict) -> bool: 

117 """Validate domain against allowed domains.""" 

118 allow_domains = scope.get("allow_domains", []) 

119 

120 if not allow_domains: 

121 # If no domains specified, allow all 

122 console.print("[yellow]Warning: No domain restrictions in policy. All domains allowed.[/yellow]") 

123 return True 

124 

125 # Check if domain matches or is subdomain of allowed domains 

126 for allowed in allow_domains: 

127 if domain == allowed or domain.endswith(f".{allowed}"): 

128 return True 

129 

130 raise ValueError(f"Domain {domain} is not in allowed domains: {allow_domains}") 

131 

132 

def validate_file(path: Path) -> bool:
    """Validate if file type is allowed for scanning.

    Args:
        path: File path; only its name and suffix are inspected.

    Returns:
        True if the file type is supported.

    Raises:
        ValueError: If the file type is not supported for scanning.
    """
    allowed_extensions = (
        ".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".go", ".rs",
        ".env", ".yaml", ".yml", ".json", ".xml", ".ini", ".conf",
        ".sh", ".bash", ".zsh", ".dockerfile", ".tf", ".hcl"
    )
    # Extension-less file names that are scannable; the conventional
    # "Dockerfile" has no suffix and would otherwise never match ".dockerfile".
    allowed_names = ("dockerfile",)

    path = Path(path)
    name = path.name.lower()
    # A dotfile such as ".env" has an empty suffix (the leading dot belongs to
    # the stem), so fall back to the whole name to honour the ".env" entry.
    file_type = path.suffix.lower() or name

    if file_type not in allowed_extensions and name not in allowed_names:
        # Show the name when there is no suffix so the message is never blank.
        raise ValueError(
            f"File type {path.suffix or path.name} is not supported for scanning"
        )

    return True

145 

146 

def policy_init_command():
    """Write the default policy to the policy file, asking before overwriting.

    When the policy file already exists the user is prompted; declining
    leaves the existing file untouched. On success a confirmation panel
    showing the file location is printed.
    """
    already_present = POLICY_FILE.exists()
    if already_present:
        console.print("[yellow]Policy file already exists. Overwrite? (y/N)[/yellow]")
        from rich.prompt import Confirm
        overwrite = Confirm.ask("Overwrite existing policy?", default=False)
        if not overwrite:
            return

    save_policy(DEFAULT_POLICY)

    # Three paragraphs separated by blank lines.
    body = "\n\n".join((
        "[green]✓ Policy file created[/green]",
        f"Location: {POLICY_FILE}",
        "Edit this file to customize your security scanning policies.",
    ))
    console.print(Panel(body, title="Policy Initialized"))

163 

164 

def policy_test_command(target: str):
    """Test if a target is allowed by current policy.

    Prints the verdict to the console; nothing is returned and no exception
    derived from ``Exception`` propagates.

    Args:
        target: Target URL, domain, or IP to check.
    """
    # The target and the error text contain user/config data; escape them so
    # square brackets are printed literally instead of being read as markup.
    from rich.markup import escape

    console.print(f"Testing policy for target: [bold]{escape(target)}[/bold]\n")

    try:
        policy = load_policy()
        allowed = validate_target(target, policy)
    except ValueError as e:
        console.print(f"[red]✗ Target blocked by policy:[/red] {escape(str(e))}")
    except Exception as e:
        console.print(f"[red]Error: {escape(str(e))}[/red]")
    else:
        # validate_target returns a bool; a False result is a rejection and
        # must not be reported as "allowed".
        if allowed:
            console.print("[green]✓ Target is allowed by policy[/green]")
        else:
            console.print("[red]✗ Target blocked by policy[/red]")

177 

178 

def check_intrusive_allowed(policy: Optional[dict] = None) -> bool:
    """Check if intrusive scans are allowed by policy.

    Args:
        policy: Policy dict (loads from file if not provided)

    Returns:
        True only when ``policies.allow_intrusive`` is the boolean ``True``;
        False when the key or section is missing or holds any other value.
    """
    if policy is None:
        policy = load_policy()

    # A YAML key with no value parses to None; treat it like a missing section.
    policies = policy.get("policies") or {}

    # Fail closed: a quoted string such as "false" is truthy, so only an
    # explicit boolean True enables intrusive scans.
    return policies.get("allow_intrusive", False) is True

185 

186 

def get_scan_limits(policy: Optional[dict] = None) -> dict:
    """Get scan limits from policy.

    Args:
        policy: Policy dict (loads from file if not provided)

    Returns:
        Dict with ``max_scans_per_day`` (from the ``billing`` section,
        default 10) and ``max_concurrent_scans`` (from the ``policies``
        section, default 5).
    """
    if policy is None:
        policy = load_policy()

    # A YAML key with no value parses to None; treat it like a missing section
    # so the defaults apply instead of raising AttributeError.
    billing = policy.get("billing") or {}
    policies = policy.get("policies") or {}

    return {
        "max_scans_per_day": billing.get("max_scans_per_day", 10),
        "max_concurrent_scans": policies.get("max_concurrent_scans", 5),
    }