from __future__ import annotations

import json
import os
import time
import tracemalloc
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, Any, Optional

try:
    import psutil  # optional
except Exception:
    psutil = None  # type: ignore

try:
    # Optional NVML (NVIDIA): provided by 'pynvml'
    import pynvml  # type: ignore
except Exception:
    pynvml = None  # type: ignore


@dataclass
class TelemetrySample:
    """One measurement record for a single pipeline step.

    Holds the start/end timestamps, the CPU and memory readings taken
    when the step finished, and optional workload counters supplied by
    the caller.
    """

    # Start/end timestamps in seconds (TelemetryRecorder fills these
    # from time.perf_counter()).
    t_start: float
    t_end: float
    # Process CPU times (user / system) and resident set size in bytes.
    cpu_user: float
    cpu_system: float
    rss_bytes: int
    # Optional workload descriptors.
    frames: int = 0
    gps_points: int = 0
    notes: str = ""

    @property
    def wall_sec(self) -> float:
        """Elapsed time between ``t_start`` and ``t_end``, never negative."""
        elapsed = self.t_end - self.t_start
        return elapsed if elapsed > 0.0 else 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain dict plus the derived ``wall_sec``."""
        return {**asdict(self), "wall_sec": self.wall_sec}


def _proc_resource() -> Dict[str, Any]:
    """Return best-effort CPU and memory figures for the current process.

    Returns:
        A dict with keys ``cpu_user`` and ``cpu_system`` (floats, CPU
        seconds consumed by this process) and ``rss_bytes`` (int).

    When psutil is importable the values come from
    ``Process.cpu_times()`` and ``Process.memory_info()``. Otherwise, or
    when the psutil query fails, a standard-library fallback is used:
    CPU times from ``os.times()`` and, if tracemalloc is currently
    tracing, the traced Python heap size as a stand-in for RSS (an
    underestimate of real resident memory; 0 when tracing is off).

    This function never raises; anything it cannot measure is reported
    as zero.
    """
    try:
        if psutil:
            proc = psutil.Process(os.getpid())
            cpu = proc.cpu_times()
            mem = proc.memory_info()
            return {
                "cpu_user": float(getattr(cpu, "user", 0.0)),
                "cpu_system": float(getattr(cpu, "system", 0.0)),
                "rss_bytes": int(getattr(mem, "rss", 0)),
            }
    except Exception:
        # Telemetry is best effort: fall through to the stdlib fallback.
        pass

    cpu_user = 0.0
    cpu_sys = 0.0
    heap_bytes = 0
    try:
        times = os.times()
        cpu_user = float(times.user)
        cpu_sys = float(times.system)
    except (AttributeError, OSError):
        pass
    try:
        if tracemalloc.is_tracing():
            # get_traced_memory() -> (current, peak); report the current size.
            heap_bytes = int(tracemalloc.get_traced_memory()[0])
    except Exception:
        pass
    return {"cpu_user": cpu_user, "cpu_system": cpu_sys, "rss_bytes": heap_bytes}


class TelemetryRecorder:
    """
    Context manager to measure wall/CPU/memory around a pipeline step.
    Usage:
        with TelemetryRecorder(frames=n, notes="coarse retrieval") as rec:
            ... work ...
        sample = rec.sample

    ``rec.sample`` is None until the ``with`` block exits. Exceptions
    raised inside the block are never suppressed.
    """

    # psutil.Process handle shared by all recorders. cpu_percent() measures
    # against the previous call on the *same* Process object, so the handle
    # has to be reused between calls for the reading to mean anything.
    _shared_proc = None

    def __init__(self, frames: int = 0, gps_points: int = 0, notes: str = "", enable_gpu: bool = True):
        """Store the workload descriptors that will be copied into the sample.

        Args:
            frames: Number of frames handled by the measured step.
            gps_points: Number of GPS points handled by the measured step.
            notes: Free-form label for the step.
            enable_gpu: Stored on the instance; callers may pass it on to
                ``collect_gpu_metrics``.
        """
        self.frames = frames
        self.gps_points = gps_points
        self.notes = notes
        self.enable_gpu = enable_gpu
        self.sample: Optional[TelemetrySample] = None
        self._t0 = 0.0
        self._cpu0_user = 0.0
        self._cpu0_sys = 0.0
        self._rss0 = 0
        self._ps_proc = None
        self._cpu_percent0 = None
        # True only when this recorder turned tracemalloc on itself.
        self._owns_tracemalloc = False

    @staticmethod
    def _process_handle():
        """Return the cached psutil.Process for this PID, or None if unavailable."""
        if not psutil:
            return None
        try:
            proc = TelemetryRecorder._shared_proc
            # Rebuild the handle if the PID changed (e.g. after a fork).
            if proc is None or proc.pid != os.getpid():
                proc = psutil.Process(os.getpid())
                TelemetryRecorder._shared_proc = proc
            return proc
        except Exception:
            return None

    def __enter__(self):
        """Record the starting timestamp and resource baseline; return self."""
        # Only start tracing if nobody else did, so that exiting this
        # recorder does not switch off tracing owned by an outer recorder
        # or by the caller.
        self._owns_tracemalloc = not tracemalloc.is_tracing()
        if self._owns_tracemalloc:
            tracemalloc.start()
        self._t0 = time.perf_counter()
        baseline = _proc_resource()
        self._cpu0_user = baseline["cpu_user"]
        self._cpu0_sys = baseline["cpu_system"]
        self._rss0 = baseline["rss_bytes"]
        # Prime the CPU% baseline on the shared handle so a later
        # collect_cpu_pct_rss() call reports usage since this point.
        self._ps_proc = self._process_handle()
        if self._ps_proc is not None:
            try:
                self._ps_proc.cpu_percent(interval=None)
            except Exception:
                self._ps_proc = None
        return self

    def __exit__(self, exc_type, exc, tb):
        """Build ``self.sample`` from the end-of-block readings."""
        t1 = time.perf_counter()
        r1 = _proc_resource()
        # Use absolute values; consumers can compute deltas if needed
        self.sample = TelemetrySample(
            t_start=self._t0,
            t_end=t1,
            cpu_user=r1["cpu_user"],
            cpu_system=r1["cpu_system"],
            # Fall back to the entry reading when the exit reading is 0.
            rss_bytes=r1["rss_bytes"] or self._rss0,
            frames=self.frames,
            gps_points=self.gps_points,
            notes=self.notes,
        )
        if self._owns_tracemalloc:
            self._owns_tracemalloc = False
            try:
                tracemalloc.stop()
            except Exception:
                pass

    @staticmethod
    def collect_cpu_pct_rss() -> Dict[str, Any]:
        """Collect instant CPU% and RSS in MB (requires psutil).

        Returns:
            ``{"cpu_pct": float | None, "rss_mb": float | None}``; both
            values are None when psutil is missing or the query fails.
            CPU% is measured since the previous reading on the shared
            process handle (the last call to this method or the last
            ``__enter__``); the very first reading in a process is 0.0.
        """
        cpu_pct = None
        rss_mb = None
        proc = TelemetryRecorder._process_handle()
        if proc is not None:
            try:
                cpu_pct = float(proc.cpu_percent(interval=None))
                rss = int(getattr(proc.memory_info(), "rss", 0))
                rss_mb = rss / 1e6
            except Exception:
                pass
        return {"cpu_pct": cpu_pct, "rss_mb": rss_mb}

    @staticmethod
    def collect_gpu_metrics(enable_gpu: bool = True) -> Dict[str, Any]:
        """Collect GPU utilization% and memory MB (requires pynvml).

        Args:
            enable_gpu: When False, skip the query and return Nones.

        Returns:
            ``{"gpu_util": float | None, "gpu_mem_mb": float | None}``.
            ``gpu_util`` is the mean utilization across all devices and
            ``gpu_mem_mb`` the summed used memory (bytes / 1e6). Both are
            None when pynvml is missing, no device is found, or any NVML
            call fails.
        """
        empty = {"gpu_util": None, "gpu_mem_mb": None}
        if not enable_gpu or not pynvml:
            return empty
        try:
            pynvml.nvmlInit()
        except Exception:
            # Nothing was initialised, so there is nothing to shut down.
            return empty

        gpu_util = None
        gpu_mem_mb = None
        try:
            count = pynvml.nvmlDeviceGetCount()
            if count > 0:
                # Aggregate across GPUs
                util_sum = 0.0
                mem_used = 0
                for i in range(count):
                    handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                    rates = pynvml.nvmlDeviceGetUtilizationRates(handle)
                    util_sum += float(getattr(rates, "gpu", 0.0))
                    mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
                    mem_used += int(getattr(mem, "used", 0))
                gpu_util = util_sum / count
                gpu_mem_mb = mem_used / 1e6
        except Exception:
            # Best effort: report whatever was (not) collected.
            pass
        finally:
            try:
                pynvml.nvmlShutdown()
            except Exception:
                pass
        return {"gpu_util": gpu_util, "gpu_mem_mb": gpu_mem_mb}

    @staticmethod
    def write_perf_json(out_dir: Path, payload: Dict[str, Any]) -> Path:
        """Write ``payload`` as indented JSON to ``<out_dir>/perf.json``.

        Args:
            out_dir: Target directory (created if missing); a plain string
                path is accepted as well.
            payload: JSON-serialisable mapping to dump.

        Returns:
            The path of ``perf.json``, whether or not the write succeeded.
            Directory-creation, serialisation and write errors are all
            swallowed so that telemetry can never abort the pipeline.
        """
        out_dir = Path(out_dir)
        path = out_dir / "perf.json"
        try:
            out_dir.mkdir(parents=True, exist_ok=True)
            path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        except Exception:
            # Best effort: don't crash pipeline for telemetry failure
            pass
        return path
