"""Base class for equation of state models and thermodynamics models"""

import abc
import datetime
import logging
import typing as tp

import numpy as np

import pttools.type_hints as th

# Module logger; the model classes below emit their validation warnings through it.
logger: logging.Logger = logging.getLogger(__name__)


class BaseModel(abc.ABC):
    """The base for both Model and ThermoModel

    All temperatures must be in units of GeV for the frequency conversion in Spectrum to work.

    Subclasses must implement :meth:`cs2` and :meth:`cs2_neg`.
    When ``gen_cs2`` / ``gen_cs2_neg`` are enabled in ``__init__`` (the default),
    these are replaced on the instance by the functions returned by
    :meth:`gen_cs2` / :meth:`gen_cs2_neg`.
    """
    # Defaults used by __init__ when the corresponding argument is None
    DEFAULT_LABEL_LATEX: str | None = None
    DEFAULT_LABEL_UNICODE: str | None = None
    DEFAULT_NAME: str | None = None
    # Zero temperature would break many of the equations
    DEFAULT_T_MIN: float = 1e-3
    DEFAULT_T_MAX: float = np.inf

    #: Whether the temperature is in proper physics units
    TEMPERATURE_IS_PHYSICAL: bool | None = None

    #: String formatting for thermodynamical quantities
    THERMO_FORMAT: str = "6e"

    def __init__(
            self,
            # Basic info
            name: str | None = None,
            label_latex: str | None = None,
            label_unicode: str | None = None,
            # Numerical values
            T_min: float | None = None,
            T_max: float | None = None,
            # Booleans
            restrict_to_valid: bool = True,
            gen_cs2: bool = True,
            gen_cs2_neg: bool = True,
            temperature_is_physical: bool | None = None,
            silence_temp: bool = False):
        """Store and validate the parameters common to all models.

        :param name: model name, defaults to :attr:`DEFAULT_NAME`.
            Should not contain spaces, since it may end up in file names.
        :param label_latex: LaTeX label, defaults to :attr:`DEFAULT_LABEL_LATEX`
        :param label_unicode: Unicode label (should not contain "$"),
            defaults to :attr:`DEFAULT_LABEL_UNICODE`
        :param T_min: lower end of the temperature validity range, defaults to :attr:`DEFAULT_T_MIN`
        :param T_max: upper end of the temperature validity range, defaults to :attr:`DEFAULT_T_MAX`
        :param restrict_to_valid: whether :meth:`validate_temp` replaces out-of-range temperatures with NaN
        :param gen_cs2: whether to replace :meth:`cs2` on this instance with the result of :meth:`gen_cs2`
        :param gen_cs2_neg: whether to replace :meth:`cs2_neg` on this instance
            with the result of :meth:`gen_cs2_neg`
        :param temperature_is_physical: defaults to :attr:`TEMPERATURE_IS_PHYSICAL`
        :param silence_temp: whether :meth:`validate_temp` suppresses its warnings
        :raises ValueError: if the name or a label is missing, or if the temperature range is invalid
        """
        self.name: str = self.DEFAULT_NAME if name is None else name
        self.label_latex: str = self.DEFAULT_LABEL_LATEX if label_latex is None else label_latex
        self.label_unicode: str = self.DEFAULT_LABEL_UNICODE if label_unicode is None else label_unicode
        self.T_min: float = self.DEFAULT_T_MIN if T_min is None else T_min
        self.T_max: float = self.DEFAULT_T_MAX if T_max is None else T_max
        self.silence_temp: bool = silence_temp
        self.restrict_to_valid: bool = restrict_to_valid
        self.temperature_is_physical: bool = self.TEMPERATURE_IS_PHYSICAL \
            if temperature_is_physical is None else temperature_is_physical

        if self.name is None:
            raise ValueError("The model must have a name.")
        if " " in self.name:
            logger.warning(
                "Model names should not have spaces to ensure that the file names don't cause problems. "
                "Got: \"%s\".",
                self.name
            )
        if not (self.label_latex and self.label_unicode):
            raise ValueError("The model must have labels.")
        if "$" in self.label_unicode:
            logger.warning(
                "The Unicode label of a model should not contain \"$\". Got: \"%s\"",
                self.label_unicode
            )
        # The negated comparisons also reject NaN bounds, which would otherwise
        # pass these checks and silently disable the validation in validate_temp.
        if not self.T_min > 0:
            raise ValueError(f"T_min should be larger than zero. Got: {self.T_min}")
        if not self.T_max > self.T_min:
            raise ValueError(f"T_max ({self.T_max}) should be higher than T_min ({self.T_min}).")

        # These instance attributes shadow the abstract cs2 / cs2_neg methods.
        if gen_cs2:
            self.cs2 = self.gen_cs2()
        if gen_cs2_neg:
            self.cs2_neg = self.gen_cs2_neg()

    # Concrete methods

    def export(self) -> dict[str, tp.Any]:
        """Export the model parameters to a dictionary. User-created model classes should extend this.

        The "datetime" entry records when the export was made.
        """
        return {
            # Basic info
            "name": self.name,
            "label_latex": self.label_latex,
            "label_unicode": self.label_unicode,
            "datetime": datetime.datetime.now(),
            # Numerical values
            "T_min": self.T_min,
            "T_max": self.T_max,
            # Booleans
            "restrict_to_valid": self.restrict_to_valid,
            "silence_temp": self.silence_temp,
            "temperature_is_physical": self.temperature_is_physical
        }

    def gen_cs2(self) -> th.CS2Fun:
        r"""This function generates a Numba-jitted $c_s^2$ function for the model."""
        raise NotImplementedError("This class does not have gen_cs2 defined")

    def gen_cs2_neg(self) -> th.CS2Fun:
        r"""This function generates a negative version of
        the Numba-jitted $c_s^2$ function to be used for maximisation.
        """
        raise NotImplementedError("This class does not have gen_cs2_neg defined")

    def info(self) -> str:
        """Get a string with information about the model.

        Each entry of :meth:`export` is shown on its own line as ``key: value``
        with the keys left-aligned. Floats are formatted with :attr:`THERMO_FORMAT`.
        """
        # Export only once, so that the column width and the printed values
        # come from the same snapshot (e.g. the same "datetime").
        data = self.export()
        key_width = max(len(key) for key in data) + 1
        return "\n".join(
            f"{key:<{key_width}}: {f'{value:{self.THERMO_FORMAT}}' if isinstance(value, float) else value}"
            for key, value in data.items()
        )

    def validate_temp(self, temp: th.FloatOrArr) -> th.FloatOrArr:
        """Validate that the given temperatures are in the validity range of the model.

        Out-of-range temperatures are logged as warnings unless :attr:`silence_temp` is set.
        If :attr:`restrict_to_valid` is set, invalid values are replaced with np.nan:
        a scalar is returned as np.nan, and for an array a copy is created
        (promoted to float if its dtype cannot hold NaN). The input is never modified.

        :param temp: a scalar temperature or an array of temperatures
        :return: the input itself if no changes were needed, otherwise the sanitised value/copy
        """
        if np.isscalar(temp):
            if temp < self.T_min:
                if not self.silence_temp:
                    logger.warning(
                        "The temperature %s is below the minimum temperature %s of the model \"%s\".",
                        temp, self.T_min, self.name
                    )
                if self.restrict_to_valid:
                    return np.nan
            elif temp > self.T_max:
                if not self.silence_temp:
                    logger.warning(
                        "The temperature %s is above the maximum temperature %s of the model \"%s\".",
                        temp, self.T_max, self.name
                    )
                if self.restrict_to_valid:
                    return np.nan
            return temp

        below = temp < self.T_min
        above = temp > self.T_max
        has_below = np.any(below)
        has_above = np.any(above)
        if not (has_below or has_above):
            return temp

        # Report the extremes of the offending values only, computed before any masking,
        # so that NaNs (pre-existing or inserted below) do not end up in the message.
        if has_below and not self.silence_temp:
            logger.warning(
                "Some temperatures (%s and possibly above) "
                "are below the minimum temperature %s of the model \"%s\".",
                np.min(temp[below]), self.T_min, self.name
            )
        if has_above and not self.silence_temp:
            logger.warning(
                "Some temperatures (%s and possibly below) "
                "are above the maximum temperature %s of the model \"%s\".",
                np.max(temp[above]), self.T_max, self.name
            )

        if self.restrict_to_valid:
            # np.array copies, so the caller's array is left untouched.
            # Integer arrays cannot store NaN, so they are promoted to float.
            dtype = temp.dtype if np.issubdtype(temp.dtype, np.inexact) else np.float64
            temp = np.array(temp, dtype=dtype)
            temp[below | above] = np.nan
        return temp

    # Abstract methods

    @abc.abstractmethod
    def cs2(self, *args, **kwargs) -> th.FloatOrArr:  # pylint: disable=method-hidden
        """Speed of sound squared $c_s^2$"""

    @abc.abstractmethod
    def cs2_neg(self, *args, **kwargs) -> th.FloatOrArr:  # pylint: disable=method-hidden
        """Speed of sound squared with a minus sign, $-c_s^2$. This is needed for finding the maximum of $c_s^2$."""
