# Dependency imports

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax as tfp

from gemlib.distributions.discrete_time_state_transition_model import (
    DiscreteTimeStateTransitionModel,
)


def test_simple_sir(evaltest, homogeneous_sir_params):
    """A single draw from the homogeneous SIR model has the expected shape.

    The draw is expected to be laid out as (num_steps, num_pops, num_events):
    one row per time step, one entry per population unit, and one column per
    event (column of the incidence matrix).
    """
    params = homogeneous_sir_params()
    sir = DiscreteTimeStateTransitionModel(**params)

    expected_shape = (
        params["num_steps"],
        params["initial_state"].shape[0],  # num_pops
        params["incidence_matrix"].shape[1],  # num_events
    )

    draw = evaltest(lambda: sir.sample(seed=0))

    assert draw is not None
    assert draw.shape == expected_shape


def test_simple_sir_jax(evaltest, homogeneous_sir_params_jax):
    """A draw built from the `homogeneous_sir_params_jax` fixture has the right shape.

    The parameters presumably arrive as JAX arrays (confirm in conftest). The
    draw is expected to be laid out as (num_steps, num_pops, num_events).
    """
    params = homogeneous_sir_params_jax()
    sir = DiscreteTimeStateTransitionModel(**params)

    expected_shape = (
        params["num_steps"],
        params["initial_state"].shape[0],  # num_pops
        params["incidence_matrix"].shape[1],  # num_events
    )

    draw = evaltest(lambda: sir.sample(seed=0))

    assert draw is not None
    assert draw.shape == expected_shape


def test_two_unit_sir(evaltest, two_unit_sir_params):
    """A single draw from the two-unit SIR model has the expected shape.

    The expected layout is (num_steps, num_pops, num_events), where num_pops
    is taken from the leading axis of `initial_state`.
    """
    params = two_unit_sir_params()
    sir = DiscreteTimeStateTransitionModel(**params)

    expected_shape = (
        params["num_steps"],
        params["initial_state"].shape[0],  # num_pops
        params["incidence_matrix"].shape[1],  # num_events
    )

    draw = evaltest(lambda: sir.sample(seed=0))

    assert draw is not None
    assert draw.shape == expected_shape


def test_log_prob_and_grads(evaltest, homogeneous_sir_params):
    """log_prob and its gradient w.r.t. the events carry the model's dtype.

    The events passed in are an all-zero array shaped like a model draw. For
    that input the gradient is mostly NaN, so its values are not checked —
    only its dtype and shape are.
    """
    model_params = homogeneous_sir_params()
    expected_dtype = model_params["initial_state"].dtype

    model = DiscreteTimeStateTransitionModel(**model_params)

    # A draw is taken only to obtain the events' shape and dtype; its
    # contents are discarded.
    zero_events = jnp.zeros_like(model.sample(seed=0))

    value_and_grad_fn = jax.value_and_grad(model.log_prob)
    log_prob, grads = evaltest(lambda: value_and_grad_fn(zero_events))

    assert log_prob.dtype == expected_dtype
    assert grads.dtype == expected_dtype

    # One gradient entry per event entry (values are mostly NaN here).
    assert grads.shape == zero_events.shape


def test_log_prob(evaltest, discrete_two_unit_sir_example):
    """log_prob of a stored draw matches the stored reference log-probability.

    The fixture is expected to provide `model_params`, a `draw` of events,
    and the reference `log_prob` for that draw.
    """
    example = discrete_two_unit_sir_example
    model = DiscreteTimeStateTransitionModel(**example["model_params"])

    lp = evaltest(lambda: model.log_prob(example["draw"]))

    # Relative error must be measured against |expected|. A log-probability
    # is typically negative, and dividing by the signed value would make the
    # quotient non-positive, so the bound would hold for any `lp`.
    np.testing.assert_allclose(lp, example["log_prob"], rtol=1.0e-6)


def test_transition_prob_matrix(evaltest, two_unit_sir_params):
    """transition_prob_matrix yields a square state matrix per step and unit.

    `initial_state` is read as (num_units, num_states). The result is
    expected to be shaped (num_steps, num_units, num_states, num_states).
    """
    params = two_unit_sir_params()
    sir = DiscreteTimeStateTransitionModel(**params)

    state_shape = params["initial_state"].shape
    num_units, num_states = state_shape[0], state_shape[1]

    events = evaltest(lambda: sir.sample(seed=0))
    tpm = evaltest(lambda: sir.transition_prob_matrix(events))

    assert tpm is not None
    assert tpm.shape == (
        params["num_steps"],
        num_units,
        num_states,
        num_states,
    )


def test_model_constraints(evaltest, homogeneous_sir_params):
    """Sampled epidemics are non-negative and conserve population.

    Draws `num_sim` event paths, reconstructs the state trajectories with
    `compute_state`, and checks that:
      * event counts and state occupancies are never negative;
      * dS/dt + dI/dt + dR/dt == 0 for every simulation, time step and unit.
    """
    num_sim = 50
    model_params = homogeneous_sir_params()

    model = DiscreteTimeStateTransitionModel(**model_params)

    eventlist = evaltest(lambda: model.sample(sample_shape=num_sim, seed=0))
    ts = evaltest(lambda: model.compute_state(eventlist))

    # Event counts and the resulting compartment sizes must be non-negative.
    assert (jnp.asarray(eventlist) >= 0).all()
    assert (jnp.asarray(ts) >= 0).all()

    # Check dS/dt + dI/dt + dR/dt = 0 at each time point. Sum only over the
    # trailing (state) axis, assumed to index compartments as in
    # `initial_state` (num_units, num_states). A single grand total over all
    # axes would let violations of opposite sign cancel, so every
    # (sim, time, unit) total is checked individually instead.
    finite_diffs = jnp.sum(
        (ts[:, 1:, ...] - ts[:, :-1, ...]) / model.time_delta, axis=-1
    )

    assert (jnp.abs(finite_diffs) < 1e-06).all()  # noqa: PLR2004


def test_log_prob_mle(
    evaltest, two_unit_sir_params, discrete_two_unit_sir_example
):
    """Maximising log_prob over (beta, gamma) recovers the example's true values.

    Nelder-Mead minimises the negative log-likelihood of the stored draw over
    log(beta), log(gamma), with phi held at its true value. The exponentiated
    optimum must lie within rtol/atol 0.08 of the example's `true_params`.
    """
    events = discrete_two_unit_sir_example["draw"]
    true_params = discrete_two_unit_sir_example["true_params"]

    pars = jnp.array(
        [true_params["beta"], true_params["gamma"]],
        dtype=discrete_two_unit_sir_example["dtype"],
    )

    def logp(pars):
        # The optimiser works on the log scale; exponentiating keeps both
        # rates positive.
        beta, gamma = jnp.unstack(jnp.exp(pars))
        model_params = two_unit_sir_params(
            beta=beta,
            phi=true_params["phi"],
            gamma=gamma,
        )
        model = DiscreteTimeStateTransitionModel(**model_params)
        return -model.log_prob(events)

    optim_results = evaltest(
        lambda: tfp.optimizer.nelder_mead_minimize(
            logp,
            # `shape` must be a sequence of ints for jax.random.uniform.
            initial_vertex=jax.random.uniform(
                key=jax.random.key(0), shape=(2,), minval=-1, maxval=1
            ),
            func_tolerance=1e-7,
        )
    )

    assert bool(optim_results.converged)
    np.testing.assert_allclose(
        np.exp(optim_results.position),
        pars,
        rtol=0.08,
        atol=0.08,
    )
