"""Model defined by data arrays loaded from an HDF5 file"""

import os.path
import typing as tp

from h5py import File
import numba
from numba.extending import overload
import numpy as np
from scipy.interpolate import splev, splrep

from pttools.bubble.boundary import Phase
from pttools.models.model import Model
import pttools.type_hints as th
from pttools.speedup.overload import np_all_fix

# Annotation alias for a one-dimensional float64 NumPy array.
# np.ndarray is generic over (shape type, dtype type), so the second
# parameter has to be np.dtype[np.float64] and not the scalar type itself.
Float1DArray = np.ndarray[tuple[int], np.dtype[np.float64]]


class DataModel(Model):
    """Model whose thermodynamic quantities are interpolated from tabulated data

    The symmetric ("s") and broken ("b") phases are each given as arrays of
    temperature, pressure, energy density and sound speed squared.
    Degree-1 splines are fitted to these arrays:
    pressure and energy density as functions of log10(T),
    and log10(T) and the sound speed squared as functions of w = p + e.
    Values for a mixed phase are obtained by weighting the symmetric-phase value
    with ``phase`` and the broken-phase value with ``1 - phase``.
    """
    def __init__(
            self,
            T_s: Float1DArray,
            T_b: Float1DArray,
            p_s: Float1DArray,
            p_b: Float1DArray,
            e_s: Float1DArray,
            e_b: Float1DArray,
            cs2_s: tp.Union[Float1DArray, None] = None,
            cs2_b: tp.Union[Float1DArray, None] = None,
            T_crit: float | None = None,
            T_nucl: float | None = None,
            T_min: float | None = None,
            T_max: float | None = None,
            T_is_physical: bool = False,
            name: str | None = None
            ):
        r"""Create a model from data arrays
        :param T_s: Temperatures $T_s$ in the symmetric phase
        :param T_b: Temperatures $T_b$ in the broken phase
        :param p_s: Pressures $p_s$ in the symmetric phase
        :param p_b: Pressures $p_b$ in the broken phase
        :param e_s: Energy densities $e_s$ in the symmetric phase
        :param e_b: Energy densities $e_b$ in the broken phase
        :param cs2_s: Sound speed squared $c_s^2$ in the symmetric phase.
            Has to be given in practice: a TypeError is raised if it is None.
        :param cs2_b: Sound speed squared $c_b^2$ in the broken phase.
            Has to be given in practice: a TypeError is raised if it is None.
        :param T_crit: Critical temperature $T_\text{c}$
        :param T_nucl: Nucleation temperature $T_\text{n}$. Stored as ``data_T_nucl``.
        :param T_min: Minimum temperature $T_\text{min}$ for which the model is valid.
            Defaults to the smallest broken-phase temperature in the data.
        :param T_max: Maximum temperature $T_\text{max}$ for which the model is valid.
            Defaults to the largest symmetric-phase temperature in the data.
        :param T_is_physical: Whether the temperature is in physical units
        :param name: Name of the model
        :raises TypeError: if cs2_s or cs2_b is None
        """
        self.data_T_s = T_s
        self.data_T_b = T_b
        self.data_p_s = p_s
        self.data_p_b = p_b
        self.data_e_s = e_s
        self.data_e_b = e_b
        self.data_cs2_s = cs2_s
        self.data_cs2_b = cs2_b
        # The nucleation temperature is not forwarded to the base-class constructor below.
        # It is kept under the "data_" prefix so that it cannot clash with
        # base-class attributes that are not visible from this module.
        self.data_T_nucl = T_nucl

        # w = p + e for both phases
        self.data_w_s = self.data_p_s + self.data_e_s
        self.data_w_b = self.data_p_b + self.data_e_b

        # The temperature splines are fitted in log10(T)
        self.data_T_s_log = np.log10(self.data_T_s)
        self.data_T_b_log = np.log10(self.data_T_b)

        # k=1 gives piecewise linear interpolation between the data points
        self.spline_p_s = splrep(self.data_T_s_log, self.data_p_s, k=1)
        self.spline_p_b = splrep(self.data_T_b_log, self.data_p_b, k=1)
        self.spline_e_s = splrep(self.data_T_s_log, self.data_e_s, k=1)
        self.spline_e_b = splrep(self.data_T_b_log, self.data_e_b, k=1)
        # Inverse relation: log10(T) as a function of w
        self.spline_temp_s = splrep(self.data_w_s, self.data_T_s_log, k=1)
        self.spline_temp_b = splrep(self.data_w_b, self.data_T_b_log, k=1)

        super().__init__(
            T_min=np.min(self.data_T_b) if T_min is None else T_min,
            T_max=np.max(self.data_T_s) if T_max is None else T_max,
            T_crit=T_crit,
            implicit_V=True,
            name=name,
            temperature_is_physical=T_is_physical,
            label_latex=name,
            label_unicode=name,
            gen_cs2=False,
            gen_cs2_neg=False,
            gen_critical=False
        )
        # gen_cs2() reads self.w_min and self.w_max, which are not set in this class
        # (presumably by the base-class constructor), so it has to run after super().__init__().
        self.cs2 = self.gen_cs2()

    @classmethod
    def from_hdf5(cls, path: str, name: str | None = None, T_is_physical: bool = False) -> "DataModel":
        """Create a model from an HDF5 file generated by e.g. WallGo

        The file has to contain the groups "high_temperature_phase" (symmetric phase)
        and "low_temperature_phase" (broken phase), each with the datasets
        "temperature", "pressure", "energy_density" and "sound_speed_squared".
        Optional attributes are read when present: "model_label",
        "critical_temperature" and "nucleation_temperature" on the file,
        "min_possible_temperature" on the low-temperature group and
        "max_possible_temperature" on the high-temperature group.

        :param path: Path of the HDF5 file
        :param name: Name of the model. Defaults to the "model_label" attribute of the file,
            or to the file name without its extension if the attribute is missing.
        :param T_is_physical: Whether the temperature is in physical units
        :return: The model constructed from the file contents
        :raises KeyError: if either of the phase groups is missing from the file
        """
        with File(path, "r") as file:
            if "high_temperature_phase" not in file or "low_temperature_phase" not in file:
                raise KeyError(
                    "The model data was not found in the HDF5 file. "
                    f"These are the file contents: {file.keys()}"
                )
            if name is None:
                name = (
                    file.attrs["model_label"]
                    if "model_label" in file.attrs
                    else os.path.splitext(os.path.basename(path))[0]
                )
            phase_s = file["high_temperature_phase"]
            phase_b = file["low_temperature_phase"]
            # The [:] slices read the datasets into memory before the file is closed.
            # Missing attributes become None, which selects the constructor defaults.
            return cls(
                T_s=phase_s["temperature"][:],
                T_b=phase_b["temperature"][:],
                p_s=phase_s["pressure"][:],
                p_b=phase_b["pressure"][:],
                e_s=phase_s["energy_density"][:],
                e_b=phase_b["energy_density"][:],
                cs2_s=phase_s["sound_speed_squared"][:],
                cs2_b=phase_b["sound_speed_squared"][:],
                T_crit=file.attrs.get("critical_temperature"),
                T_nucl=file.attrs.get("nucleation_temperature"),
                T_min=phase_b.attrs.get("min_possible_temperature"),
                T_max=phase_s.attrs.get("max_possible_temperature"),
                T_is_physical=T_is_physical,
                name=name
            )

    @staticmethod
    def interpolate(spline_s, spline_b, x: th.FloatOrArr, phase: th.FloatOrArr) -> th.FloatOrArr:
        """Interpolate between the splines of the two phases

        :param spline_s: Spline of the symmetric phase, as returned by splrep
        :param spline_b: Spline of the broken phase, as returned by splrep
        :param x: Point(s) at which the splines are evaluated
        :param phase: Weight of the symmetric-phase value; the broken-phase value gets 1 - phase
        :return: The phase-weighted sum of the two spline values
        """
        return \
            splev(x, spline_s) * phase + \
            splev(x, spline_b) * (1 - phase)

    @classmethod
    def interpolate_temp(cls, spline_s, spline_b, temp: th.FloatOrArr, phase: th.FloatOrArr) -> th.FloatOrArr:
        """Interpolate between the splines of the two phases in the given temperatures

        The splines are expected to be functions of log10(T),
        so the temperature is converted before the evaluation.
        """
        return cls.interpolate(spline_s, spline_b, np.log10(temp), phase)

    def gen_cs2(self):
        """Generate the function cs2(w, phase) for the sound speed squared

        The sound speed data of each phase is fitted with a degree-1 spline as a function of
        w = p + e. The returned function gives NaN for w outside [w_min, w_max].
        It is a plain Python dispatcher that is also registered as a Numba overload,
        so that it can be called from jitted code as well.

        :return: function cs2(w, phase)
        :raises TypeError: if the sound speed data of either phase is missing
        """
        if self.data_cs2_s is None or self.data_cs2_b is None:
            raise TypeError(
                "DataModel requires the sound speed data of both phases (cs2_s and cs2_b), "
                "but at least one of them is None."
            )
        # Plain local variables for the jitted closures below to capture
        w_min = self.w_min
        w_max = self.w_max
        spline_cs2_w_s = splrep(self.data_w_s, self.data_cs2_s, k=1)
        spline_cs2_w_b = splrep(self.data_w_b, self.data_cs2_b, k=1)

        @numba.njit
        def cs2_compute(w: th.FloatOrArr, phase: th.FloatOrArr) -> th.FloatOrArr:
            # Pure phases need only one spline evaluation
            if np_all_fix(phase == Phase.SYMMETRIC.value):
                return splev(w, spline_cs2_w_s)
            if np_all_fix(phase == Phase.BROKEN.value):
                return splev(w, spline_cs2_w_b)
            return splev(w, spline_cs2_w_s) * phase \
                + splev(w, spline_cs2_w_b) * (1 - phase)

        @numba.njit
        def cs2_scalar(w: float, phase: th.FloatOrArr) -> th.FloatOrArr:
            if w < w_min or w > w_max:
                return np.nan
            return cs2_compute(w, phase)

        @numba.njit
        def cs2_arr(w: np.ndarray, phase: th.FloatOrArr) -> np.ndarray:
            # This check somehow fixes a compilation bug in Numba 0.60.0
            if np.isscalar(w):
                raise TypeError
            invalid = np.logical_or(w < w_min, w > w_max)
            if np.any(invalid):
                # Evaluate on a copy in which the out-of-range points are NaN, so that the
                # caller's array is left untouched and the NaNs carry over to the result
                # (assuming that the spline evaluation propagates NaN inputs)
                # instead of the spline being silently extrapolated.
                w_masked = w.copy()
                w_masked[invalid] = np.nan
                return cs2_compute(w_masked, phase)
            return cs2_compute(w, phase)

        def cs2(w: th.FloatOrArr, phase: th.FloatOrArr) -> th.FloatOrArr:
            if isinstance(w, float):
                return cs2_scalar(w, phase)
            if isinstance(w, np.ndarray):
                # Zero-dimensional arrays are handled as scalars
                if not w.ndim:
                    return cs2_scalar(w.item(), phase)
                return cs2_arr(w, phase)
            raise TypeError(f"Unknown type for w: {type(w)}")

        # Type-based dispatch for calls of cs2 from Numba-compiled code
        @overload(cs2, jit_options={"nopython": True})
        def cs2_numba(w: th.FloatOrArr, phase: th.FloatOrArr) -> th.FloatOrArrNumba:
            if isinstance(w, numba.types.Float):
                return cs2_scalar
            if isinstance(w, numba.types.Array):
                return cs2_arr
            raise TypeError(f"Unknown type for w: {type(w)}")

        return cs2

    def params_str(self) -> str:
        """Return the name of the model, as the model has no other parameters to list"""
        return self.name

    def p_temp(self, temp: th.FloatOrArr, phase: th.FloatOrArr) -> th.FloatOrArr:
        """Pressure interpolated from the data at the given temperature(s) and phase"""
        return self.interpolate_temp(self.spline_p_s, self.spline_p_b, temp, phase)

    def e_temp(self, temp: th.FloatOrArr, phase: th.FloatOrArr) -> th.FloatOrArr:
        """Energy density interpolated from the data at the given temperature(s) and phase"""
        return self.interpolate_temp(self.spline_e_s, self.spline_e_b, temp, phase)

    def temp(self, w: th.FloatOrArr, phase: th.FloatOrArr) -> th.FloatOrArr:
        """Temperature interpolated from the data at the given w = p + e and phase

        The splines give log10(T) as a function of w. Each phase is converted
        back to T before the phase weighting, so the weighting is linear in T.
        """
        return \
            10**splev(w, self.spline_temp_s) * phase + \
            10**splev(w, self.spline_temp_b) * (1 - phase)

    def w(self, temp: th.FloatOrArr, phase: th.FloatOrArr) -> th.FloatOrArr:
        """w = p + e interpolated from the data at the given temperature(s) and phase"""
        return self.p_temp(temp, phase) + self.e_temp(temp, phase)