#!/usr/bin/env python3
"""
PySploit Vulnerability Assessment Engine
Comprehensive vulnerability analysis system based on Analysis/Vulnerability but enhanced with RouterSploit integration
"""

import sqlite3
import pandas as pd
import os
import re
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple
from collections import defaultdict, Counter

from ..data.database_manager import DatabaseManager


class PySploitVulnerabilityAssessment:
    """
    Enhanced vulnerability assessment engine that combines:
    - CVE database analysis (like Analysis/Vulnerability)
    - RouterSploit vulnerability data
    - Network traffic analysis
    - Service fingerprinting

    Instances hold an open SQLite connection; use the instance as a context
    manager (or call ``close()``) to release it.
    """

    # ``source`` value attached to findings produced by the RouterSploit pattern matcher.
    _ROUTERSPLOIT_SOURCE = 'RouterSploit Pattern'

    # Nominal scores used for pattern findings, which carry only a severity label.
    _SEVERITY_SCORES = {
        'CRITICAL': 9.5,
        'HIGH': 7.5,
        'MEDIUM': 5.0,
        'LOW': 2.5
    }

    # Port hints used to decide which RouterSploit pattern category is tried first.
    _WEB_PORTS = frozenset({80, 443, 8080, 8443})
    _ROUTER_PORTS = frozenset({23, 2323})
    _IOT_PORTS = frozenset({554, 37777, 34567})

    # Host/URI fragments that indicate access to a router administration interface.
    _ROUTER_ADMIN_MARKERS = ('/cgi-bin/', '/admin/', '/management/', '/setup.cgi')

    # (name, regex) pairs checked against the lower-cased HTTP request URI.
    # IGNORECASE is required because the URI is lower-cased before matching.
    _HTTP_EXPLOIT_PATTERNS = (
        ('CVE-2014-2321', re.compile(r'tmUnblock\.cgi', re.IGNORECASE)),  # TheMoon malware
        ('CVE-2016-6277', re.compile(r'setup\.cgi.*submit_button', re.IGNORECASE)),  # Netgear RCE
        ('Directory Traversal', re.compile(r'\.\./.*\.\./')),
        # Shell metacharacters only; bare '&', '(' and ')' are excluded because
        # they occur in ordinary query strings and paths.
        ('Command Injection', re.compile(r'[;|`]|\$\(|&&')),
    )

    # Substrings of DNS query names associated with known malware families.
    _MALWARE_DNS_KEYWORDS = ('zeus', 'conficker', 'mirai', 'gafgyt', 'bashlite')

    def __init__(self, db_path: Optional[str] = None):
        """
        Initialize vulnerability assessment engine.

        Args:
            db_path: Path to the SQLite vulnerability index. When omitted,
                ``~/.pysploit/vulnerability_index.db`` is used. The value is
                also passed unchanged (possibly ``None``) to ``DatabaseManager``.
        """
        self.logger = logging.getLogger(__name__)

        # Initialize database manager
        self.db_manager = DatabaseManager(db_path)

        if db_path:
            self.db_file = Path(db_path)
        else:
            self.db_file = Path.home() / ".pysploit" / "vulnerability_index.db"

        # sqlite3.connect() cannot create missing parent directories.
        self.db_file.parent.mkdir(parents=True, exist_ok=True)

        # Check if database exists, create if needed
        if not self.db_file.exists():
            self.logger.info("Vulnerability database not found, initializing...")
            self.db_manager.initialize_database()

        self.conn = sqlite3.connect(self.db_file)
        self.conn.row_factory = sqlite3.Row  # Enable dict-like access

        # RouterSploit vulnerability patterns
        self.routersploit_patterns = self._load_routersploit_patterns()

        self.logger.info(f"Vulnerability assessment engine initialized with database: {self.db_file}")

    def _load_routersploit_patterns(self) -> Dict[str, List[Dict]]:
        """
        Load RouterSploit vulnerability patterns for matching.

        Returns:
            dict: Category name -> list of pattern dicts. Each dict has a
            ``pattern`` regex (searched case-insensitively against
            ``"service version product"``) plus ``vendor``, ``vulnerability``,
            ``cve``, ``severity`` and ``description`` metadata.
        """
        return {
            'router_exploits': [
                {
                    'pattern': r'linksys.*e\d+',
                    'vendor': 'Linksys',
                    'vulnerability': 'TheMoon malware RCE',
                    'cve': 'CVE-2014-2321',
                    'severity': 'CRITICAL',
                    'description': 'Linksys E-Series vulnerable to TheMoon worm'
                },
                {
                    'pattern': r'netgear.*r[67]\d+',
                    'vendor': 'Netgear',
                    'vulnerability': 'Command injection RCE',
                    'cve': 'CVE-2016-6277',
                    'severity': 'CRITICAL',
                    'description': 'Netgear R-Series command injection vulnerability'
                },
                {
                    'pattern': r'cisco.*ios',
                    'vendor': 'Cisco',
                    'vulnerability': 'Authentication bypass',
                    'cve': 'CVE-2001-0537',
                    'severity': 'HIGH',
                    'description': 'Cisco IOS HTTP authentication bypass'
                },
                {
                    'pattern': r'dlink.*dir-\d+',
                    'vendor': 'D-Link',
                    'vulnerability': 'Admin password disclosure',
                    'cve': 'CVE-2019-17621',
                    'severity': 'HIGH',
                    'description': 'D-Link DIR series password disclosure'
                },
                {
                    'pattern': r'asus.*rt-',
                    'vendor': 'ASUS',
                    'vulnerability': 'Command injection',
                    'cve': 'CVE-2018-5999',
                    'severity': 'CRITICAL',
                    'description': 'ASUS RT series router command injection'
                }
            ],
            'iot_exploits': [
                {
                    'pattern': r'mirai|gafgyt|bashlite',
                    'vendor': 'Various',
                    'vulnerability': 'IoT botnet malware',
                    'cve': 'Multiple',
                    'severity': 'CRITICAL',
                    'description': 'IoT devices vulnerable to botnet malware'
                },
                {
                    'pattern': r'hikvision|dahua',
                    'vendor': 'Hikvision/Dahua',
                    'vulnerability': 'Authentication bypass',
                    'cve': 'CVE-2017-7921',
                    'severity': 'CRITICAL',
                    'description': 'IP camera authentication bypass'
                }
            ],
            'web_exploits': [
                {
                    'pattern': r'apache.*2\.[0-4]\.\d+',
                    'vendor': 'Apache',
                    'vulnerability': 'Various Apache vulnerabilities',
                    'cve': 'Multiple',
                    'severity': 'HIGH',
                    'description': 'Apache HTTP server vulnerabilities'
                },
                {
                    'pattern': r'nginx.*1\.[0-9]\.\d+',
                    'vendor': 'Nginx',
                    'vulnerability': 'HTTP request smuggling',
                    'cve': 'CVE-2019-20372',
                    'severity': 'MEDIUM',
                    'description': 'Nginx HTTP request smuggling'
                }
            ]
        }

    @staticmethod
    def _normalize_text(value: Any) -> str:
        """Return *value* as a lower-cased string, mapping ``None`` to ``''``."""
        return '' if value is None else str(value).lower()

    def assess_network_services(self, nmap_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Assess network services for vulnerabilities (similar to Analysis/Vulnerability)

        Args:
            nmap_data: Parsed Nmap scan data; read as a ``hosts`` list whose
                entries carry ``address`` and a ``ports`` list.

        Returns:
            dict: Vulnerability assessment results. ``vulnerable_hosts`` holds
            per-host assessments that have at least one finding,
            ``vulnerability_summary`` counts findings per severity, and
            ``critical_findings`` / ``routersploit_matches`` pair individual
            findings with their host address.
        """
        hosts = nmap_data.get('hosts') or []
        assessment = {
            'timestamp': datetime.now().isoformat(),
            'total_hosts': len(hosts),
            'vulnerable_hosts': [],
            'vulnerability_summary': defaultdict(int),
            'critical_findings': [],
            'routersploit_matches': []
        }

        for host in hosts:
            host_assessment = self._assess_single_host(host)
            if not host_assessment['vulnerabilities']:
                continue

            assessment['vulnerable_hosts'].append(host_assessment)
            address = host.get('address')

            # Update summary statistics and the per-finding lists.
            for vuln in host_assessment['vulnerabilities']:
                severity = vuln.get('severity')
                assessment['vulnerability_summary'][severity] += 1

                if severity == 'CRITICAL':
                    assessment['critical_findings'].append({
                        'host': address,
                        'vulnerability': vuln
                    })
                if vuln.get('source') == self._ROUTERSPLOIT_SOURCE:
                    assessment['routersploit_matches'].append({
                        'host': address,
                        'vulnerability': vuln
                    })

        return assessment

    def _assess_single_host(self, host: Dict[str, Any]) -> Dict[str, Any]:
        """
        Assess a single host for vulnerabilities.

        Every entry of ``host['ports']`` is checked against the CVE database
        and the RouterSploit patterns. ``None`` service/version/product values
        are treated as empty strings.

        Returns:
            dict: ``address``, ``hostname``, ``os``, the combined
            ``vulnerabilities`` list, a ``risk_score`` and a
            ``routersploit_applicable`` flag.
        """
        host_assessment = {
            'address': host.get('address'),
            'hostname': host.get('hostname', ''),
            'os': host.get('os', {}),
            'vulnerabilities': [],
            'risk_score': 0,
            'routersploit_applicable': False
        }

        for port_entry in host.get('ports') or []:
            service = self._normalize_text(port_entry.get('service'))
            version = self._normalize_text(port_entry.get('version'))
            product = self._normalize_text(port_entry.get('product'))

            # Search database for service vulnerabilities
            host_assessment['vulnerabilities'].extend(
                self._search_service_vulnerabilities(service, version, product)
            )

            # Check RouterSploit patterns
            routersploit_matches = self._check_routersploit_patterns(
                service, version, product, port_entry.get('port')
            )
            host_assessment['vulnerabilities'].extend(routersploit_matches)
            if routersploit_matches:
                host_assessment['routersploit_applicable'] = True

        host_assessment['risk_score'] = self._calculate_risk_score(host_assessment['vulnerabilities'])
        return host_assessment

    def _search_service_vulnerabilities(self, service: str, version: str, product: str) -> List[Dict[str, Any]]:
        """
        Search the CVE database for vulnerabilities matching a service fingerprint.

        The non-empty parts of ``service``, ``version`` and ``product`` are
        joined with single spaces and matched as a literal substring (SQL
        ``LIKE``) against the ``description``, ``vendors`` and ``products``
        columns. Only rows with ``base_score >= 4.0`` are returned, highest
        score first, at most 10.

        Returns:
            list: Finding dicts. Empty when there is no fingerprint text, or
            when the database query fails (the error is logged).
        """
        search_terms = ' '.join(
            part.strip() for part in (service, version, product) if part and part.strip()
        )
        if not search_terms:
            # An empty term would become '%%' and match every CVE in the database.
            return []

        # Escape LIKE wildcards so '_' or '%' in a banner are matched literally.
        escaped_terms = search_terms.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')
        search_pattern = f"%{escaped_terms}%"

        query = """
            SELECT cve_id, description, base_score, cvss_severity, vendors, products
            FROM vulnerabilities
            WHERE (description LIKE ? ESCAPE '\\'
                   OR vendors LIKE ? ESCAPE '\\'
                   OR products LIKE ? ESCAPE '\\')
            AND base_score >= 4.0
            ORDER BY base_score DESC
            LIMIT 10
        """

        try:
            rows = self.conn.execute(query, (search_pattern, search_pattern, search_pattern)).fetchall()
        except sqlite3.Error as e:
            self.logger.error(f"Error searching service vulnerabilities: {e}")
            return []

        vulnerabilities = []
        for row in rows:
            # NULL descriptions are reported as empty text rather than aborting the search.
            description = row['description'] or ''
            if len(description) > 200:
                description = description[:200] + '...'
            vulnerabilities.append({
                'cve_id': row['cve_id'],
                'description': description,
                'base_score': row['base_score'],
                'severity': row['cvss_severity'],
                'source': 'CVE Database',
                'matched_service': service,
                'matched_version': version
            })

        return vulnerabilities

    @classmethod
    def _likely_pattern_category(cls, service: str, service_info: str, port: Any) -> str:
        """Guess which pattern category best fits a service from its port and banner text."""
        try:
            port_number = int(port)
        except (TypeError, ValueError):
            port_number = None

        if port_number in cls._WEB_PORTS or 'http' in service:
            return 'web_exploits'
        if 'router' in service_info or port_number in cls._ROUTER_PORTS:
            return 'router_exploits'
        if port_number in cls._IOT_PORTS or 'camera' in service_info:
            return 'iot_exploits'
        return 'router_exploits'  # Default to router patterns

    def _check_routersploit_patterns(self, service: str, version: str, product: str, port: int) -> List[Dict[str, Any]]:
        """
        Check a service against every RouterSploit vulnerability pattern.

        All categories are searched: router and camera admin interfaces are
        usually exposed over HTTP, so gating vendor patterns by port would
        miss them. The category suggested by the port/banner heuristic is
        tried first, so its matches come first in the result.

        Args:
            port: Port number; numeric strings are accepted.

        Returns:
            list: Finding dicts tagged with ``source='RouterSploit Pattern'``
            and the pattern's ``category``.
        """
        service_info = f"{service} {version} {product}".lower()

        preferred = self._likely_pattern_category(service, service_info, port)
        categories = [preferred] + [name for name in self.routersploit_patterns if name != preferred]

        vulnerabilities = []
        for category in categories:
            for pattern_data in self.routersploit_patterns.get(category, []):
                if not re.search(pattern_data['pattern'], service_info, re.IGNORECASE):
                    continue
                vulnerabilities.append({
                    'cve_id': pattern_data['cve'],
                    'description': pattern_data['description'],
                    'base_score': self._severity_to_score(pattern_data['severity']),
                    'severity': pattern_data['severity'],
                    'source': self._ROUTERSPLOIT_SOURCE,
                    'vendor': pattern_data['vendor'],
                    'matched_service': service,
                    'matched_version': version,
                    'routersploit_module': pattern_data.get('module', 'generic'),
                    'category': category
                })

        return vulnerabilities

    def _severity_to_score(self, severity: str) -> float:
        """Convert a severity label to a nominal CVSS score (0.0 when unknown or ``None``)."""
        return self._SEVERITY_SCORES.get((severity or '').upper(), 0.0)

    def _calculate_risk_score(self, vulnerabilities: List[Dict[str, Any]]) -> float:
        """
        Calculate overall risk score for a host.

        Returns the mean ``base_score`` (RouterSploit findings weighted by 1.2;
        missing or ``None`` scores count as 0) multiplied by 1.5 and capped at
        10.0. An empty list scores 0.0.
        """
        if not vulnerabilities:
            return 0.0

        total_score = 0.0
        for vuln in vulnerabilities:
            score = vuln.get('base_score') or 0
            if vuln.get('source') == self._ROUTERSPLOIT_SOURCE:
                score *= 1.2  # Boost RouterSploit findings as they're more actionable
            total_score += score

        # Normalize to 0-10 scale
        return min(total_score / len(vulnerabilities) * 1.5, 10.0)

    def assess_pcap_traffic(self, pcap_data: Dict[str, Any], max_packets: int = 1000) -> Dict[str, Any]:
        """
        Assess PCAP traffic for vulnerability indicators (like Analysis/Vulnerability)

        Args:
            pcap_data: Parsed PCAP traffic data; ``packets`` may be any
                iterable of field dicts.
            max_packets: Maximum number of packets analyzed (performance cap).

        Returns:
            dict: Traffic-based vulnerability assessment.
            ``suspicious_indicators`` holds every indicator,
            ``attack_signatures`` the exploit-attempt indicators,
            ``vulnerability_matches`` the indicators tied to a CVE id, and
            ``compromised_hosts`` every source IP with 3 or more indicators.
        """
        assessment = {
            'timestamp': datetime.now().isoformat(),
            'total_packets': pcap_data.get('total_packets', 0),
            'suspicious_indicators': [],
            'vulnerability_matches': [],
            'attack_signatures': [],
            'compromised_hosts': []
        }

        # Analyze packets for vulnerability indicators; enumerate() keeps the cap
        # working for generators as well as lists.
        for index, packet in enumerate(pcap_data.get('packets') or []):
            if index >= max_packets:
                break
            assessment['suspicious_indicators'].extend(self._analyze_packet_for_vulnerabilities(packet))

        source_activity = defaultdict(list)
        for indicator in assessment['suspicious_indicators']:
            if indicator.get('type') == 'exploit_attempt':
                assessment['attack_signatures'].append(indicator)
            if indicator.get('cve'):
                assessment['vulnerability_matches'].append(indicator)

            # Packets without a source address cannot be attributed to a host.
            src_ip = indicator.get('src_ip')
            if src_ip:
                source_activity[src_ip].append(indicator)

        # Identify hosts with multiple suspicious activities
        for src_ip, activities in source_activity.items():
            if len(activities) >= 3:  # Threshold for suspicious activity
                assessment['compromised_hosts'].append({
                    'ip': src_ip,
                    'activity_count': len(activities),
                    'indicators': activities[:5]  # Show first 5 indicators
                })

        return assessment

    def _analyze_packet_for_vulnerabilities(self, packet: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Analyze an individual packet for vulnerability indicators.

        Reads the tshark-style fields ``ip.src``, ``ip.dst``, ``http.host``,
        ``http.request.uri`` and ``dns.qry.name``. HTTP host/URI and DNS names
        are lower-cased before matching.

        Returns:
            list: Indicator dicts of type ``router_admin_access``,
            ``exploit_attempt`` or ``malware_dns``.
        """
        indicators = []

        src_ip = packet.get('ip.src', '')
        dst_ip = packet.get('ip.dst', '')
        http_host = self._normalize_text(packet.get('http.host'))
        http_uri = self._normalize_text(packet.get('http.request.uri'))

        # Check for RouterSploit-related patterns in HTTP traffic
        if http_host or http_uri:
            target = f"{http_host}{http_uri}"
            if any(marker in target for marker in self._ROUTER_ADMIN_MARKERS):
                indicators.append({
                    'type': 'router_admin_access',
                    'description': 'Access to router administration interface detected',
                    'src_ip': src_ip,
                    'dst_ip': dst_ip,
                    'severity': 'MEDIUM',
                    'details': {'host': http_host, 'uri': http_uri}
                })

            for vuln_name, regex in self._HTTP_EXPLOIT_PATTERNS:
                if regex.search(http_uri):
                    indicators.append({
                        'type': 'exploit_attempt',
                        'description': f'Possible {vuln_name} exploit attempt detected',
                        'src_ip': src_ip,
                        'dst_ip': dst_ip,
                        'severity': 'HIGH',
                        'cve': vuln_name if vuln_name.startswith('CVE') else None,
                        'details': {'uri': http_uri, 'pattern': regex.pattern}
                    })

        # DNS analysis for IoT/Router malware
        dns_query = self._normalize_text(packet.get('dns.qry.name'))
        if dns_query:
            for malware in self._MALWARE_DNS_KEYWORDS:
                if malware in dns_query:
                    indicators.append({
                        'type': 'malware_dns',
                        'description': f'DNS query to suspected {malware} malware domain',
                        'src_ip': src_ip,
                        'severity': 'CRITICAL',
                        'details': {'domain': dns_query}
                    })

        return indicators

    def generate_comprehensive_report(self,
                                    nmap_assessment: Optional[Dict] = None,
                                    pcap_assessment: Optional[Dict] = None) -> Dict[str, Any]:
        """
        Generate a comprehensive vulnerability assessment report.

        Args:
            nmap_assessment: Result of ``assess_network_services`` (optional).
            pcap_assessment: Result of ``assess_pcap_traffic`` (optional).

        Returns:
            dict: Both assessments plus a ``summary`` (counts, overall risk
            level, scope flags), ``recommendations`` and the per-host
            ``routersploit_findings``.
        """
        # One clock read so the id and timestamp always agree.
        now = datetime.now()
        report = {
            'assessment_id': f"pysploit_assessment_{now.strftime('%Y%m%d_%H%M%S')}",
            'timestamp': now.isoformat(),
            'summary': {},
            'network_assessment': nmap_assessment or {},
            'traffic_assessment': pcap_assessment or {},
            'recommendations': [],
            'routersploit_findings': []
        }

        total_vulns = 0
        critical_count = 0
        routersploit_count = 0

        if nmap_assessment:
            for host in nmap_assessment.get('vulnerable_hosts', []):
                host_vulns = host.get('vulnerabilities', [])
                total_vulns += len(host_vulns)
                for vuln in host_vulns:
                    if vuln.get('severity') == 'CRITICAL':
                        critical_count += 1
                    if vuln.get('source') == self._ROUTERSPLOIT_SOURCE:
                        routersploit_count += 1
                        report['routersploit_findings'].append({
                            'host': host.get('address'),
                            'vulnerability': vuln
                        })

        if pcap_assessment:
            total_vulns += len(pcap_assessment.get('vulnerability_matches', []))
            critical_count += sum(
                1 for indicator in pcap_assessment.get('suspicious_indicators', [])
                if indicator.get('severity') == 'CRITICAL'
            )

        report['summary'] = {
            'total_vulnerabilities': total_vulns,
            'critical_vulnerabilities': critical_count,
            'routersploit_applicable': routersploit_count,
            'risk_level': self._calculate_overall_risk(total_vulns, critical_count),
            'assessment_scope': {
                'network_scan': nmap_assessment is not None,
                'traffic_analysis': pcap_assessment is not None
            }
        }

        report['recommendations'] = self._generate_recommendations(report)
        return report

    def _calculate_overall_risk(self, total_vulns: int, critical_count: int) -> str:
        """Map vulnerability counts to a risk level from 'MINIMAL' up to 'CRITICAL'."""
        if critical_count >= 5:
            return 'CRITICAL'
        if critical_count >= 2 or total_vulns >= 10:
            return 'HIGH'
        if total_vulns >= 3:
            return 'MEDIUM'
        if total_vulns >= 1:
            return 'LOW'
        return 'MINIMAL'

    def _generate_recommendations(self, report: Dict[str, Any]) -> List[str]:
        """Build recommendations from a report's summary and traffic assessment, then append generic advice."""
        recommendations = []
        summary = report.get('summary', {})

        if summary.get('critical_vulnerabilities', 0) > 0:
            recommendations.append("URGENT: Address critical vulnerabilities immediately")

        if summary.get('routersploit_applicable', 0) > 0:
            recommendations.append("Router/IoT devices detected with known RouterSploit vulnerabilities - update firmware")

        if report.get('traffic_assessment', {}).get('compromised_hosts'):
            recommendations.append("Investigate potentially compromised hosts identified in traffic analysis")

        recommendations.extend([
            "Implement network segmentation to isolate vulnerable devices",
            "Enable logging and monitoring for security events",
            "Regular vulnerability scanning and patch management",
            "Consider implementing intrusion detection systems"
        ])

        return recommendations

    def close(self):
        """Close the database connection if one was opened."""
        if hasattr(self, 'conn') and self.conn:
            self.conn.close()

    def __enter__(self):
        """Context manager entry; returns the engine itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit; closes the connection and lets exceptions propagate."""
        self.close()