"""Tests for the adaptive random-walk Metropolis sampling algorithm."""

import jax
import numpy as np
import pytest
import tensorflow_probability.substrates.jax as tfp

from .adaptive_random_walk_metropolis import adaptive_rwmh
from .mcmc_sampler import mcmc

# Sample-count constant.
# NOTE(review): no test visible in this file references it (the MCMC tests
# pass their step counts directly) — confirm whether it is still needed.
NUM_SAMPLES = 1000
# Short alias for the TFP (JAX substrate) distributions namespace.
tfd = tfp.distributions


def get_seed():
    """Return a JAX PRNG key built from the fixed seed 42.

    A new key is constructed on every call, so each test starts from the same
    reproducible random state.
    """
    fixed_seed = 42
    return jax.random.key(fixed_seed)


def split_seed(seed, n):
    """Derive ``n`` seeds from ``seed`` by delegating to ``tfp.random.split_seed``."""
    splitter = tfp.random.split_seed
    return splitter(seed, n=n)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_adaptive_rwmh_univariate_shapes(dtype, evaltest):
    """One kernel step on a 1-D normal preserves state structure and shapes.

    ``kernel.init`` builds a (chain state, kernel state) pair from a scalar
    starting position; a single ``kernel.step`` must return a pair whose
    pytree structure and per-leaf shapes match the input pair.
    """
    distribution = tfd.Normal(loc=dtype(0.0), scale=dtype(1.0))

    kernel = adaptive_rwmh(initial_scale=dtype(2.38))

    (cs, ks) = evaltest(lambda: kernel.init(distribution.log_prob, dtype(0.1)))
    (cs1, ks1), _info = evaltest(
        lambda: kernel.step(distribution.log_prob, (cs, ks), seed=get_seed())
    )

    for before, after in ((cs, cs1), (ks, ks1)):
        assert jax.tree.structure(before) == jax.tree.structure(after)
        # Equal structure does not rule out a leaf changing shape (e.g. a
        # scalar turning into a vector), so compare leaf shapes as well.
        before_shapes = [np.shape(leaf) for leaf in jax.tree.leaves(before)]
        after_shapes = [np.shape(leaf) for leaf in jax.tree.leaves(after)]
        assert before_shapes == after_shapes


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_adaptive_rwmh_bivariate_shapes(dtype, evaltest):
    """One kernel step on a two-argument target preserves structure and shapes.

    The log density takes two arguments, so the starting position is a
    two-element list. A single ``kernel.step`` must return a
    (chain state, kernel state) pair whose pytree structure and per-leaf
    shapes match the pair produced by ``kernel.init``.
    """

    def log_prob(x1, x2):
        x1_lp = tfd.Normal(loc=dtype(0.0), scale=dtype(1.0)).log_prob(x1)
        x2_lp = tfd.Normal(loc=dtype(1.0), scale=dtype(2.0)).log_prob(x2)
        return x1_lp + x2_lp

    kernel = adaptive_rwmh(initial_scale=dtype(2.38))

    (cs, ks) = evaltest(lambda: kernel.init(log_prob, [dtype(0.1), dtype(1.1)]))
    (cs1, ks1), _info = evaltest(
        lambda: kernel.step(log_prob, (cs, ks), seed=get_seed())
    )

    for before, after in ((cs, cs1), (ks, ks1)):
        assert jax.tree.structure(before) == jax.tree.structure(after)
        # Equal structure does not rule out a leaf changing shape, so compare
        # leaf shapes as well.
        before_shapes = [np.shape(leaf) for leaf in jax.tree.leaves(before)]
        after_shapes = [np.shape(leaf) for leaf in jax.tree.leaves(after)]
        assert before_shapes == after_shapes


@pytest.mark.parametrize("evaltest", ["jit_compile"], indirect=True)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_adaptive_rwmh_univariate_mcmc(dtype, evaltest):
    """Adaptive RWMH on a standard normal recovers its first two moments.

    Runs 2000 steps from 0.1 and discards the first 1000 as burn-in. The
    remaining draws must have an acceptance rate within 0.1 of 0.44, a mean
    within 0.2 of 0, and a standard deviation within 0.05 of 1.
    """
    num_steps = 2000
    burn_in = 1000
    distribution = tfd.Normal(loc=dtype(0.0), scale=dtype(1.0))

    kernel = adaptive_rwmh(initial_scale=dtype(2.38))

    samples, info = evaltest(
        lambda: mcmc(
            num_steps,
            sampling_algorithm=kernel,
            target_density_fn=distribution.log_prob,
            initial_position=dtype(0.1),
            seed=get_seed(),
        )
    )

    # All statistics use only the post-burn-in part of the chain.
    accept_rate = np.mean(info.is_accepted[burn_in:])
    kept = samples[burn_in:]
    sample_mean = np.mean(kept)
    sample_std = np.std(kept)

    # Messages report the observed value, since this test is stochastic and
    # a bare failing comparison would not say how far off it was.
    accept_eps = 0.1
    assert np.abs(accept_rate - 0.44) < accept_eps, f"acceptance rate {accept_rate}"

    mean_eps = 0.2
    assert np.abs(sample_mean - 0.0) < mean_eps, f"sample mean {sample_mean}"

    std_eps = 0.05
    assert np.abs(sample_std - 1.0) < std_eps, f"sample std {sample_std}"


@pytest.mark.parametrize("evaltest", ["jit_compile"], indirect=True)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_adaptive_rwmh_multivariate_mcmc(dtype, evaltest):
    """Adaptive RWMH on three independent normals recovers each marginal.

    The target is the product of Normal(0, 0.1), Normal(1, 0.2) and
    Normal(2, 0.4). After 2000 steps the first 1000 are discarded as burn-in.
    The remaining draws are checked for the overall acceptance rate (within
    0.12 of 0.234) and for each component's mean and standard deviation.
    """
    num_steps = 2000
    burn_in = 1000
    # Single source of truth for the target: used both to build the log
    # density and to state the expected moments below.
    target_locs = [0.0, 1.0, 2.0]
    target_scales = [0.1, 0.2, 0.4]

    def log_prob(x1, x2, x3):
        return sum(
            tfd.Normal(loc=dtype(loc), scale=dtype(scale)).log_prob(x)
            for x, loc, scale in zip(
                (x1, x2, x3), target_locs, target_scales, strict=True
            )
        )

    kernel = adaptive_rwmh(initial_scale=dtype(1.5))

    samples, info = evaltest(
        lambda: mcmc(
            num_steps,
            sampling_algorithm=kernel,
            target_density_fn=log_prob,
            initial_position=[dtype(0.1), dtype(-0.1), dtype(0.5)],
            seed=get_seed(),
        )
    )

    accept_rate = np.mean(info.is_accepted[burn_in:])
    accept_eps = 0.12
    assert np.abs(accept_rate - 0.234) < accept_eps, (
        f"acceptance rate {accept_rate}"
    )

    # One assert per component (instead of all() over a generator) so a
    # failure reports which component is off and by how much.
    mean_eps = 0.2
    std_eps = 0.05
    for i, (x, loc, scale) in enumerate(
        zip(samples, target_locs, target_scales, strict=True)
    ):
        kept = x[burn_in:]
        sample_mean = np.mean(kept)
        sample_std = np.std(kept)
        # The mean error is divided by (loc + 1), so the absolute tolerance is
        # 0.2 / 0.4 / 0.6 for the three components.
        # NOTE(review): this scales with the location rather than the target
        # scale, so it is loose for the narrow first component (sd 0.1).
        assert np.abs(sample_mean - loc) / (loc + 1) < mean_eps, (
            f"component {i}: sample mean {sample_mean}, expected {loc}"
        )
        assert np.abs(sample_std - scale) < std_eps, (
            f"component {i}: sample std {sample_std}, expected {scale}"
        )
