from __future__ import annotations
import logging
from pathlib import Path
from allytools.units import Length
from scanner3d.camera3d.camera3d import Camera3D
from scanner3d.zemod.zemod import ZeMod
from scanner3d.zemod.enums import ZeModFieldTypes, ZeModFieldNormalizationType
from scanner3d.analysis.psf.album import Album
from scanner3d.test.tests.camera_test import CameraTest
from scanner3d.test.tests.test_settings import TestSettings, WavelengthCriteria, FocusDistanceCriteria
from scanner3d.test.aid import wd_sequence, file_stem
from scanner3d.test.tests.psf_settings import psf_settings
from scanner3d.tuner.tuner import Tuner
from lensguild.sensor import length_to_mm

# Module-level logger; AlbumRadial reports album save and output-verification results through it.
log: logging.Logger = logging.getLogger(__name__)
class AlbumRadial(CameraTest):
    """Camera test that computes a radial FFT-PSF album and saves it to HDF5.

    The test runs while ``tuner`` applies the class-level ``settings``. Those
    settings select real-image-height fields, rectangular field normalization,
    best-focus distance, the primary wavelength, field number 1 and grid index 0.
    Inside that window it samples PSFs at two sets of positions:

    * the radial positions from ``camera.sensor.grid.get_radial(dr_mm)``,
      converted with ``length_to_mm``;
    * the working distances from ``wd_sequence(camera, dz_mm)``.

    It writes the resulting ``Album`` to ``self.output_root / (file_stem(camera) + ".h5")``
    and then checks that the file was actually produced.
    """

    test_name = "Album - Radial"
    # NOTE(review): "filed_normalization" must match TestSettings' keyword
    # spelling exactly; confirm in TestSettings before renaming either side.
    settings = TestSettings(
        field_type=ZeModFieldTypes.RealImageHeight,
        filed_normalization=ZeModFieldNormalizationType.Rectangular,
        focus_distance_criteria=FocusDistanceCriteria.BestFocus,
        wavelength_criteria=WavelengthCriteria.Primary,
        test_field_number=1,
        test_grid_index=0)

    def __init__(self, *, dr_mm: Length, dz_mm: Length):
        """Store the sampling steps.

        Args:
            dr_mm: Step passed to ``camera.sensor.grid.get_radial`` to build the
                radial sample positions.
            dz_mm: Step passed to ``wd_sequence`` to build the working-distance
                sequence.
        """
        # NOTE(review): CameraTest.__init__ is not called here. Confirm the base
        # class needs no initialisation (e.g. for ``self.output_root``).
        self.dr_mm = dr_mm
        self.dz_mm = dz_mm

    def perform(self, *, zemod: ZeMod, camera: Camera3D, _output_root: Path, tuner: Tuner) -> bool:
        """Compute the radial album, save it to HDF5 and verify the written file.

        Args:
            zemod: Model whose ``get_fftpsf()`` context provides the PSF analysis.
            camera: Camera whose sensor grid and working distances drive sampling.
            _output_root: Accepted for interface compatibility but not used. The
                output directory is read from ``self.output_root`` (presumably
                provided by CameraTest; confirm).
            tuner: Applies ``self.settings`` for the duration of the test and is
                forwarded to ``Album.compute``.

        Returns:
            True if the HDF5 file exists, has a non-zero size and the album holds
            at least one frame; otherwise False (after logging a warning).
        """
        with tuner.tune(settings=self.settings):
            r_seq = length_to_mm(camera.sensor.grid.get_radial(self.dr_mm))
            with zemod.get_fftpsf() as psf:
                # Configure the PSF analysis before computing any frame.
                psf.settings.apply(psf_settings)
                album = Album.compute(
                    tuner=tuner,
                    psf=psf,
                    grid_index=self.settings.test_grid_index,
                    x_seq=r_seq,
                    z_seq=wd_sequence(camera, self.dz_mm))
                h5_path = self.output_root / (file_stem(camera=camera) + ".h5")
                album.save(path=h5_path, on_conflict="overwrite")
                n_frames = album.n_frames
                log.info("Album saved with %d frames  to HDF5: %s", n_frames, h5_path)
            ok = self._verify_output(h5_path, n_frames)
            if ok:
                log.info("test - %s successfully done", self.test_name)
            return ok

    @staticmethod
    def _verify_output(h5_path: Path, n_frames: int) -> bool:
        """Return True if *h5_path* exists with non-zero size and *n_frames* > 0.

        The file is stat-ed exactly once. A path that vanishes (or whose parent is
        not a directory) is therefore reported as a failed verification rather
        than raising between a separate ``exists()`` check and ``stat()`` call.
        A warning with the observed state is logged on failure.
        """
        try:
            size = h5_path.stat().st_size
            exists = True
        except (FileNotFoundError, NotADirectoryError):
            size = 0
            exists = False
        ok = exists and size > 0 and n_frames > 0
        if not ok:
            log.warning("Album output verification failed: exists=%s size=%s slices=%d",
                        exists, size, n_frames)
        return ok