from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Optional
import numpy as np
from numpy.typing import NDArray
from scanner3d.zemod.iar.i_data_grid import IDataGrid
from scanner3d.analysis.psf.plotter import plot_shot


if TYPE_CHECKING:
    from scanner3d.zemod.zemod_field import ZeModField
    from scanner3d.zemod.zemod_analysis import ZeModAnalysis

log = logging.getLogger(__name__)
class Shot:
    """A single PSF shot: one data grid plus the field point it belongs to.

    Instances are not meant to be built by calling the class directly; the
    constructor refuses unless it is invoked through one of the factory
    classmethods:

    * ``compute(...)`` sets the field coordinates, runs the analysis and
      wraps the resulting data grid.
    * ``from_components(...)`` wraps an already available data grid.

    The stored values live in private slots and are exposed through
    getter-only properties.
    """

    __slots__ = ("_data_grid", "_field_x", "_field_y")

    def __init__(
        self,
        *,
        data_grid: IDataGrid,
        field_x: float,
        field_y: float,
        _internal: bool = False,
    ) -> None:
        """Store the components of the shot (factory use only).

        Args:
            data_grid: Data grid wrapped by this shot.
            field_x: X coordinate of the field point.
            field_y: Y coordinate of the field point.
            _internal: Must be truthy; the factory classmethods pass ``True``.

        Raises:
            RuntimeError: If called without ``_internal`` set, i.e. when the
                class is instantiated directly.
        """
        if _internal:
            self._data_grid = data_grid
            self._field_x = field_x
            self._field_y = field_y
            return

        message = ("Direct construction of Shot is not allowed. "
                   "Use compute(...) to calculate it from Zemax, "
                   "or from_components() to reconstruct it.")
        raise RuntimeError(message)

    @classmethod
    def from_components(
        cls,
        *,
        data_grid: IDataGrid,
        field_x: float,
        field_y: float,
    ) -> Shot:
        """Build a shot from an existing data grid and field coordinates.

        The arguments are stored as given; no conversion is applied.
        """
        shot = cls(
            data_grid=data_grid,
            field_x=field_x,
            field_y=field_y,
            _internal=True,
        )
        return shot

    @classmethod
    def compute(
        cls,
        *,
        field: ZeModField,
        psf: ZeModAnalysis,
        grid_index: int | None = None,
        x: float = 0.0,
        y: float = 0.0,
    ) -> Shot:
        """Run the analysis at field point ``(x, y)`` and wrap its data grid.

        Steps, in order: convert ``x``/``y`` to ``float``, apply them with
        ``field.set_xy``, call ``psf.run()``, then fetch the grid with
        ``get_data_grid(grid_index)`` on the run result.

        Args:
            field: Object whose ``set_xy`` receives the coordinates.
            psf: Analysis whose ``run()`` produces the result.
            grid_index: Index handed to ``get_data_grid``. When ``None`` a
                warning is logged and index ``1`` is used.
            x: Field X coordinate (converted with ``float``).
            y: Field Y coordinate (converted with ``float``).

        Returns:
            A new shot holding the fetched grid and the converted coordinates.
        """
        coord_x = float(x)
        coord_y = float(y)

        field.set_xy(coord_x, coord_y)
        analysis_result = psf.run()

        # NOTE(review): the warning calls index 1 the "first grid" -- confirm
        # that get_data_grid is 1-based.
        if grid_index is None:
            log.warning("Test did not specify a data index. First grid will be used by default.")
            grid_index = 1

        grid = analysis_result.get_data_grid(grid_index)
        return cls(data_grid=grid, field_x=coord_x, field_y=coord_y, _internal=True)

    @property
    def field_x(self) -> float:
        """X coordinate of the field point stored at construction."""
        return self._field_x

    @property
    def field_y(self) -> float:
        """Y coordinate of the field point stored at construction."""
        return self._field_y

    @property
    def data_grid(self) -> IDataGrid:
        """The wrapped data grid object."""
        return self._data_grid

    @property
    def raw_data(self) -> NDArray[np.float64]:
        """The ``values`` attribute of the wrapped data grid."""
        return self._data_grid.values

    @property
    def shape(self) -> tuple[int, int]:
        """The ``shape`` attribute of the wrapped data grid."""
        return self._data_grid.shape

    @property
    def min(self) -> float:
        """The ``value_min`` attribute of the wrapped data grid."""
        return self._data_grid.value_min

    @property
    def max(self) -> float:
        """The ``value_max`` attribute of the wrapped data grid."""
        return self._data_grid.value_max

    def plot(self, *args, **kwargs):
        """Delegate to ``plot_shot``, forwarding all arguments after ``self``."""
        return plot_shot(self, *args, **kwargs)

    def __str__(self) -> str:
        """Return a short multi-line summary (field point, shape, min, max)."""
        title = "📸PSF Shot\n"
        rule = "────────────────────────────\n"
        field_line = f"  field: x={self.field_x:.4g}, y={self.field_y:.4g},\n"
        stats_line = f"  shape={self.shape}, min={self.min:.4g}, max={self.max:.4g}\n"
        return title + rule + field_line + stats_line
