"""Functions for computing GW power spectra"""

import functools
import logging
import typing as tp

import matplotlib.pyplot as plt
import numpy as np

import pttools.type_hints as th
from pttools.bubble import Bubble, Phase
from pttools.ssm import const
from pttools.ssm.nucleation import NucType, DEFAULT_NUC_TYPE
from pttools.ssm.spec_den_gw import gen_lookup, spec_den_gw_scaled
from pttools.ssm.spec_den_v import spec_den_v
from pttools.ssm.low_k import power_spectrum_integration_low, power_spectrum_integration_int, pow_gw_junction

if tp.TYPE_CHECKING:
    from pttools.analysis.utils import FigAndAxes

logger = logging.getLogger(__name__)


class SSMSpectrum:
    """Sound Shell Model (SSM) velocity and gravitational wave power spectra for a bubble"""
    def __init__(
            self,
            bubble: Bubble,
            y: tp.Union[np.ndarray[tuple[int], np.float64], None] = None,
            z_st_thresh: float = const.Z_ST_THRESH,
            nuc_type: NucType = DEFAULT_NUC_TYPE,
            nt: int = const.NTDEFAULT,
            n_z_lookup: int = const.N_Z_LOOKUP_DEFAULT,
            r_star: float | None = None,
            # eta_star: float = 1,
            lifetime_multiplier: float = 1,
            compute: bool = True,
            low_k: bool = True,
            label_latex: str | None = None,
            label_unicode: str | None = None):
        r"""
        :param bubble: the Bubble object
        :param y: $z = kR*$ array. Defaults to ``const.Y_DEFAULT``.
        :param z_st_thresh: for $z$ values above z_st_thresh,
            use approximation rather than doing the sine transform integral.
        :param nuc_type: nucleation type
        :param nt: number of points in the t array
        :param n_z_lookup: passed to :func:`gen_lookup` when generating the $z$ lookup array
        :param r_star: $r_*$. Required for the low-k expansion and the finite source lifetime correction.
        :param lifetime_multiplier: used for computing the source lifetime factor
        :param compute: whether to compute the spectrum immediately
        :param low_k: whether to use the low-k expanded GW spectrum as ``pow_gw``
            (only has an effect when r_star is given)
        :param label_latex: LaTeX label. By default derived from the bubble label and $r_*$.
        :param label_unicode: Unicode label. By default derived from the bubble label and $r_*$.
        :raises ValueError: if y contains nan values or r_star is not positive
        """
        if y is None:
            self.y = const.Y_DEFAULT
        elif np.isnan(y).any():
            raise ValueError("y must not contain nan values.")
        else:
            self.y = y

        # Parameters
        self.bubble = bubble
        # self.de_method = de_method
        # self.method = method
        self.nuc_type = nuc_type
        self.z_st_thresh = z_st_thresh
        self.nt = nt
        self.n_z_lookup = n_z_lookup
        self.r_star = r_star
        # Todo: Make this an adjustable input parameter
        self.eta_star = 1
        self.lifetime_multiplier = lifetime_multiplier
        self.low_k = low_k

        if label_latex is None:
            if r_star is None:
                # Nothing to add: keep the bubble label intact (including its closing "$").
                label_latex = self.bubble.label_latex
            else:
                # The bubble's LaTeX label is assumed to end with a closing "$":
                # insert r_* inside the math environment before it.
                label_latex = self.bubble.label_latex[:-1] + f", r_*={r_star}$"
        self.label_latex = label_latex

        label_suffix_unicode = "" if r_star is None else f", r⁎={r_star}"
        self.label_unicode = self.bubble.label_unicode + label_suffix_unicode \
            if label_unicode is None else label_unicode

        # nan is treated like "not given" for validation purposes.
        if not (self.r_star is None or np.isnan(r_star)):
            if self.r_star <= 0:
                raise ValueError(f"r_star must be positive. Got r_star={r_star}.")
            if self.r_star >= 1:
                # Todo: Find a better reference for this.
                logger.warning(
                    "r_star < 1 is required for the phase transition to complete. "
                    "Got r_star=%s. See Hindmarsh & Hijazi 2019, p. 6.",
                    self.r_star
                )

        # Values generated by compute()
        #: $|A(z)|^2$
        self.a2: tp.Optional[np.ndarray] = None
        #: $c_s({T}_\text{gw})$
        self.cs: tp.Optional[float] = None
        #: $P_v(y)$
        self.spec_den_v: tp.Optional[np.ndarray] = None
        #: $P_v({z}_\text{lookup})$
        self.spec_den_v_lookup: tp.Optional[np.ndarray] = None
        #: Spectral density of scaled gravitational wave power
        self.spec_den_gw: tp.Optional[np.ndarray] = None
        #: $\mathcal{P}_{\tilde{v}}(q)$
        self.pow_v: tp.Optional[np.ndarray] = None
        #: $\mathcal{P}_{\text{gw}}(k)$
        self.pow_gw: tp.Optional[np.ndarray] = None
        #: $\mathcal{P}_{\text{gw}}(k)$, expanded to low frequencies
        self.pow_gw_expanded: tp.Optional[np.ndarray] = None
        #: $\mathcal{P}_{\text{gw}}(k)$ for intermediate frequencies
        self.pow_gw_int: tp.Optional[np.ndarray] = None
        #: $\mathcal{P}_{\text{gw}}(k)$ for low frequencies
        self.pow_gw_low: tp.Optional[np.ndarray] = None
        #: $\mathcal{P}_{\text{gw}}(k)$ using the Sound Shell Model (SSM) without the low-k approximation
        self.pow_gw_ssm: tp.Optional[np.ndarray] = None
        #: $z_\text{lookup}$
        self.z_lookup: tp.Optional[np.ndarray] = None

        if compute:
            self.compute()

    def beta(self, H_n: th.FloatOrArr) -> th.FloatOrArr:
        r"""Nucleation rate parameter $\beta$
        $$\beta = (8 \pi)^\frac{1}{3} \frac{{v}_\text{wall}}{{R}_*}$$
        :gw_pt_ssm:`\ ` eq. 4.16, A.14
        :notes:`\ ` eq. 7.21

        Simultaneous nucleation only!
        """
        return self.beta_tilde() * H_n

    def beta_tilde(self):
        r"""Nucleation rate parameter $\tilde{\beta}$, also known as "beta over H"
        $$\tilde{\beta} \equiv \frac{\beta}{{H}_n} = (8 \pi)^\frac{1}{3} \frac{{v}_\text{wall}}{{r}_*}$$
        :gowling_2021:`\ ` eq. 2.1

        Simultaneous nucleation only! Requires r_star to be set.
        """
        return (8*np.pi)**(1/3) * self.bubble.v_wall / self.r_star

    def compute(self, eps_lookup: float = 1e-8, lifetime_distribution_a: float = 1.):
        r"""Compute the velocity and gravitational wave spectra.

        Solves the bubble if it has not been solved yet, and then fills in
        ``cs``, ``spec_den_v``, ``a2``, ``pow_v``, ``z_lookup``, ``spec_den_v_lookup``,
        ``spec_den_gw``, ``pow_gw_ssm`` and ``pow_gw``.
        If r_star is given, also computes the low-k expansion
        (``pow_gw_low``, ``pow_gw_int``, ``pow_gw_expanded``),
        which replaces ``pow_gw`` when ``low_k`` is enabled.

        :param eps_lookup: passed as ``eps`` to :func:`gen_lookup`
        :param lifetime_distribution_a: passed as ``a`` to :func:`spec_den_v`
        """
        if not self.bubble.solved:
            self.bubble.solve()

        self.cs = np.sqrt(self.bubble.model.cs2(self.bubble.va_enthalpy_density, Phase.BROKEN))
        self.spec_den_v, self.a2 = spec_den_v(
            bub=self.bubble, z=self.y, a=lifetime_distribution_a,
            nuc_type=self.nuc_type, nt=self.nt, z_st_thresh=self.z_st_thresh, cs=self.cs, return_a2=True
        )
        self.pow_v = pow_spec(self.y, spec_den=self.spec_den_v)

        self.z_lookup = gen_lookup(y=self.y, cs=self.cs, n_z_lookup=self.n_z_lookup, eps=eps_lookup)
        self.spec_den_v_lookup = spec_den_v(
            bub=self.bubble, z=self.z_lookup, a=lifetime_distribution_a,
            nuc_type=self.nuc_type, nt=self.nt, z_st_thresh=self.z_st_thresh, cs=self.cs
        )
        # The second return value is not needed here.
        self.spec_den_gw, _ = spec_den_gw_scaled(
            z_lookup=self.z_lookup, P_v_lookup=self.spec_den_v_lookup, y=self.y, cs=self.cs,
            source_lifetime_factor=self.source_lifetime_factor
        )
        self.pow_gw_ssm = pow_spec(self.y, spec_den=self.spec_den_gw)
        self.pow_gw = self.pow_gw_ssm

        if self.r_star is not None:
            # Lorenzo 2024
            eta_end = self.eta_star + self.Htau_nl
            tau_star = self.eta_star / self.Lf
            tau_end = eta_end / self.Lf

            spec_den_low = 4/3 * power_spectrum_integration_low(
                x_data=self.y, Pv_data=self.spec_den_v,
                z=self.y, cs=self.cs, nu=self.bubble.nu_gdh2024,
                tau_star=tau_star, tau_end=tau_end
            )
            spec_den_int = 4/3 * power_spectrum_integration_int(
                x_data=self.z_lookup, Pv_data=self.spec_den_v_lookup,
                z=self.y, cs=self.cs, tau_star=tau_star
            )

            # The spectra are multiplied by this factor for the junction and
            # the combined result is divided by it again afterwards.
            factor = self.r_star * self.Htau_nl
            self.pow_gw_low = pow_spec(self.y, spec_den_low)
            self.pow_gw_int = pow_spec(self.y, spec_den_int)
            self.pow_gw_expanded = pow_gw_junction(
                z=self.y,
                Pgw_low=self.pow_gw_low * factor,
                Pgw_int=self.pow_gw_int * factor,
                Pgw_high=self.pow_gw_ssm * factor,
                cs=self.cs, nu=self.bubble.nu_gdh2024, tau_star=tau_star, tau_end=tau_end,
                HLf=self.r_star
            ) / factor
            if self.low_k:
                self.pow_gw = self.pow_gw_expanded

    @functools.cached_property
    def Htau_nl(self):
        r"""$H \tau_\text{nl} = r_* / \bar{U}_f$"""
        return self.r_star / self.bubble.ubarf

    @functools.cached_property
    def k_peak(self):
        r"""Peak wavenumber $k_\text{peak}$
        $$k_\text{peak} = \frac{2 \pi}{r_*} \frac{2}{1 + 3 \omega}$$
        where $\omega$ is the barotropic equation of state parameter of the bubble.

        Lorenzo
        """
        return 2 * np.pi / self.r_star * 2 / (1 + 3 * self.bubble.omega_barotropic)

    @functools.cached_property
    def Lf(self):
        r"""Length-scale of the fluid $L_f = 2 \pi / k_\text{peak}$

        Lorenzo
        """
        return 2 * np.pi / self.k_peak

    @functools.cached_property
    def source_lifetime_factor(self) -> float:
        r"""
        Source lifetime correction factor
        $$\frac{1}{1 + 2\nu} \left(1 - \left(1 + \frac{\Delta \eta}{\eta_*} \right)^{-1-2\nu} \right)$$
        where
        $$\frac{\Delta \eta}{\eta_*} = \lambda \frac{2 r_*}{\sqrt{K}}$$
        with $\lambda$ the lifetime multiplier and $K$ the kinetic energy fraction.
        If r_star is not given or the lifetime multiplier is infinite,
        only the prefactor $\frac{1}{1 + 2\nu}$ is returned.
        :giombi_2024_cs:`\ `, eq. 3.13
        """
        ret = 1 / (1 + 2*self.bubble.nu_gdh2024)
        if self.r_star is not None and not np.isinf(self.lifetime_multiplier):
            ret *= (
                1 - (1 + self.lifetime_multiplier * 2 * self.r_star / np.sqrt(self.bubble.kinetic_energy_fraction)) **
                (-1 - 2*self.bubble.nu_gdh2024)
            )
        return ret

    # Plotting

    def plot(
            self,
            fig: plt.Figure | None = None,
            ax: plt.Axes | None = None,
            path: str | None = None,
            **kwargs) -> "FigAndAxes":
        r"""Plot GW power spectrum $\mathcal{P}_{\text{gw}}(k)$ (alias of :meth:`plot_gw`)"""
        return self.plot_gw(fig, ax, path, **kwargs)

    def plot_gw(
            self,
            fig: plt.Figure | None = None,
            ax: plt.Axes | None = None,
            path: str | None = None,
            **kwargs) -> "FigAndAxes":
        r"""Plot GW power spectrum $\mathcal{P}_{\text{gw}}(k)$"""
        from pttools.analysis.plot_spectra import plot_spectra_gw
        return plot_spectra_gw([self], fig, ax, path, **kwargs)

    def plot_v(
            self,
            fig: plt.Figure | None = None,
            ax: plt.Axes | None = None,
            path: str | None = None,
            **kwargs) -> "FigAndAxes":
        r"""Plot velocity power spectrum $\mathcal{P}_{\tilde{v}}(q)$"""
        from pttools.analysis.plot_spectra import plot_spectra_v
        return plot_spectra_v([self], fig, ax, path, **kwargs)

    def plot_spec_den_gw(
            self,
            fig: plt.Figure | None = None,
            ax: plt.Axes | None = None,
            path: str | None = None,
            **kwargs) -> "FigAndAxes":
        """Plot spectral density of scaled GW power"""
        from pttools.analysis.plot_spectra import plot_spectra_spec_den_gw
        return plot_spectra_spec_den_gw([self], fig, ax, path, **kwargs)

    def plot_spec_den_v(
            self,
            fig: plt.Figure | None = None,
            ax: plt.Axes | None = None,
            path: str | None = None,
            **kwargs) -> "FigAndAxes":
        """Plot spectral density of the velocity field $P_v(y)$"""
        from pttools.analysis.plot_spectra import plot_spectra_spec_den_v
        return plot_spectra_spec_den_v([self], fig, ax, path, **kwargs)


def pow_spec(z: th.FloatOrArr, spec_den: th.FloatOrArr) -> th.FloatOrArr:
    r"""
    Convert a spectral density into a power spectrum at dimensionless wavenumber z.
    $$\mathcal{P}(z) = \frac{z^3}{2 \pi^2} \tilde{P}(z)$$

    :gw_pt_ssm:`\ ` eq. 4.18, but without the factor of 2.
    :gowling_2021:`\ ` eq. 2.14, but without the factor of $3K^2$

    Works element-wise for both scalars and numpy arrays.

    :param z: dimensionless wavenumber $z$
    :param spec_den: spectral density $\tilde{P}(z)$
    :return: power spectrum $\mathcal{P}(z)$
    """
    # Evaluated in the same order as z**3 / (2 pi^2) * spec_den.
    prefactor = z**3 / (2. * np.pi ** 2)
    return prefactor * spec_den
