import jax.numpy as jnp

from gemlib.distributions.continuous_markov import (
    EventList,
    _one_hot_expand_state,
    _total_flux,
    compute_state,
)

# State-change matrix shared by the tests in this module (3 rows, 2 columns).
# The compute_state tests expect transition 0 to move one count from state
# column 0 to column 1, and transition 1 to move one count from column 1 to
# column 2 -- i.e. each column here is the change applied by one transition.
INCIDENCE_MATRIX = jnp.array(
    [
        [-1, 0],
        [1, -1],
        [0, 1],
    ]
)


def rate_fn(_, state):
    """Return the pair of transition rates used by the flux test.

    Args:
        _: ignored (presumably a time argument -- confirm against the
            library's rate-function convention).
        state: 2-D array; columns 1 and 2 are read, and each row is summed
            over its last axis.

    Returns:
        A 2-tuple of 1-D arrays, one entry per row of ``state``:
        ``0.25 * state[:, 1] / row_sum`` and ``0.25 * state[:, 2]``.
    """
    row_totals = jnp.sum(state, axis=-1)
    second_column = state[:, 1]
    third_column = state[:, 2]

    first_rate = 0.25 * second_column / row_totals
    # NOTE(review): this rate scales with column 2 of the state rather than
    # being constant; confirm that this is the intended test fixture.
    second_rate = 0.25 * third_column
    return first_rate, second_rate


# Module-level starting state (one row, three columns) read by
# test_total_flux; the compute_state tests define their own local value.
initial_state = jnp.array(
    [
        [90, 10, 0],
    ]
)


def test_one_hot_expand_state(evaltest):
    """Check `_one_hot_expand_state` on the condensed counts [9, 1, 0].

    The expanded array must have shape (10, 3), and summing it over its
    leading axis must give back the original counts.
    """
    counts = jnp.array([9, 1, 0])

    def expand():
        return _one_hot_expand_state(counts)

    expanded = evaltest(expand)

    column_totals = expanded.sum(axis=0)
    assert jnp.all(column_totals == counts)
    assert expanded.shape == (10, 3)


def test_total_flux(evaltest):
    """Check the shape and sign of `_total_flux` for the module-level state.

    The rates come from `rate_fn` evaluated on `initial_state`; the result
    is expected to have shape (2, 1) with no negative entries.
    """

    def flux():
        # Rates are computed inside the evaluated callable, as part of the
        # same evaluation as the flux itself.
        rates = rate_fn(2, initial_state)
        return _total_flux(
            transition_rates=rates,
            state=initial_state,
            incidence_matrix=INCIDENCE_MATRIX,
        )

    total = evaltest(flux)

    assert total.shape == (2, 1)
    assert jnp.all(total >= 0.0)


def test_compute_state_output_shape(evaltest):
    """Check the shape `compute_state` returns for a valid event list.

    Two events applied to a single-row, three-column initial state are
    expected to give an array of shape (2, 1, 3).
    """
    start = jnp.array([[99, 1, 0]])
    event_list = EventList(
        time=jnp.array([1.0, 2.0]),
        transition=jnp.array([0, 1]),
        unit=jnp.array([0, 0]),
    )

    def run():
        return compute_state(
            incidence_matrix=INCIDENCE_MATRIX,
            initial_state=start,
            event_list=event_list,
        )

    states = evaltest(run)

    assert states.shape == (2, 1, 3)


def test_compute_state_output_value(evaltest):
    """Check the values `compute_state` returns for a valid event list.

    Without the final state, the expected output holds the state in force
    before each of the two events: the initial state, then the state after
    transition 0 has been applied.
    """
    start = jnp.array([[100, 0, 0]])
    event_list = EventList(
        time=jnp.array([1.0, 2.0]),
        transition=jnp.array([0, 1]),
        unit=jnp.array([0, 0]),
    )
    expected = jnp.array(
        [
            [[100, 0, 0]],
            [[99, 1, 0]],
        ]
    )

    def run():
        return compute_state(
            incidence_matrix=INCIDENCE_MATRIX,
            initial_state=start,
            event_list=event_list,
        )

    states = evaltest(run)

    assert states.shape == expected.shape
    assert jnp.allclose(states, expected)


def test_compute_state_closed(evaltest):
    """Check `compute_state` output when `include_final_state=True`.

    With the final state included, the expected output has one more entry
    along the leading axis than there are events: the initial state, the
    state after the first event, and the state after the second event.
    """
    initial_state = jnp.array([[100, 0, 0]])
    events = EventList(
        time=jnp.array([1.0, 2.0]),
        transition=jnp.array([0, 1]),
        unit=jnp.array([0, 0]),
    )
    expected_values = jnp.array([[[100, 0, 0]], [[99, 1, 0]], [[99, 0, 1]]])
    result = evaltest(
        lambda: compute_state(
            incidence_matrix=INCIDENCE_MATRIX,
            initial_state=initial_state,
            event_list=events,
            include_final_state=True,
        )
    )
    # Check the shape explicitly, as the other value tests do: allclose
    # broadcasts its operands, so it cannot be relied on to catch a
    # shape mismatch by itself.
    assert result.shape == expected_values.shape
    assert jnp.allclose(result, expected_values)


def test_compute_state_multiple_batches(evaltest):
    """Check `compute_state` shape and values for a two-row initial state.

    The four events alternate between unit 0 and unit 1; each expected entry
    is the full two-row state in force before the corresponding event.
    """
    start = jnp.array([[100, 0, 0], [50, 10, 5]])
    event_list = EventList(
        time=jnp.array([1.0, 1.3, 2.4, 3.3]),
        transition=jnp.array([0, 1, 1, 0]),
        unit=jnp.array([0, 1, 0, 1]),
    )

    # One entry per event, in event order.
    before_event_0 = [[100, 0, 0], [50, 10, 5]]
    before_event_1 = [[99, 1, 0], [50, 10, 5]]
    before_event_2 = [[99, 1, 0], [50, 9, 6]]
    before_event_3 = [[99, 0, 1], [50, 9, 6]]
    expected = jnp.array(
        [before_event_0, before_event_1, before_event_2, before_event_3]
    )

    def run():
        return compute_state(
            incidence_matrix=INCIDENCE_MATRIX,
            initial_state=start,
            event_list=event_list,
        )

    states = evaltest(run)

    assert states.shape == expected.shape
    assert jnp.allclose(states, expected)
