"""Right-censored events MCMC test"""

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from .right_censored_events_mh import right_censored_events_mh


@pytest.fixture
def seed():
    """Provide a fixed JAX PRNG key so the MCMC test is reproducible."""
    fixed_key = jax.random.PRNGKey(42)
    return fixed_key


def test_right_censored_events_mh(evaltest, seir_metapop_example, seed):
    """Exercise the right-censored events Metropolis-Hastings kernel.

    First checks that a single ``init``/``step`` round trip preserves the
    shape of the events tensor and the pytree structure of the kernel state.
    Then runs a short chain under a constant target log-probability and
    checks that every sampled events tensor is non-negative and that the
    chain actually changes state between successive iterations.
    """
    events = seir_metapop_example["events"]
    initial_conditions = seir_metapop_example["initial_conditions"]

    def tlp(_):
        # Constant target log-probability: the test checks kernel mechanics,
        # not inference against a particular model.
        return jnp.array(0.0, dtype=events.dtype)

    kernel = right_censored_events_mh(
        incidence_matrix=seir_metapop_example["incidence_matrix"],
        transition_index=0,
        count_max=10,
        # Presumably restricts proposals to the last 7 time indices of
        # ``events`` -- confirm against right_censored_events_mh.
        t_range=(events.shape[0] - 7, events.shape[0]),
    )

    # Single round trip: shapes and kernel-state structure must be preserved.
    cs, ks = kernel.init(tlp, events, initial_conditions=initial_conditions)
    (new_cs, new_ks), _ = kernel.step(
        tlp, (cs, ks), seed, initial_conditions=initial_conditions
    )

    assert cs.position.shape == events.shape
    assert new_cs.position.shape == events.shape
    assert jax.tree.structure(ks) == jax.tree.structure(new_ks)

    num_steps = 100
    seeds = jax.random.split(seed, num=num_steps)

    def one_step(chain_and_kernel_state, step_seed):
        (chain_state, kernel_state), _ = kernel.step(
            tlp,
            chain_and_kernel_state,
            step_seed,
            initial_conditions=initial_conditions,
        )
        # Carry the full state forward, and emit the position so that
        # ``lax.scan`` stacks one events tensor per MCMC iteration. Returning
        # ``None`` here would leave only the final state, and the movement
        # check below would then compare adjacent time points instead of
        # successive iterations.
        return (chain_state, kernel_state), chain_state.position

    _, samples = evaltest(
        lambda: jax.lax.scan(
            one_step,
            init=kernel.init(
                tlp, events, initial_conditions=initial_conditions
            ),
            xs=seeds,
        ),
    )

    # One events tensor per iteration, stacked along a new leading axis.
    assert samples.shape == (num_steps,) + tuple(events.shape)
    assert np.all(samples >= 0.0)
    # Zero mean squared difference between successive iterations would mean
    # the chain never left its starting state.
    assert np.mean((samples[1:] - samples[:-1]) ** 2) > 0.0
