import asyncio
import json
import logging
import os
import tempfile
import time
from contextlib import suppress, asynccontextmanager, contextmanager
from typing import Any, Optional, List, Dict

import dask
import dask.dataframe as dd
import numpy as np
import pandas as pd
from dask.distributed import Client, LocalCluster, get_client, Future
from filelock import FileLock

from sibi_dst.utils import Logger

# --------------------------------------------------------------------------------------
# Logger
# --------------------------------------------------------------------------------------
# Module-level logger named after this module; the free-standing helper
# functions below (_ensure_graph_integrity, _safe_gather, _safe_wait) log
# through it.
logger = Logger.default_logger(logger_name=__name__)


# --------------------------------------------------------------------------------------
# Safe numeric helpers
# --------------------------------------------------------------------------------------
def _to_int_safe(x: Any, default: int = 0) -> int:
    """
    Coerce a scalar-like value to ``int`` without ever raising.

    Resolution order:
    - ``None`` yields `default`.
    - Python / numpy integers (``bool`` excluded here) are converted directly.
    - Python / numpy floats are truncated; NaN / inf yield `default`.
    - Other numpy scalars are unwrapped with ``.item()``.
    - pandas Series/Index, list, tuple and ndarray use their first flattened
      element; empty containers yield `default`.
    - Objects exposing ``item()`` or ``iloc`` are unwrapped and retried.
    - Anything else goes through ``int(x)``.

    Any conversion failure yields `default`.
    """

    def _guarded(producer):
        # Evaluate a conversion thunk; every failure collapses to `default`.
        try:
            return producer()
        except Exception:
            return default

    if x is None:
        return default

    # Integers first. bool is an int subclass but is deliberately skipped here
    # and ends up in the generic int() conversion at the bottom.
    if not isinstance(x, bool) and isinstance(x, (int, np.integer)):
        return int(x)

    if isinstance(x, (float, np.floating)):
        return _guarded(lambda: int(x))

    # Remaining numpy scalar kinds (np.bool_, np.str_, ...)
    if isinstance(x, np.generic):
        return _guarded(lambda: int(x.item()))

    # Containers: look at the first element of the flattened data
    if isinstance(x, (pd.Series, pd.Index, list, tuple, np.ndarray)):

        def _first_element():
            flat = np.asarray(x).ravel()
            if flat.size == 0:
                return default
            return _to_int_safe(flat[0], default=default)

        return _guarded(_first_element)

    # Duck-typed wrappers
    if hasattr(x, "item"):
        return _guarded(lambda: _to_int_safe(x.item(), default=default))
    if hasattr(x, "iloc"):
        return _guarded(lambda: _to_int_safe(x.iloc[0], default=default))

    return _guarded(lambda: int(x))


# --------------------------------------------------------------------------------------
# Core safe compute / persist / gather helpers
# --------------------------------------------------------------------------------------
def _safe_compute(obj: Any, dask_client: Optional[Client] = None) -> Any:
    """
    Materialize a Dask object and return the concrete result.

    Without a client, the object's own ``compute()`` is used. With a client,
    work is submitted through ``dask_client.compute``; when the client is
    flagged ``asynchronous`` the call is scheduled onto ``dask_client.loop``
    and this function blocks until it completes. A ``Future`` coming back from
    either path is resolved before returning.
    """
    if not dask_client:
        return obj.compute()

    if getattr(dask_client, "asynchronous", False):
        # NOTE(review): this relies on dask_client.compute(obj) being accepted
        # by run_coroutine_threadsafe and on dask_client.loop being usable as
        # an asyncio loop - confirm against the installed distributed version.
        pending = asyncio.run_coroutine_threadsafe(
            dask_client.compute(obj), dask_client.loop
        )
        outcome = pending.result()
    else:
        outcome = dask_client.compute(obj)

    if isinstance(outcome, Future):
        return outcome.result()
    return outcome


def _safe_persist(obj: Any, dask_client: Optional[Client] = None) -> Any:
    """
    Persist a Dask collection and return the persisted collection.

    Without a client, the object's own ``persist()`` is used. With a client,
    ``dask_client.persist`` is called; when the client is flagged
    ``asynchronous`` the call is scheduled onto ``dask_client.loop`` and this
    function blocks until it completes. A ``Future`` coming back from either
    path is resolved before returning.
    """
    if not dask_client:
        return obj.persist()

    if getattr(dask_client, "asynchronous", False):
        # NOTE(review): this relies on dask_client.persist(obj) being accepted
        # by run_coroutine_threadsafe and on dask_client.loop being usable as
        # an asyncio loop - confirm against the installed distributed version.
        pending = asyncio.run_coroutine_threadsafe(
            dask_client.persist(obj), dask_client.loop
        )
        outcome = pending.result()
    else:
        outcome = dask_client.persist(obj)

    if isinstance(outcome, Future):
        return outcome.result()
    return outcome


def _ensure_graph_integrity(obj: Any) -> Any:
    """
    Best-effort repair of a collection whose task graph looks empty.

    If `obj` exposes a ``dask`` attribute that is falsy, the object is
    repartitioned to a single partition and that result is returned. In every
    other case, including any error raised along the way, `obj` is returned
    unchanged.
    """
    try:
        graph_is_empty = hasattr(obj, "dask") and not obj.dask
        if not graph_is_empty:
            return obj
        logger.debug("_ensure_graph_integrity: rebuilding graph from meta")
        return obj.repartition(npartitions=1)
    except Exception:
        # Never let the heuristic break the caller.
        return obj


def _safe_gather(objs: List[Any], dask_client: Optional[Client] = None) -> List[Any]:
    """
    Compute several Dask objects and return their results as a list.

    With a client, each object is submitted via ``dask_client.compute`` and the
    results are collected with ``dask_client.gather``; without one,
    ``dask.compute`` runs on the threaded scheduler.

    Failure handling:
    - ``ValueError`` mentioning "Missing dependency": each object is recomputed
      locally (objects exposing ``reset_index`` are reset with ``drop=True``
      first). If that also fails, the error is logged and re-raised.
    - Any other ``ValueError`` is re-raised unchanged.
    - Any other exception: one local ``dask.compute`` retry; if that fails too,
      the failure is logged and an empty list is returned. Callers that need
      to tell "no results" from "failed" should compare the length of the
      returned list with ``len(objs)``.
    """
    if not objs:
        return []

    try:
        if dask_client:
            computed = [dask_client.compute(o) for o in objs]
            results = dask_client.gather(computed)
        else:
            results = dask.compute(*objs, scheduler="threads")
        return list(results)
    except ValueError as e:
        if "Missing dependency" not in str(e):
            raise
        logger.warning(
            "_safe_gather: detected orphaned Dask graph. Rebuilding locally."
        )
        try:
            new_objs = [
                o.reset_index(drop=True) if hasattr(o, "reset_index") else o
                for o in objs
            ]
            return [o.compute(scheduler="threads") for o in new_objs]
        except Exception as inner:
            logger.error(f"_safe_gather recovery failed: {inner}")
            raise
    except Exception as e:
        logger.warning(
            f"_safe_gather: fallback to local compute due to {type(e).__name__}: {e}"
        )
        try:
            return list(dask.compute(*objs, scheduler="threads"))
        except Exception as inner:
            # Still best-effort (empty list), but no longer silent: without this
            # log line the caller cannot see why nothing came back.
            logger.error(
                f"_safe_gather: local fallback failed: {type(inner).__name__}: {inner}"
            )
            return []


def _safe_wait(
    obj: Any, dask_client: Optional[Client] = None, timeout: Optional[float] = None
) -> Any:
    """
    Block until the futures backing `obj` have finished, then return `obj`.

    - ``None`` is returned as-is.
    - With an explicit client: make a bounded, best-effort check that at least
      one worker is up, then wait.
    - Without one: wait through the implicit client if ``get_client()`` finds
      one; on ``ValueError`` (no client) fall back to computing `obj` on the
      threaded scheduler, discarding the result.

    Errors (including timeouts) are logged as warnings and never raised; `obj`
    is returned in every case.
    """
    if obj is None:
        return None

    # Waiting is provided by the module-level `wait` function of distributed,
    # not by a method on Client. Imported here so the block carries its own
    # dependency; dask.distributed is already imported by this module.
    from dask.distributed import wait as _distributed_wait

    try:
        if dask_client:
            # Ensure at least one worker; wait() then returns immediately if done
            with suppress(Exception):
                dask_client.wait_for_workers(1, timeout=10)
            # NOTE(review): distributed.wait resolves its client through the
            # library's default-client lookup; `dask_client` itself is only
            # used for the worker check above.
            _distributed_wait(obj, timeout=timeout)
            return obj

        # Try implicit default client
        try:
            get_client()  # raises ValueError when no client is active
            _distributed_wait(obj, timeout=timeout)
            return obj
        except ValueError:
            logger.debug(
                "_safe_wait: no active distributed client; falling back to local compute."
            )
            if hasattr(obj, "compute"):
                obj.compute(scheduler="threads")
            return obj

    except Exception as e:
        logger.warning(f"_safe_wait: {type(e).__name__} - {e}")
        return obj


# --------------------------------------------------------------------------------------
# Dask emptiness helpers
# --------------------------------------------------------------------------------------
def dask_is_probably_empty(ddf: dd.DataFrame) -> bool:
    """
    Cheap structural emptiness test that triggers no computation.

    True when the frame reports zero partitions (a missing ``npartitions``
    attribute counts as zero) or when its ``_meta`` has no columns.
    """
    if getattr(ddf, "npartitions", 0) == 0:
        return True
    return len(ddf._meta.columns) == 0


def dask_is_empty_truthful(
    ddf: dd.DataFrame, dask_client: Optional[Client] = None
) -> bool:
    """
    Exact emptiness test: sum the per-partition lengths and compare with zero.

    The sum is computed through `_safe_compute` (using `dask_client` when
    given) and normalized with `_to_int_safe`.
    """
    row_count_graph = ddf.map_partitions(len).sum()
    total = _safe_compute(row_count_graph, dask_client)
    row_count = int(_to_int_safe(total))
    return row_count == 0


def dask_is_empty(
    ddf: dd.DataFrame,
    *,
    sample: int = 4,
    dask_client: Optional[Client] = None,
) -> bool:
    """
    Heuristic emptiness check. Samples the first few partitions, then falls back to full compute.

    Steps:
    1. Structural check via `dask_is_probably_empty` (no compute).
    2. Compute the length of the first ``k`` partitions, where ``k`` is
       `sample` clamped to ``[1, npartitions]``. Any positive length means
       the frame is not empty.
    3. If every partition was probed, every probe came back, and all are zero,
       the frame is empty.
    4. Otherwise defer to `dask_is_empty_truthful` for an exact answer.
    """
    if dask_is_probably_empty(ddf):
        return True

    k = min(max(sample, 1), ddf.npartitions)
    tasks = [ddf.get_partition(i).map_partitions(len) for i in range(k)]
    probes = _safe_gather(tasks, dask_client)

    if any(_to_int_safe(n) > 0 for n in probes):
        return False

    # _safe_gather returns [] when both the primary and the fallback compute
    # fail. all() over an empty sequence is True, so without the length check
    # a compute failure would be reported as "empty".
    probed_everything = len(probes) == k and k == ddf.npartitions
    if probed_everything and all(_to_int_safe(n) == 0 for n in probes):
        return True

    return dask_is_empty_truthful(ddf, dask_client=dask_client)


# --------------------------------------------------------------------------------------
# Unique value extractor (client-safe)
# --------------------------------------------------------------------------------------
class UniqueValuesExtractor:
    """
    Extract unique non-null values from Dask or pandas columns.
    Uses optional Dask client and thread offload to avoid blocking.
    """

    def __init__(self, dask_client: Optional["Client"] = None):
        """Remember the optional client used to compute lazy series."""
        self.dask_client = dask_client

    def _compute_to_list_sync(self, series) -> List[Any]:
        """
        Materialize `series` and return its unique non-null values as a list.

        Objects exposing ``compute`` are computed (through the client when one
        is set); anything else is used as-is. Array-like results are
        de-duplicated with nulls dropped; any other result is wrapped in a
        one-element list.
        """
        if hasattr(series, "compute"):
            if self.dask_client:
                result = self.dask_client.compute(series).result()
            else:
                result = series.compute()
        else:
            result = series

        # pandas `unique()` returns an ExtensionArray (Categorical,
        # DatetimeArray, masked/string arrays, ...) for extension dtypes and
        # an Index for Index input; treat those like ndarray/Series instead of
        # wrapping the whole array in a single-element list.
        array_like = (
            np.ndarray,
            pd.Series,
            pd.Index,
            pd.api.extensions.ExtensionArray,
            list,
            tuple,
        )
        if isinstance(result, array_like):
            uniques = pd.Series(result).dropna().unique()
            return uniques.tolist() if hasattr(uniques, "tolist") else list(uniques)
        return [result]

    async def compute_to_list(self, series) -> List[Any]:
        """Offload compute to a background thread."""
        return await asyncio.to_thread(self._compute_to_list_sync, series)

    async def extract_unique_values(self, df, *columns: str) -> Dict[str, List[Any]]:
        """
        Concurrently extract unique values for requested columns.

        Returns a mapping of column name to its list of unique non-null
        values; with no columns requested the mapping is empty.
        """

        async def one(col: str):
            ser = df[col].dropna().unique()
            return col, await self.compute_to_list(ser)

        results = await asyncio.gather(*(one(c) for c in columns))
        return dict(results)


# --------------------------------------------------------------------------------------
# Dask client lifecycle mixin
# --------------------------------------------------------------------------------------
class DaskClientMixin:
    """
    Resilient Dask client lifecycle with:
    - Shared JSON registry + file lock
    - Reference counting across processes
    - Watchdog health checks with auto-reattach
    - Optional external scheduler
    - Graceful shutdown with worker retirement
    """

    # Registry file (in the system temp dir) recording the shared cluster's
    # scheduler address and refcount, plus the lock guarding access to it.
    REGISTRY_PATH = os.path.join(tempfile.gettempdir(), "shared_dask_cluster.json")
    REGISTRY_LOCK = FileLock(REGISTRY_PATH + ".lock")
    WATCHDOG_INTERVAL = 120  # seconds

    def __init__(self, **kwargs):
        """
        Set up bookkeeping only; no client is created or attached here.

        Recognized keyword: ``logger`` (falls back to the default module
        logger when missing or falsy).
        """
        self.dask_client: Optional[Client] = None
        self.own_dask_client: bool = False
        self.logger = kwargs.get("logger") or Logger.default_logger(
            logger_name=__name__
        )
        self._watchdog_task: Optional[asyncio.Task] = None
        self._watchdog_stop = asyncio.Event()

    # ---------- registry ----------
    @classmethod
    def _read_registry(cls) -> Optional[dict]:
        """
        Load the registry file.

        Returns the parsed dict, or None when the file is missing, unreadable,
        not valid JSON, not a dict, or lacks an "address" key.
        """
        if not os.path.exists(cls.REGISTRY_PATH):
            return None
        try:
            with open(cls.REGISTRY_PATH, "r") as f:
                data = json.load(f)
            if not isinstance(data, dict) or "address" not in data:
                return None
            return data
        except (json.JSONDecodeError, OSError):
            return None

    @classmethod
    def _write_registry(cls, data: dict) -> None:
        """Write `data` to a temp file, then move it over the registry path."""
        tmp = cls.REGISTRY_PATH + ".tmp"
        with open(tmp, "w") as f:
            json.dump(data, f)
        os.replace(tmp, cls.REGISTRY_PATH)

    @classmethod
    def _remove_registry(cls) -> None:
        """Delete the registry file; a missing file is not an error."""
        with suppress(FileNotFoundError):
            os.remove(cls.REGISTRY_PATH)

    @classmethod
    def _cleanup_stale_registry(cls, logger_obj=None) -> None:
        """
        Drop the registry if its recorded scheduler cannot be reached.

        Reachability is probed by opening (and immediately closing) a client
        with a 5 second timeout; any failure removes the registry file and,
        when `logger_obj` is given, logs a warning.
        """
        reg = cls._read_registry()
        if not reg:
            return
        try:
            c = Client(address=reg["address"], timeout=5)
            c.close()
        except Exception:
            if logger_obj:
                logger_obj.warning(
                    f"Stale Dask registry at {reg.get('address')}. Removing."
                )
            cls._remove_registry()

    # ---------- helpers ----------
    @staticmethod
    def _has_inflight(client: Client) -> bool:
        """
        Report whether ``client.scheduler_info()`` shows any tasks or
        processing entries. Returns False if the query fails.
        """
        try:
            info = client.scheduler_info()
            n_tasks = info.get("tasks", 0) or len(info.get("task_counts", {}))
            processing = sum(len(v) for v in info.get("processing", {}).values())
            return bool(n_tasks or processing)
        except Exception:
            return False

    @staticmethod
    def _retire_all_workers(client: Client, timeout: float = 10.0) -> None:
        """
        Best-effort retirement of every worker; errors are suppressed.

        `timeout` is accepted for interface compatibility but is not used by
        the current implementation.
        """
        with suppress(Exception):
            client.retire_workers(workers=None, close_workers=True, remove=True)

    # ---------- init ----------
    def _init_dask_client(
        self,
        dask_client: Optional[Client] = None,
        *,
        logger: Optional[Logger] = None,
        scheduler_address: Optional[str] = None,
        use_remote_cluster: bool = False,
        n_workers: int = 4,
        threads_per_worker: int = 2,
        processes: bool = False,
        asynchronous: bool = False,
        memory_limit: str = "auto",
        local_directory: Optional[str] = None,
        silence_logs: int = logging.WARNING,
        resources: Optional[dict] = None,
        timeout: int = 30,
        watchdog: bool = True,
        worker_memory_env: Optional[Dict[str, str]] = None,
    ) -> None:
        """
        Initialize or attach a Dask client.

        Resolution order:
        1. Use `dask_client`, else whatever ``get_client()`` finds.
        2. If still unset and `use_remote_cluster` with a `scheduler_address`
           is given, connect to that scheduler (falls through to step 3 on
           failure).
        3. Under the registry lock: attach to the cluster recorded in the
           shared registry (incrementing its refcount), or start a new
           ``LocalCluster`` and record it with refcount 1.

        `n_workers`, `threads_per_worker`, `processes`, `asynchronous`,
        `memory_limit`, `local_directory`, `silence_logs`, `resources` and
        `timeout` are forwarded to ``LocalCluster``; `timeout` is also used
        for client connections. `logger` replaces the instance logger when
        given. `watchdog` controls whether the health-check task is started.
        """
        self.logger = logger or self.logger
        self.dask_client = dask_client
        self.own_dask_client = False

        # Reduce noisy logs. The level constants come from the stdlib
        # `logging` module; the `logger` parameter may be None and is not a
        # source of level constants.
        logging.getLogger("distributed.scheduler").setLevel(logging.WARNING)
        logging.getLogger("distributed.worker").setLevel(logging.WARNING)
        logging.getLogger("distributed.comm").setLevel(logging.ERROR)
        logging.getLogger("distributed.batched").setLevel(logging.ERROR)
        logging.getLogger("distributed.shuffle._scheduler_plugin").setLevel(
            logging.ERROR
        )

        # 1) reuse existing client in-context
        # NOTE(review): unless step 2 returns, step 3 below always assigns a
        # new client, replacing whatever was supplied or found here - confirm
        # that this is intended.
        if self.dask_client is None:
            with suppress(ValueError, RuntimeError):
                self.dask_client = get_client()

        # 2) external scheduler
        if self.dask_client is None and use_remote_cluster and scheduler_address:
            try:
                self.dask_client = Client(address=scheduler_address, timeout=timeout)
                self.own_dask_client = True
                self.logger.info(
                    f"Connected to external scheduler {scheduler_address}. "
                    f"Dashboard: {self.dask_client.dashboard_link}"
                )
                if watchdog:
                    self._start_watchdog()
                return
            except Exception as e:
                self.logger.warning(
                    f"Remote connect failed: {e}. Falling back to local."
                )

        # Default worker memory env
        # NOTE(review): this mapping is populated but not read anywhere below
        # in this method; the environment defaults set before LocalCluster
        # creation are hard-coded.
        if worker_memory_env is None:
            worker_memory_env = {
                "DASK_WORKER_MEMORY_TARGET": "0.6",
                "DASK_WORKER_MEMORY_SPILL": "0.7",
                "DASK_WORKER_MEMORY_PAUSE": "0.8",
            }

        # 3) shared local cluster (registry)
        with self.REGISTRY_LOCK:
            self._cleanup_stale_registry(self.logger)
            reg = self._read_registry()

            # Refuse to reuse a cluster marked as closing
            if reg and reg.get("closing"):
                self._remove_registry()
                reg = None

            if reg:
                try:
                    self.dask_client = Client(address=reg["address"], timeout=timeout)
                    reg["refcount"] = int(reg.get("refcount", 0)) + 1
                    self._write_registry(reg)
                    self.logger.info(
                        f"Reusing LocalCluster at {reg['address']} (refcount={reg['refcount']})."
                    )
                    if watchdog:
                        self._start_watchdog()
                    return
                except Exception:
                    self.logger.warning("Registry address unreachable. Recreating.")
                    self._remove_registry()

            os.environ.setdefault("DISTRIBUTED_WORKER_MEMORY_TARGET", "0.6")
            os.environ.setdefault("DISTRIBUTED_WORKER_MEMORY_SPILL", "0.7")
            os.environ.setdefault("DISTRIBUTED_WORKER_MEMORY_PAUSE", "0.8")

            cluster = LocalCluster(
                n_workers=n_workers,
                threads_per_worker=threads_per_worker,
                processes=processes,
                asynchronous=asynchronous,
                memory_limit=memory_limit,
                local_directory=local_directory,
                silence_logs=silence_logs,
                resources=resources,
                timeout=timeout,
            )
            self.dask_client = Client(cluster)
            self.own_dask_client = True

            reg = {"address": cluster.scheduler_address, "refcount": 1}
            self._write_registry(reg)
            self.logger.info(
                f"Started LocalCluster {reg['address']} "
                f"({n_workers} workers x {threads_per_worker} threads). "
                f"Dashboard: {self.dask_client.dashboard_link}"
            )

        if watchdog:
            self._start_watchdog()

    # ---------- watchdog ----------
    def _start_watchdog(self) -> None:
        """
        Schedule the periodic health check on the currently running event loop.

        Every ``WATCHDOG_INTERVAL`` seconds the task pings the scheduler and
        triggers garbage collection on the workers. If the ping fails, it
        reattaches to the cluster in the registry or recreates a small
        in-process cluster. When no event loop is running in the calling
        thread, nothing is scheduled.
        """

        async def watchdog_loop():
            while not self._watchdog_stop.is_set():
                await asyncio.sleep(self.WATCHDOG_INTERVAL)
                # A stop may have been requested while sleeping. Bail out
                # before the health check so a deliberately closed client is
                # not "repaired" by recreating a cluster.
                if self._watchdog_stop.is_set():
                    break
                try:
                    if not self.dask_client:
                        raise RuntimeError("No client bound.")
                    # quick liveness check
                    self.dask_client.scheduler_info()
                    # Force GC on workers to mitigate unmanaged memory
                    with suppress(Exception):
                        self.dask_client.run(lambda: __import__("gc").collect())
                except Exception:
                    self.logger.warning("Dask watchdog: client unhealthy. Reattaching.")
                    try:
                        with self.REGISTRY_LOCK:
                            self._cleanup_stale_registry(self.logger)
                            reg = self._read_registry()
                            if reg and not reg.get("closing"):
                                self.dask_client = Client(
                                    address=reg["address"], timeout=10
                                )
                                self.logger.info("Reattached to existing LocalCluster.")
                            else:
                                # recreate minimal in-proc cluster
                                cluster = LocalCluster(
                                    n_workers=2,
                                    threads_per_worker=1,
                                    processes=False,
                                    silence_logs=logging.WARNING,
                                )
                                self.dask_client = Client(cluster)
                                self.own_dask_client = True
                                self._write_registry(
                                    {
                                        "address": cluster.scheduler_address,
                                        "refcount": 1,
                                    }
                                )
                                self.logger.info("Recreated LocalCluster.")
                    except Exception as e:
                        self.logger.error(f"Watchdog reattach failed: {e}")

        # get_running_loop() raises RuntimeError when no loop is running, which
        # is exactly the "cannot schedule a task" condition, and it avoids the
        # deprecated implicit loop creation of get_event_loop().
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            self.logger.debug("Watchdog not started. No running loop.")
            return
        self._watchdog_task = loop.create_task(watchdog_loop())
        self.logger.debug("Started Dask watchdog.")

    async def _stop_watchdog(self) -> None:
        """
        Signal the watchdog to stop, cancel its pending sleep and wait up to
        five seconds for the task to finish.
        """
        self._watchdog_stop.set()
        task = self._watchdog_task
        if task:
            # The task spends its idle time in asyncio.sleep(); cancelling is
            # the only way to end it before the interval elapses.
            task.cancel()
            with suppress(Exception):
                await asyncio.wait([task], timeout=5)
            self._watchdog_task = None

    # ---------- close ----------
    def _close_dask_client(self) -> None:
        """
        Release this instance's claim on the shared cluster.

        With a registry entry carrying a refcount, the count is decremented.
        At zero the entry is marked as closing, in-flight work gets up to 15
        seconds to drain, workers are retired, client and cluster are closed
        and the registry file is removed. Without registry bookkeeping, only
        a client this instance owns is shut down. Finally the watchdog, if
        any, is asked to stop. Does nothing when no client is bound.
        """
        if not self.dask_client:
            return

        with self.REGISTRY_LOCK:
            reg = self._read_registry()
            if reg and "refcount" in reg:
                reg["refcount"] = max(0, int(reg["refcount"]) - 1)
                if reg["refcount"] == 0:
                    self.logger.info("Refcount reached 0. Closing LocalCluster.")
                    # mark closing to block new attachers during drain
                    reg["closing"] = True
                    self._write_registry(reg)

                    client = self.dask_client
                    cluster = getattr(client, "cluster", None)

                    # Graceful drain with bounded wait
                    try:
                        deadline = time.time() + 15
                        while self._has_inflight(client) and time.time() < deadline:
                            time.sleep(0.5)
                        self._retire_all_workers(client)
                    except Exception as e:
                        # Draining is best-effort; shutdown proceeds regardless.
                        self.logger.debug(f"Drain before close failed: {e}")

                    # Close client and cluster
                    with suppress(Exception):
                        client.close()
                    with suppress(Exception):
                        if cluster:
                            cluster.close()

                    # Remove registry
                    self._remove_registry()
                else:
                    self._write_registry(reg)
                    self.logger.debug(f"Decremented refcount to {reg['refcount']}.")
            else:
                # No registry bookkeeping; only close clusters we created
                if self.own_dask_client:
                    client = self.dask_client
                    cluster = getattr(client, "cluster", None)
                    with suppress(Exception):
                        self._retire_all_workers(client, timeout=8.0)
                    with suppress(Exception):
                        client.close()
                    with suppress(Exception):
                        if cluster:
                            cluster.close()
                self.logger.debug("Closed client without registry tracking.")

        # stop watchdog
        task = self._watchdog_task
        if task:
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No loop is running in this thread, so a stop task cannot be
                # scheduled here. Raise the stop flag directly and cancel the
                # watchdog on its own loop.
                self._watchdog_stop.set()
                with suppress(RuntimeError):  # the task's loop is already closed
                    task.get_loop().call_soon_threadsafe(task.cancel)
                self._watchdog_task = None
            else:
                asyncio.create_task(self._stop_watchdog())


# --------------------------------------------------------------------------------------
# Persistent singleton
# --------------------------------------------------------------------------------------
# Process-wide cache for get_persistent_client(): stays None until that
# function first runs and is replaced whenever the cached instance has no
# client bound.
_persistent_mixin: Optional[DaskClientMixin] = None


def get_persistent_client(
    *,
    logger: Optional[Logger] = None,
    use_remote_cluster: bool = False,
    scheduler_address: Optional[str] = None,
) -> Client:
    """
    Return the module-wide Dask client, creating it on first use.

    A fresh ``DaskClientMixin`` is built and initialized (4 workers x 2
    threads, in-process, watchdog on) whenever no mixin is cached or the
    cached one has no client bound. `use_remote_cluster` and
    `scheduler_address` are forwarded to the mixin's initializer; `logger` is
    passed to the mixin constructor.
    """
    global _persistent_mixin

    cached = _persistent_mixin
    if cached is not None and cached.dask_client is not None:
        return cached.dask_client  # type: ignore[return-value]

    # Publish the new mixin before initializing it, so the module-level
    # reference is updated even if initialization raises.
    _persistent_mixin = DaskClientMixin(logger=logger)
    init_options = dict(
        use_remote_cluster=use_remote_cluster,
        scheduler_address=scheduler_address,
        n_workers=4,
        threads_per_worker=2,
        processes=False,
        watchdog=True,
    )
    _persistent_mixin._init_dask_client(**init_options)
    return _persistent_mixin.dask_client  # type: ignore[return-value]


# --------------------------------------------------------------------------------------
# Shared session contexts
# --------------------------------------------------------------------------------------
def shared_dask_session(*, async_mode: bool = True, **kwargs):
    """
    Build a context manager that yields a shared Dask client.

    The client is initialized immediately when this function is called (with
    `kwargs` forwarded to ``DaskClientMixin._init_dask_client``), not when the
    context is entered. On exit the mixin's ``_close_dask_client`` runs.
    Returns an async context manager when `async_mode` is truthy, otherwise a
    regular one.
    """
    session = DaskClientMixin()
    session._init_dask_client(**kwargs)

    if not async_mode:

        @contextmanager
        def _sync_manager():
            try:
                yield session.dask_client
            finally:
                session._close_dask_client()

        return _sync_manager()

    @asynccontextmanager
    async def _async_manager():
        try:
            yield session.dask_client
        finally:
            session._close_dask_client()

    return _async_manager()

