from __future__ import annotations

import atexit
import json
import os
import threading
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from queue import Queue, Full, Empty
from typing import Any, Dict, Optional, Union
from urllib import request, error

# Public names exported by ``from <module> import *``.
__all__ = [
    "configure",
    "emit",
    "record",
    "flush",
    "close",
    "stats",
    "set_user",
    "from_openai",
    "__version__",
]

# Library version; also embedded in the User-Agent header by _user_agent().
__version__ = "0.2.4"

# Alias for a dictionary with string keys and arbitrary values (a JSON object).
JsonDict = Dict[str, Any]


def _dt_to_rfc3339_z(dt: datetime) -> str:
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")


def _mask_key(key: Optional[str]) -> str:
    if not key:
        return ""
    if len(key) <= 8:
        return "***"
    return key[:4] + "…" + key[-4:]


def _env(name: str, fallback: Optional[str] = None) -> Optional[str]:
    return os.getenv(name, fallback) if fallback is not None else os.getenv(name)


def _resolve_config(api_key: Optional[str], base_url: Optional[str]) -> tuple[str, str]:
    # Read NIVARA_* environment variables
    ak = api_key or _env("NIVARA_API_KEY")
    if not ak:
        raise ValueError("API key required: set NIVARA_API_KEY")
    base = (base_url or _env("NIVARA_BASE_URL") or "https://api.getnivara.com").rstrip("/")
    return ak, base


def _user_agent() -> str:
    """Return the User-Agent header value, which carries the library version."""
    return "nivara-python/{}".format(__version__)


def _http_post(url: str, headers: Dict[str, str], payload: Dict[str, Any], timeout: float) -> tuple[int, Dict[str, Any] | str]:
    """POST *payload* as compact JSON and return ``(status, body)``.

    The body is the parsed JSON response (an empty response parses as ``{}``);
    when it cannot be parsed, the raw text is returned instead, decoded as
    UTF-8 with invalid bytes dropped. Exceptions raised by ``urlopen`` are not
    caught here and propagate to the caller.
    """
    encoded = json.dumps(payload, separators=(",", ":")).encode("utf-8")
    post = request.Request(url=url, data=encoded, method="POST", headers=headers)
    with request.urlopen(post, timeout=timeout) as resp:
        # Fall back to getcode() when the response lacks a usable ``status``.
        status = getattr(resp, "status", None) or resp.getcode()
        raw = resp.read()
    try:
        body = json.loads(raw.decode("utf-8") or "{}")
    except Exception:
        body = raw.decode("utf-8", "ignore")
    return status, body


def _should_retry(status: Optional[int]) -> bool:
    if status is None:
        return True
    return status in (408, 429) or 500 <= status <= 599


def _backoff(attempt: int) -> float:
    # attempt: 0,1,2,... -> ~0.2, 0.5, 0.9, capped ~2.0 with jitter
    base = min(2.0, 0.2 * (1.6 ** attempt))
    return base + (0.05 * (attempt % 3))


@dataclass
class _Stats:
    """Mutable delivery counters kept by the client and reported by ``stats()``."""

    # Number of events accepted onto the background queue.
    queued: int = 0
    # Number of events the server acknowledged with HTTP 201.
    sent: int = 0
    # Number of events dropped (queue full) or whose send ultimately failed.
    failed: int = 0
    # Message describing the most recent failure, if any.
    last_error: Optional[str] = None


class _Client:
    def __init__(self) -> None:
        self.api_key = None  # type: Optional[str]
        self.base_url = "https://api.getnivara.com"
        self.timeout = 2.0
        self.retries = 0
        self.mode = "background"  # background | sync
        self.queue_size = 10000
        self.debug = False
        self.default_user: Optional[str] = None

        self._q: Queue[tuple[JsonDict, Optional[float], int]] | None = None
        self._worker: threading.Thread | None = None
        self._stop = threading.Event()
        self._stats = _Stats()
        self._lock = threading.Lock()

    # Configuration and lifecycle
    def configure(
        self,
        *,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: float = 2.0,
        retries: int = 0,
        mode: str = "background",
        queue_size: int = 10000,
        debug: bool = False,
    ) -> None:
        ak, base = _resolve_config(api_key, base_url)
        self.api_key = ak
        self.base_url = base
        self.timeout = float(timeout)
        self.retries = max(0, int(retries))
        self.mode = mode if mode in ("background", "sync") else "background"
        self.queue_size = max(1, int(queue_size))
        self.debug = bool(debug)
        if self.mode == "background":
            self._ensure_worker()
        else:
            self._shutdown_worker()

    def _ensure_worker(self) -> None:
        if self._q is None:
            self._q = Queue(maxsize=self.queue_size)
        if self._worker is None or not self._worker.is_alive():
            self._stop.clear()
            self._worker = threading.Thread(target=self._run, name="nivara-worker", daemon=True)
            self._worker.start()

    def _shutdown_worker(self) -> None:
        if self._worker is not None:
            self._stop.set()
            try:
                if self._q is not None:
                    # unblock get()
                    self._q.put_nowait(({"__shutdown__": True}, None, 0))
            except Full:
                pass
            self._worker.join(timeout=2.0)
            self._worker = None
        self._q = None

    # Public ops
    def emit(self, payload: JsonDict, *, timeout: Optional[float] = None, retries: Optional[int] = None) -> JsonDict:
        # Background mode
        if self.mode == "background":
            if self._q is None:
                self._ensure_worker()
            try:
                self._q.put_nowait((payload, timeout, retries if retries is not None else self.retries))
                with self._lock:
                    self._stats.queued += 1
                return {"status": "queued"}
            except Full:
                with self._lock:
                    self._stats.failed += 1
                    self._stats.last_error = "queue full"
                if self.debug:
                    print("[nivara] queue full; dropping event")
                return {"status": "error", "error": "queue full"}

        # Sync mode
        return self._send(payload, timeout=timeout or self.timeout, retries=retries if retries is not None else self.retries)

    def _send(self, payload: JsonDict, *, timeout: float, retries: int) -> JsonDict:
        if not self.api_key:
            ak, base = _resolve_config(None, None)
            self.api_key, self.base_url = ak, base

        url = f"{self.base_url}/v1/record"
        headers = {
            "Content-Type": "application/json",
            "X-API-Key": self.api_key,
            "User-Agent": _user_agent(),
        }

        attempt = 0
        while True:
            try:
                status, body = _http_post(url, headers, payload, timeout)
                if status == 201:
                    with self._lock:
                        self._stats.sent += 1
                    return {"status": "ok", "http_status": status, **(body if isinstance(body, dict) else {})}
                # Non-201
                err_dict = body if isinstance(body, dict) else {"error": str(body)}
                err_dict.setdefault("status", "error")
                err_dict["http_status"] = status
                if retries > 0 and _should_retry(status):
                    if self.debug:
                        print(f"[nivara] retryable status {status}; retrying...")
                    time.sleep(_backoff(attempt))
                    attempt += 1
                    retries -= 1
                    continue
                with self._lock:
                    self._stats.failed += 1
                    self._stats.last_error = err_dict.get("error") or f"http {status}"
                return err_dict
            except error.HTTPError as e:
                try:
                    txt = e.read().decode("utf-8")
                    err = json.loads(txt)
                    if not isinstance(err, dict):
                        err = {"error": txt}
                except Exception:
                    err = {"error": str(e)}
                err.setdefault("status", "error")
                err["http_status"] = getattr(e, "code", None)
                if retries > 0 and _should_retry(err.get("http_status")):
                    if self.debug:
                        print(f"[nivara] HTTPError {err.get('http_status')}; retrying...")
                    time.sleep(_backoff(attempt))
                    attempt += 1
                    retries -= 1
                    continue
                with self._lock:
                    self._stats.failed += 1
                    self._stats.last_error = err.get("error")
                return err
            except Exception as e:
                if retries > 0:
                    if self.debug:
                        print(f"[nivara] transport error: {e}; retrying...")
                    time.sleep(_backoff(attempt))
                    attempt += 1
                    retries -= 1
                    continue
                with self._lock:
                    self._stats.failed += 1
                    self._stats.last_error = str(e)
                return {"status": "error", "error": str(e)}

    def _run(self) -> None:
        if self.debug:
            ak_mask = _mask_key(self.api_key)
            print(f"[nivara] worker start base={self.base_url} key={ak_mask}")
        while not self._stop.is_set():
            try:
                item = self._q.get(timeout=0.2) if self._q is not None else None
            except Empty:
                continue
            if item is None:
                continue
            payload, timeout_override, retries_override = item
            if "__shutdown__" in payload:
                self._q.task_done()
                break
            _ = self._send(payload, timeout=timeout_override or self.timeout, retries=retries_override)
            self._q.task_done()
        if self.debug:
            print("[nivara] worker stop")

    def flush(self, *, timeout: Optional[float] = None) -> bool:
        q = self._q
        if q is None:
            return True
        start = time.time()
        while q.unfinished_tasks > 0:
            if timeout is not None and (time.time() - start) > timeout:
                return False
            time.sleep(0.05)
        return True

    def close(self) -> None:
        self._shutdown_worker()

    def stats(self) -> Dict[str, Any]:
        with self._lock:
            return {
                "queued": self._stats.queued,
                "sent": self._stats.sent,
                "failed": self._stats.failed,
                "last_error": self._stats.last_error,
            }


# Module-wide client instance used by the public functions below.
_client = _Client()


def configure(
    *,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
    timeout: float = 2.0,
    retries: int = 0,
    mode: str = "background",
    queue_size: int = 10000,
    debug: bool = False,
) -> None:
    """Apply settings to the module-wide client.

    Every argument is forwarded unchanged to the client's ``configure``
    method. When ``api_key`` or ``base_url`` is omitted, ``NIVARA_API_KEY``
    and ``NIVARA_BASE_URL`` are used instead. The default mode sends events
    from a background thread.
    """
    options = {
        "base_url": base_url,
        "api_key": api_key,
        "timeout": timeout,
        "retries": retries,
        "mode": mode,
        "queue_size": queue_size,
        "debug": debug,
    }
    _client.configure(**options)


def _build_payload(
    *,
    metric: str,
    ts: Optional[Union[str, datetime]] = None,
    user_identifier: Optional[str] = None,
    input_tokens: Optional[int] = None,
    output_tokens: Optional[int] = None,
    cached_tokens: Optional[int] = None,
    reasoning_tokens: Optional[int] = None,
) -> JsonDict:
    if not isinstance(metric, str) or not metric.strip():
        raise ValueError("metric is required and must be a non-empty string")
    payload: JsonDict = {"metric": metric}
    if ts is not None:
        if isinstance(ts, datetime):
            payload["ts"] = _dt_to_rfc3339_z(ts)
        elif isinstance(ts, str):
            payload["ts"] = ts
        else:
            raise TypeError("ts must be datetime or RFC3339 string")
    if user_identifier is not None:
        if not isinstance(user_identifier, str):
            raise TypeError("user_identifier must be a string")
        payload["user_identifier"] = user_identifier
    def add_int(name: str, value: Optional[int]):
        if value is None:
            return
        if not isinstance(value, int):
            raise TypeError(f"{name} must be int")
        if value < 0:
            raise ValueError(f"{name} must be >= 0")
        payload[name] = value
    add_int("input_tokens", input_tokens)
    add_int("output_tokens", output_tokens)
    add_int("cached_tokens", cached_tokens)
    add_int("reasoning_tokens", reasoning_tokens)
    return payload


def emit(
    *,
    metric: str,
    ts: Optional[Union[str, datetime]] = None,
    user_identifier: Optional[str] = None,
    input_tokens: Optional[int] = None,
    output_tokens: Optional[int] = None,
    cached_tokens: Optional[int] = None,
    reasoning_tokens: Optional[int] = None,
    timeout: Optional[float] = None,
    sample_rate: float = 1.0,
) -> JsonDict:
    """Emit a metric using the configured client (background by default).

    When *sample_rate* is below 1.0 the event is kept with probability
    *sample_rate* and otherwise dropped client-side, in which case
    ``{"status": "sampled_out"}`` is returned without validating or sending
    anything. A rate of 0 or less drops every event.

    If *user_identifier* is omitted, the default set via ``set_user`` is
    used. The remaining arguments are validated by ``_build_payload`` (which
    raises ``ValueError``/``TypeError`` on bad input) and the result of the
    client's ``emit`` is returned.
    """
    # Sampling
    if sample_rate < 1.0:
        import random  # local import: only needed on the sampling path

        # random.random() is uniform on [0.0, 1.0), so comparing with >=
        # keeps an event with probability sample_rate and never keeps one
        # when the rate is zero or negative.
        if random.random() >= sample_rate:
            return {"status": "sampled_out"}

    if user_identifier is None and _client.default_user:
        user_identifier = _client.default_user

    payload = _build_payload(
        metric=metric,
        ts=ts,
        user_identifier=user_identifier,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        cached_tokens=cached_tokens,
        reasoning_tokens=reasoning_tokens,
    )
    return _client.emit(payload, timeout=timeout, retries=None)


def record(
    *,
    metric: str,
    ts: Optional[Union[str, datetime]] = None,
    user_identifier: Optional[str] = None,
    input_tokens: Optional[int] = None,
    output_tokens: Optional[int] = None,
    cached_tokens: Optional[int] = None,
    reasoning_tokens: Optional[int] = None,
    timeout: Optional[float] = None,
) -> JsonDict:
    """Send one event right now on the calling thread and return the outcome.

    The background queue is bypassed regardless of the configured mode. If
    *user_identifier* is omitted, the default set via ``set_user`` is used.
    The client's configured retry count applies, and *timeout* falls back to
    the client's timeout when not given.
    """
    if user_identifier is None:
        # An unset or empty default leaves the identifier out of the payload.
        user_identifier = _client.default_user or None

    fields = {
        "metric": metric,
        "ts": ts,
        "user_identifier": user_identifier,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cached_tokens": cached_tokens,
        "reasoning_tokens": reasoning_tokens,
    }
    payload = _build_payload(**fields)

    # Call the sender directly so the event never goes through the queue.
    effective_timeout = timeout or _client.timeout
    return _client._send(payload, timeout=effective_timeout, retries=_client.retries)


def stats() -> Dict[str, Any]:
    """Return the module-wide client's counters: queued, sent, failed and last_error."""
    return _client.stats()


def flush(timeout: Optional[float] = None) -> bool:
    """Wait for the module-wide client's background queue to drain.

    Returns True once no queued work remains (or when there is no queue) and
    False if the wait is given up first, e.g. because *timeout* seconds
    elapsed. ``None`` means no time limit.
    """
    return _client.flush(timeout=timeout)


def close() -> None:
    """Stop the module-wide client's background worker and discard its queue."""
    _client.close()


def set_user(user_identifier: Optional[str]) -> None:
    """Set the default user identifier used by emit()/record() when none is passed.

    Passing ``None`` clears the default.
    """
    _client.default_user = user_identifier


def from_openai(response: Any, *, user_identifier: Optional[str] = None) -> Dict[str, Any]:
    """Extract usage fields from an OpenAI response-like object.

    Returns a dict suitable for splatting into emit()/record() (excluding
    metric). Both naming schemes are understood: ``input_tokens`` /
    ``output_tokens`` with ``input_tokens_details`` / ``output_tokens_details``,
    and ``prompt_tokens`` / ``completion_tokens`` with
    ``prompt_tokens_details`` / ``completion_tokens_details``. The first
    scheme wins when both are present. *response* and its nested usage values
    may be objects with attributes or plain dicts. Anything missing counts
    as 0.
    """

    def lookup(obj: Any, *names: str) -> Any:
        """Return the first non-None value found under *names* on *obj*."""
        if obj is None:
            return None
        for name in names:
            if isinstance(obj, dict):
                value = obj.get(name)
            else:
                value = getattr(obj, name, None)
            if value is not None:
                return value
        return None

    usage = lookup(response, "usage")
    in_details = lookup(usage, "input_tokens_details", "prompt_tokens_details")
    out_details = lookup(usage, "output_tokens_details", "completion_tokens_details")

    out: Dict[str, Any] = {}
    # Leave user_identifier out entirely when unset, to comply with schema
    # additionalProperties: false.
    if user_identifier is not None:
        out["user_identifier"] = user_identifier
    out["input_tokens"] = int(lookup(usage, "input_tokens", "prompt_tokens") or 0)
    out["output_tokens"] = int(lookup(usage, "output_tokens", "completion_tokens") or 0)
    out["cached_tokens"] = int(lookup(in_details, "cached_tokens") or 0)
    out["reasoning_tokens"] = int(lookup(out_details, "reasoning_tokens") or 0)
    return out


# Configure a reasonable default on import using env if available.
# NIVARA_TIMEOUT, NIVARA_RETRIES, NIVARA_MODE, NIVARA_QUEUE_SIZE and
# NIVARA_DEBUG ("1" enables debug output) supply the settings; the API key and
# base URL are resolved inside configure() from NIVARA_API_KEY and
# NIVARA_BASE_URL because None is passed for both here.
try:
    configure(
        base_url=None,
        api_key=None,
        timeout=float(_env("NIVARA_TIMEOUT", "2.0") or 2.0),
        retries=int(_env("NIVARA_RETRIES", "0") or 0),
        mode=str(_env("NIVARA_MODE", "background") or "background"),
        queue_size=int(_env("NIVARA_QUEUE_SIZE", "10000") or 10000),
        debug=(_env("NIVARA_DEBUG", "0") == "1"),
    )
except Exception:
    # Defer errors until first call if env not present. Any failure here (a
    # missing API key, or a numeric variable that does not parse) is
    # swallowed, and in that case the client keeps its built-in defaults.
    pass


@atexit.register
def _cleanup() -> None:
    """Exit hook: give pending events a short chance to go out, then stop the worker.

    Waits at most one second for the queue to drain. close() then stops the
    worker and discards the queue, so events still pending at that point are
    not sent. Any exception is suppressed so interpreter exit is not
    disturbed.
    """
    # Attempt a short flush at exit for background mode
    try:
        _client.flush(timeout=1.0)
        _client.close()
    except Exception:
        pass
