from __future__ import annotations
import numpy as np

from dataclasses import dataclass
from typing import Generic, List, Tuple, TypeVar, Iterable, Any


T_Ray = TypeVar("T_Ray")

@dataclass(slots=True)
class RayBatch(Generic[T_Ray]):
    method: Any
    rays_type: Any
    grid: Tuple[int, int]
    to_surface: int
    wave_number: int
    rays: List[T_Ray]
    x_lin: np.ndarray | None
    y_lin: np.ndarray | None
    process_time: float | None = None

    @property
    def gx(self) -> int: return self.grid[0]
    @property
    def gy(self) -> int: return self.grid[1]
    @property
    def total(self) -> int: return self.gx * self.gy
    @property
    def shape(self) -> tuple[int, int]: return self.gy, self.gx


    def array(self, attr: str) -> np.ndarray:
        """Return a (gy, gx) array of a per-ray attribute (e.g. 'x', 'y', 'z', 'intensity', 'opd', ...)."""
        vals = [getattr(r, attr) for r in self.rays]
        if len(vals) != self.total:
            # If you ever change acceptance criteria and the count differs, pad with NaN
            pad = self.total - len(vals)
            vals += [np.nan] * max(0, pad)
        return np.asarray(vals, dtype=float).reshape(self.shape)

    def vectors(self, *attrs: str) -> dict[str, np.ndarray]:
        """Batch-pull multiple attributes at once, all reshaped (gy, gx)."""
        return {a: self.array(a) for a in attrs}

    def as_struct(self, attrs: Iterable[str]) -> dict[str, np.ndarray]:
        """Alias of vectors; pass any iterable of attribute names."""
        return self.vectors(*tuple(attrs))

    # ---- common convenience fields ----
    def xy(self) -> tuple[np.ndarray, np.ndarray]:
        return self.array("x"), self.array("y")

    def xyz(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        return self.array("x"), self.array("y"), self.array("z")

    def __str__(self) -> str:
        """Pretty one-line summary."""
        n = len(self.rays)
        gx, gy = self.grid
        first = self.rays[0] if n else None
        attrs = ", ".join(vars(first).keys()) if first else "–"
        return (
            f"RayBatch[{gx}×{gy}]\n "
            f"Processed time {self.process_time:.3f} sec"
            f"({n} rays) \n"
            f"method={getattr(self.method, 'name', self.method)}\n"
            f"wave={self.wave_number}\n"
            f"to_surface={self.to_surface}\n"
            f"rays_type={getattr(self.rays_type, 'name', self.rays_type)}\n"
            f"fields: {attrs}"
        )



    @classmethod
    def collect(cls,
                tracer: Any,
                rays: List[T_Ray],
                *,
                method: Any,
                rays_type: Any,
                grid: Tuple[int, int],
                to_surface: int,
                wave_number: int) -> "RayBatch[T_Ray]":
        return cls(
            method=method,
            rays_type=rays_type,
            grid=grid,
            to_surface=to_surface,
            wave_number=wave_number,
            rays=rays,
            x_lin=getattr(tracer, "x_lin", None),
            y_lin=getattr(tracer, "y_lin", None),
            process_time= getattr(tracer, "process_time", None))