#!/usr/bin/env python
"""Debug script for test_simple_gaussian_transform failure."""

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import spherical_jn

from fftloggin.cosmology import RadialIntegrator

# Rebuild the inputs of the failing test: a log-spaced radial grid and a
# Gaussian-shaped source evaluated on that grid.
chi = np.logspace(0, 3, num=4096)  # 4096 samples from 10**0 to 10**3
n = 64  # passed on as the ``n=`` argument of RadialIntegrator
chi0, sigma = 100.0, 30.0  # centre and width parameter of the source
_scaled_offset = (chi - chi0) / sigma
s = np.exp(-(_scaled_offset**2))  # exp(-((chi - chi0) / sigma)**2), no 1/2 factor

# Multipoles to transform and the palette used to draw them.
ells = [10, 100]
colors = ["blue", "green", "red"]

# Short sanity summary of the inputs before running the transforms.
_chi_lo, _chi_hi = chi.min(), chi.max()
_s_peak_at = np.argmax(s)
print(f"chi range: [{_chi_lo:.3f}, {_chi_hi:.3f}]")
print(f"chi0 (center): {chi0}")
print(f"s max value: {s.max():.6f} at chi={chi[_s_peak_at]:.3f}")
print()

# For every multipole, run the RadialIntegrator transform and, next to it, a
# brute-force quadrature of the same integral so the two can be compared.
results = {}
for ell in ells:
    integrator = RadialIntegrator(chi, s, ell, n=n, recenter=True, lowring=True)
    result = integrator.result
    k = integrator.k

    # Count of selected chi_mask entries relative to the length of the input
    # chi grid (printed unlabeled, as a bare number).
    print(integrator.chi_mask.astype(int).sum() / len(chi))

    # The chi mask is applied to the k-side arrays in reversed order.
    # NOTE(review): presumably because the k grid is ordered reciprocally to
    # the chi grid -- confirm against RadialIntegrator.
    mask = integrator.chi_mask[::-1]
    k_sel = k[mask]
    result_sel = result[mask]

    # Reference answer: trapezoidal integral of s(chi) * j_ell(k * chi) over
    # the original chi grid, one row per selected k value.
    # NOTE: np.trapezoid is the NumPy >= 2.0 name of the trapezoidal rule.
    kchi = np.outer(k_sel, chi)
    bessel = spherical_jn(ell, kchi)
    result_quad = np.trapezoid(s * bessel, x=chi, axis=-1)

    # Position of the largest |result| *within the masked arrays*.
    peak_idx = np.argmax(np.abs(result_sel))

    results[ell] = {
        "chi": integrator.chi[integrator.chi_mask],
        "s_resampled": integrator.s_resampled[integrator.chi_mask],
        "result": result_sel,
        "result_quad": result_quad,
        "k": k_sel,
        "peak_idx": peak_idx,
    }

    # The k range and result maximum below are taken over the full, unmasked
    # arrays; the peak index and peak k value refer to the masked arrays.
    print(f"ell = {ell}:")
    print(f"  k range: [{k.min():.6f}, {k.max():.6f}]")
    print(f"  result max: {np.abs(result).max():.6e}")
    print(f"  result peak index: {peak_idx}")
    # peak_idx indexes the masked arrays, so the matching k must be read from
    # the masked k as well; indexing the unmasked k reports the wrong value
    # whenever the mask drops entries ahead of the peak.
    print(f"  result peak k value: {k_sel[peak_idx]:.6f}")
    print()

# Figure layout: one panel for the source function followed by one panel per
# multipole, stacked in a single column.
nplots = 1 + len(ells)
fig, axes = plt.subplots(nrows=nplots, ncols=1, figsize=(4, 3 * nplots))

# Panel 0: the source function on the original chi grid, with a dashed
# vertical marker at chi0. Log-scaled x axis to match the log-spaced grid.
source_ax = axes[0]
source_ax.set_xscale("log")
source_ax.plot(chi, s, "k-", linewidth=2)
source_ax.axvline(
    chi0,
    color="r",
    linestyle="--",
    label=f"chi0 = {chi0}",
    alpha=0.7,
)
source_ax.set_title("Source Function (Gaussian Window)", fontsize=14)
source_ax.set_xlabel("chi (comoving distance)", fontsize=12)
source_ax.set_ylabel("s (source function)", fontsize=12)
source_ax.grid(True, alpha=0.3)

# Panels 1..len(ells): for each multipole, compare |result| from the
# integrator with |result| from direct quadrature, and mark the integrator's
# peak. The masked, resampled source of each multipole is overlaid on panel 0.
for i, ell in enumerate(ells):
    entry = results[ell]
    k = entry["k"]
    chi_resampled = entry["chi"]
    s_resampled = entry["s_resampled"]
    result = entry["result"]
    result_quad = entry["result_quad"]
    peak_idx = entry["peak_idx"]

    # Cycle through the palette so that listing more multipoles than colours
    # reuses colours instead of raising IndexError.
    color = colors[i % len(colors)]
    ax = axes[i + 1]

    # Resampled source on the shared source panel.
    axes[0].plot(
        chi_resampled,
        s_resampled,
        ls="--",
        linewidth=4,
        color=color,
        label=f"ell = {ell}, npts = {len(chi_resampled)}",
        alpha=0.8,
    )

    # Integrator output (solid) against the quadrature reference (dashed).
    ax.plot(
        k,
        np.abs(result),
        ls="-",
        linewidth=2,
        color=color,
        label=f"ell = {ell} (FFTLog)",
        alpha=0.8,
    )
    ax.plot(
        k,
        np.abs(result_quad),
        ls="--",
        linewidth=2,
        color=color,
        label=f"ell = {ell} (Quadrature)",
        alpha=0.8,
    )

    # Marker at the largest |result|; peak_idx indexes the masked arrays
    # stored in `results`, which is what k and result are here.
    ax.plot(
        k[peak_idx],
        np.abs(result[peak_idx]),
        "o",
        color=color,
        markersize=8,
        alpha=0.8,
    )

    ax.set_xlabel("k (wavenumber)", fontsize=12)
    ax.set_ylabel("|Δ_ell(k)| (radial integral)", fontsize=12)
    ax.set_title("Radial Integral Results for Different ells", fontsize=14)
    ax.set_xscale("log")
    # The y axis is deliberately left linear.
    ax.grid(True, alpha=0.3)
    ax.legend()

# The source-panel legend is drawn last so it includes the per-ell overlays
# added inside the loop.
axes[0].legend()
plt.tight_layout()
# Save before show(): the file is written even if the window is closed
# immediately or no interactive backend is available.
plt.savefig("debug_gaussian_test.png", dpi=150, bbox_inches="tight")
print("Plot saved to: debug_gaussian_test.png")
plt.show()
