#!/usr/bin/env python3
"""
PySploit Vulnerability Database
Core vulnerability database management and querying functionality.
"""

import sqlite3
import pandas as pd
import os
import re
import json
from collections import defaultdict, Counter
from datetime import datetime
from typing import Dict, List, Optional, Any


class VulnerabilityDatabase:
    """
    Main interface for vulnerability database operations.

    Wraps a SQLite connection and exposes read-only queries against the
    ``vulnerabilities`` table: lookups by CVE ID, free-text search, loading
    records by source ('exploitdb', 'routersploit') and summary statistics.
    All query methods return rows as plain dictionaries keyed by column name.

    Instances can be used as context managers; the connection is closed on
    exit.
    """

    def __init__(self, db_path: Optional[str] = None):
        """
        Initialize vulnerability database connection.

        Args:
            db_path (str, optional): Path to SQLite database file. When
                omitted (or empty), a list of default locations is probed
                and the first existing file is used.

        Raises:
            FileNotFoundError: If no database file exists at the given path
                or at any of the default locations.
        """
        if db_path:
            self.db_file = db_path
        else:
            # Default to looking for database in common locations
            potential_paths = [
                "vulnerability_index.db",
                "data/vulnerability_index.db",
                os.path.join(os.path.expanduser("~"), ".pysploit", "vulnerability_index.db"),
                # Path from our original project
                os.path.join("..", "Analysis", "Vulnerability", "exploit_db", "index", "vulnerability_index.db"),
            ]

            # First existing candidate wins; None when nothing was found.
            self.db_file = next(
                (path for path in potential_paths if os.path.exists(path)),
                None,
            )

        # Checked up front because sqlite3.connect() would otherwise silently
        # create a new, empty database file at a non-existent path.
        if not self.db_file or not os.path.exists(self.db_file):
            raise FileNotFoundError(
                "Vulnerability database not found. Please ensure the database is available or "
                "specify the path with db_path parameter."
            )

        self.conn = sqlite3.connect(self.db_file)
        self.conn.row_factory = sqlite3.Row  # Enable dict-like access to rows

    def _fetch_all(self, sql: str, params=()) -> List[Dict[str, Any]]:
        """Run a parameterized query and return every row as a dictionary."""
        cursor = self.conn.cursor()
        try:
            cursor.execute(sql, params)
            return [dict(row) for row in cursor.fetchall()]
        finally:
            # Release the cursor even when the query raises.
            cursor.close()

    def search_by_cve(self, cve_id: str) -> List[Dict[str, Any]]:
        """
        Search for vulnerability by CVE ID.

        The match is an exact comparison against the ``cve_id`` column.

        Args:
            cve_id (str): CVE identifier (e.g., 'CVE-2021-44228')

        Returns:
            list: List of vulnerability records matching the CVE
        """
        return self._fetch_all(
            "SELECT * FROM vulnerabilities WHERE cve_id = ?",
            (cve_id,),
        )

    def search(self, query: str, category: Optional[str] = None,
               severity: Optional[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
        """
        Search vulnerabilities with optional filters.

        Every supplied filter narrows the result (filters are ANDed). Text
        filters are substring matches implemented with SQL ``LIKE``, so any
        ``%`` or ``_`` in the input still acts as a wildcard.

        Args:
            query (str): Search query (searches in description, vendors,
                products and CVE ID). An empty value disables the text filter.
            category (str, optional): Vulnerability category filter, matched
                against the exploit_type and platform columns
            severity (str, optional): Severity filter (LOW, MEDIUM, HIGH,
                CRITICAL); upper-cased before comparison
            limit (int): Maximum results to return

        Returns:
            list: Matching vulnerability records

        Raises:
            TypeError: If ``limit`` cannot be converted to an integer.
            ValueError: If ``limit`` is a string that is not a plain integer.
        """
        # Coerce before touching the database so a crafted value (for
        # example "1 OFFSET 1") can never become part of the SQL text.
        row_limit = int(limit)

        sql_query = "SELECT * FROM vulnerabilities WHERE 1=1"
        params: List[Any] = []

        # Text search across multiple fields
        if query:
            sql_query += " AND (description LIKE ? OR vendors LIKE ? OR products LIKE ? OR cve_id LIKE ?)"
            params.extend([f"%{query}%"] * 4)

        # Category filter
        if category:
            sql_query += " AND (exploit_type LIKE ? OR platform LIKE ?)"
            params.extend([f"%{category}%"] * 2)

        # Severity filter
        if severity:
            sql_query += " AND cvss_severity = ?"
            params.append(severity.upper())

        # Bound as a parameter rather than interpolated into the statement.
        sql_query += " LIMIT ?"
        params.append(row_limit)

        return self._fetch_all(sql_query, params)

    def search_by_category(self, category: str) -> List[Dict[str, Any]]:
        """
        Search vulnerabilities by category (router, iot, web, etc.).

        Delegates to :meth:`search` with no text query and a cap of 1000
        results.

        Args:
            category (str): Category to search for

        Returns:
            list: Vulnerability records in the category
        """
        return self.search("", category=category, limit=1000)

    def load_cves(self) -> List[Dict[str, Any]]:
        """
        Load all CVE records from the database.

        A record counts as a CVE when its ``cve_id`` is neither NULL nor an
        empty string.

        Returns:
            list: All CVE records
        """
        return self._fetch_all(
            "SELECT * FROM vulnerabilities WHERE cve_id IS NOT NULL AND cve_id != ''"
        )

    def load_exploits(self) -> List[Dict[str, Any]]:
        """
        Load ExploitDB entries from the database.

        Returns:
            list: Records whose source is 'exploitdb'
        """
        return self._fetch_all(
            "SELECT * FROM vulnerabilities WHERE source = 'exploitdb'"
        )

    def load_routersploit(self) -> List[Dict[str, Any]]:
        """
        Load RouterSploit module data from the database.

        Returns:
            list: Records whose source is 'routersploit'
        """
        return self._fetch_all(
            "SELECT * FROM vulnerabilities WHERE source = 'routersploit'"
        )

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get database statistics.

        Returns:
            dict: Statistics with the keys
                'total_vulnerabilities' (int), 'by_source' (source -> count),
                'by_severity' (severity -> count, NULL severities excluded)
                and 'by_year' (year string taken from the CVE ID -> count,
                limited to the ten highest years).
        """
        stats: Dict[str, Any] = {}
        cursor = self.conn.cursor()
        try:
            # Total vulnerabilities
            cursor.execute("SELECT COUNT(*) FROM vulnerabilities")
            stats['total_vulnerabilities'] = cursor.fetchone()[0]

            # By source
            cursor.execute("SELECT source, COUNT(*) FROM vulnerabilities GROUP BY source")
            stats['by_source'] = {row[0]: row[1] for row in cursor.fetchall()}

            # By severity
            cursor.execute(
                "SELECT cvss_severity, COUNT(*) FROM vulnerabilities "
                "WHERE cvss_severity IS NOT NULL GROUP BY cvss_severity"
            )
            stats['by_severity'] = {row[0]: row[1] for row in cursor.fetchall()}

            # By year (from CVE IDs). substr() is 1-indexed, so position 5 is
            # the first digit of the year in 'CVE-YYYY-NNNN'.
            cursor.execute("""
                SELECT substr(cve_id, 5, 4) as year, COUNT(*)
                FROM vulnerabilities
                WHERE cve_id LIKE 'CVE-%'
                GROUP BY year
                ORDER BY year DESC
                LIMIT 10
            """)
            stats['by_year'] = {row[0]: row[1] for row in cursor.fetchall()}
        finally:
            cursor.close()

        return stats

    def search_exploitdb(self, query: str) -> List[Dict[str, Any]]:
        """
        Search specifically in ExploitDB entries.

        Performs a substring (``LIKE``) match on the description,
        exploit_type and platform columns of records whose source is
        'exploitdb'.

        Args:
            query (str): Search query

        Returns:
            list: Matching ExploitDB records
        """
        sql_query = """
            SELECT * FROM vulnerabilities
            WHERE source = 'exploitdb'
            AND (description LIKE ? OR exploit_type LIKE ? OR platform LIKE ?)
        """
        search_term = f"%{query}%"
        return self._fetch_all(sql_query, (search_term,) * 3)

    def get_high_severity_cves(self, limit: int = 50) -> List[Dict[str, Any]]:
        """
        Get high severity CVEs (CRITICAL and HIGH).

        Results are ordered by CVSS score, highest first. Records without a
        CVE ID (NULL or empty string) are excluded, matching load_cves().

        Args:
            limit (int): Maximum results to return

        Returns:
            list: High severity CVE records
        """
        sql_query = """
            SELECT * FROM vulnerabilities
            WHERE cvss_severity IN ('CRITICAL', 'HIGH')
            AND cve_id IS NOT NULL
            AND cve_id != ''
            ORDER BY cvss_score DESC
            LIMIT ?
        """
        return self._fetch_all(sql_query, (limit,))

    def find_recent_cves(self, year: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Find recent CVEs by year.

        Args:
            year (int, optional): Year to search (defaults to current year)

        Returns:
            list: Records whose CVE ID starts with 'CVE-<year>-'
        """
        if year is None:
            year = datetime.now().year

        return self._fetch_all(
            "SELECT * FROM vulnerabilities WHERE cve_id LIKE ?",
            (f"CVE-{year}-%",),
        )

    def close(self):
        """Close database connection."""
        # getattr keeps close() safe on an instance where conn was never set.
        conn = getattr(self, "conn", None)
        if conn:
            conn.close()

    def __enter__(self):
        """Context manager entry; returns the database object itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit; closes the connection, never suppresses errors."""
        self.close()