"""Tests for the left-censored event times Metropolis-Hastings kernel."""

from typing import NamedTuple

import jax
import numpy as np
import pytest

from .left_censored_events_mh import left_censored_events_mh

# Array type alias used to annotate the example state fields. Imported from
# ``numpy.typing`` explicitly: some NumPy releases do not expose the
# ``typing`` submodule as ``np.typing`` until it has been imported.
from numpy.typing import NDArray as Tensor


class ExampleEventsInitCond(NamedTuple):
    """Example state handed to the kernel as its chain position.

    The field names are the ones the tests pass to the kernel as
    ``events_varname`` and ``initial_conditions_varname``.
    """

    seir_events: Tensor  # built from the example's "events" entry
    seir_initial_state: Tensor  # built from the "initial_conditions" entry


@pytest.fixture
def seed():
    """Fixed JAX PRNG key so that the kernel's random draws are repeatable."""
    prng_key = jax.random.key(42)
    return prng_key


def test_shape(evaltest, seir_metapop_example, seed):
    """Test that shapes do not get altered in the kernel.

    Builds the example state, initialises the kernel, takes a single step and
    checks that the position both before and after the step has the same
    field shapes as the example state. It also checks that the chain-state
    and kernel-state pytree structures are unchanged by the step.
    """
    example_state = ExampleEventsInitCond(
        seir_metapop_example["events"],
        seir_metapop_example["initial_conditions"],
    )

    kernel = left_censored_events_mh(
        incidence_matrix=seir_metapop_example["incidence_matrix"],
        transition_index=1,
        max_timepoint=7,
        max_events=10,
        events_varname="seir_events",
        initial_conditions_varname="seir_initial_state",
        name="test_left_censored_events_mh",
    )

    def tlp(*_):
        # Constant target log-probability: only shapes are under test here.
        return np.asarray(0.0, dtype=example_state.seir_events.dtype)

    cs, ks = kernel.init(tlp, example_state)
    (new_cs, new_ks), _ = evaltest(lambda: kernel.step(tlp, (cs, ks), seed))

    # Check the initialised state and, crucially, the stepped state -- the
    # latter is the one the kernel could actually have reshaped.
    for chain_state in (cs, new_cs):
        position = chain_state.position
        assert position.seir_events.shape == example_state.seir_events.shape
        assert (
            position.seir_initial_state.shape
            == example_state.seir_initial_state.shape
        )
    assert jax.tree.structure(cs) == jax.tree.structure(new_cs)
    assert jax.tree.structure(ks) == jax.tree.structure(new_ks)


def test_mcmc(evaltest, seir_metapop_example, seed):
    """Test that multiple invocations of the kernel result in valid
    event timeseries, and also definitely move around.

    The kernel is iterated 100 times with ``jax.lax.scan``. The position
    after every step is collected as the scan output, so the checks below
    run over the whole trace of samples, whose leading axis is the MCMC
    iteration.
    """
    example_state = ExampleEventsInitCond(
        seir_metapop_example["events"],
        seir_metapop_example["initial_conditions"],
    )

    kernel = left_censored_events_mh(
        incidence_matrix=seir_metapop_example["incidence_matrix"],
        transition_index=1,
        max_timepoint=7,
        max_events=10,
        events_varname="seir_events",
        initial_conditions_varname="seir_initial_state",
        name="test_left_censored_events_mh",
    )

    def tlp(*_):
        # Constant target log-probability.
        return np.asarray(0.0, dtype=example_state.seir_events.dtype)

    initial_chain_and_kernel_state = kernel.init(tlp, example_state)

    def one_step(chain_and_kernel_state, step_key):
        new_chain_and_kernel_state, _ = kernel.step(
            tlp, chain_and_kernel_state, step_key
        )
        chain_state, _ = new_chain_and_kernel_state
        # Emit the position as the scan output so a trace of samples is
        # accumulated. Returning ``None`` here would leave only the final
        # state, which has no iteration axis to compare along.
        return new_chain_and_kernel_state, chain_state.position

    seeds = jax.random.split(seed, num=100)
    _, samples = evaltest(
        lambda: jax.lax.scan(
            one_step, init=initial_chain_and_kernel_state, xs=seeds
        ),
    )

    # Validity: no negative values anywhere in the sampled trace.
    assert np.all(samples.seir_events >= 0.0)
    assert np.all(samples.seir_initial_state >= 0.0)

    # Movement: successive MCMC samples (axis 0) must differ somewhere.
    assert (
        np.mean((samples.seir_events[1:] - samples.seir_events[:-1]) ** 2) > 0.0
    )
    assert (
        np.mean(
            (samples.seir_initial_state[1:] - samples.seir_initial_state[:-1])
            ** 2
        )
        > 0.0
    )
