"""Tests Brownian Bridge kernel"""

import os
import pickle as pkl

import jax
import jax.numpy as jnp
import numpy as np
import pytest
import tensorflow_probability.substrates.jax as tfp

from gemlib.distributions.brownian import BrownianMotion

from .brownian_bridge_kernel import UncalibratedBrownianBridgeKernel

# Shorthand for the distributions module of the JAX-substrate TFP import above.
tfd = tfp.distributions


@pytest.fixture
def seed():
    """Provide a fixed JAX PRNG key so each test draws reproducible numbers."""
    prng_key = jax.random.key(42)
    return prng_key


@pytest.fixture(params=[np.float32, np.float64])
def brownian_example(request):
    """Return a hard-coded example dataset, once per floating-point dtype.

    The fixture is parametrised over ``np.float32`` and ``np.float64``; every
    array in the returned dict is created with the requested dtype.

    Returns:
        A dict with three entries:

        * ``"mu0"`` -- a scalar array.
        * ``"mu"`` -- a 1-D array of 99 values.
        * ``"y"`` -- a 1-D array of 100 values.

        NOTE(review): according to the original docstring these values come
        from the model built in ``test_poisson_with_brownian_mean``; the
        generating script is not part of this file.
    """
    dtype = request.param
    # fmt: off
    example = {
        'mu0': np.array(1.078298782437224, dtype=dtype),
        'mu': np.array(
            [
                1.05857034,  0.32002898,  0.37228629,  0.40525237,  0.12359111,
                0.09761669, -0.06408588,  0.13988239,  0.99996464,  0.92174459,
                0.89632451,  0.76809755,  0.61848128,  0.34516299,  0.43396905,
                0.80968081,  0.95935981,  1.36020218,  1.78247587,  1.48679809,
                1.92589046,  2.10760496,  1.71412132,  1.77749278,  1.88419343,
                2.78635343,  2.7826425 ,  2.76047156,  2.9203929 ,  3.45938202,
                3.46593501,  3.72072366,  4.29276331,  4.52024769,  4.03456866,
                3.67338132,  3.89255809,  4.36232553,  4.36436465,  4.41965069,
                4.10126142,  3.99169247,  3.86904207,  3.80634771,  3.95730443,
                4.15291629,  4.23671593,  4.91920361,  4.85009102,  4.04911164,
                3.83959911,  4.1658448 ,  4.22685046,  4.3965696 ,  4.29211524,
                4.25564851,  4.20041381,  3.77908754,  3.98601494,  3.8937047 ,
                4.31104904,  4.2890428 ,  4.47919182,  4.23664283,  4.2335522 ,
                3.95874123,  3.66550923,  3.7419551 ,  3.58068457,  3.5154879 ,
                3.51269999,  3.73871861,  3.98074317,  3.79272018,  3.34503227,
                3.08081016,  3.2689756 ,  2.9434675 ,  3.05988444,  3.47211118,
                3.1298929 ,  2.90523247,  2.77559176,  3.16931728,  2.73967911,
                2.72032434,  2.57052955,  2.93432295,  2.21730476,  2.5978165 ,
                2.51272339,  3.22263945,  2.99662432,  3.3611697 ,  3.15646481,
                2.89784185,  2.34103054,  2.25303897,  2.82420474
            ],
            dtype=dtype,
        ),
        'y': np.array(
            [
                1.,   2.,   4.,   1.,   1.,   0.,   2.,   1.,   0.,   2.,   1.,
                2.,   1.,   1.,   2.,   2.,   1.,   2.,   1.,   7.,   4.,   5.,
                4.,   3.,  11.,   8.,  16.,  22.,  13.,  20.,  44.,  25.,  44.,
                74., 103.,  59.,  41.,  32.,  70.,  82.,  70.,  72.,  62.,  50.,
                37.,  57.,  55.,  66., 128., 142.,  67.,  49.,  57.,  74.,  79.,
                73.,  65.,  69.,  48.,  41.,  60.,  76.,  74.,  67.,  62.,  60.,
                58.,  35.,  47.,  48.,  25.,  37.,  38.,  44.,  30.,  22.,  19.,
                33.,  22.,  19.,  35.,  25.,  15.,  24.,  21.,  13.,  17.,  14.,
                18.,  12.,   9.,  13.,  27.,  29.,  30.,  19.,  22.,  13.,   9.,
                23.
            ],
            dtype=dtype,
        )
    }
    # fmt: on
    return example


def test_simple_brownian_motion(evaltest, brownian_example, seed):
    """Check the Brownian-bridge kernel when targeting ``BrownianMotion``.

    A ``BrownianMotion`` over 100 index points on ``[0, 10)`` is used both to
    draw the initial state and as the target log-density.  A
    Metropolis-Hastings chain driven by ``UncalibratedBrownianBridgeKernel``
    is run for 10000 steps, after which the sample mean of the first
    component and the sample variance of the last component are compared
    with their expected values (0 and 10 respectively).

    Args:
        evaltest: fixture defined outside this file; called here with a
            zero-argument callable and expected to return that callable's
            result -- TODO confirm against its definition.
        brownian_example: example dataset; only the dtype of ``"y"`` is used.
        seed: JAX PRNG key.
    """
    dtype = brownian_example["y"].dtype
    x = np.arange(0.0, 10.0, 0.1, dtype=dtype)
    Y = BrownianMotion(x)
    y = Y.sample(seed=seed)

    kernel = tfp.mcmc.MetropolisHastings(
        inner_kernel=UncalibratedBrownianBridgeKernel(
            Y.log_prob,
            index_points=x,
            span=90,
            scale=1.0,
        )
    )

    samples, results = evaltest(
        lambda: tfp.mcmc.sample_chain(
            num_results=10000,
            kernel=kernel,
            current_state=y,
            seed=seed,
        )
    )

    print("Acceptance rate:", results.is_accepted.mean())

    # assert_allclose takes (actual, desired): the tolerance is
    # atol + rtol * |desired|, so the fixed expected value must come second
    # for the relative term to be anchored to it rather than to the noisy
    # sample statistic.
    np.testing.assert_allclose(np.mean(samples[:, 0]), 0.0, atol=1.0, rtol=0.1)
    np.testing.assert_allclose(np.var(samples[:, -1]), 10.0, atol=1.0, rtol=0.1)


def test_poisson_with_brownian_mean(evaltest, brownian_example, seed):
    """Sample a Brownian log-rate path under a Poisson observation model.

    The joint model is::

        mu0 ~ Normal(1, 1)
        mu  ~ BrownianMotion(index_points, x0=mu0)
        y   ~ Poisson(exp([0, mu]))      (independent over the last axis)

    ``mu0`` and ``y`` are pinned to the fixture values, and ``mu`` is sampled
    with a Metropolis-Hastings chain driven by
    ``UncalibratedBrownianBridgeKernel``, starting from a constant path at
    3.0.  The posterior mean of the second half of the chain is then compared
    with the fixture's ``mu``.

    Args:
        evaltest: fixture defined outside this file; called here with a
            zero-argument callable and expected to return that callable's
            result -- TODO confirm against its definition.
        brownian_example: example dataset with ``"mu0"``, ``"mu"`` and ``"y"``.
        seed: JAX PRNG key.
    """
    dtype = brownian_example["mu"].dtype
    index_points = np.arange(0.0, 10.0, 0.1, dtype=dtype)
    print("test dtype:", dtype)

    def build_model():
        mu0 = tfd.Normal(
            loc=jnp.asarray(1.0, dtype=dtype),
            scale=jnp.asarray(1.0, dtype=dtype),
        )

        def mu(mu0):
            return BrownianMotion(index_points, x0=mu0)

        def y(mu):
            # Prepend a fixed zero log-rate so the rate vector is one element
            # longer than ``mu`` (the fixture's ``y`` has one more entry than
            # its ``mu``).
            rate = jnp.concatenate([jnp.zeros(1, dtype=dtype), mu], axis=-1)
            return tfd.Independent(
                tfd.Poisson(rate=jnp.exp(rate)),
                reinterpreted_batch_ndims=1,
            )

        return tfd.JointDistributionNamed({"mu0": mu0, "mu": mu, "y": y})

    model = build_model()

    def logp(mu):
        """Joint log-density as a function of ``mu`` with ``mu0``/``y`` fixed."""
        return model.log_prob(
            {
                "mu0": brownian_example["mu0"],
                "mu": mu,
                "y": brownian_example["y"],
            }
        )

    mcmc_kernel = tfp.mcmc.MetropolisHastings(
        inner_kernel=UncalibratedBrownianBridgeKernel(
            logp,
            index_points=index_points,
            span=5,
            scale=1.0,
            left=True,
            right=True,
        )
    )

    num_results = 5000
    samples, results = evaltest(
        lambda: tfp.mcmc.sample_chain(
            num_results=num_results,
            kernel=mcmc_kernel,
            current_state=jnp.full(
                brownian_example["mu"].shape, 3.0, dtype=dtype
            ),
            seed=seed,
        )
    )
    print("Acceptance rate:", results.is_accepted.mean())

    # The chain starts from a constant path at 3.0, so drop the first half as
    # burn-in before summarising the draws.
    posterior_mean = np.mean(np.asarray(samples)[num_results // 2 :], axis=0)

    # Previously this compared 0.0 with 0.0 and could never fail.
    # NOTE(review): assumes the fixture's ``mu`` is the latent path that
    # generated ``y``, so the posterior mean should sit near it -- confirm.
    np.testing.assert_allclose(
        posterior_mean,
        brownian_example["mu"],
        rtol=1.5,
        atol=2.0,
    )
