from __future__ import annotations
import logging
import numpy as np
from pathlib import Path
from typing import Optional, Sequence, TYPE_CHECKING, Iterator, overload
from numbers import Real
from scanner3d.analysis.psf.frame import Frame
from scanner3d.h5.album_to_h5 import save_album, load_album
from scanner3d.scanner.scanner_ref import ScannerRef

if TYPE_CHECKING:
    from scanner3d.zemod import ZeMod
    from scanner3d.tuner.tuner import Tuner
    from scanner3d.zemod.zemod_analysis import ZeModAnalysis
    from scanner3d.zemod.zemod_row import ZeModRow

# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
class Album:
    """Collection of PSF frames, each associated with one z value.

    ``frames[i]`` is paired with ``z_seq[i]``.  Instances are not meant to be
    built directly; use :meth:`compute`, :meth:`load` or
    :meth:`from_components`.

    Indexing:
        * ``album[z]`` with a real number returns the frame selected by
          :meth:`get_nearest` (the number is a z value, *not* a list index).
        * ``album[a:b]`` returns a positional slice of the frame list.
    """

    __slots__ = ("_frames", "_z_seq", "_path", "_n_frames", "_camera_ref")

    def __init__(
        self,
        *,
        frames: list[Frame],
        z_seq: Sequence[float],
        path: Optional[Path] = None,
        camera_ref: ScannerRef | None = None,
        _internal: bool = False,
    ) -> None:
        """Store the album components (internal use only).

        Args:
            frames: Frames of the album; copied into a new list.
            z_seq: z value of each frame, in the same order; converted to
                a list of ``float``.
            path: File the album is associated with, if any.
            camera_ref: Scanner reference attached to the album, if any.
            _internal: Must be ``True``; set by the factory class methods.

        Raises:
            RuntimeError: If called without ``_internal=True``.
        """
        if not _internal:
            raise RuntimeError(
                "Direct construction of Album is not allowed. "
                "Use Album.compute(...), Album.load(...) or "
                "Album.from_components(...)."
            )

        self._frames = list(frames)
        self._z_seq = [float(z) for z in z_seq]
        self._path = path
        # Kept for helpers that may read the slot directly; the public
        # ``n_frames`` property is derived from the live frame list.
        self._n_frames = len(self._frames)
        self._camera_ref = camera_ref

        if len(self._frames) != len(self._z_seq):
            # Not fatal (existing files may rely on it), but z-based lookup
            # pairs frames and z values by position.
            log.warning(
                "[PSF Album] %d frames but %d z values; z-based lookup may be unreliable",
                len(self._frames),
                len(self._z_seq),
            )

    @property
    def frames(self) -> list[Frame]:
        """The internal frame list (not a copy)."""
        return self._frames

    @property
    def z_seq(self) -> list[float]:
        """The internal list of z values (not a copy)."""
        return self._z_seq

    @property
    def path(self) -> Optional[Path]:
        """Path given at construction or set by the last :meth:`save`."""
        return self._path

    @property
    def n_frames(self) -> int:
        """Current number of frames; always equal to ``len(album)``."""
        return len(self._frames)

    @property
    def camera_ref(self) -> ScannerRef | None:
        """Scanner reference attached to the album, or ``None``."""
        return self._camera_ref

    def __iter__(self) -> Iterator[Frame]:
        """Iterate over the frames in stored order."""
        return iter(self._frames)

    def __len__(self) -> int:
        """Return the number of frames."""
        return len(self._frames)

    @overload
    def __getitem__(self, item: float) -> Frame: ...
    @overload
    def __getitem__(self, item: slice) -> list[Frame]: ...

    def __getitem__(self, item):
        """Look up by z value (real number) or by positional slice.

        Raises:
            TypeError: For any other index type, including ``bool``.
        """
        # bool is a subclass of int (hence Real); reject it explicitly.
        if isinstance(item, Real) and not isinstance(item, bool):
            return self.get_nearest(float(item))
        if isinstance(item, slice):
            return self._frames[item]
        raise TypeError(f"Unsupported index type for PSF Album: {type(item)!r}")

    def get_nearest(self, wd: float) -> Frame:
        """Return the frame selected for the requested z value ``wd``.

        Selection rule: the frame with the largest z that does not exceed
        ``wd``.  If every z is greater than ``wd``, the frame with the
        smallest z is returned.  When a z value occurs more than once, the
        first matching frame wins.

        Raises:
            ValueError: If the album has no z values.
        """
        if not self._z_seq:
            raise ValueError("PSF Album has no z sequence; cannot search nearest frame.")
        z = np.asarray(self._z_seq, dtype=float)
        eligible = z[z <= wd]
        chosen_z = eligible.max() if eligible.size > 0 else z.min()
        # First position holding the chosen z value.
        idx = int(np.where(z == chosen_z)[0][0])
        frame = self._frames[idx]
        log.info(
            "[PSF Album] Requested Z=%.3f → selected frame %.3f mm (index %d)",
            wd,
            chosen_z,
            idx,
        )
        return frame

    @classmethod
    def load(cls, path: Path | str) -> "Album":
        """Load an album from ``path`` via ``load_album``."""
        return load_album(Path(path))

    def save(self, path: Path | str, **kwargs) -> list[str]:
        """Save the album via ``save_album`` and remember ``path``.

        Extra keyword arguments are forwarded to ``save_album``.  The stored
        path is only updated after ``save_album`` returns.

        Returns:
            Whatever ``save_album`` returns.
        """
        target = Path(path)
        written = save_album(album=self, path=target, **kwargs)
        self._path = target
        return written

    @classmethod
    def from_components(
        cls,
        *,
        frames: list[Frame],
        z_seq: Sequence[float],
        path: Optional[Path] = None,
        camera_ref: ScannerRef | None = None,
    ) -> Album:
        """Build an album from already-available frames and z values."""
        return cls(
            frames=list(frames),
            z_seq=list(z_seq),
            path=path,
            camera_ref=camera_ref,
            _internal=True,
        )

    @classmethod
    def compute(
        cls,
        *,
        tuner: Tuner,
        psf: ZeModAnalysis,
        grid_index: int,
        z_seq: Sequence[float],
        x_seq=None,
        y_seq=None,
    ) -> "Album":
        """Compute one frame per z value and collect them into an album.

        For every z the thickness of the row returned by
        ``tuner.sm.get_wd_row()`` is set to that z, a profile is requested
        from the tuner, and ``Frame.compute`` is called.

        Side effect: the row's thickness is left at the last z value; it is
        not restored.

        Args:
            tuner: Provides the working-distance row, the test field, the
                profile and the scanner reference.
            psf: Analysis object forwarded to ``Frame.compute``.
            grid_index: Forwarded to ``Frame.compute``.
            z_seq: z values to evaluate; any iterable is accepted.
            x_seq: Forwarded to ``Frame.compute``.
            y_seq: Forwarded to ``Frame.compute``.
        """
        # Materialise once: a one-shot iterator would otherwise be exhausted
        # by the loop and reach the constructor empty.
        z_values = list(z_seq)
        frames: list[Frame] = []
        z_row = tuner.sm.get_wd_row()  # row in the LDE that sets the working distance
        field = tuner.fm.get_test_field()  # field used for the test

        for z in z_values:
            z_float = float(z)
            z_row.thickness = z_float
            profile = tuner.get_profile()
            psf_slice = Frame.compute(
                field=field,
                psf=psf,
                grid_index=grid_index,
                x_seq=x_seq,
                y_seq=y_seq,
                profile=profile,
                z=z,
            )
            log.debug("PSF Frame @ %.2f mm created", z_float)
            frames.append(psf_slice)
        return cls(
            frames=frames,
            z_seq=z_values,
            camera_ref=tuner.scanner_ref,
            _internal=True,
        )

    def __str__(self) -> str:
        """Return a multi-line, human-readable summary."""
        if not self._z_seq:
            return "📘 PSF Album (empty)"
        wds = sorted(self._z_seq)
        n = len(wds)
        lines = [
            "📘 PSF Album",
            "────────────────────────────",
            f"  File: {self._path.name if self._path else '(unsaved)'}",
            f"  Number of frames: {n}",
            f"  Range: {wds[0]:.2f} mm → {wds[-1]:.2f} mm",
        ]
        return "\n".join(lines)

    def __repr__(self) -> str:
        """Return a compact one-line description in angle brackets."""
        if not self._z_seq:
            return "<PSF Album (empty)>"
        wds = sorted(self._z_seq)
        return f"<PSF Album {len(wds)} frames {wds[0]:.2f}–{wds[-1]:.2f} mm>"
