"""
Database Manager
Handles vulnerability database initialization, updates, and management
"""

import json
import logging
import os
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional

from .cve_sources import CVEDataManager


class DatabaseManager:
    """
    Manages vulnerability database creation, updates, and maintenance.

    Only CVE metadata obtained through ``CVEDataManager`` is stored; exploit
    code is never downloaded. Each public method opens its own short-lived
    SQLite connection, and that connection is always closed before the method
    returns.
    """

    def __init__(self, db_path: Optional[str] = None):
        """
        Initialize the database manager.

        Args:
            db_path: Path to the SQLite file. Defaults to
                ``~/.pysploit/vulnerability_index.db``.

        The parent directory of the database file is created if missing.
        """
        self.logger = logging.getLogger(__name__)

        if db_path:
            self.db_path = Path(db_path)
        else:
            # Default database location
            self.db_path = Path.home() / ".pysploit" / "vulnerability_index.db"

        # Ensure directory exists
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

    @contextmanager
    def _connect(self):
        """Yield a connection that commits on success, rolls back on error, and is always closed."""
        conn = sqlite3.connect(self.db_path)
        try:
            # `with conn` only scopes the transaction; it does NOT close the
            # connection, hence the explicit close in `finally`.
            with conn:
                yield conn
        finally:
            conn.close()

    def _schema_exists(self) -> bool:
        """Return True if the database file exists and holds both required tables."""
        if not self.db_path.exists():
            return False
        try:
            with self._connect() as conn:
                (table_count,) = conn.execute(
                    "SELECT COUNT(*) FROM sqlite_master "
                    "WHERE type = 'table' "
                    "AND name IN ('vulnerabilities', 'database_metadata')"
                ).fetchone()
        except sqlite3.Error:
            # Unreadable or non-SQLite file: treat as "no usable schema".
            return False
        return table_count == 2

    def initialize_database(self, force: bool = False) -> bool:
        """
        Create the vulnerability schema (tables, indexes, metadata rows).

        Args:
            force: Re-run schema creation and refresh the ``created_at`` and
                ``version`` metadata even if the schema is already present.
                Existing vulnerability rows are NOT deleted (all DDL uses
                ``IF NOT EXISTS``).

        Returns:
            bool: True if successful, False if any error occurred.
        """
        # Check for the schema rather than just the file: an earlier failed
        # run (sqlite3.connect creates the file) can leave an empty database.
        if not force and self._schema_exists():
            self.logger.info("Database already exists at %s", self.db_path)
            return True

        try:
            with self._connect() as conn:
                cursor = conn.cursor()

                # Create main vulnerabilities table. "references" is an SQL
                # keyword (foreign-key clause) and must be quoted to be used
                # as a column name.
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS vulnerabilities (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        cve_id TEXT,
                        description TEXT,
                        published_date TEXT,
                        modified_date TEXT,
                        base_score REAL,
                        cvss_severity TEXT,
                        cvss_vector TEXT,
                        vendors TEXT,
                        products TEXT,
                        versions TEXT,
                        exploit_type TEXT,
                        platform TEXT,
                        source TEXT,
                        "references" TEXT,
                        created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                        updated_at TEXT DEFAULT CURRENT_TIMESTAMP
                    )
                ''')

                # Create indexes for better search performance
                cursor.execute('CREATE INDEX IF NOT EXISTS idx_cve_id ON vulnerabilities(cve_id)')
                cursor.execute('CREATE INDEX IF NOT EXISTS idx_severity ON vulnerabilities(cvss_severity)')
                cursor.execute('CREATE INDEX IF NOT EXISTS idx_score ON vulnerabilities(base_score)')
                cursor.execute('CREATE INDEX IF NOT EXISTS idx_source ON vulnerabilities(source)')
                cursor.execute('CREATE INDEX IF NOT EXISTS idx_platform ON vulnerabilities(platform)')

                # Create metadata table for tracking updates
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS database_metadata (
                        key TEXT PRIMARY KEY,
                        value TEXT,
                        updated_at TEXT DEFAULT CURRENT_TIMESTAMP
                    )
                ''')

                # Initialize metadata
                cursor.execute('''
                    INSERT OR REPLACE INTO database_metadata (key, value) 
                    VALUES ('created_at', ?)
                ''', (datetime.now().isoformat(),))

                cursor.execute('''
                    INSERT OR REPLACE INTO database_metadata (key, value) 
                    VALUES ('version', '1.0')
                ''')

            self.logger.info("Database initialized successfully at %s", self.db_path)
            return True

        except Exception as e:
            self.logger.error("Failed to initialize database: %s", e)
            return False

    def update_from_cve_sources(self,
                                days_back: int = 30,
                                severity_filter: Optional[str] = None,
                                max_records: int = 1000) -> Dict[str, Any]:
        """
        Safely update the database with CVE metadata from CVEDataManager.

        Does NOT download exploit code, only CVE metadata.

        Args:
            days_back: How many days back to fetch CVEs.
            severity_filter: Only fetch CVEs of this severity (e.g. HIGH,
                CRITICAL); falsy means no filter is passed to the fetcher.
            max_records: Maximum number of fetched CVEs to process.

        Returns:
            dict: Counters ``fetched``, ``new_records``, ``updated_records``,
            ``skipped``, ``errors``; or ``{'error': message}`` on failure.
        """
        if not self._schema_exists():
            self.logger.info("Database doesn't exist, initializing...")
            if not self.initialize_database():
                return {'error': 'Failed to initialize database'}

        stats = {
            'fetched': 0,
            'new_records': 0,
            'updated_records': 0,
            'skipped': 0,
            'errors': 0
        }

        try:
            # Use CVE data manager for safe API access
            data_manager = CVEDataManager()

            self.logger.info("Fetching CVEs from last %s days...", days_back)

            if severity_filter:
                cves = data_manager.get_recent_cves(days=days_back, severity_filter=severity_filter)
            else:
                cves = data_manager.get_recent_cves(days=days_back)

            # Limit results to prevent overwhelming the database
            cves = cves[:max_records]
            stats['fetched'] = len(cves)

            self.logger.info("Processing %s CVEs...", len(cves))

            with self._connect() as conn:
                cursor = conn.cursor()

                for cve in cves:
                    # Best effort: one malformed record must not abort the batch.
                    try:
                        if self._upsert_cve(cursor, cve):
                            stats['new_records'] += 1
                        else:
                            stats['updated_records'] += 1
                    except Exception as e:
                        self.logger.error("Error processing CVE %s: %s", cve.get('id', 'unknown'), e)
                        stats['errors'] += 1

                # Update metadata
                now = datetime.now().isoformat()
                cursor.execute('''
                    INSERT OR REPLACE INTO database_metadata (key, value, updated_at) 
                    VALUES ('last_update', ?, ?)
                ''', (now, now))

            self.logger.info("Database update completed: %s", stats)
            return stats

        except Exception as e:
            self.logger.error("Failed to update database: %s", e)
            return {'error': str(e)}

    def _upsert_cve(self, cursor: sqlite3.Cursor, cve: Dict[str, Any]) -> bool:
        """
        Insert or update one CVE row keyed by ``cve['id']``.

        Returns:
            bool: True if a new row was inserted, False if an existing row
            was updated.

        Raises:
            KeyError: if ``cve`` has no ``'id'`` key.
        """
        cursor.execute('SELECT id FROM vulnerabilities WHERE cve_id = ?', (cve['id'],))
        existing = cursor.fetchone()

        cve_data = self._prepare_cve_data(cve)

        if existing:
            # cve_data[1:] is description .. references (8 values), matching
            # the SET list order; then updated_at and the WHERE key.
            cursor.execute('''
                UPDATE vulnerabilities 
                SET description = ?, published_date = ?, modified_date = ?,
                    base_score = ?, cvss_severity = ?, vendors = ?, 
                    products = ?, "references" = ?, updated_at = ?
                WHERE cve_id = ?
            ''', (*cve_data[1:], datetime.now().isoformat(), cve['id']))
            return False

        cursor.execute('''
            INSERT INTO vulnerabilities 
            (cve_id, description, published_date, modified_date,
             base_score, cvss_severity, vendors, products, 
             source, "references")
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'NVD', ?)
        ''', cve_data)
        return True

    def _prepare_cve_data(self, cve: Dict[str, Any]) -> tuple:
        """
        Build the row tuple for database insertion.

        Order: cve_id, description (truncated to 2000 chars), published_date,
        modified_date, base_score, cvss_severity, vendors, products,
        references. List-valued fields are JSON-encoded, or '' when empty or
        missing.
        """
        def _json_or_empty(key: str) -> str:
            value = cve.get(key)
            return json.dumps(value) if value else ''

        return (
            cve.get('id', ''),
            # `or ''` also covers an explicit None description.
            (cve.get('description') or '')[:2000],
            cve.get('published', ''),
            cve.get('modified', ''),
            cve.get('base_score'),
            cve.get('severity', ''),
            _json_or_empty('vendors'),
            _json_or_empty('products'),
            _json_or_empty('references'),
        )

    def get_database_stats(self) -> Dict[str, Any]:
        """
        Get database statistics and metadata.

        Returns:
            dict: Path, total/CRITICAL/HIGH counts, latest row ``updated_at``
            and the metadata table; or ``{'error': message}``.
        """
        if not self.db_path.exists():
            return {'error': 'Database not found'}

        try:
            with self._connect() as conn:
                cursor = conn.cursor()

                cursor.execute('SELECT COUNT(*) FROM vulnerabilities')
                total_cves = cursor.fetchone()[0]

                # Bound parameters instead of "CRITICAL": double-quoted
                # strings are identifiers in standard SQL, and SQLite builds
                # with DQS disabled reject them.
                severity_sql = 'SELECT COUNT(*) FROM vulnerabilities WHERE cvss_severity = ?'
                cursor.execute(severity_sql, ('CRITICAL',))
                critical_cves = cursor.fetchone()[0]

                cursor.execute(severity_sql, ('HIGH',))
                high_cves = cursor.fetchone()[0]

                cursor.execute('SELECT key, value, updated_at FROM database_metadata')
                metadata = {row[0]: {'value': row[1], 'updated_at': row[2]}
                            for row in cursor.fetchall()}

                cursor.execute('''
                    SELECT MAX(updated_at) FROM vulnerabilities 
                    WHERE updated_at IS NOT NULL
                ''')
                last_cve_update = cursor.fetchone()[0]

            return {
                'database_path': str(self.db_path),
                'total_cves': total_cves,
                'critical_cves': critical_cves,
                'high_cves': high_cves,
                'last_cve_update': last_cve_update,
                'metadata': metadata
            }

        except Exception as e:
            return {'error': str(e)}

    def cleanup_old_records(self, days_old: int = 365) -> int:
        """
        Delete non-HIGH/non-CRITICAL records published before the cutoff.

        Args:
            days_old: Age threshold in days relative to now.

        Returns:
            int: Number of deleted rows (0 if the database is missing or on error).
        """
        if not self.db_path.exists():
            return 0

        # ISO-8601 strings compare correctly as text, so a string cutoff works.
        cutoff_date = (datetime.now() - timedelta(days=days_old)).isoformat()

        try:
            with self._connect() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    DELETE FROM vulnerabilities 
                    WHERE published_date < ? AND cvss_severity NOT IN ('HIGH', 'CRITICAL')
                ''', (cutoff_date,))
                deleted_count = cursor.rowcount

            self.logger.info("Cleaned up %s old CVE records", deleted_count)
            return deleted_count

        except Exception as e:
            self.logger.error("Failed to cleanup old records: %s", e)
            return 0

    def backup_database(self, backup_path: Optional[str] = None) -> bool:
        """
        Create a consistent backup of the database.

        Uses SQLite's online backup API rather than a raw file copy, which
        could capture a half-written database or miss journal state.

        Args:
            backup_path: Destination file. Defaults to a timestamped
                ``vulnerability_backup_YYYYmmdd_HHMMSS.db`` next to the database.

        Returns:
            bool: True on success, False if the source is missing or on error.
        """
        if not self.db_path.exists():
            return False

        if backup_path:
            target = Path(backup_path)
        else:
            target = self.db_path.parent / f"vulnerability_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.db"

        try:
            with self._connect() as source:
                destination = sqlite3.connect(target)
                try:
                    source.backup(destination)
                finally:
                    destination.close()
            self.logger.info("Database backed up to %s", target)
            return True

        except Exception as e:
            self.logger.error("Failed to backup database: %s", e)
            return False

    def verify_database_integrity(self) -> Dict[str, Any]:
        """
        Verify database integrity and return diagnostic info.

        Returns:
            dict: First row of ``PRAGMA integrity_check``, required/existing
            tables, ``idx_*`` indexes, file size in MB, and ``is_healthy``
            (integrity 'ok' and both tables present); or ``{'error': message}``.
        """
        if not self.db_path.exists():
            return {'error': 'Database not found'}

        try:
            with self._connect() as conn:
                cursor = conn.cursor()

                cursor.execute('PRAGMA integrity_check')
                integrity_result = cursor.fetchone()[0]

                cursor.execute('''
                    SELECT name FROM sqlite_master 
                    WHERE type='table' AND name IN ('vulnerabilities', 'database_metadata')
                ''')
                tables = [row[0] for row in cursor.fetchall()]

                cursor.execute('''
                    SELECT name FROM sqlite_master 
                    WHERE type='index' AND name LIKE 'idx_%'
                ''')
                indexes = [row[0] for row in cursor.fetchall()]

            return {
                'integrity_check': integrity_result,
                'required_tables': ['vulnerabilities', 'database_metadata'],
                'existing_tables': tables,
                'indexes': indexes,
                'database_size_mb': os.path.getsize(self.db_path) / (1024 * 1024),
                'is_healthy': integrity_result == 'ok' and len(tables) == 2
            }

        except Exception as e:
            return {'error': str(e)}


# Convenience functions for easy database management
def initialize_pysploit_database(db_path: Optional[str] = None, force: bool = False) -> bool:
    """
    Create the PySploit vulnerability database schema.

    Convenience wrapper around ``DatabaseManager.initialize_database``.

    Args:
        db_path: Database file path; ``None`` selects the manager's default.
        force: Passed through as the ``force`` flag of ``initialize_database``.

    Returns:
        bool: The result of ``initialize_database``.
    """
    return DatabaseManager(db_path).initialize_database(force=force)


def update_vulnerability_database(db_path: Optional[str] = None,
                                  days_back: int = 7,
                                  severity_filter: Optional[str] = "HIGH",
                                  max_records: int = 500) -> Dict[str, Any]:
    """
    Update the vulnerability database with recently published CVEs.

    Convenience wrapper around ``DatabaseManager.update_from_cve_sources``.

    Args:
        db_path: Database file path; ``None`` selects the manager's default.
        days_back: How many days back to fetch CVEs (default 7).
        severity_filter: Severity to restrict the fetch to (default "HIGH").
            Pass ``None`` to call the fetcher without a severity filter.
        max_records: Maximum number of CVEs to process. The default of 500
            is a deliberately conservative limit.

    Returns:
        dict: Update statistics, or a dict with an ``'error'`` key on failure.
    """
    manager = DatabaseManager(db_path)
    return manager.update_from_cve_sources(
        days_back=days_back,
        severity_filter=severity_filter,
        max_records=max_records,
    )


def get_database_status(db_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Report statistics and metadata for the vulnerability database.

    Convenience wrapper around ``DatabaseManager.get_database_stats``.

    Args:
        db_path: Database file path; ``None`` selects the manager's default.

    Returns:
        dict: The stats dict, or a dict with an ``'error'`` key on failure.
    """
    return DatabaseManager(db_path).get_database_stats()