#!/usr/bin/env python3
"""
PySploit Nmap Analyzer
Parse and analyze Nmap XML scan results for vulnerability assessment.
"""

import xml.etree.ElementTree as ET
import pandas as pd
import re
from typing import Dict, List, Any, Optional
import os


class NmapAnalyzer:
    """
    Analyze Nmap XML scan results and extract vulnerability-relevant information.
    """
    
    def __init__(self):
        """Create an analyzer; no instance attributes are set up here."""
    
    def parse_xml(self, xml_path: str) -> Dict[str, Any]:
        """
        Parse Nmap XML file and extract scan results.

        Args:
            xml_path (str): Path to Nmap XML file

        Returns:
            dict: Parsed scan results with the keys ``scan_info``, ``hosts``
                and ``summary``.

        Raises:
            FileNotFoundError: If ``xml_path`` does not exist.
            ValueError: If the file is not well-formed XML. The underlying
                ``ParseError`` is attached as ``__cause__``.
            RuntimeError: If any other error occurs while reading the file or
                extracting data. The original exception is attached as
                ``__cause__``.
        """
        if not os.path.exists(xml_path):
            raise FileNotFoundError(f"Nmap XML file not found: {xml_path}")

        try:
            # NOTE(review): xml.etree.ElementTree is documented as not secure
            # against maliciously constructed data; only feed it scan files
            # from a trusted source.
            root = ET.parse(xml_path).getroot()

            return {
                'scan_info': self._extract_scan_info(root),
                'hosts': self._extract_hosts(root),
                'summary': self._generate_summary(root),
            }

        except ET.ParseError as e:
            # Chain explicitly so the parser's position info stays reachable.
            raise ValueError(f"Invalid XML file: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Error parsing XML: {e}") from e
    
    def _extract_scan_info(self, root: ET.Element) -> Dict[str, Any]:
        """Collect scan metadata from the root element and its ``scaninfo`` child."""
        # Attributes carried by the root element itself.
        metadata = {
            'scanner': root.get('scanner', 'nmap'),
            'version': root.get('version', 'unknown'),
            'start_time': root.get('startstr', 'unknown'),
        }

        details = root.find('scaninfo')
        if details is None:
            # Without a scaninfo child only the root attributes are reported.
            return metadata

        metadata.update(
            type=details.get('type', 'unknown'),
            protocol=details.get('protocol', 'unknown'),
            num_services=details.get('numservices', '0'),
        )
        return metadata
    
    def _extract_hosts(self, root: ET.Element) -> List[Dict[str, Any]]:
        """Return the parsed record of every direct ``host`` child that yields data."""
        parsed = (self._extract_single_host(node) for node in root.findall('host'))
        # Falsy results (None or an empty dict) are left out of the list.
        return [record for record in parsed if record]
    
    def _extract_single_host(self, host: ET.Element) -> Optional[Dict[str, Any]]:
        """
        Extract information for a single host.

        Args:
            host (ET.Element): A ``host`` element from the scan output.

        Returns:
            dict or None: ``None`` when the host has a ``status`` element whose
                state is not ``'up'``. Otherwise a dict that may contain
                ``status``, ``ip``, ``hostname``, ``os``, ``ports``,
                ``open_ports`` and ``host_scripts``; each key is present only
                when the corresponding data was found.
        """
        host_data: Dict[str, Any] = {}

        # Host status
        status = host.find('status')
        if status is not None:
            host_data['status'] = status.get('state', 'unknown')
            if host_data['status'] != 'up':
                return None  # Skip hosts that are not up

        # IP address: prefer IPv4, but fall back to IPv6 so that IPv6-only
        # hosts are not left without an address.
        address = host.find("address[@addrtype='ipv4']")
        if address is None:
            address = host.find("address[@addrtype='ipv6']")
        if address is not None:
            host_data['ip'] = address.get('addr', 'unknown')

        # Hostname (only the first listed name is kept)
        hostnames = host.find('hostnames')
        if hostnames is not None:
            hostname = hostnames.find('hostname')
            if hostname is not None:
                host_data['hostname'] = hostname.get('name', '')

        # OS detection
        os_info = self._extract_os_info(host)
        if os_info:
            host_data['os'] = os_info

        # Ports and services
        ports = self._extract_ports(host)
        if ports:
            host_data['ports'] = ports
            host_data['open_ports'] = [p for p in ports if p.get('state') == 'open']

        # NSE scripts
        host_scripts = self._extract_host_scripts(host)
        if host_scripts:
            host_data['host_scripts'] = host_scripts

        return host_data
    
    def _extract_os_info(self, host: ET.Element) -> Optional[Dict[str, Any]]:
        """
        Extract OS detection information.

        Args:
            host (ET.Element): A ``host`` element from the scan output.

        Returns:
            dict or None: ``None`` when the host has no ``os`` child. Otherwise
                a dict with ``name``/``accuracy`` (from the first ``osmatch``)
                and ``type``/``vendor``/``osfamily`` (from an ``osclass``),
                each group present only when its element was found.
        """
        os_elem = host.find('os')
        if os_elem is None:
            return None

        os_info: Dict[str, Any] = {}

        # OS matches
        osmatch = os_elem.find('osmatch')
        if osmatch is not None:
            os_info['name'] = osmatch.get('name', 'unknown')
            os_info['accuracy'] = osmatch.get('accuracy', '0')

        # OS classes: try the layout with <osclass> directly under <os> first,
        # then the layout where <osclass> is nested inside <osmatch>.
        osclass = os_elem.find('osclass')
        if osclass is None and osmatch is not None:
            osclass = osmatch.find('osclass')
        if osclass is not None:
            os_info['type'] = osclass.get('type', 'unknown')
            os_info['vendor'] = osclass.get('vendor', 'unknown')
            os_info['osfamily'] = osclass.get('osfamily', 'unknown')

        return os_info
    
    def _extract_ports(self, host: ET.Element) -> List[Dict[str, Any]]:
        """Build one record per ``port`` element found under the host's ``ports`` child."""
        container = host.find('ports')
        if container is None:
            return []

        # Service attributes copied under their own name, with their fallbacks.
        service_fields = (
            ('product', ''),
            ('version', ''),
            ('extrainfo', ''),
            ('method', 'unknown'),
            ('conf', '0'),
        )

        records = []
        for node in container.findall('port'):
            entry = {
                'port': node.get('portid', '0'),
                'protocol': node.get('protocol', 'tcp'),
            }

            state_node = node.find('state')
            if state_node is not None:
                entry['state'] = state_node.get('state', 'unknown')
                entry['reason'] = state_node.get('reason', 'unknown')

            service_node = node.find('service')
            if service_node is not None:
                # The service name is the only attribute stored under a different key.
                entry['service'] = service_node.get('name', 'unknown')
                for attribute, fallback in service_fields:
                    entry[attribute] = service_node.get(attribute, fallback)

            script_results = self._extract_port_scripts(node)
            if script_results:
                entry['scripts'] = script_results

            records.append(entry)

        return records
    
    def _extract_port_scripts(self, port: ET.Element) -> List[Dict[str, Any]]:
        """Return one record per ``script`` element directly under the port."""

        def describe(node: ET.Element) -> Dict[str, Any]:
            # 'elements' is attached only when structured data was found.
            record = dict(id=node.get('id', 'unknown'), output=node.get('output', ''))
            structured = self._extract_script_elements(node)
            if structured:
                record['elements'] = structured
            return record

        return [describe(node) for node in port.findall('script')]
    
    def _extract_host_scripts(self, host: ET.Element) -> List[Dict[str, Any]]:
        """Return one record per ``script`` element inside the host's ``hostscript`` child."""
        container = host.find('hostscript')
        if container is None:
            return []

        results = []
        for node in container.findall('script'):
            record = {
                'id': node.get('id', 'unknown'),
                'output': node.get('output', ''),
            }
            structured = self._extract_script_elements(node)
            # Only scripts that produced structured data get an 'elements' key.
            results.append({**record, 'elements': structured} if structured else record)

        return results
    
    def _extract_script_elements(self, script: ET.Element) -> Dict[str, Any]:
        """
        Flatten the ``elem``/``table`` children of an NSE script into a dict.

        Every ``elem`` at any depth that has both a ``key`` attribute and
        non-empty text is stored at the top level. Every ``table`` at any depth
        is then stored under its ``key`` (or ``'table'`` when it has none) as a
        dict of its direct keyed ``elem`` children; tables that yield no such
        entries are skipped. Later entries overwrite earlier ones that share
        a key.
        """
        flattened = {
            node.get('key'): node.text
            for node in script.findall('.//elem')
            if node.get('key') and node.text
        }

        for group in script.findall('.//table'):
            rows = {
                child.get('key'): child.text
                for child in group.findall('elem')
                if child.get('key') and child.text
            }
            if rows:
                flattened[group.get('key', 'table')] = rows

        return flattened
    
    def _generate_summary(self, root: ET.Element) -> Dict[str, Any]:
        """
        Generate scan summary statistics.

        Args:
            root (ET.Element): Root element of the parsed scan output.

        Returns:
            dict: Counters ``total_hosts``, ``hosts_up``, ``hosts_down`` and
                ``open_ports``, plus ``services`` (service name -> number of
                open ports) and ``os_types`` (OS class type -> number of hosts).
        """
        summary: Dict[str, Any] = {
            'total_hosts': 0,
            'hosts_up': 0,
            'hosts_down': 0,
            'open_ports': 0,
            'services': {},
            'os_types': {}
        }

        for host in root.findall('host'):
            summary['total_hosts'] += 1

            # A host without a status element is counted as down.
            status = host.find('status')
            if status is not None and status.get('state') == 'up':
                summary['hosts_up'] += 1
            else:
                summary['hosts_down'] += 1

            # Count services and open ports
            ports_elem = host.find('ports')
            if ports_elem is not None:
                for port in ports_elem.findall('port'):
                    state = port.find('state')
                    if state is not None and state.get('state') == 'open':
                        summary['open_ports'] += 1

                        service = port.find('service')
                        if service is not None:
                            service_name = service.get('name', 'unknown')
                            summary['services'][service_name] = summary['services'].get(service_name, 0) + 1

            # Count OS types. <osclass> may sit directly under <os> or be
            # nested inside the first <osmatch>; accept both layouts.
            os_elem = host.find('os')
            if os_elem is not None:
                osclass = os_elem.find('osclass')
                if osclass is None:
                    osmatch = os_elem.find('osmatch')
                    if osmatch is not None:
                        osclass = osmatch.find('osclass')
                if osclass is not None:
                    os_type = osclass.get('type', 'unknown')
                    summary['os_types'][os_type] = summary['os_types'].get(os_type, 0) + 1

        return summary
    
    def extract_services(self, scan_results: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Flatten the open ports of every host into a list of service records.

        Args:
            scan_results (dict): Parsed scan results from parse_xml()

        Returns:
            list: One dict per open port with ``host``, ``port``, ``protocol``,
                ``service``, ``product``, ``version`` and ``state``; a
                ``vulnerability_scripts`` key is added when any of the port's
                scripts has ``'vuln'`` in its id.
        """
        # Port fields copied verbatim, paired with the value used when absent.
        field_defaults = (
            ('port', '0'),
            ('protocol', 'tcp'),
            ('service', 'unknown'),
            ('product', ''),
            ('version', ''),
            ('state', 'unknown'),
        )

        collected = []
        for host_record in scan_results.get('hosts', []):
            address = host_record.get('ip', 'unknown')

            for entry in host_record.get('open_ports', []):
                record = {'host': address}
                record.update(
                    (field, entry.get(field, fallback))
                    for field, fallback in field_defaults
                )

                flagged = list(
                    filter(lambda item: 'vuln' in item.get('id', ''), entry.get('scripts', []))
                )
                if flagged:
                    record['vulnerability_scripts'] = flagged

                collected.append(record)

        return collected
    
    def identify_vulnerabilities(self, services: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Run the per-service checks and concatenate their findings.

        Args:
            services (list): Service information from extract_services()

        Returns:
            list: Potential vulnerability indicators, in service order
        """
        return [
            finding
            for service in services
            for finding in self._check_service_vulnerabilities(service)
        ]
    
    def _check_service_vulnerabilities(self, service: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Check a single service for vulnerability indicators.

        Three independent checks are applied: whether the service name is in a
        fixed table of insecure services, whether the product version looks
        outdated, and whether any attached NSE vulnerability script reports a
        positive result.

        Args:
            service (dict): One service record as produced by extract_services()

        Returns:
            list: Zero or more finding dicts, each with ``type``, ``severity``,
                ``description``, ``host`` and ``port`` plus type-specific keys.
        """
        vulnerabilities = []

        # Check for known vulnerable services
        vulnerable_services = {
            'ftp': {'severity': 'MEDIUM', 'description': 'FTP service (potential cleartext authentication)'},
            'telnet': {'severity': 'HIGH', 'description': 'Telnet service (cleartext authentication)'},
            'rlogin': {'severity': 'HIGH', 'description': 'Rlogin service (insecure authentication)'},
            'smtp': {'severity': 'LOW', 'description': 'SMTP service (potential information disclosure)'},
            'snmp': {'severity': 'MEDIUM', 'description': 'SNMP service (potential community string exposure)'}
        }

        service_name = service.get('service', '').lower()
        if service_name in vulnerable_services:
            vuln_info = vulnerable_services[service_name]
            vulnerabilities.append({
                'type': 'insecure_service',
                'severity': vuln_info['severity'],
                'description': vuln_info['description'],
                'host': service.get('host', 'unknown'),
                'port': service.get('port', 'unknown'),
                'service': service_name
            })

        # Check for outdated software versions
        product = service.get('product', '')
        version = service.get('version', '')

        if product and version:
            # Simple check for obviously old versions
            if self._is_old_version(product, version):
                vulnerabilities.append({
                    'type': 'outdated_software',
                    'severity': 'MEDIUM',
                    'description': f'Potentially outdated {product} version {version}',
                    'host': service.get('host', 'unknown'),
                    'port': service.get('port', 'unknown'),
                    'product': product,
                    'version': version
                })

        # Check NSE vulnerability scripts
        vuln_scripts = service.get('vulnerability_scripts', [])
        for script in vuln_scripts:
            script_id = script.get('id', '')
            output = script.get('output', '')

            # "NOT VULNERABLE" contains the substring "VULNERABLE", so negated
            # mentions are removed before testing; otherwise a clean script
            # result would be reported as a HIGH finding.
            positive_text = re.sub(r'\bNOT\s+VULNERABLE\b', '', output.upper())
            if 'VULNERABLE' in positive_text:
                vulnerabilities.append({
                    'type': 'nse_vulnerability',
                    'severity': 'HIGH',
                    'description': f'NSE script {script_id} detected vulnerability',
                    'host': service.get('host', 'unknown'),
                    'port': service.get('port', 'unknown'),
                    'script_id': script_id,
                    'script_output': output[:200]  # Truncate long output
                })

        return vulnerabilities
    
    def _is_old_version(self, product: str, version: str) -> bool:
        """
        Simple heuristic to detect potentially old software versions.

        The first occurrence of ``20`` followed by two digits in ``version`` is
        read as a year; the version counts as old when that year lies more than
        three years before the current calendar year. Version strings without
        such a pattern are never reported as old.

        Args:
            product (str): Product name (currently not used by the heuristic)
            version (str): Version string to inspect

        Returns:
            bool: True when the embedded year is more than three years old
        """
        # Local import keeps this method self-contained.
        from datetime import date

        # This is a simplified check - in practice you'd want a more sophisticated approach
        try:
            year_match = re.search(r'20(\d{2})', version)
        except TypeError:
            # A non-string version cannot be inspected; stay best-effort.
            return False

        if year_match is None:
            return False

        year = int(year_match.group(0))
        # Use the system date so the threshold does not go stale.
        return (date.today().year - year) > 3  # Consider 3+ years old as potentially outdated