#!/usr/bin/env python3
"""
PySploit Helper Functions
Utility functions for configuration, logging, and common operations.
"""

import os
import json
import logging
from typing import Dict, Any, Optional
from pathlib import Path


def _merge_config(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict with ``overrides`` recursively layered over ``base``."""
    merged: Dict[str, Any] = {}
    for key, value in base.items():
        # Copy nested dicts so the result never aliases the inputs.
        merged[key] = _merge_config(value, {}) if isinstance(value, dict) else value
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = _merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


def _read_config_file(path: str) -> Dict[str, Any]:
    """Parse a JSON config file; raise ValueError unless it holds an object."""
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not isinstance(data, dict):
        raise ValueError("configuration root must be a JSON object")
    return data


def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Load PySploit configuration from file.

    User settings are merged recursively over the built-in defaults, so a
    file that overrides a single nested key (e.g. ``logging.level``) keeps
    the remaining defaults of that section.

    Lookup order: the explicit ``config_path`` (if given and present), then
    ``pysploit_config.json`` in the working directory,
    ``~/.pysploit/config.json`` and ``/etc/pysploit/config.json``. The first
    file that loads successfully wins; unreadable or malformed files are
    skipped.

    Args:
        config_path (str, optional): Path to config file

    Returns:
        dict: Configuration settings (defaults when no file could be loaded)
    """
    default_config = {
        'database': {
            'path': 'vulnerability_index.db',
            'auto_update': True,
            'update_interval_days': 7
        },
        'analysis': {
            'default_filter': 'all',
            'confidence_threshold': 'MEDIUM',
            'severity_threshold': 'LOW'
        },
        'reports': {
            'default_format': 'json',
            'output_directory': 'reports',
            'include_raw_data': False
        },
        'logging': {
            'level': 'INFO',
            'file': 'pysploit.log',
            'console': True
        }
    }

    if config_path and os.path.exists(config_path):
        try:
            return _merge_config(default_config, _read_config_file(config_path))
        except (OSError, ValueError) as e:
            # Report the problem with the explicitly requested file, then
            # fall through to the well-known locations below.
            print(f"Error loading config from {config_path}: {e}")

    # Look for config in common locations
    config_locations = [
        'pysploit_config.json',
        os.path.expanduser('~/.pysploit/config.json'),
        '/etc/pysploit/config.json'
    ]

    for location in config_locations:
        if not os.path.exists(location):
            continue
        try:
            return _merge_config(default_config, _read_config_file(location))
        except (OSError, ValueError):
            # Implicit locations are best-effort: try the next candidate.
            continue

    return default_config


def setup_logging(config: Optional[Dict[str, Any]] = None) -> logging.Logger:
    """
    Setup logging configuration for PySploit.

    Reconfigures the shared ``'pysploit'`` logger: previously attached
    handlers are removed and closed, then a console handler and/or a file
    handler are attached according to ``config``.

    Args:
        config (dict, optional): Logging configuration with the keys
            ``level`` (level name or numeric level), ``console`` (bool) and
            ``file`` (log file path). When omitted, the ``logging`` section
            returned by ``load_config()`` is used.

    Returns:
        Logger: Configured logger instance
    """
    if config is None:
        config = load_config().get('logging', {})

    # Create logger
    logger = logging.getLogger('pysploit')

    # Resolve the level defensively: an unknown name or a non-string value
    # must not crash logging setup.
    raw_level = config.get('level', 'INFO')
    if isinstance(raw_level, int) and not isinstance(raw_level, bool):
        level = raw_level
    else:
        level = getattr(logging, str(raw_level).upper(), None)
    # Some attributes of the logging module are not levels (functions,
    # booleans), so only genuine integers are accepted.
    level_is_valid = isinstance(level, int) and not isinstance(level, bool)
    logger.setLevel(level if level_is_valid else logging.INFO)

    # Remove AND close existing handlers so open log files are released
    # when logging is configured more than once.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # Console handler
    if config.get('console', True):
        console_handler = logging.StreamHandler()
        console_format = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        console_handler.setFormatter(console_format)
        logger.addHandler(console_handler)

    # File handler
    log_file = config.get('file')
    if log_file:
        try:
            # Create log directory if it doesn't exist
            log_dir = os.path.dirname(log_file)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)

            file_handler = logging.FileHandler(log_file)
            file_format = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s'
            )
            file_handler.setFormatter(file_format)
            logger.addHandler(file_handler)
        except Exception as e:
            # File logging is best-effort; the failure is reported rather
            # than raised.
            logger.warning(f"Could not setup file logging: {e}")

    if not level_is_valid:
        logger.warning(f"Invalid log level {raw_level!r}; falling back to INFO")

    return logger


def ensure_directory(directory: str) -> str:
    """
    Make sure a directory is present on disk, creating it when needed.

    Args:
        directory (str): Directory path, relative or absolute

    Returns:
        str: Absolute path to directory
    """
    resolved = os.path.abspath(directory)
    if not os.path.isdir(resolved):
        # Missing intermediate directories are created as well.
        os.makedirs(resolved, exist_ok=True)
    return resolved


def get_data_directory() -> str:
    """
    Get or create PySploit data directory.

    The per-user location ``~/.pysploit/data`` is preferred. If it cannot be
    created or used (read-only home, permission problems, a file in the
    way), ``./pysploit_data`` in the current working directory is used
    instead.

    Returns:
        str: Absolute path to data directory

    Raises:
        OSError: If the fallback directory cannot be created either
    """
    user_data_dir = os.path.expanduser('~/.pysploit/data')

    # Attempt the creation directly instead of probing permissions first:
    # an up-front os.access() check can be stale by the time the directory
    # is made, and a failure here should reach the fallback, not the caller.
    try:
        return ensure_directory(user_data_dir)
    except OSError:
        pass

    # Fall back to current directory
    return ensure_directory('./pysploit_data')


def format_cve_id(cve_id: str) -> str:
    """
    Format CVE ID to standard format.

    The value is stripped and upper-cased. A bare identifier of the form
    ``YYYY-NNNN`` (year starting with 19 or 20, at least four sequence
    digits) gets the ``CVE-`` prefix added. Any other text is returned
    stripped and upper-cased, without a prefix.

    Args:
        cve_id (str): CVE identifier

    Returns:
        str: Formatted CVE ID, or empty string for empty input
    """
    import re

    if not cve_id:
        return ""

    # Remove extra whitespace and convert to uppercase
    cve_id = cve_id.strip().upper()

    # Add the prefix only when the whole value looks like a bare CVE number
    # (e.g. "2021-1234"); a looser check would also prefix arbitrary text
    # that merely starts with a year.
    if not cve_id.startswith('CVE-') and re.fullmatch(r'(?:19|20)[0-9]{2}-[0-9]{4,}', cve_id):
        cve_id = f"CVE-{cve_id}"

    return cve_id


def parse_severity(severity: str) -> str:
    """
    Parse and normalize severity string.

    Accepts either a severity label (case-insensitive, surrounding
    whitespace ignored) or a numeric CVSS score written as text. Scores are
    bucketed as: >= 9.0 CRITICAL, >= 7.0 HIGH, >= 4.0 MEDIUM, otherwise LOW.
    Numbers outside the CVSS range 0-10, as well as NaN and infinity, are
    not valid scores and yield UNKNOWN.

    Args:
        severity (str): Severity indicator

    Returns:
        str: Normalized severity (CRITICAL, HIGH, MEDIUM, LOW, UNKNOWN)
    """
    if not severity:
        return "UNKNOWN"

    severity = severity.upper().strip()

    # Map various severity labels
    severity_mapping = {
        'CRITICAL': 'CRITICAL',
        'HIGH': 'HIGH',
        'MEDIUM': 'MEDIUM',
        'MODERATE': 'MEDIUM',
        'LOW': 'LOW',
        'INFO': 'LOW',
        'INFORMATIONAL': 'LOW',
    }

    # Try direct mapping first
    if severity in severity_mapping:
        return severity_mapping[severity]

    # Otherwise try to read the value as a CVSS score
    try:
        score = float(severity)
    except ValueError:
        return "UNKNOWN"

    # float() also accepts "nan" and "inf"; this range check is False for
    # both, so they are rejected together with out-of-range numbers.
    if not 0.0 <= score <= 10.0:
        return "UNKNOWN"

    if score >= 9.0:
        return 'CRITICAL'
    if score >= 7.0:
        return 'HIGH'
    if score >= 4.0:
        return 'MEDIUM'
    return 'LOW'


def validate_ip_address(ip: str) -> bool:
    """
    Check whether a value parses as an IP address.

    Parsing is delegated to ``ipaddress.ip_address``, so both IPv4 and IPv6
    notations are recognised.

    Args:
        ip (str): IP address string

    Returns:
        bool: True if valid IP address
    """
    import ipaddress

    try:
        ipaddress.ip_address(ip)
    except ValueError:
        # Raised by ip_address() for anything it cannot parse.
        return False
    return True


def extract_domain_from_url(url: str) -> str:
    """
    Pull the network-location part out of a URL.

    The value is the lower-cased ``netloc`` component produced by
    ``urllib.parse.urlparse``; any user info or port present in the URL is
    kept as part of it. A URL without a ``//`` authority section (for
    example ``example.com/path``) has no netloc and gives an empty string.

    Args:
        url (str): URL string

    Returns:
        str: Domain name or empty string
    """
    from urllib.parse import urlparse

    try:
        network_location = urlparse(url).netloc
        return network_location.lower()
    except Exception:
        # Any parsing failure is reported as "no domain".
        return ""


def get_file_hash(file_path: str, algorithm: str = 'sha256') -> str:
    """
    Calculate file hash.

    The file is read in fixed-size chunks, so large files are never loaded
    into memory as a whole.

    Args:
        file_path (str): Path to file
        algorithm (str): Hash algorithm name understood by ``hashlib.new``
            (e.g. md5, sha1, sha256)

    Returns:
        str: Hex digest of the file, or empty string if the algorithm is
            unsupported or the file cannot be read
    """
    import hashlib

    try:
        # hashlib.new() is the supported lookup-by-name entry point; fetching
        # an arbitrary module attribute by name could call something that is
        # not a hash constructor at all.
        hash_func = hashlib.new(algorithm)

        with open(file_path, 'rb') as f:
            for chunk in iter(lambda: f.read(65536), b""):
                hash_func.update(chunk)

        return hash_func.hexdigest()
    except (OSError, ValueError, TypeError):
        # OSError: unreadable/missing file. ValueError: unknown algorithm.
        # TypeError: non-string arguments, or variable-length digests that
        # need an explicit length for hexdigest().
        return ""


def safe_json_loads(json_str: str, default: Any = None) -> Any:
    """
    Safely load JSON string with fallback.

    Args:
        json_str (str): JSON string (bytes are accepted by ``json.loads``
            as well)
        default: Default value if parsing fails

    Returns:
        Any: Parsed JSON or default value
    """
    try:
        return json.loads(json_str)
    except (ValueError, TypeError, RecursionError):
        # ValueError covers json.JSONDecodeError and also the
        # UnicodeDecodeError raised for undecodable bytes input.
        # TypeError: input is not str/bytes/bytearray (e.g. None).
        # RecursionError: pathologically deep nesting.
        return default


def truncate_string(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """
    Truncate string to specified length.

    The result is never longer than ``max_length``. When the text has to be
    shortened, ``suffix`` is appended and counted towards the limit. If the
    limit is too small to hold the suffix, the text is cut hard without it.

    Args:
        text (str): Input text (None is treated as empty)
        max_length (int): Maximum length; negative values behave like 0
        suffix (str): Suffix to add if truncated

    Returns:
        str: Truncated string
    """
    if text is None:
        return ""

    if len(text) <= max_length:
        return text

    limit = max(max_length, 0)
    if limit < len(suffix):
        # No room for the suffix. Subtracting its length would produce a
        # negative slice index and return MORE than max_length characters.
        return text[:limit]

    return text[:limit - len(suffix)] + suffix