from __future__ import annotations
import numpy as np
import logging
from typing import Sequence, TYPE_CHECKING
from scanner3d.tuner.profile import Profile
from scanner3d.analysis.psf.shot import Shot
from scanner3d.analysis.psf.frame_meta import FrameMeta
from scanner3d.analysis.psf.plotter import plot_frame

if TYPE_CHECKING:
    from scanner3d.zemod.zemod_field import ZeModField
    from scanner3d.zemod.zemod_analysis import ZeModAnalysis


log = logging.getLogger(__name__)


class Frame:
    """A set of PSF ``Shot`` objects sampled on an x/y grid at a single ``z``.

    Shots are stored in row-major order: the shot at grid position
    ``(row, col)`` (``row`` indexing ``y_seq``, ``col`` indexing ``x_seq``)
    lives at flat index ``row * len(x_seq) + col``. This matches the order
    produced by ``compute`` (``np.meshgrid(x, y)`` raveled).

    Instances must be created with ``compute(...)`` or ``from_components(...)``;
    direct construction raises ``RuntimeError``.
    """

    __slots__ = ("_shots", "_x_seq", "_y_seq", "_meta", "_profile", "_values", "_z")

    def __init__(
        self,
        *,
        shots: list[Shot],
        x_seq: np.ndarray,
        y_seq: np.ndarray,
        meta: FrameMeta | None,
        profile: Profile,
        z: float,
        _internal: bool = False,
    ) -> None:
        """Store the components and stack every shot's ``raw_data`` into ``values``.

        Raises:
            RuntimeError: if called without ``_internal=True`` (i.e. directly).
            ValueError: if the shots' ``raw_data`` arrays differ in shape.
        """
        if not _internal:
            raise RuntimeError(
                "Direct construction of Frame is not allowed. "
                "Use compute(...) to calculate it from Zemax, or from_components() to reconstruct it.")

        self._shots = shots
        self._x_seq = x_seq
        self._y_seq = y_seq
        self._meta = meta
        self._profile = profile
        self._z = z
        if shots:
            shapes = {shot.raw_data.shape for shot in shots}
            if len(shapes) != 1:
                raise ValueError(f"All shots in a Frame must have the same raw_data.shape, got: {shapes}")
            # Resulting shape: (n_shots, *raw_data.shape).
            self._values = np.stack([shot.raw_data for shot in shots])
        else:
            self._values = np.empty((0, 0, 0), dtype=float)  # -> (n_shots, Ny, Nx)

    @classmethod
    def from_components(
        cls,
        *,
        shots: list[Shot],
        x_seq: Sequence[float] | np.ndarray,
        y_seq: Sequence[float] | np.ndarray,
        profile: Profile,
        z: float,
        meta: FrameMeta | None = None,
    ) -> Frame:
        """Rebuild a Frame from already-computed parts.

        ``x_seq`` / ``y_seq`` are converted to float arrays. ``meta`` is
        stored as given (it stays ``None`` when omitted).
        """
        x_arr = np.asarray(x_seq, dtype=float)
        y_arr = np.asarray(y_seq, dtype=float)
        return cls(shots=shots, x_seq=x_arr, y_seq=y_arr, meta=meta, profile=profile, z=z, _internal=True)

    @classmethod
    def compute(
        cls,
        *,
        field: ZeModField,
        psf: ZeModAnalysis,
        grid_index: int,
        x_seq: Sequence[float] | None = None,
        y_seq: Sequence[float] | None = None,
        profile: Profile,
        z: float,
    ) -> Frame:
        """Compute one ``Shot`` per grid point via ``Shot.compute`` and wrap them.

        * both ``x_seq`` and ``y_seq`` -> 2D grid of ``len(y) * len(x)`` shots
          (row-major, rows follow ``y_seq``);
        * only ``x_seq`` -> 1D scan along X at ``y = 0`` (``y_seq`` stored as ``[0.0]``);
        * only ``y_seq`` -> 1D scan along Y at ``x = 0`` (``x_seq`` stored as ``[0.0]``).

        Raises:
            ValueError: if neither ``x_seq`` nor ``y_seq`` is given.
        """
        if x_seq is None and y_seq is None:
            raise ValueError("Provide x_seq and/or y_seq (x for 1D-X, y for 1D-Y, both for 2D).")

        x_arr = np.asarray(x_seq, dtype=float) if x_seq is not None else None
        y_arr = np.asarray(y_seq, dtype=float) if y_seq is not None else None

        if x_arr is not None and y_arr is not None:
            # 2D: meshgrid gives shape (len(y), len(x)); ravel is row-major,
            # so the flat index is row * len(x) + col.
            xx, yy = np.meshgrid(x_arr, y_arr)
            xs, ys = xx.ravel(), yy.ravel()
            x_meta, y_meta = x_arr, y_arr
        elif x_arr is not None:
            # 1D scan along X (y = 0).
            xs = x_arr
            ys = np.zeros_like(xs)
            x_meta, y_meta = x_arr, np.asarray([0.0], dtype=float)
        else:
            # 1D scan along Y (x = 0); y_arr is non-None because of the guard above.
            ys = y_arr
            xs = np.zeros_like(ys)
            x_meta = np.asarray([0.0], dtype=float)
            y_meta = y_arr  # type: ignore[assignment]

        shots: list[Shot] = [
            Shot.compute(field=field, psf=psf, x=float(xv), y=float(yv), grid_index=grid_index)
            for xv, yv in zip(xs, ys)
        ]
        meta = FrameMeta.from_shots(shots=shots)
        return cls(shots=shots, x_seq=x_meta, y_seq=y_meta, meta=meta, _internal=True, profile=profile, z=z)

    @property
    def shots(self) -> list[Shot]:
        """Shots in row-major grid order."""
        return self._shots

    @property
    def x_seq(self) -> np.ndarray:
        """Column coordinates (indexed by ``col``)."""
        return self._x_seq

    @property
    def y_seq(self) -> np.ndarray:
        """Row coordinates (indexed by ``row``)."""
        return self._y_seq

    @property
    def meta(self) -> FrameMeta | None:
        """Frame metadata; may be ``None`` for frames built by ``from_components``."""
        return self._meta

    @property
    def profile(self) -> Profile:
        return self._profile

    @property
    def values(self) -> np.ndarray:
        """All shots' ``raw_data`` stacked along axis 0."""
        return self._values

    @property
    def z(self) -> float:
        return self._z

    def _flat_index(self, row: int, col: int) -> int:
        """Map grid position ``(row, col)`` to the flat index into ``shots``.

        Negative indices count from the end of their own axis, as for Python
        sequences. Out-of-range indices raise ``IndexError`` instead of
        silently wrapping into a neighbouring row.
        """
        ny = len(self._y_seq)
        nx = len(self._x_seq)
        r = row + ny if row < 0 else row
        c = col + nx if col < 0 else col
        if not (0 <= r < ny and 0 <= c < nx):
            raise IndexError(f"grid index ({row}, {col}) out of range for a {ny}x{nx} frame")
        return r * nx + c

    def __call__(self, row: int, col: int) -> Shot:
        """
        Return the Shot at grid position (row, col).
        row = index into y_seq
        col = index into x_seq

        Raises IndexError when either index is outside its axis.
        """
        return self._shots[self._flat_index(row, col)]

    def __getitem__(self, idx: int | slice | tuple[int, int]) -> Shot | list[Shot]:
        """
        Allow both:
            frame[i]        -> 1D index (or slice) into shots
            frame[row, col] -> grid index (bounds-checked per axis)
        """
        if isinstance(idx, tuple) and len(idx) == 2:
            row, col = idx
            return self._shots[self._flat_index(row, col)]
        return self._shots[idx]

    @staticmethod
    def _range_str(arr: np.ndarray) -> str:
        """Format ``min=..., max=...,`` ignoring NaNs; ``n/a`` when no values remain."""
        valid = arr[~np.isnan(arr)]
        if valid.size == 0:
            # np.nanmin/np.nanmax would raise (empty) or warn (all-NaN) here.
            return "min=n/a, max=n/a,"
        return f"min={valid.min():.4g}, max={valid.max():.4g},"

    def __str__(self) -> str:
        """Human-readable summary of the frame's grid, shots and metadata."""
        x_arr = np.asarray(self._x_seq, dtype=float)
        y_arr = np.asarray(self._y_seq, dtype=float)

        parts = [
            "🖼️ PSF Frame",
            "────────────────────────────",
            f"  z: {self.z:.2f}",
            f"  x: shape={x_arr.shape}, {self._range_str(x_arr)}",
            f"  y: shape={y_arr.shape}, {self._range_str(y_arr)}",
            f"  total shots: {len(self._shots)},",
            f"  values shape: {self._values.shape},",
            f"  meta: {self._meta}",
        ]
        return "\n".join(parts)

    def plot(self, *args, **kwargs):
        """Delegate to ``plot_frame(self, *args, **kwargs)``."""
        return plot_frame(self, *args, **kwargs)


