from __future__ import annotations
from h5py import Group
import numpy as np
from typing import Optional
from scanner3d.analysis.psf.frame import Frame, FrameMeta
from scanner3d.analysis.psf.shot import Shot
from scanner3d.zemod.iar.grid_meta import GridMeta
from scanner3d.zemod.iar.data_grid import DataGrid
from scanner3d.tuner.profile import Profile
from scanner3d.h5.h5aid import suggest_chunks
from scanner3d.h5.frame_h5 import FrameH5

from allytools.formats import (
    save_dataclass,
    set_attrs_from_dataclass, read_dataclass_from_h5, read_attrs_into_dataclass
)

# Format version tag stored in the FrameH5.FORMAT_VERSION attribute by
# write_frame(). read_frame() accepts groups carrying exactly this value, or
# no version attribute at all; any other value is rejected.
FRAME_H5_VERSION = "2.0"

def write_frame(
        grp: Group,
        frame: Frame,
        *,
        compression: Optional[str],
        compression_opts: Optional[int]) -> None:
    """Serialize ``frame`` into the H5 group ``grp``.

    Layout written into ``grp``:

    * ``FrameH5.VALUES`` - every shot's ``raw_data`` stacked along a new
      leading axis, shape ``(n_shots, *raw_shape)``, chunked with
      ``suggest_chunks`` and optionally compressed.
    * ``FrameH5.X_SEQ`` / ``FrameH5.Y_SEQ`` - ``frame.x_seq`` /
      ``frame.y_seq`` converted to float arrays.
    * ``FrameH5.Z`` - ``frame.z``.
    * ``FrameH5.PROFILE`` - subgroup whose attributes mirror ``frame.profile``.
    * ``FrameH5.GRID_META`` - subgroup holding the *first* shot's
      ``data_grid.meta``; only one grid meta is stored for the whole frame
      (presumably all shots share it - not checked here).
    * attribute ``FrameH5.FORMAT_VERSION`` set to ``FRAME_H5_VERSION``.

    ``read_frame`` maps stored shot ``k`` back to
    ``(row, col) = divmod(k, len(x_seq))``. ``frame.shots`` is therefore
    expected to be ordered row-major with x varying fastest. That ordering is
    not verified here.

    Args:
        grp: Destination H5 group.
        frame: Frame to serialize.
        compression: h5py compression filter name, or ``None`` to store
            the values dataset without a compression filter.
        compression_opts: Options for ``compression``; ignored when
            ``compression`` is ``None``.

    Raises:
        ValueError: If the frame has no shots, if the shots' ``raw_data``
            shapes differ or are not 2-D, or if
            ``x_seq.size * y_seq.size`` does not equal the number of shots.
    """
    shots = frame.shots
    n_shots = len(shots)
    if n_shots == 0:
        raise ValueError("Frame has no shots; nothing to save.")
    shapes = {shot.raw_data.shape for shot in shots}
    if len(shapes) != 1:
        raise ValueError(f"All shots in a Frame must have same raw_data.shape, got {shapes}")
    (shot_shape,) = shapes
    # read_frame unpacks VALUES as a 3-D (N, Ny, Nx) array, i.e. 2-D shots.
    if len(shot_shape) != 2:
        raise ValueError(f"Shot raw_data must be 2-D, got shape {shot_shape}")

    x_seq = np.asarray(frame.x_seq, dtype=float)
    y_seq = np.asarray(frame.y_seq, dtype=float)
    # Validate before writing anything to grp. This mirrors read_frame's
    # consistency check, so an inconsistent frame never leaves behind a
    # half-written group that could not be read back.
    if x_seq.size * y_seq.size != n_shots:
        raise ValueError(
            f"Inconsistent frame: nx*ny={x_seq.size * y_seq.size}, "
            f"but frame has {n_shots} shots")

    values = np.stack([shot.raw_data for shot in shots])
    ds_chunks = suggest_chunks(values.shape)
    # Pass filter settings only when compression was requested.
    filter_kwargs = {}
    if compression is not None:
        filter_kwargs = {
            "compression": compression,
            "compression_opts": compression_opts,
        }
    grp.create_dataset(FrameH5.VALUES, data=values, chunks=ds_chunks, **filter_kwargs)

    grp.create_dataset(FrameH5.X_SEQ, data=x_seq)
    grp.create_dataset(FrameH5.Y_SEQ, data=y_seq)
    grp.create_dataset(FrameH5.Z, data=frame.z)
    profile_grp = grp.create_group(FrameH5.PROFILE)
    set_attrs_from_dataclass(profile_grp, frame.profile)
    grid_meta_grp = grp.create_group(FrameH5.GRID_META)
    save_dataclass(grid_meta_grp, shots[0].data_grid.meta)
    grp.attrs[FrameH5.FORMAT_VERSION] = FRAME_H5_VERSION


def read_frame(grp: Group) -> Frame:
    """Reconstruct a Frame object from an H5 group written by write_frame().

    Reads the stacked shot values, the x/y field sequences, ``z``, the
    profile attributes and the shared grid meta. It then rebuilds one
    ``Shot`` per stored slice. Every shot reuses the same ``GridMeta``.

    Args:
        grp: H5 group previously populated by ``write_frame``.

    Returns:
        The reconstructed ``Frame``.

    Raises:
        ValueError: If the format-version attribute is present but differs
            from ``FRAME_H5_VERSION``, if the values dataset is not 3-D, or
            if ``x_seq.size * y_seq.size`` does not equal the number of
            stored shots.
    """
    version = grp.attrs.get(FrameH5.FORMAT_VERSION, None)
    # h5py may return string attributes as bytes / numpy.bytes_ (e.g.
    # fixed-length strings or files written by other tools). Normalize them
    # so a valid version tag is not rejected just because of its storage type.
    if isinstance(version, bytes):
        version = version.decode("utf-8")
    if version not in (None, FRAME_H5_VERSION):
        raise ValueError(
            f"Unsupported Frame H5 format version: {version!r}, "
            f"expected {FRAME_H5_VERSION}")
    values = np.asarray(grp[FrameH5.VALUES][...], dtype=float)   # (N, Ny, Nx)
    if values.ndim != 3:
        raise ValueError(
            f"Frame values dataset must be 3-D (N, Ny, Nx), got shape {values.shape}")
    x_seq = np.asarray(grp[FrameH5.X_SEQ][...], dtype=float)
    y_seq = np.asarray(grp[FrameH5.Y_SEQ][...], dtype=float)
    z = float(grp[FrameH5.Z][()])

    n_shots = values.shape[0]
    nx = x_seq.size
    ny = y_seq.size

    if nx * ny != n_shots:
        raise ValueError(f"Inconsistent frame: nx*ny={nx*ny}, but values.shape[0]={n_shots}")

    profile = read_attrs_into_dataclass(grp[FrameH5.PROFILE], Profile)
    grid_meta = read_dataclass_from_h5(grp[FrameH5.GRID_META], GridMeta)

    shots: list[Shot] = []
    for k in range(n_shots):
        # Shots are stored row-major: x varies fastest, so slice k sits at
        # column k % nx (x index) and row k // nx (y index).
        row, col = divmod(k, nx)
        data_grid = DataGrid.from_components(values[k, :, :], grid_meta)
        shots.append(Shot.from_components(
            data_grid=data_grid,
            field_x=float(x_seq[col]),
            field_y=float(y_seq[row])))
    meta = FrameMeta.from_shots(shots)
    return Frame.from_components(shots=shots, x_seq=x_seq, y_seq=y_seq, meta=meta, profile=profile, z=z)