"""Test for Chain Binomial Rippler Kernel"""

import jax
import pytest
import tensorflow_probability.substrates.jax as tfp

from gemlib.distributions import DiscreteTimeStateTransitionModel

from .chain_binomial_rippler import CBRKernel

# Short alias for the TFP (JAX substrate) distributions namespace.
tfd = tfp.distributions

# NOTE(review): not referenced by the test in this file as visible here;
# confirm whether anything else imports it before relying on/removing it.
MIN_EVENTS: int = 10


@pytest.fixture
def seed():
    """Provide a fixed, typed JAX PRNG key so test randomness is reproducible."""
    fixed_seed_value = 42
    prng_key = jax.random.key(fixed_seed_value)
    return prng_key


def test_cbr_kernel(evaltest, discrete_two_unit_sir_example, seed):
    """Run one CBRKernel step on the two-unit discrete-time SIR example.

    Observations are simulated as a Binomial(probs=0.5) thinning of event
    column 1 (index 1 of the last axis) of the example draw. The kernel is
    then bootstrapped on that draw and advanced by a single step. The
    resulting state is checked for shape and non-negativity.
    """
    # Independent keys for simulating the data and for the MCMC step:
    # reusing one JAX key for both would correlate the two random streams.
    obs_seed, kernel_seed = jax.random.split(seed)

    model = DiscreteTimeStateTransitionModel(
        **discrete_two_unit_sir_example["model_params"]
    )
    current_state = discrete_two_unit_sir_example["draw"]

    def obs_process(current_events):
        # Binomial thinning of event column 1. The rightmost batch dimension
        # is folded into the event shape via Independent.
        return tfd.Independent(
            distribution=tfd.Binomial(
                total_count=current_events[..., 1], probs=0.5
            ),
            reinterpreted_batch_ndims=1,
        )

    observed_cases = obs_process(current_state).sample(seed=obs_seed)

    def tlp_fn(current_events):
        # Observation log-likelihood reduced to a scalar. It is passed as
        # CBRKernel's first argument (presumably the target log-prob term;
        # confirm against CBRKernel).
        return obs_process(current_events).log_prob(observed_cases).sum()

    kernel = CBRKernel(tlp_fn, model=model)
    pkr = evaltest(lambda: kernel.bootstrap_results(current_state))
    samples, _ = kernel.one_step(current_state, pkr, kernel_seed)

    # The new state must keep the layout of the state it was derived from,
    # and event counts can never be negative.
    assert samples.shape == current_state.shape
    assert bool((samples >= 0).all())
