"""Test ContinuousTimeStateTransitionModel"""

import jax
import jax.numpy as jnp
import numpy as np
import pytest
import tensorflow_probability.substrates.jax as tfp

from gemlib.distributions.continuous_time_state_transition_model import (
    ContinuousTimeStateTransitionModel,
    EventList,
)

# Shorthand for the distributions namespace of TFP's JAX substrate.
tfd = tfp.distributions

# 3x2 integer matrix with the same entries as the "incidence_matrix" entry
# built by the `example_ilm` fixture (the fixture uses float32).
# NOTE(review): `tfd` and the three constants below are not referenced in the
# visible part of this file -- confirm whether they are still needed.
INCIDENCE_MATRIX = np.array([[-1, 0], [1, -1], [0, 1]])
NUM_STEPS = 1999
MIN_EPIDEMIC_LEN = 10


@pytest.fixture
def example_ilm():
    """Hand-built SIR example with 4 units and a 7-entry event list.

    Returns a dict with three entries:

    * ``"incidence_matrix"`` -- 3x2 float32 array.
    * ``"event_list"`` -- an ``EventList`` holding 7 events.  The last two
      have infinite time and transition index 2, which lies outside the two
      columns of the incidence matrix (presumably padding -- confirm).
    * ``"initial_conditions"`` -- 4x3 float32 array, one row per unit.
    """
    float_t = np.float32
    int_t = np.int32

    incidence = np.array([[-1, 0], [1, -1], [0, 1]], dtype=float_t)

    event_times = np.array(
        [0.4, 1.3, 1.5, 1.9, 2.3, np.inf, np.inf], dtype=float_t
    )
    event_transitions = np.array([0, 0, 1, 1, 1, 2, 2], dtype=int_t)
    event_units = np.array([1, 2, 0, 2, 1, 0, 0], dtype=int_t)
    events = EventList(
        time=event_times,
        transition=event_transitions,
        unit=event_units,
    )

    initial_state = np.array(
        [[0, 1, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0]], dtype=float_t
    )

    return {
        "incidence_matrix": incidence,
        "event_list": events,
        "initial_conditions": initial_state,
    }


def seed():
    """Return a JAX PRNG key built from the fixed integer seed 1010.

    NOTE(review): the tests in this file pass this function object itself
    (``seed=seed``) rather than its result -- presumably the sampler accepts
    a callable seed; confirm.
    """
    prng_key = jax.random.key(1010)
    return prng_key


def test_simple_sir(evaltest, homogeneous_sir_params):
    """Sample once and check the event-list lengths and the batch shape.

    The leading dimension of every event-list field must equal the
    ``num_steps`` entry of the model parameters, and the model's batch shape
    must compare equal to an empty list.
    """
    params = homogeneous_sir_params()
    expected_len = params["num_steps"]
    model = ContinuousTimeStateTransitionModel(**params)

    def draw():
        return model.sample(seed=seed)

    events = evaltest(draw)

    assert events is not None
    for field_name in ("time", "transition", "unit"):
        assert getattr(events, field_name).shape[0] == expected_len
    assert model.batch_shape == []


def test_two_unit_sir(evaltest, two_unit_sir_params):
    """Sample from the two-unit SIR model and evaluate its log-probability.

    This is a smoke test: it checks that ``sample`` and ``log_prob`` both run
    on the parameters supplied by the ``two_unit_sir_params`` fixture, and
    that the log-probability of the model's own draw is not NaN.
    """
    model_params = two_unit_sir_params()

    model = ContinuousTimeStateTransitionModel(**model_params)

    sim = evaltest(lambda: model.sample(seed=seed))
    log_prob = evaltest(lambda: model.log_prob(sim))

    # The result used to be discarded, so a NaN log-probability would have
    # passed silently.  Only NaN is rejected here; no specific value is
    # pinned, keeping this an operability check.
    assert not np.any(np.isnan(log_prob))


def test_simple_sir_loglik(evaltest, example_ilm):
    """Compare ``log_prob`` of the fixture event list to a hand-computed value."""

    def rate_fn(_arg0, _arg1):
        # Constant rates, independent of both arguments: (S->I, I->R).
        return 0.5, 0.7

    # Value worked out by hand for the fixture's event list.
    expected_loglik = -7.256319192936088

    epi_model = ContinuousTimeStateTransitionModel(
        transition_rate_fn=rate_fn,
        incidence_matrix=example_ilm["incidence_matrix"],
        initial_state=example_ilm["initial_conditions"],
        num_steps=100,
        initial_time=0.0,
    )

    def evaluate():
        return epi_model.log_prob(example_ilm["event_list"])

    log_lik = evaltest(evaluate)

    np.testing.assert_almost_equal(
        log_lik, desired=expected_loglik, decimal=5
    )


def test_loglik_mle(evaltest, cont_time_homogeneous_sir_example):
    """Recover the example's true parameters by maximising the likelihood.

    The negative log-likelihood of the fixture's draw is minimised with
    Nelder-Mead over log-transformed rates; the optimiser must converge and
    the exponentiated optimum must match the fixture's ``"true_params"``
    values within ``atol=0.2, rtol=0.1``.
    """
    example = cont_time_homogeneous_sir_example
    true_values = list(example["true_params"].values())

    def opt_fn(log_rate_parameters):
        """Return negative log likelihood"""
        rates = jnp.exp(log_rate_parameters)
        beta, gamma = jnp.unstack(rates)
        model = ContinuousTimeStateTransitionModel(
            **example["model_params_fn"](beta, gamma)
        )
        return -model.log_prob(example["draw"])

    # Starting point of the search, on the log scale.
    start = np.array([-2.0, -2.0], dtype=example["dtype"])

    def minimise():
        return tfp.optimizer.nelder_mead_minimize(
            opt_fn,
            initial_vertex=start,
            func_tolerance=1e-6,
            position_tolerance=1e-3,
        )

    opt = evaltest(minimise)

    assert opt.converged
    estimates = np.exp(opt.position)
    np.testing.assert_allclose(estimates, true_values, atol=0.2, rtol=0.1)


def test_batch_sample(evaltest, homogeneous_sir_params):
    """Check that a batched ``sample`` yields fields of the expected shape.

    With ``sample_shape=(3, 2)`` every event-list field must have shape
    ``(3, 2, model.num_steps)``.
    """
    sample_shape = (3, 2)
    model = ContinuousTimeStateTransitionModel(
        **homogeneous_sir_params(0.5, 0.14)
    )

    def draw():
        return model.sample(sample_shape=sample_shape, seed=seed)

    sims = evaltest(draw)

    expected_shape = (*sample_shape, model.num_steps)
    for field_name in ("time", "transition", "unit"):
        field_shape = getattr(sims, field_name).shape
        assert field_shape == expected_shape, field_name


def test_batch_log_prob(evaltest, homogeneous_sir_params):
    """Check that ``log_prob`` of a batched sample has the sample's shape."""
    batch_dims = (3, 2)
    params = homogeneous_sir_params(0.5, 0.14)
    model = ContinuousTimeStateTransitionModel(**params)

    def draw():
        return model.sample(sample_shape=batch_dims, seed=seed)

    sims = evaltest(draw)

    def score():
        return model.log_prob(sims)

    log_probs = evaltest(score)

    assert log_probs.shape == batch_dims
