from __future__ import annotations
import logging
from pathlib import Path
from allytools.strings import sanitize
from scanner3d.camera3d.camera3d import Camera3D
from scanner3d.zemod.zemod import ZeMod
from scanner3d.zemod.enums import ZeModFieldTypes, ZeModFieldNormalizationType
from zempy.raytracer import GenericRayTracer, TraceMethod
from zempy.zosapi.tools.raytrace.enums import RaysType
from scanner3d.test.tests.camera_test import CameraTest
from scanner3d.test.tests.test_settings import TestSettings
from scanner3d.tuner.tuner import Tuner

from scanner3d.analysis.raytracing.ray_batch import RayBatch


# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
class RayMatrix(CameraTest):
    """Camera test that traces a rectangular grid of rays through the model.

    ``run`` copies the camera's sensor half-extents, primary wavelength and
    focus distance into ``settings``, applies them through the tuner, runs a
    batch ray trace and collects the result into a ``RayBatch``, which is
    reported through the module logger.

    Class attributes (override in a subclass to change the trace):
        test_name: Name of the per-test output sub-directory.
        settings: Settings handed to ``tuner.tune`` for the duration of the
            trace. Defined at class level, so the object is shared by all
            instances and is mutated by every ``run`` call.
        grid_shape: ``(x, y)`` pair passed as ``grid`` to the ray tracer.
        to_surface: Value passed as ``to_surface`` to the ray tracer.
        wave_number: Value passed as ``wave_number`` to the ray tracer.
    """

    test_name = "ray_matrix"
    # NOTE(review): ``filed_normalization`` looks like a typo for
    # ``field_normalization``; it is kept because it must match the keyword
    # that TestSettings actually declares -- confirm against TestSettings.
    settings = TestSettings(
        field_type=ZeModFieldTypes.RealImageHeight,
        filed_normalization=ZeModFieldNormalizationType.Rectangular,
    )

    # Fixed 100 x 100 grid. An earlier draft derived it from the sensor pixel
    # counts (camera.sensor.width_pix / 8, camera.sensor.height_pix / 8).
    grid_shape = (100, 100)
    # NOTE(review): presumably a surface index in the optical model -- confirm
    # what 0 selects in GenericRayTracer.
    to_surface = 0
    # NOTE(review): presumably the wavelength index in the system -- confirm.
    wave_number = 1

    def __init__(self):
        # NOTE(review): intentionally left as in the original -- the base
        # initializer is not called here. Confirm CameraTest needs no setup.
        pass

    def run(self, *, zemod: ZeMod, camera: Camera3D, output_root: Path, tuner: Tuner) -> bool:
        """Trace the ray grid for ``camera`` and log the collected batch.

        Args:
            zemod: Model wrapper that supplies the native batch ray tracer.
            camera: Camera whose sensor size, primary wavelength and focus
                distance are copied into ``settings``.
            output_root: Root directory; ``<output_root>/<camera>/<test_name>``
                is created, but this method writes nothing into it.
            tuner: Provides the ``tune`` context manager that applies
                ``settings`` while the trace runs.

        Returns:
            ``True`` once the trace and collection have completed.
        """
        settings = self.settings
        # Edge fields are half the sensor extent in millimetres.
        settings.edge_field_x = camera.sensor.width.value_mm / 2
        settings.edge_field_y = camera.sensor.height.value_mm / 2
        settings.primary_wavelength = camera.primary_wavelength
        settings.working_distance = camera.z_range.z_focus

        with tuner.tune(settings=settings):
            cam_dir = sanitize(getattr(camera, "description", "camera"))
            output_dir = output_root / cam_dir / self.test_name
            output_dir.mkdir(parents=True, exist_ok=True)

            ray_tracer = GenericRayTracer(zemod.get_batch_ray_tracer())

            # One shared argument set keeps the trace and the collection step
            # from drifting apart.
            trace_args = {
                "method": TraceMethod.Normal_Unpolarized,
                "rays_type": RaysType.Real,
                "grid": self.grid_shape,
                "to_surface": self.to_surface,
                "wave_number": self.wave_number,
            }
            result = ray_tracer.run(**trace_args)
            batch = RayBatch.collect(tracer=ray_tracer, rays=result, **trace_args)

            # Report through the module logger rather than stdout.
            log.info("Ray batch for %s: %s", cam_dir, batch)

            # Returned inside the context so the result is unchanged if the
            # tuner's context manager handles exceptions on exit.
            return True
