#!/usr/bin/env python3
"""
PySploit Filters
Traffic and vulnerability filtering utilities.
"""

from typing import Dict, List, Any, Optional
import re


class TrafficFilter:
    """Filter network traffic data based on various criteria.

    The packet filters read the list stored under ``'raw_data'`` and return
    a shallow copy of the input dict in which ``'raw_data'`` holds only the
    matching packets and ``'packet_count'`` holds their number.  When the
    key a filter needs is absent, the input object itself is returned
    unchanged.  No filter modifies its input.

    NOTE(review): the packet keys (``_ws.col.Protocol``, ``ip.src``,
    ``tcp.srcport`` ...) look like tshark field names -- confirm against
    the code that produces ``raw_data``.
    """

    @staticmethod
    def _with_packets(traffic_data: Dict[str, Any],
                      packets: List[Any]) -> Dict[str, Any]:
        """Return a shallow copy of *traffic_data* holding only *packets*."""
        result = traffic_data.copy()
        result['raw_data'] = packets
        result['packet_count'] = len(packets)
        return result

    @staticmethod
    def filter_by_protocol(traffic_data: Dict[str, Any], protocols: List[str]) -> Dict[str, Any]:
        """Keep packets whose protocol is one of *protocols*.

        The comparison is case-insensitive.  Packets with a missing or
        ``None`` protocol are treated as having an empty protocol name.
        """
        if 'raw_data' not in traffic_data:
            return traffic_data

        wanted = {p.upper() for p in protocols}
        matching = [
            packet for packet in traffic_data['raw_data']
            # ``or ''`` keeps a None-valued field from raising on .upper().
            if str(packet.get('_ws.col.Protocol') or '').upper() in wanted
        ]
        return TrafficFilter._with_packets(traffic_data, matching)

    @staticmethod
    def filter_by_ip(traffic_data: Dict[str, Any], ip_addresses: List[str]) -> Dict[str, Any]:
        """Keep packets whose source or destination IP is in *ip_addresses*.

        Addresses are compared by exact equality; no normalisation is done.
        """
        if 'raw_data' not in traffic_data:
            return traffic_data

        matching = []
        for packet in traffic_data['raw_data']:
            endpoints = (packet.get('ip.src', ''), packet.get('ip.dst', ''))
            # Equality-based membership: works even if a packet value is
            # not hashable.
            if any(ip in endpoints for ip in ip_addresses):
                matching.append(packet)

        return TrafficFilter._with_packets(traffic_data, matching)

    @staticmethod
    def filter_by_port(traffic_data: Dict[str, Any], ports: List[int]) -> Dict[str, Any]:
        """Keep packets whose source or destination port is in *ports*.

        The TCP port fields are consulted first, falling back to the UDP
        fields.  Both the requested ports and the packet values are
        compared as strings, so a packet matches whether its port is stored
        as ``'443'`` or ``443``.
        """
        if 'raw_data' not in traffic_data:
            return traffic_data

        wanted = {str(p) for p in ports}
        matching = []
        for packet in traffic_data['raw_data']:
            src_port = packet.get('tcp.srcport', packet.get('udp.srcport', ''))
            dst_port = packet.get('tcp.dstport', packet.get('udp.dstport', ''))
            if str(src_port) in wanted or str(dst_port) in wanted:
                matching.append(packet)

        return TrafficFilter._with_packets(traffic_data, matching)

    @staticmethod
    def filter_suspicious_traffic(traffic_data: Dict[str, Any]) -> Dict[str, Any]:
        """Keep only HIGH and CRITICAL severity findings.

        Reads ``traffic_data['analysis_results']['vulnerabilities']`` and
        returns a copy of *traffic_data* whose ``'analysis_results'`` is a
        copy containing only the entries whose ``'severity'`` is HIGH or
        CRITICAL (case-insensitive).  If ``'analysis_results'`` is absent
        the input object is returned unchanged.
        """
        if 'analysis_results' not in traffic_data:
            return traffic_data

        analysis = traffic_data['analysis_results']
        suspicious_vulns = [
            v for v in analysis.get('vulnerabilities', [])
            if str(v.get('severity') or '').upper() in ('HIGH', 'CRITICAL')
        ]

        # Copy the nested dict as well: the outer copy is shallow, so
        # assigning through it would overwrite the caller's own list.
        filtered_analysis = dict(analysis)
        filtered_analysis['vulnerabilities'] = suspicious_vulns

        filtered_data = traffic_data.copy()
        filtered_data['analysis_results'] = filtered_analysis
        return filtered_data


class VulnerabilityFilter:
    """Filter vulnerability data based on various criteria.

    Every method takes a list of vulnerability (or match) dicts and returns
    a new list; the input list and its dicts are never modified.  Text
    fields whose value is ``None`` are treated as empty strings.
    """

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return *value* as a string, mapping ``None`` to ``''``."""
        return '' if value is None else str(value)

    @staticmethod
    def filter_by_severity(vulnerabilities: List[Dict[str, Any]], 
                          min_severity: str = 'MEDIUM') -> List[Dict[str, Any]]:
        """Keep vulnerabilities at or above *min_severity*.

        Severity is read from ``'cvss_severity'`` and ranked
        LOW < MEDIUM < HIGH < CRITICAL (case-insensitive).  A missing or
        unrecognised severity ranks as LOW; an unrecognised *min_severity*
        is treated as MEDIUM.
        """
        severity_order = {'LOW': 1, 'MEDIUM': 2, 'HIGH': 3, 'CRITICAL': 4}
        min_level = severity_order.get(min_severity.upper(), 2)
        as_text = VulnerabilityFilter._as_text

        return [
            v for v in vulnerabilities
            if severity_order.get(as_text(v.get('cvss_severity', 'LOW')).upper(), 1) >= min_level
        ]

    @staticmethod
    def filter_by_cvss_score(vulnerabilities: List[Dict[str, Any]], 
                           min_score: float = 5.0) -> List[Dict[str, Any]]:
        """Keep vulnerabilities whose ``'cvss_score'`` is >= *min_score*.

        Entries whose score cannot be converted to ``float`` (for example
        ``None`` or ``'n/a'``) are kept.  Entries with no ``'cvss_score'``
        key at all are scored as 0 and therefore kept only when
        *min_score* is <= 0.
        """
        filtered = []

        for vuln in vulnerabilities:
            # Only the conversion is guarded, so a bad *min_score* surfaces
            # as an error instead of silently keeping every entry.
            try:
                score = float(vuln.get('cvss_score', 0))
            except (ValueError, TypeError):
                # Include vulnerabilities without valid CVSS scores
                filtered.append(vuln)
            else:
                if score >= min_score:
                    filtered.append(vuln)

        return filtered

    @staticmethod
    def filter_by_category(vulnerabilities: List[Dict[str, Any]], 
                         categories: List[str]) -> List[Dict[str, Any]]:
        """Keep vulnerabilities mentioning any of *categories*.

        A vulnerability matches when a category occurs as a
        case-insensitive substring of its ``'exploit_type'``,
        ``'platform'`` or ``'description'`` field.
        """
        categories_lower = [c.lower() for c in categories]
        as_text = VulnerabilityFilter._as_text

        filtered = []
        for vuln in vulnerabilities:
            fields = (
                as_text(vuln.get('exploit_type')).lower(),
                as_text(vuln.get('platform')).lower(),
                as_text(vuln.get('description')).lower(),
            )
            if any(cat in text for cat in categories_lower for text in fields):
                filtered.append(vuln)

        return filtered

    @staticmethod
    def filter_by_year(vulnerabilities: List[Dict[str, Any]], 
                      years: List[int]) -> List[Dict[str, Any]]:
        """Keep vulnerabilities whose CVE year is in *years*.

        The year is taken from the ``'cve_id'`` field (e.g. CVE-2021-1234),
        matched case-insensitively.  Entries whose ``'cve_id'`` does not
        contain a CVE identifier are always kept.
        """
        # Case-insensitive so that 'cve-2021-1234' is recognised as a CVE
        # rather than falling into the "non-CVE, always keep" branch.
        cve_year_pattern = re.compile(r'CVE-(\d{4})-', re.IGNORECASE)
        wanted_years = set(years)
        as_text = VulnerabilityFilter._as_text

        filtered = []
        for vuln in vulnerabilities:
            year_match = cve_year_pattern.search(as_text(vuln.get('cve_id')))
            if year_match is None:
                # Include non-CVE vulnerabilities
                filtered.append(vuln)
            elif int(year_match.group(1)) in wanted_years:
                filtered.append(vuln)

        return filtered

    @staticmethod
    def filter_recent_cves(vulnerabilities: List[Dict[str, Any]], 
                          years_back: int = 2) -> List[Dict[str, Any]]:
        """Keep CVEs from the current year and the *years_back* years before it.

        The current year comes from the local system clock.  Filtering is
        delegated to :meth:`filter_by_year`, so entries without a CVE
        identifier are kept as well.
        """
        from datetime import datetime
        current_year = datetime.now().year
        recent_years = list(range(current_year - years_back, current_year + 1))

        return VulnerabilityFilter.filter_by_year(vulnerabilities, recent_years)

    @staticmethod
    def filter_exploitable(vulnerabilities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Keep vulnerabilities whose text suggests a known exploit.

        This is a keyword heuristic: an entry is kept when any indicator
        phrase occurs as a case-insensitive substring of its
        ``'description'`` or ``'exploit_type'`` field.
        """
        exploitable_indicators = [
            'exploit', 'metasploit', 'proof of concept', 'poc', 
            'remote code execution', 'rce', 'buffer overflow'
        ]
        as_text = VulnerabilityFilter._as_text

        filtered = []
        for vuln in vulnerabilities:
            description = as_text(vuln.get('description')).lower()
            exploit_type = as_text(vuln.get('exploit_type')).lower()

            if any(indicator in description or indicator in exploit_type 
                   for indicator in exploitable_indicators):
                filtered.append(vuln)

        return filtered

    @staticmethod
    def filter_by_confidence(matches: List[Dict[str, Any]], 
                           min_confidence: str = 'MEDIUM') -> List[Dict[str, Any]]:
        """Keep matches at or above *min_confidence*.

        Confidence is read from ``'confidence'`` and ranked
        LOW < MEDIUM < HIGH (case-insensitive).  A missing or unrecognised
        confidence ranks as LOW; an unrecognised *min_confidence* is
        treated as MEDIUM.
        """
        confidence_order = {'LOW': 1, 'MEDIUM': 2, 'HIGH': 3}
        min_level = confidence_order.get(min_confidence.upper(), 2)
        as_text = VulnerabilityFilter._as_text

        return [
            m for m in matches
            if confidence_order.get(as_text(m.get('confidence', 'LOW')).upper(), 1) >= min_level
        ]

    @staticmethod
    def deduplicate_cves(vulnerabilities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Remove duplicate CVE entries, keeping the most detailed one.

        Entries sharing a ``'cve_id'`` are collapsed to the one with the
        longest ``'description'``; on a tie the earliest entry wins.  The
        result lists the CVE entries first, in order of first appearance
        of each id, followed by all entries without a ``'cve_id'`` in
        their original order.
        """
        best_by_cve = {}
        without_cve = []

        for vuln in vulnerabilities:
            cve_id = vuln.get('cve_id')
            if not cve_id:
                without_cve.append(vuln)
                continue

            if cve_id not in best_by_cve:
                best_by_cve[cve_id] = vuln
                continue

            # Keep the entry with the more detailed description; ``or ''``
            # lets a None description count as empty.
            current_desc_len = len(vuln.get('description') or '')
            existing_desc_len = len(best_by_cve[cve_id].get('description') or '')
            if current_desc_len > existing_desc_len:
                best_by_cve[cve_id] = vuln

        return list(best_by_cve.values()) + without_cve