"""Exploit Database data collector."""
from __future__ import annotations

import csv
from datetime import datetime, timezone
from io import StringIO
from typing import Any, Dict, List

from .base import BaseDataSource, SourceResult


class ExploitDBDataSource(BaseDataSource):
    """Collect exploit data from Exploit Database.

    Exploit-DB provides a CSV file with all exploits:
    https://gitlab.com/exploit-database/exploitdb/-/raw/main/files_exploits.csv
    """

    source_name = "exploit_db"
    CSV_URL = "https://gitlab.com/exploit-database/exploitdb/-/raw/main/files_exploits.csv"

    # Candidate CSV columns holding an exploit's date, tried in order. "date" is
    # the column this collector has always read; the current Exploit-DB header
    # appears to name it "date_published" instead (TODO confirm against live CSV).
    _DATE_COLUMNS = ("date", "date_published")
    _DATE_FORMAT = "%Y-%m-%d"

    def __init__(self, **kwargs: Any):
        """Create the source, forwarding all keyword arguments to BaseDataSource.

        ``verify_ssl`` defaults to False here; callers may pass ``verify_ssl=True``.
        """
        # Disable SSL verification for corporate proxies.
        # NOTE(review): this turns off TLS certificate checking by default for a
        # public download; consider requiring callers to opt out explicitly.
        kwargs.setdefault("verify_ssl", False)
        super().__init__(**kwargs)

    @classmethod
    def _exploit_date(cls, row: Dict[str, str]) -> datetime | None:
        """Return the row's date as a naive datetime, or None if unavailable.

        The first non-empty column from ``_DATE_COLUMNS`` is parsed as
        ``YYYY-MM-DD``; a malformed value yields None rather than raising.
        """
        for column in cls._DATE_COLUMNS:
            # `or ""` also covers None, which DictReader uses for short rows.
            value = (row.get(column) or "").strip()
            if value:
                try:
                    return datetime.strptime(value, cls._DATE_FORMAT)
                except ValueError:
                    return None
        return None

    @staticmethod
    def _is_after_cutoff(exploit_date: datetime, cutoff: datetime) -> bool:
        """Return True if *exploit_date* is strictly later than *cutoff*.

        CSV dates are naive. When *cutoff* is timezone-aware they are treated as
        UTC, because comparing naive and aware datetimes raises ``TypeError``.
        """
        if cutoff.utcoffset() is not None:
            exploit_date = exploit_date.replace(tzinfo=timezone.utc)
        return exploit_date > cutoff

    @staticmethod
    def _codes_contain(codes: str | None, cve_id: str) -> bool:
        """Return True if *cve_id* is one of the identifiers listed in *codes*.

        The ``codes`` cell is assumed to hold several identifiers separated by
        ``;`` (``,`` is accepted too). Tokens are compared whole and
        case-insensitively, so ``CVE-2021-1234`` does not match
        ``CVE-2021-12345``. An empty *cve_id* matches nothing.
        """
        target = cve_id.strip().upper()
        if not target or not codes:
            return False
        tokens = codes.replace(",", ";").split(";")
        return any(token.strip().upper() == target for token in tokens)

    def collect_all_exploits(self, *, cutoff: datetime) -> SourceResult:
        """Download the Exploit-DB CSV and return exploits not newer than *cutoff*.

        Rows with a missing or unparseable date are kept.

        Args:
            cutoff: Exploits dated strictly after this are skipped. May be naive
                or timezone-aware (CSV dates are then interpreted as UTC).

        Returns:
            SourceResult whose payload holds ``"exploits"`` (one dict per CSV
            row, keyed by the CSV header) and ``"total_results"``.
        """
        response = self._request("GET", self.CSV_URL)
        reader = csv.DictReader(StringIO(response.text))

        exploits: List[Dict[str, str]] = []
        for row in reader:
            exploit_date = self._exploit_date(row)
            if exploit_date is not None and self._is_after_cutoff(exploit_date, cutoff):
                continue
            exploits.append(row)

        return SourceResult(
            source=self.source_name,
            package="all_exploits",
            collected_at=datetime.now(timezone.utc),
            payload={
                "exploits": exploits,
                "total_results": len(exploits),
            },
        )

    def collect_by_cve(self, cve_id: str, *, cutoff: datetime) -> SourceResult:
        """Collect exploits referencing *cve_id* by filtering the full dataset.

        Args:
            cve_id: CVE identifier to look for in each row's ``codes`` column.
            cutoff: Forwarded to :meth:`collect_all_exploits`.

        Returns:
            SourceResult whose payload holds ``"cve_id"``, the matching
            ``"exploits"`` and ``"total_results"``.
        """
        # The whole CSV is downloaded on every call; there is no per-CVE endpoint.
        all_exploits = self.collect_all_exploits(cutoff=cutoff).payload["exploits"]

        cve_exploits = [
            exploit
            for exploit in all_exploits
            if self._codes_contain(exploit.get("codes"), cve_id)
        ]

        return SourceResult(
            source=self.source_name,
            package=cve_id,
            collected_at=datetime.now(timezone.utc),
            payload={
                "cve_id": cve_id,
                "exploits": cve_exploits,
                "total_results": len(cve_exploits),
            },
        )

    def collect(self, package: str, *, cutoff: datetime) -> SourceResult:
        """Collect exploit data. Package can be 'all' or a CVE ID.

        ``"all"`` returns the full cutoff-filtered dataset. Any other value,
        with or without a ``CVE-`` prefix, is looked up as a CVE identifier.
        """
        if package == "all":
            return self.collect_all_exploits(cutoff=cutoff)
        return self.collect_by_cve(package, cutoff=cutoff)


# Public API of this module: only the Exploit-DB data source is exported.
__all__ = ["ExploitDBDataSource"]
