"""Test event time samplers"""

# ruff: noqa: PLR2004

import jax
import numpy as np
import pytest

from gemlib.distributions.discrete_markov import compute_state
from gemlib.mcmc.discrete_time_state_transition_model.move_events_impl import (
    UncalibratedEventTimesUpdate,
    discrete_move_events_proposal,
    events_state_count_bounding_fn,
)


@pytest.fixture
def seed():
    """Provide a JAX PRNG key built from the fixed seed value 42."""
    prng_key = jax.random.key(42)
    return prng_key


def test_discrete_move_events_proposal(evaltest, sir_metapop_example, seed):
    """Check the structure of samples drawn from the move-events proposal.

    A sample must expose the fields `unit`, `timepoint`, `delta` and
    `event_count`, each of shape `(num_units,)` and dtype int32.  The
    log-probability of that sample must be a scalar with the same dtype as
    the example's incidence matrix.
    """
    num_units = 2
    expected_fields = ("unit", "timepoint", "delta", "event_count")
    expected_shape = (num_units,)

    incidence_matrix = sir_metapop_example["incidence_matrix"]
    bounding_fn = events_state_count_bounding_fn(10)

    proposal = discrete_move_events_proposal(
        incidence_matrix=incidence_matrix,
        target_transition_id=1,
        num_units=num_units,
        delta_max=4,
        count_max=10,
        initial_conditions=sir_metapop_example["initial_conditions"],
        events=sir_metapop_example["events"],
        count_bounding_fn=bounding_fn,
        name="foo",
    )

    sample = evaltest(lambda: proposal.sample(seed=seed))
    lp = evaltest(lambda: proposal.log_prob(sample))

    # Field names first, then every field's shape, then every field's dtype.
    assert sample._fields == expected_fields
    for field_name in expected_fields:
        assert getattr(sample, field_name).shape == expected_shape

    fields_by_name = sample._asdict()
    for field_name, value in fields_by_name.items():
        assert value.dtype == np.int32, f"Field `{field_name}` is not int32"

    assert lp.shape == ()
    assert lp.dtype == incidence_matrix.dtype


@pytest.mark.parametrize("evaltest", ["eager", "jit_compile"], indirect=True)
def test_uncalibrated_event_time_update(evaltest, sir_metapop_example, seed):
    """Check `UncalibratedEventTimesUpdate` structure and non-negativity.

    Runs once for each `evaltest` parameter ("eager" and "jit_compile").

    The first part verifies that `one_step` returns kernel results with the
    same tree structure as `bootstrap_results`, and events with the same
    tree structure as the input events.

    The second part chains 100 kernel steps under a constant target
    log-probability and verifies that the events produced by *every* step,
    and the states computed from them, are non-negative.
    """
    events = sir_metapop_example["events"]
    initial_conditions = sir_metapop_example["initial_conditions"]
    incidence_matrix = sir_metapop_example["incidence_matrix"]

    def tlp(_):
        # Constant target: gives the sampler no guidance at all.
        return np.array(0.0, events.dtype)

    kernel = UncalibratedEventTimesUpdate(
        target_log_prob_fn=tlp,
        incidence_matrix=incidence_matrix,
        initial_conditions=initial_conditions,
        target_transition_id=1,
        delta_max=4,
        num_units=1,
        count_max=10,
    )

    # Test structures are consistent
    pkr = evaltest(lambda: kernel.bootstrap_results(events))
    next_events, results = evaltest(
        lambda: kernel.one_step(events, pkr, seed=seed)
    )
    assert jax.tree.structure(pkr) == jax.tree.structure(results)
    assert jax.tree.structure(events) == jax.tree.structure(next_events)

    # Test that multiple invocations of the kernel do not
    # allow the state to go negative, even if the tlp is not
    # there to guide the sampler.
    step_seeds = jax.random.split(seed, num=100)

    def one_step(carry, step_seed):
        new_carry = kernel.one_step(*carry, seed=step_seed)
        # Emit the updated events as the per-step scan output so that every
        # intermediate sample is inspected, not only the final carry.
        return new_carry, new_carry[0]

    _, event_trace = evaltest(
        lambda: jax.lax.scan(
            one_step,
            init=(events, pkr),
            xs=step_seeds,
        )
    )

    assert np.all(event_trace >= 0.0)

    # `compute_state` is called once per step with a single events array,
    # exactly as it would be for one sample; the last iteration covers the
    # final state of the chain.
    for step_events in event_trace:
        step_states = compute_state(
            initial_state=initial_conditions,
            events=step_events,
            incidence_matrix=incidence_matrix,
            closed=True,
        )
        assert np.all(step_states >= 0.0)
