from __future__ import annotations
import logging
from pathlib import Path
from allytools.units import Length
from scanner3d.camera3d.camera3d import Camera3D
from scanner3d.zemod.zemod import ZeMod
from scanner3d.zemod.enums.enums import ZeModFieldTypes, ZeModFieldNormalizationType
from scanner3d.tuner.tuner import Tuner
from scanner3d.test.tests.test_settings import TestSettings, WavelengthCriteria, FocusDistanceCriteria
from scanner3d.analysis.psf.album import Album
from scanner3d.test.tests.camera_test import CameraTest
from scanner3d.test.aid import wd_triplet, file_stem
from scanner3d.test.tests.psf_settings import psf_settings
from lensguild.sensor import length_to_mm

log = logging.getLogger(__name__)
class MajorFrames(CameraTest):
    """Compute an FFT-PSF album over the sensor grid and save it to disk.

    The album is computed for the working distances returned by
    ``wd_triplet(camera)``, at sensor positions sampled with ``n_samples``
    points along x and along y. Every frame in the album is written as a PNG,
    and the whole album is written as one HDF5 file.
    """

    test_name = "Major Frames"
    # Analysis configuration applied via ``tuner.tune`` in ``perform``.
    settings = TestSettings(
        field_type=ZeModFieldTypes.RealImageHeight,
        # NOTE(review): the keyword spelling ``filed_normalization`` must match
        # TestSettings' parameter name; looks like an upstream typo - confirm
        # before renaming.
        filed_normalization=ZeModFieldNormalizationType.Rectangular,
        focus_distance_criteria=FocusDistanceCriteria.BestFocus,
        wavelength_criteria=WavelengthCriteria.Primary,
        test_field_number=1,
        # Also forwarded to ``Album.compute`` as ``grid_index``.
        test_grid_index=0,
    )

    def __init__(self, n_samples: int):
        """Create the test.

        Args:
            n_samples: Sample count passed to the sensor grid's ``x1d_n`` and
                ``y1d_n`` when building the PSF sample positions.
        """
        self.n_samples = n_samples

    def perform(self, *, zemod: ZeMod, camera: Camera3D, _output_root: Path, tuner: Tuner) -> bool:
        """Compute the PSF album and save its frames (PNG) and the album (HDF5).

        Args:
            zemod: Model handle whose ``get_fftpsf()`` context provides the PSF
                analysis.
            camera: Camera whose sensor grid and working distances are sampled.
            _output_root: Not used by this method; files are written under
                ``self.output_root``. NOTE(review): presumably
                ``self.output_root`` is provided by ``CameraTest`` - confirm.
            tuner: Entered as ``tuner.tune(settings=self.settings)`` for the
                whole computation and also passed to ``Album.compute``.

        Returns:
            True when at least one frame PNG was written, every PNG written by
            this call exists, and the HDF5 album file exists; otherwise False.
        """
        with tuner.tune(settings=self.settings):
            with zemod.get_fftpsf() as psf:
                psf.settings.apply(psf_settings)
                # TODO: check desired units (positions are converted with length_to_mm)
                x_seq = length_to_mm(camera.sensor.grid.x1d_n(self.n_samples))
                y_seq = length_to_mm(camera.sensor.grid.y1d_n(self.n_samples))
                album = Album.compute(
                    tuner=tuner,
                    psf=psf,
                    grid_index=self.settings.test_grid_index,
                    x_seq=x_seq,
                    y_seq=y_seq,
                    z_seq=wd_triplet(camera),
                )

                output_root = self.output_root
                png_paths: list[Path] = []
                for frame in album.frames:
                    image_bundle = frame.plot()
                    png_path = output_root / (file_stem(camera=camera, wd_mm=frame.z) + ".png")
                    image_bundle.save(str(png_path))
                    png_paths.append(png_path)
                    log.info("Frame saved as image: %s", png_path)

                h5_path = output_root / (file_stem(camera=camera) + ".h5")
                album.save(path=h5_path, on_conflict="overwrite")
                log.info("Album saved with %d frames  to HDF5: %s", album.n_frames, h5_path)

            # Verify only the files written by this run: globbing every *.png in
            # output_root would let stale images from earlier runs mask a failure.
            return bool(png_paths) and all(p.exists() for p in png_paths) and h5_path.exists()