"""MongoDB-Tasks für Prefect Flows."""

from __future__ import annotations

import json
import time
from typing import Any, Dict, Iterator, List, Optional, Tuple

from prefect import task, get_run_logger
from pymongo import MongoClient

from ..utils.validation import process_document


# Explain stages that indicate an index-backed access path (IDHACK is the
# _id-index fast path, COUNT_SCAN appears for index-only counts).
_INDEX_STAGES = frozenset(
    {"IXSCAN", "DISTINCT_SCAN", "COUNT_SCAN", "IDHACK", "EXPRESS_IXSCAN"}
)

# Keys under which explain output nests child stages or sub-plans.
# "queryPlan" is where slot-based-engine (SBE) explains put the classic tree.
_PLAN_CHILD_KEYS = (
    "inputStage",
    "inputStages",
    "shards",
    "executionStages",
    "winningPlan",
    "queryPlan",
)


def _iter_plan_nodes(plan: Any) -> Iterator[Dict[str, Any]]:
    """Yield every dict node reachable from an explain plan fragment (depth-first)."""
    pending: List[Any] = [plan]
    while pending:
        node = pending.pop()
        if isinstance(node, (list, tuple)):
            pending.extend(node)
        elif isinstance(node, dict):
            yield node
            for key in _PLAN_CHILD_KEYS:
                child = node.get(key)
                if key == "shards" and isinstance(child, dict):
                    # Mapping form {shardName: plan}: descend into the plans only.
                    pending.extend(child.values())
                elif child is not None:
                    pending.append(child)


def _plan_uses_index(plan: Any) -> bool:
    """Return True if any stage reachable from *plan* is an index-backed stage."""
    return any(
        isinstance(node.get("stage"), str) and node["stage"] in _INDEX_STAGES
        for node in _iter_plan_nodes(plan)
    )


def _plan_index_names(plan: Any) -> set[str]:
    """Return the set of ``indexName`` values found anywhere in *plan*."""
    return {
        node["indexName"]
        for node in _iter_plan_nodes(plan)
        if isinstance(node.get("indexName"), str)
    }


@task(
    name="Check MongoDB Indexes",
    tags=["mongodb", "indexes"],
)
def check_indexes(
    uri: str,
    database: str,
    collection: str,
    query_json: str,
    explain: bool = False,
) -> Dict[str, Any]:
    """Optionally run ``explain()`` for a query and report whether an index is used.

    Args:
        uri: MongoDB connection URI.
        database: Database name.
        collection: Collection name.
        query_json: MongoDB filter as a JSON string.
        explain: When True, run ``explain('executionStats')`` and inspect the
            plan for index usage; when False, return a "skipped" result without
            connecting.

    Returns:
        Either ``{"status": "skipped", "reason": "explain_disabled"}`` or a dict
        with ``status="explain_performed"``, ``index_used``, ``index_names``,
        ``winning_stage``, ``plan_summary``, ``total_docs_examined`` and
        ``total_keys_examined``.

    Raises:
        ValueError: If ``query_json`` is not valid JSON (checked even when
            ``explain`` is False).
    """
    logger = get_run_logger()

    try:
        query = json.loads(query_json)
    except json.JSONDecodeError as e:
        raise ValueError(f"Ungültiges JSON in MongoDB-Query: {e}") from e

    if not explain:
        logger.info("Explain deaktiviert, Indexanalyse übersprungen")
        return {"status": "skipped", "reason": "explain_disabled"}

    client = None
    try:
        client = MongoClient(uri)
        # Build the cursor outside the fallback try so that a TypeError raised
        # by find() itself (e.g. a non-mapping filter) is not misreported as a
        # missing "verbosity" argument. explain() clones the cursor, so
        # calling it twice on the same cursor is safe.
        cursor = client[database][collection].find(query)
        try:
            explain_doc = cursor.explain(verbosity="executionStats")
        except TypeError:
            logger.warning("Explain unterstützt kein verbosity-Argument, verwende Standard-Explain")
            explain_doc = cursor.explain()

        winning_plan = explain_doc.get("queryPlanner", {}).get("winningPlan", {})
        execution_stats = explain_doc.get("executionStats", {})

        index_used = _plan_uses_index(winning_plan) or _plan_uses_index(
            execution_stats.get("executionStages")
        )
        index_names = sorted(_plan_index_names(winning_plan))

        if index_used:
            logger.info("Explain zeigt Index-Nutzung an")
            if index_names:
                logger.info(f"Verwendete Indexe laut Explain: {index_names}")
        else:
            logger.warning("Explain zeigt keine Index-Nutzung (vermutlich COLLSCAN)")

        # SBE explains have no top-level "stage"; fall back to the nested tree.
        winning_stage = winning_plan.get("stage")
        nested_plan = winning_plan.get("queryPlan")
        if winning_stage is None and isinstance(nested_plan, dict):
            winning_stage = nested_plan.get("stage")

        return {
            "status": "explain_performed",
            "index_used": index_used,
            "index_names": index_names,
            "winning_stage": winning_stage,
            "plan_summary": winning_plan.get("planSummary"),
            "total_docs_examined": execution_stats.get("totalDocsExamined"),
            "total_keys_examined": execution_stats.get("totalKeysExamined"),
        }

    finally:
        if client is not None:
            client.close()


def fetch_documents_cursor(
    uri: str,
    database: str,
    collection: str,
    query_json: str,
    batch_size: int = 100,
    required_fields: Optional[List[str]] = None,
    exclude_fields: Optional[List[str]] = None,
) -> Iterator[List[Dict[str, Any]]]:
    """Run a MongoDB aggregation and yield the matching documents in batches.

    The pipeline filters on the database side:

    - ``$match`` with the parsed filter query (omitted when the query is empty).
    - ``$match`` requiring every ``required_fields`` entry to exist and be non-null.

    ``exclude_fields`` are deliberately NOT projected away here. The code leaves
    them for ``process_document()`` to remove after ID mapping, because that
    mapping may need fields such as ``_id`` that are also excluded.

    Args:
        uri: MongoDB connection URI.
        database: Database name.
        collection: Collection name.
        query_json: MongoDB filter as a JSON string.
        batch_size: Number of documents per yielded batch. Also passed to the
            server as the cursor ``batchSize``.
        required_fields: Fields that must exist and be non-null (``$match``).
        exclude_fields: Fields to drop later. Only logged here.

    Yields:
        Lists of up to ``batch_size`` documents. The last batch may be shorter.

    Raises:
        ValueError: If ``query_json`` is not valid JSON.
        Exception: Connection/driver errors from pymongo propagate unchanged.
    """
    logger = get_run_logger()

    try:
        query = json.loads(query_json)
    except json.JSONDecodeError as e:
        raise ValueError(f"Ungültiges JSON in MongoDB-Query: {e}") from e

    logger.info(f"MongoDB-Aggregation: db={database}, collection={collection}, query={query}")

    pipeline: List[Dict[str, Any]] = []

    # Stage 1: user-supplied filter.
    if query:
        pipeline.append({"$match": query})

    # Stage 2: required fields must exist and must not be null.
    if required_fields:
        required_match = {field: {"$exists": True, "$ne": None} for field in required_fields}
        pipeline.append({"$match": required_match})
        logger.info(f"Pflichtfeld-Filter: {required_fields}")

    # Exclude fields are intentionally not projected here (see docstring):
    # ID mapping downstream may still need them.
    if exclude_fields:
        logger.info(f"Exclude-Felder (werden nach ID-Mapping entfernt): {exclude_fields}")

    logger.debug(f"Aggregation Pipeline: {pipeline}")

    client = None
    try:
        client = MongoClient(uri)
        coll = client[database][collection]

        # Monotonic clock for interval measurement. Because this is a generator,
        # the elapsed time also includes the time the consumer spends between
        # batches, so the rate reflects end-to-end throughput.
        start_time = time.perf_counter()

        cursor = coll.aggregate(pipeline, batchSize=batch_size)

        batch: List[Dict[str, Any]] = []
        total_docs = 0

        for doc in cursor:
            batch.append(doc)
            if len(batch) >= batch_size:
                total_docs += len(batch)
                logger.debug(f"Batch mit {len(batch)} Dokumenten geladen (total: {total_docs})")
                yield batch
                batch = []

        # Flush the trailing partial batch, if any.
        if batch:
            total_docs += len(batch)
            logger.debug(f"Letzter Batch mit {len(batch)} Dokumenten geladen (total: {total_docs})")
            yield batch

        elapsed = time.perf_counter() - start_time
        # Guard against a zero interval (very fast/empty results) to avoid
        # ZeroDivisionError in the rate calculation.
        rate = total_docs / elapsed if elapsed > 0 else 0.0
        logger.info(
            f"MongoDB-Aggregation abgeschlossen: {total_docs} Dokumente in {elapsed:.2f}s "
            f"({rate:.1f} docs/s)"
        )

    finally:
        # Also runs when the consumer closes the generator early.
        if client is not None:
            client.close()


@task(
    name="Validate and Filter Document",
    tags=["mongodb", "validation"],
)
def validate_and_filter_document(
    doc: Dict[str, Any],
    required_fields: List[str],
    exclude_fields: List[str],
    mongodb_id_field: str,
    solr_id_field: str,
) -> Tuple[Optional[Dict[str, Any]], Optional[str], Optional[List[str]]]:
    """Prefect task that delegates a single document to ``process_document``.

    Args:
        doc: The document to validate and filter.
        required_fields: Fields that must be present.
        exclude_fields: Fields to remove from the output.
        mongodb_id_field: Name of the ID field in MongoDB.
        solr_id_field: Name of the ID field in Solr.

    Returns:
        Whatever ``process_document`` returns, documented as the tuple
        ``(processed_doc, skip_reason, missing_fields)``.
    """
    # Forward everything by keyword so the call does not depend on the
    # positional parameter order of process_document.
    forwarded = {
        "doc": doc,
        "required_fields": required_fields,
        "exclude_fields": exclude_fields,
        "mongodb_id_field": mongodb_id_field,
        "solr_id_field": solr_id_field,
    }
    return process_document(**forwarded)


def extract_ids_from_documents(
    documents: List[Dict[str, Any]],
    mongodb_id_field: str,
    solr_id_field: str,
) -> List[str]:
    """Extract document IDs as strings (used by the delete-only mode).

    Documents that lack ``mongodb_id_field``, or whose value for it is ``None``,
    are skipped with a warning instead of raising. This prevents a null ID from
    being turned into the literal string ``"None"``.

    Args:
        documents: Documents to read IDs from.
        mongodb_id_field: Name of the ID field in the MongoDB documents.
        solr_id_field: Name of the ID field in Solr. Not used by this function;
            kept so the signature matches the other processing helpers.

    Returns:
        The IDs converted with ``str()``, in document order.
    """
    logger = get_run_logger()
    ids: List[str] = []

    for doc in documents:
        raw_id = doc.get(mongodb_id_field)
        if raw_id is None:
            logger.warning(f"Dokument ohne ID-Feld '{mongodb_id_field}' übersprungen: {doc}")
            continue

        # Solr deletes by string ID, so normalize e.g. ObjectId/int to str.
        ids.append(str(raw_id))

    logger.info(f"{len(ids)} IDs aus {len(documents)} Dokumenten extrahiert")
    return ids


def process_document_batch(
    documents: List[Dict[str, Any]],
    required_fields: List[str],
    exclude_fields: List[str],
    mongodb_id_field: str,
    solr_id_field: str,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Run ``process_document`` over a batch and split results into kept and skipped.

    Args:
        documents: Documents to process.
        required_fields: Fields that must be present.
        exclude_fields: Fields to remove from the output.
        mongodb_id_field: Name of the ID field in MongoDB.
        solr_id_field: Name of the ID field in Solr.

    Returns:
        Tuple ``(valid_docs, skipped_docs)``:

        - ``valid_docs``: the truthy processed documents, in input order.
        - ``skipped_docs``: one dict per rejected document, with keys
          ``mongodb_id`` (stringified, ``"unknown"`` if absent), ``reason`` and
          ``missing_fields``.
    """
    logger = get_run_logger()
    accepted: List[Dict[str, Any]] = []
    rejected: List[Dict[str, Any]] = []

    for source_doc in documents:
        result_doc, reason, missing = process_document(
            doc=source_doc,
            required_fields=required_fields,
            exclude_fields=exclude_fields,
            mongodb_id_field=mongodb_id_field,
            solr_id_field=solr_id_field,
        )

        if not result_doc:
            # Record enough metadata to report why this document was dropped.
            source_id = source_doc.get(mongodb_id_field, "unknown")
            rejected.append(
                {
                    "mongodb_id": str(source_id),
                    "reason": reason,
                    "missing_fields": missing,
                }
            )
            logger.warning(
                f"Dokument übersprungen: ID={source_id}, Grund={reason}, "
                f"Fehlende Felder={missing}"
            )
            continue

        accepted.append(result_doc)

    logger.info(f"Batch verarbeitet: {len(accepted)} gültig, {len(rejected)} übersprungen")
    return accepted, rejected
