"""Utilities for tracking glyph emission history and related metrics."""

from __future__ import annotations

from typing import Any
from collections import deque, Counter
from itertools import islice
from collections.abc import Iterable, Mapping
from functools import lru_cache

from .constants import get_param
from .collections_utils import ensure_collection
from .logging_utils import get_logger

# Module-level logger obtained from the project logging helper, keyed by
# this module's ``__name__``.
logger = get_logger(__name__)

# Names exported as the public API of this module (what a star-import
# exposes). Helpers prefixed with an underscore are intentionally omitted.
__all__ = (
    "HistoryDict",
    "push_glyph",
    "recent_glyph",
    "ensure_history",
    "current_step_idx",
    "append_metric",
    "last_glyph",
    "count_glyphs",
)


@lru_cache(maxsize=1)
def _resolve_validate_window():
    """Return ``validators.validate_window``, importing it on first use.

    The import is deferred to call time instead of module import time
    (presumably to avoid a circular import with ``.validators`` -- confirm
    before moving it to the top of the file). ``lru_cache`` memoizes the
    resolved callable so later calls skip the import machinery.
    """
    from .validators import validate_window as resolved

    return resolved


def _validate_window(window: int, *, positive: bool = False) -> int:
    """Validate ``window`` through the lazily resolved project validator.

    Both arguments are forwarded unchanged; whatever the validator returns
    (or raises) is propagated to the caller.
    """
    validator = _resolve_validate_window()
    return validator(window, positive=positive)


def _ensure_history(
    nd: dict[str, Any], window: int, *, create_zero: bool = False
) -> tuple[int, deque | None]:
    """Validate ``window`` and make ``nd['glyph_history']`` a matching deque.

    Returns ``(validated_window, history)``. When the validated window is
    zero and ``create_zero`` is false, ``nd`` is left untouched and the
    history slot of the result is ``None``. Otherwise the stored history is
    guaranteed to be a ``deque`` whose ``maxlen`` equals the validated
    window; a value that does not satisfy this is rebuilt and written back.
    """

    size = _validate_window(window)
    if size == 0 and not create_zero:
        return size, None

    current = nd.setdefault("glyph_history", deque(maxlen=size))
    if isinstance(current, deque) and current.maxlen == size:
        # Already in the expected shape: nothing to rebuild.
        return size, current

    # Rebuild from whatever is stored. Raw strings/bytes are not treated as
    # sequences of glyphs, and non-iterable values are dropped.
    if isinstance(current, (str, bytes, bytearray)):
        seed: Iterable[Any] = ()
    else:
        try:
            seed = ensure_collection(current, max_materialize=None)
        except TypeError:
            logger.debug(
                "Discarding non-iterable glyph history value %r", current
            )
            seed = ()

    rebuilt = deque(seed, maxlen=size)
    nd["glyph_history"] = rebuilt
    return size, rebuilt


def push_glyph(nd: dict[str, Any], glyph: str, window: int) -> None:
    """Record ``glyph`` in the node history, bounded to ``window`` entries.

    :func:`_ensure_history` validates ``window`` and creates (or rebuilds)
    the deque, including for a zero window. The glyph is stored as ``str``.
    """

    history = _ensure_history(nd, window, create_zero=True)[1]
    history.append(str(glyph))


def recent_glyph(nd: dict[str, Any], glyph: str, window: int) -> bool:
    """Tell whether ``glyph`` occurs in the last ``window`` emissions.

    ``window`` is validated by :func:`_ensure_history`, which also makes
    sure the stored history is a deque bounded to ``window``. With a
    ``window`` of zero the answer is ``False`` and ``nd`` is not modified.
    The glyph is compared by its ``str`` form.
    """

    size, history = _ensure_history(nd, window)
    return size != 0 and str(glyph) in history


class HistoryDict(dict):
    """Dict specialized for bounded history series and usage counts.

    Usage counts are tracked explicitly via :meth:`get_increment`. Accessing
    keys through ``__getitem__`` or :meth:`get` does not affect the internal
    counters, avoiding surprising evictions on mere reads. Counts are kept in
    a :class:`collections.Counter`; :meth:`pop_least_used` removes the entry
    with the smallest count, the earliest-tracked key winning ties.

    When ``maxlen`` is positive, reading a key (``__getitem__``, :meth:`get`,
    :meth:`setdefault`, :meth:`get_increment`) coerces a non-mapping value
    into a ``deque`` bounded by ``maxlen`` and stores it back. When ``maxlen``
    is zero or negative, values are returned as stored.

    Parameters
    ----------
    data:
        Initial mapping to populate the dictionary. With a positive
        ``maxlen``, list values and deques bounded by a different ``maxlen``
        are converted to deques bounded by ``maxlen`` (keeping the most
        recent items).
    maxlen:
        Maximum length for history lists stored as values.
    """

    def __init__(
        self,
        data: dict[str, Any] | None = None,
        *,
        maxlen: int = 0,
    ) -> None:
        super().__init__(data or {})
        self._maxlen = maxlen
        self._counts: Counter[str] = Counter()
        bounded = self._maxlen > 0
        # Snapshot the items: values are replaced while walking them.
        for k, v in list(self.items()):
            if bounded and (
                isinstance(v, list)
                or (isinstance(v, deque) and v.maxlen != self._maxlen)
            ):
                # Re-bounding deques matters when an existing history is
                # rebuilt with a new ``maxlen``: without it old series would
                # silently keep their previous bound.
                super().__setitem__(k, deque(v, maxlen=self._maxlen))
            self._counts[k] = 0

    def _increment(self, key: str) -> None:
        """Increase usage count for ``key``."""
        self._counts[key] += 1

    def _to_deque(self, val: Any) -> deque:
        """Coerce ``val`` to a deque respecting ``self._maxlen``.

        ``Iterable`` inputs (excluding ``str`` and ``bytes``) are expanded into
        the deque, while single values are wrapped. Existing deques are
        returned unchanged.
        """

        if isinstance(val, deque):
            return val
        if isinstance(val, Iterable) and not isinstance(val, (str, bytes)):
            return deque(val, maxlen=self._maxlen)
        return deque([val], maxlen=self._maxlen)

    def _resolve_value(self, key: str, default: Any, *, insert: bool) -> Any:
        """Fetch the value for ``key``, optionally inserting ``default``.

        With ``insert`` true the lookup goes through ``dict.setdefault``;
        otherwise a missing key raises :class:`KeyError`. When ``maxlen`` is
        positive, non-mapping values are coerced to a bounded deque which is
        stored back so later reads return the same object.
        """
        if insert:
            val = super().setdefault(key, default)
        else:
            val = super().__getitem__(key)
        if self._maxlen > 0 and not isinstance(val, Mapping):
            coerced = self._to_deque(val)
            if coerced is not val:
                # Only write back when coercion produced a new object.
                super().__setitem__(key, coerced)
                val = coerced
        return val

    def get_increment(self, key: str, default: Any = None) -> Any:
        """Return the value for ``key`` and bump its usage count.

        A missing ``key`` is inserted with ``default`` first.
        """
        insert = key not in self
        val = self._resolve_value(key, default, insert=insert)
        self._increment(key)
        return val

    def __getitem__(self, key):  # type: ignore[override]
        return self._resolve_value(key, None, insert=False)

    def get(self, key, default=None):  # type: ignore[override]
        try:
            return self._resolve_value(key, None, insert=False)
        except KeyError:
            return default

    def __setitem__(self, key, value):  # type: ignore[override]
        super().__setitem__(key, value)
        if key not in self._counts:
            self._counts[key] = 0

    def setdefault(self, key, default=None):  # type: ignore[override]
        insert = key not in self
        val = self._resolve_value(key, default, insert=insert)
        if insert:
            self._counts[key] = 0
        return val

    def pop_least_used(self) -> Any:
        """Remove and return the value with the smallest usage count.

        Raises
        ------
        KeyError
            If the dictionary holds no entries.
        """
        # ``dict.update`` and ``|=`` bypass ``__setitem__``, so keys added
        # that way have no counter yet. Register them with a zero count so
        # they can be evicted instead of being invisible to this method.
        for key in self:
            if key not in self._counts:
                self._counts[key] = 0
        while self._counts:
            key = min(self._counts, key=self._counts.get)
            self._counts.pop(key, None)
            # Counters may outlive their key (e.g. after ``del``); skip those.
            if key in self:
                return super().pop(key)
        raise KeyError("HistoryDict is empty; cannot pop least used")

    def pop_least_used_batch(self, k: int) -> None:
        """Evict up to ``k`` least-used entries, stopping early when empty."""
        for _ in range(max(0, int(k))):
            try:
                self.pop_least_used()
            except KeyError:
                break


def ensure_history(G) -> dict[str, Any]:
    """Ensure ``G.graph['history']`` exists and return it.

    ``HISTORY_MAXLEN`` is read with ``get_param`` and checked by
    :func:`_validate_window`; it must be non-negative, otherwise the
    validator is expected to raise :class:`ValueError`.

    * When ``HISTORY_MAXLEN`` is zero, a regular ``dict`` is used: an
      existing :class:`HistoryDict` is copied into a plain ``dict`` and a
      missing history is created empty. Any other existing value is returned
      as found.
    * Otherwise the history is a :class:`HistoryDict` with that ``maxlen``;
      it is (re)built when missing, of another type, or configured with a
      different ``maxlen``. If it holds more than ``maxlen`` keys, the
      least-used ones are evicted.

    Whenever the history object is replaced, the ``"_metrics_history_id"``
    entry of ``G.graph`` is removed (presumably a cached identity of the
    previous history object -- confirm against its consumer).
    """
    # Validate directly; no node dict or deque is needed just to check the
    # configured length.
    maxlen = _validate_window(int(get_param(G, "HISTORY_MAXLEN")))
    hist = G.graph.get("history")
    replaced = False
    if maxlen == 0:
        if isinstance(hist, HistoryDict):
            hist = dict(hist)
            replaced = True
        elif hist is None:
            hist = {}
            replaced = True
    else:
        if not isinstance(hist, HistoryDict) or hist._maxlen != maxlen:
            hist = HistoryDict(hist, maxlen=maxlen)
            replaced = True
        excess = len(hist) - maxlen
        if excess > 0:
            hist.pop_least_used_batch(excess)
    if replaced:
        G.graph["history"] = hist
        G.graph.pop("_metrics_history_id", None)
    return hist


def current_step_idx(G) -> int:
    """Return the number of entries under ``history['C_steps']``.

    ``G`` may be an object exposing a ``graph`` mapping or the mapping
    itself. A missing history or missing ``C_steps`` series counts as zero.
    """

    container = getattr(G, "graph", G)
    history = container.get("history", {})
    steps = history.get("C_steps", [])
    return len(steps)

    

def append_metric(hist: dict[str, Any], key: str, value: Any) -> None:
    """Append ``value`` to the series at ``hist[key]``.

    A missing ``key`` is first initialised through ``hist.setdefault`` with
    an empty list; the value is then appended to whatever that call returns.
    """
    series = hist.setdefault(key, [])
    series.append(value)


def last_glyph(nd: dict[str, Any]) -> str | None:
    """Return the newest entry of ``nd['glyph_history']``, if any.

    ``None`` is returned when the key is absent or the history is empty.
    """
    history = nd.get("glyph_history")
    if not history:
        return None
    return history[-1]


def count_glyphs(
    G, window: int | None = None, *, last_only: bool = False
) -> Counter:
    """Tally glyph occurrences across the nodes of ``G``.

    With ``window`` set to ``None`` every stored entry of each node history
    is counted; otherwise only the last ``window`` entries are. A ``window``
    of zero produces an empty :class:`Counter`, and validation of ``window``
    is delegated to :func:`_validate_window`. When ``last_only`` is true,
    only the most recent glyph of each node contributes.
    """

    if window is not None:
        window = _validate_window(window)
        if window == 0:
            return Counter()

    tally: Counter[str] = Counter()
    for _node, attrs in G.nodes(data=True):
        if last_only:
            latest = last_glyph(attrs)
            if latest:
                tally[latest] += 1
        else:
            history = attrs.get("glyph_history")
            if history:
                if window is None:
                    tally.update(history)
                else:
                    # Skip everything before the trailing ``window`` items.
                    first = max(len(history) - window, 0)
                    tally.update(islice(history, first, None))

    return tally
