#!/usr/bin/env python3
"""
PySploit Vulnerability Matcher
Match network traffic and scan results against vulnerability databases.
"""

import re
from typing import Dict, List, Any, Optional
from ..database.embedded import EmbeddedVulnerabilityDatabase


class VulnerabilityMatcher:
    """
    Match network data against vulnerability databases to identify potential security issues.
    """
    
    def __init__(self, database: EmbeddedVulnerabilityDatabase):
        """
        Create a matcher bound to a vulnerability database.

        Args:
            database (EmbeddedVulnerabilityDatabase): Database object that the
                matching methods query for vulnerability records.
        """
        # All lookups performed by this class go through this single handle.
        self.database = database
    
    def match_traffic(self, traffic_analysis: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Collect database matches for a traffic analysis report.

        Args:
            traffic_analysis (dict): Results from PcapAnalyzer.analyze(); only the
                'analysis_results' entry is inspected.

        Returns:
            list: Indicator matches, followed by protocol matches, followed by
                host-pattern matches. Empty when 'analysis_results' is absent.
        """
        if 'analysis_results' not in traffic_analysis:
            return []

        report = traffic_analysis['analysis_results']
        found = []

        # Each section of the report is optional; sections are processed in a
        # fixed order so the output ordering is stable.
        if 'vulnerabilities' in report:
            for entry in report['vulnerabilities']:
                found += self._match_indicator_to_database(entry)

        if 'protocols' in report:
            found += self._match_protocols(report['protocols'])

        if 'hosts' in report:
            found += self._match_host_patterns(report['hosts'])

        return found
    
    def match_nmap_data(self, nmap_results: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Collect database matches for every host in an Nmap result set.

        Args:
            nmap_results (dict): Results from NmapAnalyzer.parse_xml(); the
                'hosts' entry is iterated when present.

        Returns:
            list: For each host in order, its service matches (one lookup per
                entry in 'open_ports') followed by its OS matches.
        """
        def matches_for(host):
            # Hosts without an 'ip' entry are reported under 'unknown'.
            address = host.get('ip', 'unknown')

            for port_entry in host.get('open_ports', []):
                yield from self._match_service_to_database(address, port_entry)

            if 'os' in host:
                yield from self._match_os_to_database(address, host['os'])

        return [
            match
            for host in nmap_results.get('hosts', [])
            for match in matches_for(host)
        ]
    
    def match_router_signatures(self, traffic_data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Match traffic data specifically against router vulnerabilities.

        Only the RouterSploit dataset returned by the database is consulted;
        each of its entries is checked against the traffic data and kept when
        the check reports a match.

        Args:
            traffic_data (dict): Network traffic analysis results

        Returns:
            list: Router-specific vulnerability matches, in dataset order
        """
        # A database without RouterSploit data (None) simply yields no matches.
        routersploit_data = self.database.load_routersploit() or []

        candidates = (
            self._check_router_vulnerability_match(traffic_data, vuln)
            for vuln in routersploit_data
        )
        # The per-entry check returns None when the traffic does not match.
        return [match for match in candidates if match]
    
    def _match_indicator_to_database(self, indicator: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Look up database records related to one traffic indicator.

        The indicator's 'type' is translated to a search phrase (unknown types
        are searched verbatim) and up to 10 records are requested. An indicator
        with no usable type produces no matches and no database query.
        """
        # Indicator type -> phrase used in the database search.
        phrase_for_type = {
            'sql_injection': 'sql injection',
            'xss': 'cross-site scripting',
            'path_traversal': 'directory traversal',
            'cleartext_protocol': 'cleartext',
            'weak_authentication': 'authentication'
        }

        kind = indicator.get('type', '')
        query = phrase_for_type.get(kind, kind)
        if not query:
            return []

        return [
            {
                'match_type': 'traffic_indicator',
                'confidence': self._calculate_match_confidence(indicator, record),
                'indicator': indicator,
                'vulnerability': record,
                'source_ip': indicator.get('source_ip', 'unknown'),
                'destination_ip': indicator.get('destination_ip', 'unknown'),
                'severity': indicator.get('severity', 'UNKNOWN')
            }
            for record in self.database.search(query, limit=10)
        ]
    
    def _match_protocols(self, protocols: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Look up database records for each protocol flagged as risky.

        Every name under 'risky_protocols' is lower-cased and searched for (at
        most 5 records each); every returned record becomes a MEDIUM-confidence
        match whose severity is the record's 'cvss_severity'.
        """
        return [
            {
                'match_type': 'protocol_vulnerability',
                'confidence': 'MEDIUM',
                'protocol': name,
                'vulnerability': record,
                'severity': record.get('cvss_severity', 'UNKNOWN')
            }
            for name in protocols.get('risky_protocols', [])
            for record in self.database.search(name.lower(), limit=5)
        ]
    
    def _match_host_patterns(self, hosts: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Match host communication patterns against vulnerability indicators.

        Each entry of 'top_communications' maps a (source, destination) pair
        to a packet count. Pairs with more than 100 packets are paired with
        the database's 'denial of service' records as LOW-confidence matches.

        Args:
            hosts (dict): Host section of a traffic analysis; only
                'top_communications' is read.

        Returns:
            list: One match per (high-volume pair, database record).
        """
        matches = []

        top_communications = hosts.get('top_communications', {})

        # The search phrase never varies, so the query is issued at most once
        # and only when some pair actually exceeds the threshold.
        dos_records = None

        for (src_ip, dst_ip), count in top_communications.items():
            if not count > 100:  # Only high volume communication is of interest
                continue

            if dos_records is None:
                dos_records = list(self.database.search('denial of service', limit=3))

            for db_vuln in dos_records:
                matches.append({
                    'match_type': 'communication_pattern',
                    'confidence': 'LOW',
                    'pattern': f'High volume communication: {src_ip} -> {dst_ip} ({count} packets)',
                    'vulnerability': db_vuln,
                    'source_ip': src_ip,
                    'destination_ip': dst_ip,
                    'packet_count': count
                })

        return matches
    
    def _match_service_to_database(self, host_ip: str, port_info: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Look up database records for one open port's service.

        Two independent lookups are made: one by service name (up to 10
        records, each graded by _calculate_service_match_confidence) and one by
        "<product> <version>" (up to 5 records, always HIGH confidence). The
        second lookup requires both product and version to be non-empty.
        """
        name = port_info.get('service', '')
        product = port_info.get('product', '')
        version = port_info.get('version', '')

        found = []

        # Pass 1: records that mention the service itself.
        if name:
            for record in self.database.search(name, limit=10):
                level = self._calculate_service_match_confidence(port_info, record)
                if level == 'NONE':
                    continue
                found.append({
                    'match_type': 'service_vulnerability',
                    'confidence': level,
                    'host': host_ip,
                    'port': port_info.get('port', 'unknown'),
                    'service': name,
                    'product': product,
                    'version': version,
                    'vulnerability': record,
                    'severity': record.get('cvss_severity', 'UNKNOWN')
                })

        # Pass 2: records for the exact product/version string.
        if product and version:
            found.extend(
                {
                    'match_type': 'product_vulnerability',
                    'confidence': 'HIGH',
                    'host': host_ip,
                    'port': port_info.get('port', 'unknown'),
                    'product': product,
                    'version': version,
                    'vulnerability': record,
                    'severity': record.get('cvss_severity', 'UNKNOWN')
                }
                for record in self.database.search(f"{product} {version}", limit=5)
            )

        return found
    
    def _match_os_to_database(self, host_ip: str, os_info: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Look up database records for a host's detected operating system.

        The OS 'name' is used as the search phrase (at most 5 records); without
        a name no query is made. Every record becomes a MEDIUM-confidence match.
        """
        name = os_info.get('name', '')
        vendor = os_info.get('vendor', '')

        if not name:
            return []

        return [
            {
                'match_type': 'os_vulnerability',
                'confidence': 'MEDIUM',
                'host': host_ip,
                'os_name': name,
                'os_vendor': vendor,
                'vulnerability': record,
                'severity': record.get('cvss_severity', 'UNKNOWN')
            }
            for record in self.database.search(name, limit=5)
        ]
    
    def _check_router_vulnerability_match(self, traffic_data: Dict[str, Any], 
                                         router_vuln: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Report whether any captured packet looks router-related.

        Each item of traffic_data['raw_data'] is converted with str() and
        scanned for a fixed set of substrings. The first packet containing one
        of them produces a MEDIUM-confidence match for router_vuln; otherwise
        None is returned. The content of router_vuln does not influence the
        decision, only the returned record.
        """
        # Substrings treated as hints of router traffic: private address
        # prefixes, common admin paths, and the word itself.
        markers = (
            '192.168.',
            '10.0.0.',
            '172.16.',
            '/admin',
            '/cgi-bin',
            'router',
        )

        rendered = (str(packet) for packet in traffic_data.get('raw_data', []))
        hit = next(
            (text for text in rendered if any(marker in text for marker in markers)),
            None,
        )
        if hit is None:
            return None

        return {
            'match_type': 'router_pattern',
            'confidence': 'MEDIUM',
            'vulnerability': router_vuln,
            'traffic_pattern': hit[:100],  # keep only the first 100 characters
            'severity': router_vuln.get('cvss_severity', 'UNKNOWN')
        }
    
    def _calculate_match_confidence(self, indicator: Dict[str, Any], 
                                   db_vuln: Dict[str, Any]) -> str:
        """
        Calculate confidence level for a vulnerability match.

        Confidence is based on word overlap between the two 'description'
        fields (case-insensitive, whitespace-separated): the number of shared
        distinct words divided by the size of the larger word set.

        Args:
            indicator (dict): Traffic indicator; 'description' is read.
            db_vuln (dict): Database record; 'description' is read.

        Returns:
            str: 'HIGH' above 0.5 overlap, 'MEDIUM' above 0.2, otherwise 'LOW'.
                A missing, None or empty description on either side gives 'LOW'.
        """
        # `or ''` covers records that carry an explicit None description,
        # which would otherwise fail on .lower().
        indicator_words = set((indicator.get('description') or '').lower().split())
        vuln_words = set((db_vuln.get('description') or '').lower().split())

        # Nothing to compare; this also keeps the division below safe.
        if not indicator_words or not vuln_words:
            return 'LOW'

        shared = indicator_words & vuln_words
        similarity = len(shared) / max(len(indicator_words), len(vuln_words))

        if similarity > 0.5:
            return 'HIGH'
        if similarity > 0.2:
            return 'MEDIUM'
        return 'LOW'
    
    def _calculate_service_match_confidence(self, service: Dict[str, Any], 
                                          db_vuln: Dict[str, Any]) -> str:
        """
        Calculate confidence for service vulnerability matches.

        All comparisons are case-insensitive substring tests against the
        record's 'description'.

        Args:
            service (dict): Port/service entry; 'service' and 'product' are read.
            db_vuln (dict): Database record; 'description' is read.

        Returns:
            str: 'HIGH' when the full service name or the product name occurs
                in the description, 'MEDIUM' when any word of the service name
                longer than 3 characters occurs in it, otherwise 'LOW'.
        """
        # `or ''` tolerates fields that are present but None.
        service_name = (service.get('service') or '').lower()
        product = (service.get('product') or '').lower()
        vuln_desc = (db_vuln.get('description') or '').lower()

        # An empty string is a substring of every string, so blank names must
        # be excluded explicitly or every record would rate HIGH.
        if service_name and service_name in vuln_desc:
            return 'HIGH'

        if product and product in vuln_desc:
            return 'HIGH'

        # Short words (<= 3 chars) are skipped as too unspecific for a partial match.
        if any(word in vuln_desc for word in service_name.split() if len(word) > 3):
            return 'MEDIUM'

        return 'LOW'