import marimo

# marimo version recorded by the notebook generator, and the notebook app to
# which every cell below is attached via the `@app.cell` decorator.
__generated_with = "0.17.0"
app = marimo.App(width="medium")


@app.cell
def _():
    # Expose marimo under its conventional alias so other cells can request `mo`.
    import marimo as mo
    return (mo,)


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
    # Using FFTLog

    In this tutorial, we will reproduce the example in appendix B of [[astro-ph/9905191]](https://arxiv.org/abs/astro-ph/9905191). See also the [webpage](https://jila.colorado.edu/~ajsh/FFTLog/).

    ## Running the notebook

    If you use the `uv` package manager, you can run the notebook from the main directory as:

    ```shell
    uv run --group tutorial marimo edit notebooks/tutorial.py
    ```

    See the [`uv` docs](https://docs.astral.sh/uv/guides/integration/jupyter/#using-jupyter-within-a-project) for more details.
    """
    )
    return


@app.cell
def _():
    # Shared imports for the whole notebook; only the names returned below are
    # visible to other cells.
    from typing import Sequence

    import camb
    import camb.model  # make the `camb.model` submodule reachable as `camb.model`
    import matplotlib.pyplot as plt
    import numpy as np
    import numpy.typing as npt
    import scipy.interpolate
    import scipy.special
    from camb.sources import (
        GaussianSourceWindow,
        SourceWindow,
    )

    from fftloggin import FFTLog, Grid
    from fftloggin.kernels import BesselJKernel, SphericalBesselJKernel

    # High-resolution figures (stands in for Jupyter's
    # `%config InlineBackend.figure_format = 'retina'`, which marimo lacks).
    plt.rcParams.update({"figure.dpi": 300, "savefig.dpi": 300})
    return (
        BesselJKernel,
        FFTLog,
        GaussianSourceWindow,
        Grid,
        Sequence,
        SourceWindow,
        SphericalBesselJKernel,
        camb,
        np,
        npt,
        plt,
        scipy,
    )


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""We use CAMB to generate the matter power spectrum:""")
    return


@app.cell
def _(Sequence, SourceWindow, camb, np, plt):
    h = 0.675
    # k range (h/Mpc) and number of samples of the CAMB matter power spectrum.
    minkh = 1e-4
    maxkh = 1
    npoints = 512
    lmax = 2000
    # Uniform spacing in ln k of the requested k grid.
    dlog = np.log(maxkh / minkh) / (npoints - 1)

    redshift_slices = [0, 1, 2]

    cosmology_params = {
        "H0": 100 * h,
        "ombh2": 0.022,
        "omch2": 0.122,
        "mnu": 0,
        "omk": 0,
        "tau": 0.06,
        "As": 2e-9,
        "ns": 0.965,
    }

    def update_params(
        params: camb.CAMBparams,
        lmax: int | None = None,
        kmax: float | None = None,
        dlog: float | None = None,
        redshifts: Sequence[float] | None = None,
        non_linear: bool = False,
        windows: Sequence[SourceWindow] | None = None,
    ) -> camb.CAMBparams:
        """Configure ``params`` in place for the requested outputs and return it.

        Parameters
        ----------
        params
            CAMB parameter object; it is modified in place.
        lmax
            If truthy, forwarded to ``set_for_lmax``.
        kmax
            Maximum k (1/Mpc) of the matter power spectrum. Only used together
            with ``redshifts``; CAMB's default is kept when ``None``.
        dlog
            ln-k spacing of the desired k grid, converted to
            ``k_per_logint = int(1 / dlog)``. CAMB's default sampling is kept
            when ``None``.
        redshifts
            If given, enables matter power spectrum output at these redshifts.
        non_linear
            Whether to request non-linear corrections.
        windows
            If given, assigned to ``params.SourceWindows``.

        Returns
        -------
        The same ``params`` object.
        """
        if lmax:
            params.set_for_lmax(lmax=lmax, nonlinear=non_linear)

        if redshifts is not None:
            matter_power_kwargs = {"redshifts": redshifts, "nonlinear": non_linear}
            # Forward optional settings only when supplied, so CAMB's defaults
            # apply instead of failing on None (e.g. ``int(1 / None)``).
            if kmax is not None:
                matter_power_kwargs["kmax"] = kmax
            if dlog is not None:
                matter_power_kwargs["k_per_logint"] = int(1 / dlog)
            params.set_matter_power(**matter_power_kwargs)

        if windows is not None:
            params.SourceWindows = windows
        return params

    params = camb.set_params(**cosmology_params)
    params = update_params(
        params,
        kmax=maxkh * h,  # CAMB's kmax is in 1/Mpc, not h/Mpc
        dlog=dlog,
        redshifts=redshift_slices,
        non_linear=False,
        windows=None,
        lmax=lmax,
    )

    _results = camb.get_results(params)

    kh, _, pk = _results.get_matter_power_spectrum(
        minkh=minkh, maxkh=maxkh, npoints=npoints
    )
    _fig, _ax = plt.subplots()
    for _i, redshift in enumerate(redshift_slices):
        _ax.loglog(kh, pk[_i, :], label=f"$z={redshift}$")
    _ax.set_xlabel("$k \\, \\rm{[h/Mpc]}$")
    _ax.set_ylabel("$P(k) \\, \\rm{[Mpc/h]^3}$")
    _ax.set_title("Matter Power Spectrum")
    _ax.legend()
    _ax.grid()
    _fig.tight_layout()
    _fig
    return h, kh, lmax, maxkh, minkh, params, pk, update_params


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
    ## The Hankel transform

    Converting the power spectrum into the real-space correlation function $\xi(r)$ amounts to taking the Fourier transform,

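    $$ \xi(r) = \int \frac{d^3k}{(2\pi)^3} P(k)e^{i \bm{k} \cdot \bm{r}}$$

    Since $P(k)$ depends only on $k = |\bm{k}|$, the angular part of the integral turns the plane wave $e^{i \bm{k} \cdot \bm{r}}$ into $4\pi j_0(kr)$. Combined with the $(2\pi)^{-3}$ factor, this leaves a one-dimensional integral over a spherical Bessel function,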

    $$
    \begin{align*}
    \xi(r) &= \frac{1}{2\pi^2}\int dk k^2 P(k) j_0(kr) \\
    &= \frac{1}{2\pi^2}\int dk k^2 P(k) \frac{\sin (kr)}{kr}.
    \end{align*}
    $$

    The code computes a Hankel transform, whose normalization differs slightly from the expression above. For a transform pair $a(r), \tilde{a}(k)$, the forward and inverse transforms are

    $$
    \begin{align*}
    \tilde{a}(k) &= \int_0^\infty J_\mu(kr)a(r)kdr,\\
    a(r) &= \int_0^\infty J_\mu(kr)\tilde{a}(k)rdk,
    \end{align*}
    $$

    where $J_\mu(\cdot)$ is the Bessel function of order $\mu$. It is related to the spherical Bessel function $j_\ell$ of integer order $\ell$ by

    $$j_\ell(x) = \sqrt{\frac{\pi}{2x}}J_{\ell + 1/2}(x).$$

    Recasting the expression of the correlation function as an explicit Hankel transform,

    $$
    \begin{align*}
    \xi(r) &= \frac{1}{2\pi^2}\int k^2 P(k) \sqrt{\frac{\pi}{2kr}}J_{1/2}(kr)dk \\
    &= \frac{1}{(2\pi r)^{3/2}}\int k^{3/2} P(k) J_{1/2}(kr)rdk\\
    &= \frac{1}{(2\pi r)^{3/2}} \rm{IFHT} [k^{3/2} P(k), \mu=1/2]
    \end{align*}
    $$

    In the cell below, we compute the correlation function $\xi(r)$ via the inverse Hankel transform using the FFTLog algorithm. For comparison, we also evaluate the integral directly with the trapezoidal rule in $\ln k$.
    """
    )
    return


@app.cell
def _(np, scipy):
    def get_xi_direct_integration(kh: np.ndarray, r: np.ndarray, pk: np.ndarray):
        """Reference correlation function by brute-force quadrature.

        Evaluates xi(r) = 1 / (2 pi^2) * integral of k^3 P(k) j_0(k r) d(ln k)
        with the trapezoidal rule, for every separation in ``r``.
        ``kh`` and ``pk`` are samples of the power spectrum on the same grid.
        """
        kr = np.outer(r, kh)
        j0 = scipy.special.spherical_jn(0, kr)
        # dk k^2 = d(ln k) k^3, so integrate the k^3-weighted spectrum in ln k.
        weighted = pk * kh**3
        xi = np.trapezoid(weighted * j0, x=np.log(kh), axis=-1)
        return xi / (2 * np.pi**2)
    return (get_xi_direct_integration,)


@app.cell
def _(mo):
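    mo.md(r"""In FFTLog, the Bessel-function kernel enters through its Mellin transform, $U_\mu(x) = \int_0^\infty t^{x} J_\mu(t)\,dt$, which converges only for $-\mu - 1 < \mathrm{Re}(x) < 1/2$ (equivalently, $\int_0^\infty t^{s-1} J_\mu(t)\,dt$ converges for $-\mu < \mathrm{Re}(s) < 3/2$, with $s = x + 1$). The bias $q$ sets the real part of $x$, so it must satisfy $-\mu - 1 < q < 1/2$. Since $\mu=1/2$ here, we can test different values of the bias in the range $(-3/2, 1/2)$.""")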
    return


@app.cell
def _(
    BesselJKernel,
    FFTLog,
    Sequence,
    get_xi_direct_integration,
    kh,
    np,
    npt,
    pk,
    plt,
):
    def plot_correlation_function_with_bias_comparison(
        k: npt.ArrayLike,
        pk: npt.ArrayLike,
        bias: Sequence[float],
        logc: float = 0,
        minimize_ringing: bool = False,
        **kwargs,
    ):
        """Compare FFTLog xi(r) for several bias values against direct integration.

        The top panel shows xi(r) for each bias ``q`` plus the direct-integration
        reference; the bottom panel shows the absolute relative error of each
        FFTLog curve. Extra keyword arguments go to ``plt.subplots``.

        Returns ``(fig, (ax, ax_residue))``.
        """
        fig, (ax, ax_residue) = plt.subplots(2, 1, **kwargs)
        k = np.asarray(k)
        pk = np.asarray(pk)
        for q in bias:
            fftlog = FFTLog.from_array(
                k,
                kernel=BesselJKernel(mu=0.5),
                bias=q,
                minimize_ringing=minimize_ringing,
                logc=logc,
            )
            grid = fftlog.create_grid(k=k)
            r = grid.r
            # xi(r) = IFHT[k^{3/2} P(k), mu=1/2] / (2 pi r)^{3/2}
            xi_fftlog = fftlog.inverse(k ** (3 / 2) * pk) / (2 * np.pi * r) ** (3 / 2)
            xi_direct_integration = get_xi_direct_integration(k, r, pk)
            residue = np.abs(xi_fftlog / xi_direct_integration - 1)
            curve_label = rf"$q={q}$"
            ax.semilogx(r, xi_fftlog, label=curve_label)
            ax_residue.loglog(r, residue, label=curve_label)

        # Reference curve on the grid of the last bias value.
        ax.semilogx(r, xi_direct_integration, label="Direct Integration")
        for axis, ylabel in ((ax, r"$\xi(r)$"), (ax_residue, "Relative error")):
            axis.set_xlabel("$r \\, \\rm{[Mpc/h]}$")
            axis.set_ylabel(ylabel)
            axis.legend()

        fig.tight_layout()
        return fig, (ax, ax_residue)

    _fig, _ = plot_correlation_function_with_bias_comparison(
        kh, pk[0, :], bias=[-0.5, 0, 0.5], logc=0, minimize_ringing=False
    )
    _fig
    return


@app.cell
def _():
    # Empty cell: defines and displays nothing.
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
    ## Minimizing ringing

    Let's try to reduce the oscillations (ringing) by setting `minimize_ringing=True`, which lets FFTLog choose an optimal $\log(k_0 r_0)$:
    """
    )
    return


@app.cell
def _(BesselJKernel, FFTLog, get_xi_direct_integration, kh, np, npt, pk, plt):
    def plot_correlation_function_with_optimal_center_comparison(
        k: npt.ArrayLike,
        pk: npt.ArrayLike,
        bias: float = 0,
        logc: float = 0,
        **kwargs,
    ):
        """Compare FFTLog xi(r) with and without ringing minimization.

        Runs FFTLog once with the given ``logc`` and once with
        ``minimize_ringing=True``. It plots both against direct integration
        (top) and their absolute relative errors (bottom). Extra keyword
        arguments go to ``plt.subplots``.

        Returns ``(fig, (ax, ax_residue))``.
        """
        fig, (ax, ax_residue) = plt.subplots(2, 1, **kwargs)
        k = np.asarray(k)
        pk = np.asarray(pk)
        for minimize_ringing in (False, True):
            fftlog = FFTLog.from_array(
                k,
                kernel=BesselJKernel(mu=0.5),
                bias=bias,
                minimize_ringing=minimize_ringing,
                logc=logc,
            )
            grid = fftlog.create_grid(k=k)
            r = grid.r
            # xi(r) = IFHT[k^{3/2} P(k), mu=1/2] / (2 pi r)^{3/2}
            xi_fftlog = fftlog.inverse(k ** (3 / 2) * pk) / (2 * np.pi * r) ** (3 / 2)
            xi_direct_integration = get_xi_direct_integration(k, r, pk)
            residue = np.abs(xi_fftlog / xi_direct_integration - 1)
            if minimize_ringing:
                logc_label = r"\text{optimal}"
            else:
                logc_label = fftlog.logc
            label = rf"$\log (k_0 r_0) = {logc_label}$"
            ax.semilogx(r, xi_fftlog, label=label)
            ax_residue.loglog(r, residue, label=label)

        # Reference curve on the grid of the last (optimal-centre) run.
        ax.semilogx(r, xi_direct_integration, label="Direct Integration")
        for axis, ylabel in ((ax, r"$\xi(r)$"), (ax_residue, "Relative error")):
            axis.set_xlabel("$r \\, \\rm{[Mpc/h]}$")
            axis.set_ylabel(ylabel)
            axis.legend()

        fig.tight_layout()
        return fig, (ax, ax_residue)

    _fig, _ = plot_correlation_function_with_optimal_center_comparison(
        kh,
        pk[0, :],
        bias=0,
        logc=0,
    )
    _fig
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
    ## Projecting observables on the sphere

    Many cosmological observables, from CMB temperature and polarization to weak lensing, are more naturally defined in harmonic space. The angular cross-power spectrum of two observables is related to the primordial power spectrum $\mathcal{P}(k)$ by

    $$
    C^{ij}_{\ell} = \frac{2}{\pi}\int_0^\infty dk\, k^2 \Delta^i_\ell (k) \Delta^j_\ell (k) \mathcal{P}(k),
    $$

    The quantity $\Delta^i_\ell(k)$ is generally referred to as the kernel for the observable $i$, and is defined by a line-of-sight integral

    $$
    \Delta^i_\ell (k) = \int_0^\infty d\chi W^i(\chi) \mathcal{T}(\chi, k) j_\ell(k\chi),
    $$

    where $W^i(\chi)$ is a radial window function and $\mathcal{T}(\chi, k)$ is the transfer function describing the cosmological evolution of the Fourier modes probed by that tracer. The spherical Bessel function $j_\ell(k\chi)$ suggests recasting the expression as a Hankel transform, which the FFTLog algorithm computes efficiently. One obstacle is that the transfer function $\mathcal{T}(\chi, k)$ depends on both $\chi$ and $k$.

    In linear theory and in the absence of anisotropic stress, the transfer functions factor into $\mathcal{T}(\chi, k)=\mathcal{T}_{\chi}(\chi) \mathcal{T}_{k}(k)$. We may then cast the line-of-sight integral into the form expected by FFTLog,

    $$
    \begin{align*}
    \Delta^i_\ell (k) &= \mathcal{T}_{k}(k) \int_0^\infty W^i(\chi) \mathcal{T}_{\chi}(\chi) j_\ell(k\chi) d\chi \\
    &= \mathcal{T}_{k}(k) \int_0^\infty W^i(\chi) \mathcal{T}_{\chi}(\chi) \sqrt{\frac{\pi}{2k\chi}} J_{\ell + 1/2}(k\chi) d\chi \\
    &= \sqrt{\frac{\pi}{2}}k^{-3/2}\mathcal{T}_{k}(k) \int_0^\infty \frac{W^i(\chi) \mathcal{T}_{\chi}(\chi)}{\chi^{1/2}} J_{\ell + 1/2}(k\chi) kd\chi \\
    &= \sqrt{\frac{\pi}{2}}k^{-3/2}\mathcal{T}_{k}(k) \text{FHT} \left[\frac{W^i(\chi) \mathcal{T}_{\chi}(\chi)}{\chi^{1/2}}, \mu = \ell + 1/2 \right]
    \end{align*}
    $$
    """
    )
    return


@app.cell
def _(camb, np, params):
    # Comoving radial distances at z = 0.002 and z = 4, and their ratio L
    # (the dynamic range a log-spaced chi grid has to cover).
    _chimin, _chimax = camb.get_background(params).comoving_radial_distance(
        np.array([0.002, 4])
    )
    print(f"chi_min: {_chimin:.1e}, chi_max: {_chimax:.1e}, L: {_chimax / _chimin:.1e}")
    return


@app.cell
def _(
    GaussianSourceWindow,
    SourceWindow,
    camb,
    h,
    lmax,
    maxkh,
    minkh,
    np,
    npt,
    params,
    plt,
    update_params,
):
    def get_nbar(z: npt.NDArray, mu: float, sigma: float):
        """
        Normalized Gaussian window in redshift with mean ``mu`` and standard deviation ``sigma``.

        The exponent is -x^2 / 2 so that the 1 / (sqrt(2 pi) sigma) prefactor
        normalizes the window to unit integral and ``sigma`` is its standard
        deviation, consistent with the ``GaussianSourceWindow(sigma=sigma)`` it
        is paired with below.
        """
        x = (z - mu) / sigma
        return np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi) / sigma

    def get_source_term(
        params: camb.CAMBparams,
        k: npt.NDArray,
        zmin: float,
        zmax: float,
        lmax: int,
        tracer_bias: float,
        sigma_relative: float,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, SourceWindow]:
        """Build the line-of-sight source term S = n(z) D(z) H(z) of a Gaussian tracer.

        Parameters
        ----------
        params
            CAMB parameters. NOTE(review): ``update_params`` modifies this object
            in place (redshifts, k sampling, source windows), so the caller's
            object is changed too.
        k
            Log-spaced wavenumbers (1/Mpc); their range and spacing set CAMB's
            k sampling.
        zmin, zmax
            Redshift range of the window.
        lmax
            Passed to ``update_params``.
        tracer_bias
            Linear bias of the CAMB source window.
        sigma_relative
            Window width as a fraction of ``zmax - zmin``.

        Returns
        -------
        z
            256 redshifts evenly spaced in ``[zmin, zmax]``.
        chi
            Comoving radial distance at ``z``.
        source_term
            n(z) * D(z) * H(z), with D = sigma_8(z) / sigma_8(z[0]).
        source_window
            The matching CAMB ``GaussianSourceWindow``.
        """
        # Define redshift array for CAMB calculations
        # CAMB allows at most 256 points
        z = np.linspace(zmin, zmax, 256)

        # Define source window function: Gaussian centered at zmid
        zmid = 0.5 * (zmin + zmax)
        sigma = sigma_relative * (zmax - zmin)
        nbar = get_nbar(z, mu=zmid, sigma=sigma)
        source_window = GaussianSourceWindow(
            redshift=zmid, sigma=sigma, source_type="counts", bias=tracer_bias
        )

        # Request CAMB outputs on the same k range / spacing as `k`.
        npoints = k.shape[0]
        dlog = (np.log(k[-1]) - np.log(k[0])) / (npoints - 1)
        params = update_params(
            params,
            kmax=k[-1],
            lmax=lmax,
            redshifts=z,
            dlog=dlog,
            windows=[source_window],
        )
        results = camb.get_results(params)
        chi = results.comoving_radial_distance(z)

        # In linear theory the growth factor is scale independent, so take it
        # from sigma_8(z). The reversal puts it in the ascending order of `z`
        # (CAMB appears to order redshifts from high to low — verify if changed).
        sigma_8 = results.get_sigma8()[::-1]
        growth_function = sigma_8 / sigma_8[0]
        # Jacobian dz/dchi = H(z) (with c = 1).
        jacobian = results.h_of_z(z)
        source_term = nbar * growth_function * jacobian
        return z, chi, source_term, source_window

    _fftlog_npoints = 2048
    k = np.geomspace(minkh, maxkh, _fftlog_npoints) * h
    z, chi, source_term, source_window_camb = get_source_term(
        params,
        k,
        zmin=2e-4,
        zmax=0.2,
        lmax=lmax,
        tracer_bias=1,
        sigma_relative=0.2,
    )
    _fig, _ax = plt.subplots()
    _ax.semilogx(chi, source_term)
    _ax.set(xlabel=r"$\chi$", ylabel=r"$S(\chi)$")
    _fig.tight_layout()
    _fig
    return chi, k, source_term, z


@app.cell
def _(
    FFTLog,
    Grid,
    SphericalBesselJKernel,
    chi,
    k,
    lmax,
    np,
    npt,
    plt,
    source_term,
    z,
):
    def compute_kernel_with_fftlog(
        k: npt.NDArray,
        z: npt.NDArray,
        chi: npt.NDArray,
        source_term: npt.NDArray,
        ells: npt.NDArray,
        fftlog_bias: float = 0.1,
        p: float = 0.68,
        minimize_ringing: bool = False,
        recenter: bool = False,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, FFTLog, Grid]:
        """Compute Delta_ell(k) = int S(chi) j_ell(k chi) dchi for every ell with FFTLog.

        Parameters
        ----------
        k
            Log-spaced wavenumbers; their spacing and count fix the FFTLog grid.
        z
            Unused; kept for signature parity with the quadrature helper.
        chi, source_term
            Comoving distances (must be ascending, as required by np.interp)
            and the source term S(chi) sampled on them.
        ells
            Multipoles; one FFTLog transform row per ell.
        fftlog_bias
            Power-law bias passed to FFTLog.
        p
            Central probability mass of the source term used to locate the
            r-grid center when ``recenter`` is True.
        minimize_ringing
            Forwarded to FFTLog.
        recenter
            If True, center the r grid on the bulk of the source term;
            otherwise on the geometric mean of the ``chi`` endpoints.

        Returns
        -------
        mask
            True where the FFTLog r grid lies inside ``[chi[0], chi[-1]]``.
        argument
            The source term interpolated on the r grid (zero outside ``chi``).
        kernel
            Delta_ell(k), one row per ell.
        fftlog, grid
            The FFTLog instance and its (r, k) grid.
        """
        if recenter:
            # Center the r grid on the bulk of the source: take rc as the
            # geometric mean of the (alpha, 1 - alpha) quantiles of the
            # normalized source term. The leading 0 makes cdf[i] the integral up
            # to chi[i], so searchsorted indices map directly onto `chi`.
            increments = 0.5 * np.diff(chi) * (source_term[1:] + source_term[:-1])
            source_term_cdf = np.concatenate(([0.0], np.cumsum(increments)))
            source_term_cdf /= source_term_cdf[-1]  # Normalize
            alpha = (1 - p) / 2
            imin, imax = np.searchsorted(
                source_term_cdf, [min(alpha, 1 - alpha), max(alpha, 1 - alpha)]
            )
            rc = np.sqrt(chi[imin] * chi[imax])
        else:
            # Keep the center as-is, pad values to the left and right
            rc = np.sqrt(chi[0] * chi[-1])

        # One spherical-Bessel kernel row per multipole (ells as a column).
        bessel_kernel = SphericalBesselJKernel(ell=ells.reshape(-1, 1))
        n = k.shape[-1]
        dlog = np.log(k[-1] / k[0]) / (n - 1)
        # Per-multipole centering log(k_c r_c) = log(ell + 1/2): the k grid of
        # each ell is centered on k_* = (ell + 1/2) / r_c.
        fftlog = FFTLog(
            kernel=bessel_kernel,
            dlog=dlog,
            n=n,
            bias=fftlog_bias,
            minimize_ringing=minimize_ringing,
            logc=np.log(ells + 0.5).reshape(-1, 1),
        )
        # Log-spaced r grid with the same spacing as k, centered on rc.
        j = np.arange(n)
        jc = (n - 1) / 2
        r = rc * np.exp(dlog * (j - jc))
        grid = fftlog.create_grid(r=r)
        mask = (grid.r >= chi[0]) & (grid.r <= chi[-1])
        # Interpolate the source term on the FFTLog grid. Depending on dlog and
        # n, the grid may reach unphysically large chi (beyond the conformal age
        # of the universe), so the source is set to zero outside `chi`.
        argument = np.interp(grid.r, chi, source_term, left=0, right=0)
        kernel_data = fftlog.forward(argument)
        # FFTLog computes \int s(chi) j_ell(k chi) k dchi, so divide out the k.
        delta_ell = kernel_data / grid.k
        return mask, argument, delta_ell, fftlog, grid

    ells = np.arange(lmax + 1)
    mask_chi, _argument, kernel, fftlog, grid = compute_kernel_with_fftlog(
        k,
        z,
        chi,
        source_term,
        ells,
        fftlog_bias=0.01,
        p=0.68,
        minimize_ringing=False,
        recenter=True,
    )
    print(f"{mask_chi.astype(int).sum()} / {fftlog.n} points inside physical range for chi")

    _fig, _ax = plt.subplots()
    _ax.scatter(grid.r[mask_chi], _argument[mask_chi], c="tab:blue", s=20)
    _ax.scatter(grid.r[~mask_chi], _argument[~mask_chi], c="red", label="Padded values", s=20)
    _ax.axvline(grid.rcenter, color="k", ls="--", label=r"$\chi_*$")
    _ax.set(
        xlabel=r"$\chi$",
        ylabel=r"$S(\chi)$",
        title="Source function interpolated and padded",
        xscale="log",
    )
    _ax.legend()
    _fig.tight_layout()
    _fig
    return grid, kernel, mask_chi


@app.cell
def _(lmax, mo):
    # Multipole selector for the kernel plot; covers every ell from 0 to lmax.
    selected_ell = mo.ui.slider(
        label=r"Select value of $\ell$ for plotting:",
        start=0,
        stop=lmax,
        step=1,
    )
    return (selected_ell,)


@app.cell
def _(mo, selected_ell):
    # Show the slider next to a live readout of its current value.
    _readout = mo.md(f"Selected: {selected_ell.value}")
    mo.hstack([selected_ell, _readout])
    return


@app.cell
def _(Grid, grid, kernel, np, npt, plt, selected_ell):
    def _plot_kernel(
        kernel: npt.NDArray,
        grid: Grid,
        selected_ell: int,
        mask: npt.NDArray | None = None,
        plot_maximum: bool = False,
    ):
        """Plot Delta_ell(k) for the single multipole ``selected_ell``.

        ``mask`` selects which k samples to draw (all by default). When
        ``plot_maximum`` is True, a red line marks the sample where the kernel
        peaks. A dashed line marks k_* = (ell + 1/2) / chi_*.

        Returns ``(fig, ax)``.
        """
        fig, ax = plt.subplots()
        if mask is None:
            mask = np.ones(grid.k.shape[-1], dtype=bool)
        # grid.k may hold one row per ell (per-ell FFTLog centering).
        k = grid.k[selected_ell, mask] if grid.k.ndim > 1 else grid.k[mask]
        y = kernel[selected_ell, mask]
        ax.semilogx(
            k,
            y,
            label=rf"$\ell = {selected_ell}$",
        )
        if plot_maximum:
            imax = np.argmax(y)
            ax.axvline(k[imax], color="red", label=rf"$i = {imax}$")
        k_star = (selected_ell + 0.5) / grid.rcenter
        ax.axvline(k_star, color="k", ls="--", label=r"$k_*=(\ell + 1/2) / \chi_*$")
        ax.set(xlabel=r"$k$", ylabel=r"$\Delta_\ell (k)$")
        ax.legend()
        fig.tight_layout()
        return fig, ax

    _fig, _ = _plot_kernel(kernel, grid, selected_ell.value, mask=None)
    _fig
    return


@app.cell
def _(chi, grid, np, npt, plt, scipy, source_term, z):
    def compute_kernel_with_quadrature(
        k: npt.NDArray,
        z: npt.NDArray,
        chi: npt.NDArray,
        source_term: npt.NDArray,
        ells: npt.NDArray,
        method: str = "naive",
        nzeros: int = 5,
        npoints: int | None = None,
    ):
        """Reference Delta_ell(k) = int S(chi) j_ell(k chi) dchi by direct quadrature.

        Parameters
        ----------
        k
            1-D array of wavenumbers.
        z
            Unused; kept for signature parity with the FFTLog helper.
        chi, source_term
            Ascending comoving distances and the source term sampled on them.
        ells
            1-D array of multipoles.
        method
            ``"naive"`` applies the trapezoidal rule on the tabulated ``chi``.
            Any other value integrates in x = k chi from 0 to the ``nzeros``-th
            Bessel zero on ``npoints`` samples (default ``len(k)``).

        Returns
        -------
        Array of shape ``(len(ells), len(k))``.
        """
        k = np.asarray(k)
        if method == "naive":
            kchi = np.outer(k, chi)
            bessel = scipy.special.spherical_jn(ells.reshape(-1, 1, 1), kchi)
            return np.trapezoid(bessel * source_term, x=chi, axis=-1)

        if npoints is None:
            npoints = k.shape[-1]
        source_term_interpolator = scipy.interpolate.CubicSpline(chi, source_term)
        kernel = []
        for ell in ells:
            # Integrate in x = k * chi:  \int S(x / k) j_ell(x) dx / k
            # NOTE: jn_zeros gives zeros of the cylindrical J_ell; they only
            # approximate those of the spherical j_ell (zeros of J_{ell+1/2}).
            xmax = scipy.special.jn_zeros(ell, nzeros)[-1]
            x = np.linspace(0, xmax, npoints)
            bessel = scipy.special.spherical_jn(ell, x)
            chi_from_x = x / k[:, None]
            # The source vanishes outside the tabulated chi range; a cubic
            # extrapolation would instead invent (possibly large) values there.
            inside = (chi_from_x >= chi[0]) & (chi_from_x <= chi[-1])
            source_term_x = np.where(inside, source_term_interpolator(chi_from_x), 0.0)
            kernel.append(np.trapezoid(source_term_x * bessel, x=x, axis=-1) / k)

        return np.stack(kernel, axis=0)

    def _plot_kernel(
        k: npt.NDArray,
        kernel: npt.NDArray,
        ells: npt.NDArray,
        plot_maximum: bool = False,
    ):
        """Plot one Delta_ell(k) curve per ell; returns ``(fig, ax)``."""
        assert kernel.shape == (*ells.shape, *k.shape), (
            f"{kernel.shape} {ells.shape} {k.shape}"
        )
        fig, ax = plt.subplots()
        for i, ell in enumerate(ells):
            y = kernel[i, ...]
            ax.semilogx(k, y, label=rf"$\ell = {ell}$")
            if plot_maximum:
                imax = np.argmax(y)
                ax.axvline(k[imax], color="red", label=rf"$i = {imax}$")
        ax.set(xlabel=r"$k$", ylabel=r"$\Delta_\ell (k)$")
        ax.legend()
        fig.tight_layout()
        return fig, ax

    _ells = np.array([159])
    # grid.k may hold one row per ell. Use the row of the plotted multipole so
    # that k is 1-D; np.outer would otherwise flatten every row together.
    _k = grid.k[_ells[0]] if grid.k.ndim > 1 else grid.k
    _kernel_quadrature_naive = compute_kernel_with_quadrature(
        _k, z, chi, source_term, _ells, method="naive", nzeros=5
    )
    _fig, _ = _plot_kernel(_k, _kernel_quadrature_naive, _ells)
    _fig
    return


@app.cell
def _(Sequence, camb, grid, kernel, np, npt):
    def compare_tracer_cls_with_camb(
        params: camb.CAMBparams,
        k: npt.NDArray,
        zmin: float,
        zmax: float,
        lmax: int,
        tracer_bias: float,
        sigma_relative: float,
        p: float = 0.68,
        minimize_ringing: bool = False,
        fftlog_npoints: int | None = None,
        fftlog_bias: float = 0,
        plot_kernel_ells: Sequence[int] | None = None,
        **kwargs,
    ):
        """Compare FFTLog-based tracer auto-spectra with CAMB's source C_ell.

        Work in progress: the FFTLog kernel and its grid come from the notebook
        cell that computes ``kernel``/``grid``. Only ``params`` and ``lmax`` are
        currently used; the other arguments are accepted but ignored.

        Returns
        -------
        cls_fftlog
            (2 / pi) * integral of k^3 P(k, z=0) Delta_ell(k)^2 d(ln k), one
            value per row of ``kernel``.
        cls_camb
            CAMB's raw ``"W1xW1"`` source spectrum up to ``lmax``.
        """
        if not np.isfinite(kernel).all():
            raise ValueError("FFTLog kernel contains non-finite values")
        # NOTE(review): assumes `params` already carries the counts
        # SourceWindows; get_source_cls_dict needs them for "W1xW1".
        results = camb.get_results(params)
        pk = results.get_matter_power_interpolator(
            nonlinear=False, hubble_units=False, k_hunit=False
        )
        # grid.k may be 1-D or hold one row per ell; align it with `kernel`.
        k_grid = np.broadcast_to(grid.k, kernel.shape)
        # NOTE(review): rows may extend beyond the interpolator's k range.
        power = np.stack([pk.P(0, k_row) for k_row in k_grid])
        # C_ell = (2/pi) int dk k^2 P Delta^2 = (2/pi) int dlnk k^3 P Delta^2
        integrand = k_grid**3 * power * kernel**2
        cls_fftlog = (2 / np.pi) * np.trapezoid(integrand, x=np.log(k_grid), axis=-1)
        cls_camb = results.get_source_cls_dict(lmax=lmax, raw_cl=True)["W1xW1"]
        return cls_fftlog, cls_camb
    return


@app.cell
def _():
    # Empty cell: defines and displays nothing.
    return


@app.cell
def _():
    # Empty cell: defines and displays nothing.
    return


# Entry point when the file is executed directly: run the marimo app.
if __name__ == "__main__":
    app.run()
