"""Test partially censored events move for DiscreteTimeStateTransitionModel"""

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from gemlib.mcmc.discrete_time_state_transition_model.move_events import (
    move_events,
)


def get_seed():
    """Build the fixed-seed (42) typed JAX PRNG key shared by the tests."""
    fixed_seed = 42
    key = jax.random.key(fixed_seed)
    return key


@pytest.mark.parametrize("evaltest", ["eager", "jit_compile"], indirect=True)
def test_move_events(evaltest, seir_metapop_example):
    """Check that the ``move_events`` kernel initialises, steps and mixes.

    A flat target log-probability is used, so any movement of the chain is
    driven purely by the kernel's proposal/acceptance mechanism. The test:

    1. runs a single ``init``/``step`` round trip and checks shapes and
       kernel-state pytree structure;
    2. runs 100 steps under ``jax.lax.scan`` (eagerly or jit-compiled,
       according to the ``evaltest`` fixture), collects the position after
       every step, and checks the samples are non-negative and that the
       chain moved between iterations.
    """
    events = seir_metapop_example["events"]
    initial_conditions = seir_metapop_example["initial_conditions"]

    def tlp(_):
        # Flat target: every event configuration has the same log-density.
        return jnp.array(0.0, dtype=events.dtype)

    kernel = move_events(
        seir_metapop_example["incidence_matrix"],
        transition_index=1,
        num_units=2,
        delta_max=4,
        count_max=10,
    )

    # --- Single init/step round trip -------------------------------------
    (cs, ks) = kernel.init(
        target_log_prob_fn=tlp,
        position=events,
        initial_conditions=initial_conditions,
    )

    (new_cs, new_ks), _ = kernel.step(
        tlp, (cs, ks), get_seed(), initial_conditions
    )

    assert cs.position.shape == events.shape
    assert new_cs.position.shape == events.shape
    # The kernel state must keep an identical pytree structure across a
    # step; this is also what makes it a valid ``lax.scan`` carry.
    assert jax.tree.structure(ks) == jax.tree.structure(new_ks)

    # --- Multi-step run ---------------------------------------------------
    seeds = jax.random.split(get_seed(), num=100)

    def one_step(chain_and_kernel_state, seed):
        new_chain_and_kernel_state, _info = kernel.step(
            tlp,
            chain_and_kernel_state,
            seed,
            initial_conditions,
        )
        # Carry the full state forward, but emit the chain position as the
        # per-iteration scan output so that ``scan`` stacks the samples.
        return new_chain_and_kernel_state, new_chain_and_kernel_state[0].position

    init_chain_and_state = kernel.init(tlp, events, initial_conditions)
    (cs, ks), samples = evaltest(
        lambda: jax.lax.scan(
            one_step,
            init=init_chain_and_state,
            xs=seeds,
        )
    )

    # ``samples`` is stacked along a new leading iteration axis.
    assert samples.shape == (seeds.shape[0], *events.shape)
    assert cs.position.shape == events.shape
    assert np.all(np.greater_equal(samples, 0.0))
    # Successive MCMC iterations must differ at least once, i.e. the chain
    # actually moved rather than rejecting every proposal.
    assert np.mean((samples[1:] - samples[:-1]) ** 2) > 0.0
