"""
PySploit Core Vulnerability Analyzer
Advanced vulnerability assessment engine with comprehensive database integration
"""

import re
import json
from typing import Dict, List, Any, Optional
from datetime import datetime
from collections import defaultdict, Counter

from ..database.embedded import get_embedded_database


class PySploitAnalyzer:
    """
    Advanced vulnerability assessment engine providing comprehensive security analysis
    Features integrated vulnerability database with 55,000+ vulnerabilities from multiple sources

    Service fingerprints are matched against RouterSploit rows and vendor CVE
    rows from the embedded database; packet contents are matched against the
    database's regex vulnerability patterns. Use as a context manager (or call
    ``close()``) to close the database handle.
    """

    # Vendor keywords looked up in the CVE table, checked in this order; only
    # the first keyword present in a service fingerprint is queried.
    _VENDOR_TERMS = ('apache', 'nginx', 'microsoft', 'cisco', 'linksys', 'netgear')
    # Maximum number of CVE entries reported per matched vendor.
    _MAX_VENDOR_VULNS = 3
    # Character limits after which report text is cut and suffixed with '...'.
    _DESCRIPTION_LIMIT = 200
    _EVIDENCE_LIMIT = 100
    # Indicator count per source IP at which a host is flagged as possibly
    # compromised, and at which its risk level is raised from MEDIUM to HIGH.
    _COMPROMISE_THRESHOLD = 3
    _HIGH_RISK_THRESHOLD = 5

    def __init__(self):
        """Initialize core analyzer: open the embedded DB and load its regex patterns."""
        self.db = get_embedded_database()
        self.vulnerability_patterns = self.db.get_vulnerability_patterns()

        print("[+] PySploit Analyzer initialized successfully")

    # ------------------------------------------------------------------ #
    # Small helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _as_lower_text(value: Any) -> str:
        """Return *value* as lower-cased text; None -> '', bytes decoded as UTF-8."""
        if value is None:
            return ''
        if isinstance(value, (bytes, bytearray)):
            # Decode instead of letting an f-string embed the "b'...'" repr.
            value = bytes(value).decode('utf-8', errors='replace')
        return str(value).lower()

    @staticmethod
    def _parse_json_list(value: Any) -> List[Any]:
        """
        Decode a JSON-list database column.

        Accepts an already-decoded list/tuple, a JSON string, or an empty/None
        value. Malformed JSON yields [] so one bad row cannot abort a whole
        scan; a decoded scalar is wrapped in a one-element list.
        """
        if not value:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        try:
            decoded = json.loads(value)
        except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError
            return []
        if decoded is None:
            return []
        return decoded if isinstance(decoded, list) else [decoded]

    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Cut *text* to *limit* characters and append '...' when it was longer."""
        return text[:limit] + '...' if len(text) > limit else text

    # ------------------------------------------------------------------ #
    # Service analysis
    # ------------------------------------------------------------------ #

    def analyze_service_vulnerabilities(self,
                                        host_ip: str,
                                        services: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Analyze services for vulnerabilities using embedded database

        Args:
            host_ip: IP address of the host
            services: List of services [{'port': 80, 'service': 'http', 'version': '1.0', 'banner': 'Apache/2.4'}]

        Returns:
            dict: 'host', 'timestamp', 'services_analyzed',
            'vulnerabilities_found', 'risk_score' (mean base score, capped at
            10.0) and 'recommendations'.
        """
        analysis = {
            'host': host_ip,
            'timestamp': datetime.now().isoformat(),
            'services_analyzed': len(services),
            'vulnerabilities_found': [],
            'risk_score': 0,
            'recommendations': []
        }

        # The RouterSploit table does not depend on the service, so query it
        # once per host rather than once per service.
        routersploit_vulns = list(self.db.search_routersploit() or []) if services else []

        for service in services:
            analysis['vulnerabilities_found'].extend(
                self._analyze_single_service(service, routersploit_vulns)
            )

        analysis['risk_score'] = self._calculate_risk_score(analysis['vulnerabilities_found'])
        analysis['recommendations'] = self._generate_recommendations(analysis['vulnerabilities_found'])

        return analysis

    def _analyze_single_service(self,
                                service: Dict[str, Any],
                                routersploit_vulns: Optional[List[Dict[str, Any]]] = None
                                ) -> List[Dict[str, Any]]:
        """
        Analyze a single service for vulnerabilities.

        Args:
            service: Service dict with optional 'port', 'service', 'version'
                and 'banner' keys; None values are treated as empty text.
            routersploit_vulns: Pre-fetched RouterSploit rows; queried from the
                database when omitted.

        Returns:
            list: One finding per RouterSploit row whose port list and
            fingerprint patterns match, followed by up to ``_MAX_VENDOR_VULNS``
            CVE findings for the first vendor keyword found in the fingerprint.
        """
        if routersploit_vulns is None:
            routersploit_vulns = self.db.search_routersploit() or []

        port = service.get('port', 0)
        service_name = self._as_lower_text(service.get('service'))
        version = self._as_lower_text(service.get('version'))
        banner = self._as_lower_text(service.get('banner'))
        affected_service = f"{service_name}:{port}"

        # Combined, lower-cased fingerprint used for all substring matching.
        service_info = f"{service_name} {version} {banner}"

        vulnerabilities = []

        for vuln in routersploit_vulns:
            # An empty port list means the module is not port-restricted.
            target_ports = self._parse_json_list(vuln.get('target_ports', '[]'))
            if target_ports and port not in target_ports:
                continue

            fingerprint_patterns = self._parse_json_list(vuln.get('fingerprint_patterns', '[]'))
            matched = next(
                (p for p in fingerprint_patterns
                 if isinstance(p, str) and p.lower() in service_info),
                None,
            )
            if matched is None:
                continue

            vulnerabilities.append({
                'vuln_id': vuln['id'],
                'cve_id': vuln['cve_id'],
                'title': f"{vuln['id']} - {matched.title()} Device Vulnerability",
                'description': vuln['description'],
                'severity': vuln['severity'],
                'base_score': vuln['base_score'],
                'source': 'RouterSploit (Embedded)',
                'matched_pattern': matched,
                'affected_service': affected_service,
                'category': vuln['category']
            })

        # Only the first vendor keyword present (in _VENDOR_TERMS order) is queried.
        vendor = next((term for term in self._VENDOR_TERMS if term in service_info), None)
        if vendor is not None:
            for vuln in self.db.search_by_vendor(vendor)[:self._MAX_VENDOR_VULNS]:
                vulnerabilities.append({
                    'vuln_id': vuln['cve_id'],
                    'cve_id': vuln['cve_id'],
                    'title': f"CVE - {vendor.title()} Vulnerability",
                    'description': self._truncate(vuln['description'] or '', self._DESCRIPTION_LIMIT),
                    'severity': vuln['severity'],
                    'base_score': vuln['base_score'],
                    'source': 'CVE (Embedded)',
                    'matched_pattern': vendor,
                    'affected_service': affected_service,
                    'category': 'cve'
                })

        return vulnerabilities

    # ------------------------------------------------------------------ #
    # Traffic analysis
    # ------------------------------------------------------------------ #

    def analyze_network_traffic(self,
                                traffic_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Analyze network traffic for vulnerability indicators

        Args:
            traffic_data: List of packet dictionaries with fields like src_ip, dst_ip, payload, etc.

        Returns:
            dict: 'timestamp', 'packets_analyzed', 'suspicious_indicators',
            'attack_attempts' (currently never populated),
            'compromised_hosts' and a 'summary' of counts.
        """
        analysis = {
            'timestamp': datetime.now().isoformat(),
            'packets_analyzed': len(traffic_data),
            'suspicious_indicators': [],
            'attack_attempts': [],
            'compromised_hosts': [],
            'summary': {}
        }

        host_activity = defaultdict(list)

        for packet in traffic_data:
            indicators = self._analyze_packet(packet)
            analysis['suspicious_indicators'].extend(indicators)

            # Attribute indicators to the sending host.
            src_ip = packet.get('src_ip', packet.get('ip.src', ''))
            if src_ip and indicators:
                host_activity[src_ip].extend(indicators)

        for host_ip, activities in host_activity.items():
            if len(activities) < self._COMPROMISE_THRESHOLD:
                continue
            analysis['compromised_hosts'].append({
                'ip': host_ip,
                'suspicious_activities': len(activities),
                # Sorted so repeated runs produce identical reports.
                'activity_types': sorted({activity['type'] for activity in activities}, key=str),
                'risk_level': 'HIGH' if len(activities) >= self._HIGH_RISK_THRESHOLD else 'MEDIUM'
            })

        indicators_all = analysis['suspicious_indicators']
        analysis['summary'] = {
            'total_indicators': len(indicators_all),
            'unique_attack_types': len({ind['type'] for ind in indicators_all}),
            'potentially_compromised_hosts': len(analysis['compromised_hosts']),
            'critical_alerts': sum(1 for ind in indicators_all if ind.get('severity') == 'CRITICAL')
        }

        return analysis

    def _analyze_packet(self, packet: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Analyze individual packet for vulnerability indicators.

        Payload, URI and host fields (with tshark-style fallbacks such as
        'ip.src' / 'http.request.uri') are lower-cased, joined and searched
        case-insensitively with every database pattern. Patterns that are not
        valid regexes are skipped.

        Returns:
            list: One indicator dict per matching pattern.
        """
        src_ip = packet.get('src_ip', packet.get('ip.src', ''))
        dst_ip = packet.get('dst_ip', packet.get('ip.dst', ''))
        payload = self._as_lower_text(packet.get('payload', packet.get('data', '')))
        uri = self._as_lower_text(packet.get('uri', packet.get('http.request.uri', '')))
        host = self._as_lower_text(packet.get('host', packet.get('http.host', '')))

        full_content = f"{payload} {uri} {host}"
        evidence = self._truncate(full_content, self._EVIDENCE_LIMIT)

        indicators = []
        for pattern_data in self.vulnerability_patterns:
            pattern = pattern_data['pattern']

            try:
                matched = re.search(pattern, full_content, re.IGNORECASE)
            except re.error:
                # Skip invalid regex patterns stored in the database.
                continue

            if matched:
                indicators.append({
                    'type': pattern_data['pattern_type'],
                    'severity': pattern_data['severity'],
                    'description': f"{pattern_data['pattern_type']} pattern detected",
                    'pattern_matched': pattern,
                    'src_ip': src_ip,
                    'dst_ip': dst_ip,
                    'category': pattern_data['category'],
                    'evidence': evidence
                })

        return indicators

    # ------------------------------------------------------------------ #
    # Combined assessment
    # ------------------------------------------------------------------ #

    def quick_assessment(self,
                         target_hosts: Optional[List[Dict[str, Any]]] = None,
                         traffic_data: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]:
        """
        Perform quick vulnerability assessment

        Either input may be omitted; the corresponding result section is then
        left as None.

        Args:
            target_hosts: List of host dictionaries with 'ip' and 'services'
            traffic_data: Network traffic data for analysis

        Returns:
            dict: Comprehensive assessment results. 'overall_risk' is CRITICAL
            for >= 3 critical findings/alerts, HIGH for >= 1, MEDIUM when any
            host vulnerability was found, otherwise LOW.
        """
        # One clock read so the id and timestamp always agree.
        started = datetime.now()
        assessment = {
            'assessment_id': f"pysploit_{started.strftime('%Y%m%d_%H%M%S')}",
            'timestamp': started.isoformat(),
            'assessment_type': 'Offline',
            'network_assessment': None,
            'traffic_assessment': None,
            'overall_risk': 'LOW',
            'critical_findings': [],
            'recommendations': []
        }

        if target_hosts:
            network_results = []

            for host in target_hosts:
                host_ip = host.get('ip', 'unknown')
                services = host.get('services', [])

                host_analysis = self.analyze_service_vulnerabilities(host_ip, services)
                network_results.append(host_analysis)

                for vuln in host_analysis['vulnerabilities_found']:
                    if vuln['severity'] == 'CRITICAL':
                        assessment['critical_findings'].append({
                            'host': host_ip,
                            'vulnerability': vuln['title'],
                            'cve': vuln.get('cve_id', 'N/A'),
                            'score': vuln['base_score']
                        })

            assessment['network_assessment'] = {
                'hosts_analyzed': len(target_hosts),
                'vulnerable_hosts': sum(1 for h in network_results if h['vulnerabilities_found']),
                'total_vulnerabilities': sum(len(h['vulnerabilities_found']) for h in network_results),
                'host_results': network_results
            }

        if traffic_data:
            assessment['traffic_assessment'] = self.analyze_network_traffic(traffic_data)

        # The section keys are always present but may hold None, so
        # dict.get(key, {}) would return None; fall back with `or {}`.
        network_assessment = assessment['network_assessment'] or {}
        traffic_assessment = assessment['traffic_assessment'] or {}

        critical_count = len(assessment['critical_findings'])
        critical_count += traffic_assessment.get('summary', {}).get('critical_alerts', 0)

        if critical_count >= 3:
            assessment['overall_risk'] = 'CRITICAL'
        elif critical_count >= 1:
            assessment['overall_risk'] = 'HIGH'
        elif network_assessment.get('total_vulnerabilities', 0) > 0:
            assessment['overall_risk'] = 'MEDIUM'

        assessment['recommendations'] = self._generate_assessment_recommendations(assessment)

        return assessment

    # ------------------------------------------------------------------ #
    # Scoring and recommendations
    # ------------------------------------------------------------------ #

    def _calculate_risk_score(self, vulnerabilities: List[Dict[str, Any]]) -> float:
        """Return the mean 'base_score' (missing/None counted as 0), capped at 10.0."""
        if not vulnerabilities:
            return 0.0

        total_score = sum(vuln.get('base_score') or 0 for vuln in vulnerabilities)
        return min(total_score / len(vulnerabilities), 10.0)

    def _generate_recommendations(self, vulnerabilities: List[Dict[str, Any]]) -> List[str]:
        """Generate recommendations based on vulnerability severities and sources."""
        if not vulnerabilities:
            return ['No vulnerabilities detected - maintain current security posture']

        recommendations = []
        severity_counts = Counter(vuln['severity'] for vuln in vulnerabilities)

        if severity_counts['CRITICAL'] > 0:
            recommendations.append(f"URGENT: {severity_counts['CRITICAL']} critical vulnerabilities require immediate attention")

        if severity_counts['HIGH'] > 0:
            recommendations.append(f"HIGH PRIORITY: Address {severity_counts['HIGH']} high-severity vulnerabilities")

        if any('RouterSploit' in v.get('source', '') for v in vulnerabilities):
            recommendations.append("Router/IoT devices detected with known exploits - update firmware immediately")

        recommendations.extend([
            "Implement network segmentation to isolate vulnerable devices",
            "Enable comprehensive logging and monitoring",
            "Regular security assessments and patch management",
            "Consider deploying intrusion detection systems"
        ])

        return recommendations

    def _generate_assessment_recommendations(self, assessment: Dict[str, Any]) -> List[str]:
        """Generate overall assessment recommendations from a quick_assessment result."""
        recommendations = []

        overall_risk = assessment.get('overall_risk', 'LOW')

        if overall_risk == 'CRITICAL':
            recommendations.append("[CRITICAL] Immediate security intervention required")
        elif overall_risk == 'HIGH':
            recommendations.append("[HIGH RISK] Priority security remediation needed")

        if assessment.get('critical_findings'):
            recommendations.append(f"Address {len(assessment['critical_findings'])} critical security vulnerabilities")

        # 'traffic_assessment' may be present with value None.
        traffic_assessment = assessment.get('traffic_assessment') or {}
        if traffic_assessment.get('compromised_hosts'):
            recommendations.append("Investigate potentially compromised hosts identified in traffic analysis")

        recommendations.extend([
            "Conduct regular vulnerability assessments",
            "Implement defense-in-depth security strategy",
            "Monitor network traffic for anomalous behavior",
            "Maintain up-to-date asset inventory and patch management"
        ])

        return recommendations

    # ------------------------------------------------------------------ #
    # Metadata and lifecycle
    # ------------------------------------------------------------------ #

    def get_database_info(self) -> Dict[str, Any]:
        """Get information about the embedded database (live statistics plus static descriptions)."""
        stats = self.db.get_statistics()

        return {
            'database_type': 'Comprehensive',
            'statistics': stats,
            'capabilities': [
                'CVE vulnerability assessment',
                'ExploitDB security analysis',
                'NVD threat intelligence',
                'RouterSploit exploitation detection',
                'Network traffic analysis',
                'Service vulnerability assessment',
                'Pattern-based threat detection'
            ],
            'data_sources': [
                'CVE database (comprehensive vulnerability intelligence)',
                'ExploitDB (46,000+ exploit signatures)',
                'NVD (9,000+ vulnerability entries)',
                'RouterSploit framework modules',
                'Advanced vulnerability detection patterns'
            ]
        }

    def close(self):
        """Close the embedded database handle, if one was opened."""
        if hasattr(self, 'db'):
            self.db.close()

    def __enter__(self):
        """Context manager entry; returns the analyzer itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit; closes the database and lets exceptions propagate."""
        self.close()


# Convenience functions for quick use
def analyze_host(ip: str, services: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Run a one-off service vulnerability analysis for a single host.

    Args:
        ip: IP address recorded in the result.
        services: Service dicts passed through to
            ``PySploitAnalyzer.analyze_service_vulnerabilities``.

    Returns:
        dict: The analyzer's per-host result.
    """
    # NOTE(review): the temporary analyzer is never closed. If
    # get_embedded_database() hands out a per-analyzer handle, wrapping this in
    # `with PySploitAnalyzer() as ...` would release it — confirm it is not a
    # shared instance before changing this.
    return PySploitAnalyzer().analyze_service_vulnerabilities(ip, services)


def analyze_traffic(packets: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Run a one-off traffic analysis over a list of packet dicts.

    Args:
        packets: Packet dicts passed through to
            ``PySploitAnalyzer.analyze_network_traffic``.

    Returns:
        dict: The analyzer's traffic analysis result.
    """
    # NOTE(review): the temporary analyzer is never closed; see analyze_host
    # for the same caveat about get_embedded_database() possibly being shared.
    return PySploitAnalyzer().analyze_network_traffic(packets)


def quick_scan(hosts: List[Dict[str, Any]] = None, 
               traffic: List[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Run a one-off combined host and traffic assessment.

    Args:
        hosts: Host dicts ({'ip': ..., 'services': [...]}) or None.
        traffic: Packet dicts or None.

    Returns:
        dict: The result of ``PySploitAnalyzer.quick_assessment``.
    """
    # NOTE(review): the temporary analyzer is never closed; confirm whether
    # get_embedded_database() returns a shared instance before adding `with`.
    return PySploitAnalyzer().quick_assessment(hosts, traffic)