"""
FFTLog Tutorial Script

This tutorial reproduces the example in appendix B of [astro-ph/9905191].
See also the webpage: https://jila.colorado.edu/~ajsh/FFTLog/

This is a standalone version extracted from the marimo notebook tutorial.py
"""

from typing import Sequence

import camb
import camb.model
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import scipy.special
from camb.sources import (
    GaussianSourceWindow,
    SourceWindow,
)

from fftloggin import Grid
from fftloggin.kernels import BesselJKernel, SphericalBesselJKernel

# Configure matplotlib: render and save every figure at 300 dpi.
plt.rcParams.update({"figure.dpi": 300, "savefig.dpi": 300})


# ============================================================================
# CAMB Setup and Matter Power Spectrum
# ============================================================================

# Reduced Hubble constant (H0 = 100 h).
h = 0.675
# Tabulation of the power spectrum: k/h range and number of log-spaced samples.
minkh, maxkh = 1e-4, 1
npoints = 512
# Highest multipole requested from CAMB.
lmax = 2000
# Uniform spacing in ln(k/h) between consecutive samples.
dlog = np.log(maxkh / minkh) / (npoints - 1)
# Number of samples per unit ln k implied by ``dlog`` (truncated to an int).
k_per_logint = int(1 / dlog)

# Redshifts at which the matter power spectrum is tabulated.
redshift_slices = [0, 1, 2]

# Base cosmology handed to ``camb.set_params``.
cosmology_params = dict(
    H0=100 * h,
    ombh2=0.022,
    omch2=0.122,
    mnu=0,
    omk=0,
    tau=0.06,
    As=2e-9,
    ns=0.965,
)


def update_params(
    params: camb.CAMBparams,
    lmax: int | None = None,
    kmax: float | None = None,
    dlog: float | None = None,
    redshifts: Sequence[float] | None = None,
    non_linear: bool = False,
    windows: Sequence[SourceWindow] | None = None,
) -> camb.CAMBparams:
    """
    Configure ``params`` in place for the requested outputs and return it.

    Parameters
    ----------
    params : CAMB parameter object; it is mutated in place.
    lmax : if truthy, ``set_for_lmax`` is called with this maximum multipole.
    kmax : maximum wavenumber for ``set_matter_power``; CAMB's own default is
        used when None.
    dlog : log-spacing of the k grid, converted to ``k_per_logint = int(1/dlog)``;
        CAMB's own default sampling is used when None.
    redshifts : if given, matter power / transfer output is enabled at these
        redshifts via ``set_matter_power``.
    non_linear : forwarded as ``nonlinear`` to both CAMB setters.
    windows : if given, replaces ``params.SourceWindows``.

    Returns
    -------
    The same ``params`` object, for chaining.
    """
    if lmax:
        params.set_for_lmax(lmax=lmax, nonlinear=non_linear)

    if redshifts is not None:
        matter_kwargs = {"redshifts": redshifts, "nonlinear": non_linear}
        # Only override CAMB's defaults for options the caller actually set:
        # forwarding kmax=None would hand CAMB an invalid value, and
        # int(1 / None) would raise TypeError.
        if kmax is not None:
            matter_kwargs["kmax"] = kmax
        if dlog is not None:
            matter_kwargs["k_per_logint"] = int(1 / dlog)
        params.set_matter_power(**matter_kwargs)

    if windows is not None:
        params.SourceWindows = windows
    return params


# Build the CAMB parameters and request the linear matter power spectrum at
# ``redshift_slices``.
# NOTE(review): CAMB documents set_matter_power's kmax in 1/Mpc (k, not k/h),
# i.e. maxkh * h; maxkh / h is larger, which over-samples but is safe —
# confirm this is intended.
params = update_params(
    camb.set_params(**cosmology_params),
    lmax=lmax,
    kmax=maxkh / h,
    dlog=dlog,
    redshifts=redshift_slices,
    non_linear=False,
    windows=None,
)

results = camb.get_results(params)

# kh: k/h samples, z: output redshifts, pk: P(k) with one row per redshift.
kh, z, pk = results.get_matter_power_spectrum(
    minkh=minkh,
    maxkh=maxkh,
    npoints=npoints,
)


def plot_matter_power_spectrum():
    """
    Plot the matter power spectrum for different redshifts.

    Uses the module-level ``kh``, ``z`` and ``pk`` returned by
    ``get_matter_power_spectrum``. Each row ``pk[i]`` is labelled with the
    redshift ``z[i]`` reported alongside it, so the legend stays correct even
    if CAMB re-orders the requested redshifts.

    Returns the created figure.
    """
    fig, ax = plt.subplots()
    for redshift, pk_at_z in zip(z, pk):
        ax.loglog(kh, pk_at_z, label=f"$z={redshift:g}$")
    ax.set_xlabel("$k \\, \\rm{[h/Mpc]}$")
    ax.set_ylabel("$P(k) \\, \\rm{[Mpc/h]^3}$")
    ax.set_title("Matter Power Spectrum")
    ax.legend()
    ax.grid()
    fig.tight_layout()
    return fig


# ============================================================================
# The Hankel Transform
# ============================================================================


def get_xi_direct_integration(kh: np.ndarray, r: np.ndarray, pk: np.ndarray):
    """
    Compute correlation function via direct integration.

    Converting the power spectrum into the real-space correlation function ξ(r)
    amounts to taking the Fourier transform.
    """
    log_kh = np.log(kh)
    integrand = pk * kh**3 * scipy.special.spherical_jn(0, np.outer(r, kh))
    return np.trapezoid(integrand, x=log_kh, axis=-1) / (2 * np.pi**2)


def plot_correlation_function_with_bias_comparison(
    k: npt.ArrayLike,
    pk: npt.ArrayLike,
    bias: Sequence[float],
    logc: float = 0,
    minimize_ringing: bool = False,
    **kwargs,
):
    """
    Compare correlation functions computed with different bias values.

    In FFTLog, the kernel of the Bessel function is given by its Mellin
    transform, int_0^inf x^(s-1) J_mu(x) dx, which converges for
    -mu < Re(s) < 3/2. The FFTLog bias q corresponds to s = q + 1, so the
    admissible bias range is (-mu - 1, 1/2). Since mu = 1/2 here, we can test
    bias values in the range (-3/2, 1/2).

    For every bias, xi(r) is obtained from the inverse FFTLog transform of
    k^{3/2} P(k) with a J_{1/2} kernel. The lower panel shows the relative
    error |xi_FFTLog / xi_direct - 1| against direct quadrature.

    Parameters
    ----------
    k, pk : sampled wavenumbers and power spectrum.
    bias : FFTLog bias values to compare (at least one).
    logc, minimize_ringing : forwarded to ``Grid.from_k``.
    **kwargs : forwarded to ``plt.subplots``.

    Returns
    -------
    (fig, (ax, ax_residue))

    Raises
    ------
    ValueError : if ``bias`` is empty.
    """
    bias = list(bias)
    if not bias:
        raise ValueError("bias must contain at least one value")

    fig, (ax, ax_residue) = plt.subplots(2, 1, **kwargs)
    k = np.asarray(k)
    pk = np.asarray(pk)
    for q in bias:
        kernel = BesselJKernel(mu=0.5, bias=q)
        grid = Grid.from_k(
            k=k, kernel=kernel, minimize_ringing=minimize_ringing, logc=logc
        )
        grid.inverse(k ** (3 / 2) * pk)
        xi_fftlog = grid.ar / (2 * np.pi * grid.r) ** (3 / 2)
        xi_direct_integration = get_xi_direct_integration(k, grid.r, pk)
        # Take |.| after subtracting 1: |a/b| - 1 can be negative (dropped on a
        # log axis) and reports zero error when the sign is wrong.
        residue = np.abs(xi_fftlog / xi_direct_integration - 1)
        ax.semilogx(grid.r, xi_fftlog, label=rf"$q={q}$")
        ax_residue.loglog(grid.r, residue, label=rf"$q={q}$")

    ax.semilogx(grid.r, xi_direct_integration, label="Direct Integration")
    for _ax, ylabel in zip((ax, ax_residue), (r"$\xi(r)$", "Relative error")):
        _ax.set_xlabel("$r \\, \\rm{[Mpc/h]}$")
        _ax.set_ylabel(ylabel)
        _ax.legend()

    fig.tight_layout()
    return fig, (ax, ax_residue)


# ============================================================================
# Minimizing Ringing
# ============================================================================


def plot_correlation_function_with_optimal_center_comparison(
    k: npt.ArrayLike,
    pk: npt.ArrayLike,
    bias: float = 0,
    logc: float = 0,
    **kwargs,
):
    """
    Compare correlation functions with and without ringing minimization.

    Setting minimize_ringing=True reduces oscillations by optimizing the
    log-center parameter. Both variants use a J_{1/2} kernel with the given
    bias. The lower panel shows the relative error
    |xi_FFTLog / xi_direct - 1| against direct quadrature.

    Parameters
    ----------
    k, pk : sampled wavenumbers and power spectrum.
    bias : FFTLog bias for the kernel.
    logc : log-center used when ringing minimization is off.
    **kwargs : forwarded to ``plt.subplots``.

    Returns
    -------
    (fig, (ax, ax_residue))
    """
    fig, (ax, ax_residue) = plt.subplots(2, 1, **kwargs)
    k = np.asarray(k)
    pk = np.asarray(pk)
    for minimize_ringing in (False, True):
        kernel = BesselJKernel(mu=0.5, bias=bias)
        grid = Grid.from_k(
            k=k, kernel=kernel, minimize_ringing=minimize_ringing, logc=logc
        )
        grid.inverse(k ** (3 / 2) * pk)
        xi_fftlog = grid.ar / (2 * np.pi * grid.r) ** (3 / 2)
        xi_direct_integration = get_xi_direct_integration(k, grid.r, pk)
        # Take |.| after subtracting 1 so sign errors are not reported as zero
        # and the value is never negative on the log axis.
        residue = np.abs(xi_fftlog / xi_direct_integration - 1)
        # \mathrm is supported by every mathtext version (unlike \text).
        logc_label = r"\mathrm{optimal}" if minimize_ringing else grid.fftlog.logc
        label = rf"$\log (k_0 r_0) = {logc_label}$"
        ax.semilogx(grid.r, xi_fftlog, label=label)
        ax_residue.loglog(grid.r, residue, label=label)

    ax.semilogx(grid.r, xi_direct_integration, label="Direct Integration")
    for _ax, ylabel in zip((ax, ax_residue), (r"$\xi(r)$", "Relative error")):
        _ax.set_xlabel("$r \\, \\rm{[Mpc/h]}$")
        _ax.set_ylabel(ylabel)
        _ax.legend()

    fig.tight_layout()
    return fig, (ax, ax_residue)


# ============================================================================
# Projecting Observables on the Sphere
# ============================================================================


def get_nbar(z: npt.NDArray, mu: float, sigma: float):
    """
    Get normalized redshift distribution.

    Returns the Gaussian probability density

        n(z) = exp(-(z - mu)^2 / (2 sigma^2)) / (sqrt(2 pi) sigma),

    with mean ``mu`` and standard deviation ``sigma``. It integrates to one
    over z. Array-like ``z`` is accepted and the result matches its shape.
    """
    x = (np.asarray(z) - mu) / sigma
    # The 1/2 in the exponent is what makes the sqrt(2*pi)*sigma prefactor a
    # correct normalization; without it the density integrates to 1/sqrt(2).
    return np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi) / sigma


def compare_tracer_cls_with_camb(
    params: camb.CAMBparams,
    k: npt.NDArray,
    lmax: int,
    zmax: float,
    tracer_bias: float,
    sigma_relative: float,
    minimize_ringing: bool = False,
    fftlog_bias: float = 0,
    plot_kernel_ells: Sequence[int] | None = None,
    **kwargs,
):
    """
    Compare angular power spectra computed with FFTLog vs CAMB.

    Many cosmological observables are naturally defined in harmonic space.
    The angular cross-correlation is related to the primordial power spectrum
    via line-of-sight integrals involving spherical Bessel functions.

    A Gaussian number-count window is placed midway between the grid's
    minimum redshift and ``zmax``. It is projected with FFTLog using one
    spherical Bessel kernel per multipole 0..lmax, and the result is compared
    with CAMB's ``W1xW1`` source C_ell.

    Parameters
    ----------
    params : CAMB parameters. NOTE: mutated in place through ``update_params``
        (lmax, matter-power redshifts, source windows).
    k : log-spaced wavenumbers. They are passed to CAMB as ``kmax=k[-1]`` and
        to ``scalar_power``, so presumably in 1/Mpc — confirm.
    lmax : highest multipole computed.
    zmax : redshift at the far edge of the radial grid.
    tracer_bias : linear bias of the CAMB Gaussian source window.
    sigma_relative : Gaussian width as a fraction of ``zmax - zmin``.
    minimize_ringing, fftlog_bias : forwarded to the FFTLog grid / kernel.
    plot_kernel_ells : if given, adds a panel with quadrature radial kernels
        for these multipoles.
    **kwargs : accepted for backward compatibility; ignored.

    Returns
    -------
    (fig, axs) : the figure and its array of axes.
    """
    background = camb.get_background(params)
    rmax = background.comoving_radial_distance(zmax)
    # Give the radial grid the same dynamic range as the k grid.
    rmin = rmax * (k[0] / k[-1])
    npoints = k.shape[0]
    # Uniform ln(k) spacing of the input grid (same definition as module dlog).
    dlog = np.log(k[-1] / k[0]) / (npoints - 1)
    kc = np.sqrt(k[0] * k[-1])
    rc = np.sqrt(rmax * rmin)
    logc = np.log(kc * rc)
    ells = np.arange(lmax + 1)
    # Multipoles along a leading axis: one FFTLog kernel per ell.
    bessel_kernel = SphericalBesselJKernel(ell=ells.reshape(-1, 1), bias=fftlog_bias)
    grid = Grid.from_k(
        k=k, kernel=bessel_kernel, minimize_ringing=minimize_ringing, logc=logc
    )

    # Redshift of every radial node: tabulate chi(z) on a linear z grid and
    # invert it by interpolation (ascending z, since grid.r ascends).
    zmin = background.redshift_at_comoving_radial_distance(grid.r[0])
    z_linear = np.linspace(zmin, zmax, npoints)
    chi_linear = background.comoving_radial_distance(z_linear)
    z = np.interp(grid.r, chi_linear, z_linear)

    sigma = (zmax - zmin) * sigma_relative
    zmid = 0.5 * (zmin + zmax)
    nbar = get_nbar(z, mu=zmid, sigma=sigma)
    source_window = GaussianSourceWindow(
        redshift=zmid, sigma=sigma, source_type="counts", bias=tracer_bias
    )
    # CAMB orders transfer redshifts from high to low z ("earliest first"), so
    # request them in that order and flip the transfer back to match ``z``.
    params = update_params(
        params,
        kmax=k[-1],
        lmax=lmax,
        redshifts=z[::-1],
        dlog=dlog,
        windows=[source_window],
    )
    results = camb.get_results(params)
    transfer = results.get_matter_transfer_data()
    transfer_k = transfer.transfer_data[camb.model.Transfer_kh - 1, :, 0] * params.h
    transfer_matter_of_z = (
        transfer.transfer_data[camb.model.Transfer_tot - 1, 0, ::-1]
        * transfer_k[0] ** 2
    )

    # n(z) -> n(chi) requires the Jacobian dz/dchi = H(z).
    jacobian = results.h_of_z(z)
    source_chi_of_z = nbar * transfer_matter_of_z * jacobian
    chi_of_z = results.comoving_radial_distance(z)
    source_chi = np.interp(grid.r, chi_of_z, source_chi_of_z)
    # Radial kernels Delta_ell(k), one row per multipole.
    delta_ell = grid.forward(source_chi) / k
    integrand = delta_ell**2 * params.scalar_power(grid.k) / (k**3 / (2 * np.pi**2))
    cls_fftlog = 4 * np.pi * np.trapezoid(integrand, x=np.log(grid.k), axis=-1)
    cls_camb = results.get_source_cls_dict(lmax=lmax, raw_cl=True)["W1xW1"]

    plot_kernels = plot_kernel_ells is not None
    nplots = 2 + int(plot_kernels)
    fig, axs = plt.subplots(nplots, 1, figsize=(4, 3 * nplots))
    if plot_kernels:
        ax_kernel, ax_plot, ax_residue = axs.ravel()
        plot_kernel_ells_arr = np.asarray(plot_kernel_ells).reshape(-1, 1, 1)
        kchi = np.outer(k, grid.r)
        bessel = scipy.special.spherical_jn(plot_kernel_ells_arr, kchi)
        kernel_quadrature = np.trapezoid(bessel * source_chi, x=grid.r, axis=-1)
        for i, ell in enumerate(plot_kernel_ells):
            ax_kernel.semilogx(
                grid.k,
                np.abs(kernel_quadrature[i, ...]),
                label=rf"$\ell={ell}$ quadrature",
            )
        ax_kernel.set(xlabel=r"$k$", ylabel=r"$\Delta(k)$")
        ax_kernel.legend()
    else:
        ax_plot, ax_residue = axs.ravel()

    ax_plot.semilogy(ells, cls_fftlog, label="FFTLog")
    ax_plot.semilogy(ells, cls_camb, label="CAMB")
    ax_plot.set(xlabel=r"$\ell$", ylabel=r"$C_\ell$")
    ax_plot.legend()

    # Plot a true relative error, skipping any multipole where CAMB's C_ell is
    # exactly zero to avoid dividing by it.
    valid = cls_camb != 0
    residue = np.abs(cls_fftlog[valid] / cls_camb[valid] - 1)
    ax_residue.semilogy(ells[valid], residue)
    ax_residue.set(xlabel=r"$\ell$", ylabel="Relative error")
    fig.tight_layout()
    return fig, axs


# ============================================================================
# Main Execution
# ============================================================================

if __name__ == "__main__":
    print("FFTLog Tutorial")
    print("=" * 80)
    print()

    print("1. Plotting matter power spectrum...")
    fig1 = plot_matter_power_spectrum()
    fig1.savefig("matter_power_spectrum.png")
    print("   Saved: matter_power_spectrum.png")

    print("\n2. Comparing correlation functions with different bias values...")
    fig2, _ = plot_correlation_function_with_bias_comparison(
        kh, pk[0, :], bias=[-0.5, 0, 0.5], logc=0, minimize_ringing=False
    )
    fig2.savefig("correlation_function_bias_comparison.png")
    print("   Saved: correlation_function_bias_comparison.png")

    print("\n3. Comparing with and without ringing minimization...")
    fig3, _ = plot_correlation_function_with_optimal_center_comparison(
        kh,
        pk[0, :],
        bias=0,
        logc=0,
    )
    fig3.savefig("correlation_function_ringing_comparison.png")
    print("   Saved: correlation_function_ringing_comparison.png")

    print("\n4. Comparing tracer angular power spectra (FFTLog vs CAMB)...")
    # Each radial grid node becomes a CAMB transfer redshift; the 256 cap
    # presumably keeps that count within CAMB's limit — confirm.
    k = np.geomspace(minkh, maxkh, min(npoints, 256)) / h
    # NOTE: this call updates the module-level ``params`` in place.
    # The cosmology is already encoded in ``params``, so no cosmology kwargs
    # are passed here (the function would ignore them anyway).
    fig4, _ = compare_tracer_cls_with_camb(
        params,
        k,
        lmax,
        zmax=1,
        tracer_bias=1,
        sigma_relative=0.01,
        minimize_ringing=False,
        plot_kernel_ells=[50],
    )
    fig4.savefig("tracer_cls_comparison.png")
    print("   Saved: tracer_cls_comparison.png")

    print("\n" + "=" * 80)
    print("Tutorial complete! All figures saved.")
    print("\nTo display the figures interactively, uncomment plt.show() below:")
    # plt.show()
