import jax.numpy as jnp
import numpy as np
from jax import random

from gemlib.distributions.discrete_markov import (
    compute_state,
    discrete_markov_log_prob,
    discrete_markov_simulation,
    make_transition_prob_matrix_fn,
)

# Stoichiometry shared by the tests below: one row per state, one column per
# transition. Column 0 moves one unit from state 0 to state 1; column 1 moves
# one unit from state 1 to state 2.
INCIDENCE_MATRIX = np.array(
    [
        [-1, 0],
        [1, -1],
        [0, 1],
    ]
)
# Lower-case alias of the same array object; unused in this module but kept
# in case something else refers to it.
incidence_matrix = INCIDENCE_MATRIX
# Length of one discrete time step used when building transition matrices.
TIME_DELTA = 1.0


def test_compute_state_output_shape(evaltest):
    """compute_state yields one (batch, state) slice per timestep of events."""
    start = jnp.array([[99, 1, 0]])
    # Two timesteps x one batch member x two transitions.
    step_events = jnp.array([[[2.0, 0.0]], [[0.0, 1.0]]])

    def _run():
        return compute_state(start, step_events, INCIDENCE_MATRIX)

    out = evaltest(_run)
    assert out.shape == (2, 1, 3)


def test_compute_state_output_value(evaltest):
    """compute_state reports the state *before* each timestep's events.

    With the default (not closed) behaviour the first slice is the initial
    state and the second reflects only the first timestep's events; the
    final timestep's events do not appear in the output.
    """
    events = jnp.array([[[1, 0]], [[0, 1]]])
    start = jnp.array([[100, 0, 0]])
    want = jnp.array([[[100, 0, 0]], [[99, 1, 0]]])

    got = evaltest(lambda: compute_state(start, events, INCIDENCE_MATRIX))

    assert got.shape == want.shape
    assert jnp.allclose(got, want)


def test_compute_state_closed(evaltest):
    """With closed=True the state after the final timestep is appended.

    Two timesteps of events therefore produce three states: the initial
    state, the state after step 0, and the state after step 1.
    """
    initial_state = jnp.array([[100, 0, 0]])
    events = jnp.array([[[1, 0]], [[0, 1]]])
    expected_values = jnp.array([[[100, 0, 0]], [[99, 1, 0]], [[99, 0, 1]]])
    result = evaltest(
        lambda: compute_state(
            initial_state, events, INCIDENCE_MATRIX, closed=True
        )
    )
    # jnp.allclose broadcasts its inputs, so pin the shape explicitly (as the
    # sibling compute_state tests do) to catch a missing or extra timestep.
    assert result.shape == expected_values.shape
    assert jnp.allclose(result, expected_values)


def test_compute_state_multiple_batches(evaltest):
    """Each batch member evolves independently under its own events.

    Within every timestep slice, row b holds batch member b.
    """
    initial_state = jnp.array([[100, 0, 0], [50, 10, 5]])
    events = jnp.array(
        [
            [[1, 0], [0, 2]],  # timestep 0: member 0, member 1
            [[0, 1], [1, 0]],  # timestep 1: member 0, member 1
        ]
    )
    expected = jnp.array(
        [
            [[100, 0, 0], [50, 10, 5]],  # before timestep 0
            [[99, 1, 0], [50, 8, 7]],  # after timestep 0
        ]
    )

    states = evaltest(
        lambda: compute_state(initial_state, events, INCIDENCE_MATRIX)
    )

    assert states.shape == expected.shape
    assert jnp.allclose(states, expected)


def test_compute_state_all_zeroes(evaltest):
    """With no events at all, every timestep repeats the initial state."""
    n_steps = 4
    initial_state = jnp.array([[10, 9, 8]])
    events = jnp.zeros((n_steps, 1, 2))
    # Same (n_steps, 1, 3) array as spelling out the initial state n times.
    expected_values = jnp.tile(initial_state, (n_steps, 1, 1))

    result = evaltest(
        lambda: compute_state(initial_state, events, INCIDENCE_MATRIX)
    )

    assert result.shape == expected_values.shape
    assert jnp.allclose(result, expected_values)


def test_transition_prob_matrix(evaltest):
    """Single-member transition matrix has the expected shape and values.

    The expected diagonal entries equal exp(-0.2) and exp(-0.1) for the two
    rates over a unit step, with the remainder of each row on the next state.
    """

    def transition_rate_fn(_t, _state):
        rate_0_to_1 = jnp.array([0.2])
        rate_1_to_2 = jnp.array([0.1])
        return rate_0_to_1, rate_1_to_2

    tpmf = make_transition_prob_matrix_fn(
        transition_rate_fn, TIME_DELTA, INCIDENCE_MATRIX
    )
    state = jnp.array([[10, 1, 0]])
    expected = jnp.array(
        [
            [
                [0.8187308, 0.18126923, 0.0],
                [0.0, 0.9048374, 0.09516257],
                [0.0, 0.0, 1.0],
            ]
        ]
    )

    matrices = evaltest(lambda: tpmf(0, state))

    assert matrices.shape == expected.shape
    assert jnp.allclose(matrices, expected)


def test_transition_prob_matrix_batched(evaltest):
    """One transition matrix per batch member, each from its own rates."""

    def transition_rate_fn(_t, _state):
        return (jnp.array([0.2, 0.3, 0.1]), jnp.array([0.1, 0.4, 0.3]))

    tpmf = make_transition_prob_matrix_fn(
        transition_rate_fn, TIME_DELTA, INCIDENCE_MATRIX
    )
    state = jnp.array([[10, 1, 0], [15, 2, 0], [8, 2, 0]])

    member_0 = [
        [0.8187308, 0.18126923, 0.0],
        [0.0, 0.9048374, 0.09516257],
        [0.0, 0.0, 1.0],
    ]
    member_1 = [
        [0.7408182, 0.2591818, 0.0],
        [0.0, 0.67032003, 0.32967997],
        [0.0, 0.0, 1.0],
    ]
    member_2 = [
        [0.90483737, 0.09516263, 0.0],
        [0.0, 0.7408182, 0.2591818],
        [0.0, 0.0, 1.0],
    ]
    expected = jnp.array([member_0, member_1, member_2])

    matrices = evaltest(lambda: tpmf(0, state))

    assert matrices.shape == expected.shape
    assert jnp.allclose(matrices, expected)


def test_transition_prob_matrix_all_zeroes(evaltest):
    """Zero rates (on an all-zero state) give an identity transition matrix."""

    def transition_rate_fn(_t, _state):
        return (jnp.array([0.0]), jnp.array([0.0]))

    state = jnp.array([[0, 0, 0]])
    expected_values = jnp.expand_dims(jnp.eye(3), axis=0)
    tpmf = make_transition_prob_matrix_fn(
        transition_rate_fn, TIME_DELTA, INCIDENCE_MATRIX
    )
    results = evaltest(lambda: tpmf(0, state))
    assert results.shape == expected_values.shape
    assert jnp.allclose(results, expected_values)


# The three simulation tests below previously had their `def` lines commented
# out, which fused their bodies into the test above. They are restored as
# separate module-level tests so pytest collects and reports each one.


def test_markov_simulation_shape(evaltest):
    """Simulating an unbatched state yields per-step times and event arrays.

    The assertions expect one time per step (two steps for [0, 2) at step 1)
    and events of shape (steps, 1, states, states).
    """

    def transition_rate_fn(_t, _state):
        return (jnp.array([0.2]), jnp.array([0.1]))

    state = jnp.array([10, 1, 0])
    tpmf = make_transition_prob_matrix_fn(
        transition_rate_fn, TIME_DELTA, INCIDENCE_MATRIX
    )
    key = random.key(0)
    times, all_events = evaltest(
        lambda: discrete_markov_simulation(
            tpmf, state, start=0.0, end=2.0, time_step=1.0, seed=key
        )
    )
    assert times.shape == (2,)
    assert all_events.shape == (2, 1, 3, 3)


def test_markov_simulation_batch_shape(evaltest):
    """A batch of three states adds a batch axis of size 3 to the events."""

    def transition_rate_fn(_t, _state):
        return (jnp.array([0.2] * 3), jnp.array([0.1] * 3))

    states = jnp.array([[10, 1, 0], [20, 0, 0], [100, 5, 0]])
    tpmf = make_transition_prob_matrix_fn(
        transition_rate_fn, TIME_DELTA, INCIDENCE_MATRIX
    )
    key = random.key(0)
    times, all_events = evaltest(
        lambda: discrete_markov_simulation(
            tpmf, states, start=0.0, end=4.0, time_step=1.0, seed=key
        )
    )
    assert times.shape == (4,)
    assert all_events.shape == (4, 3, 3, 3)


def test_markov_simulation_all_zeroes(evaltest):
    """With zero rates every unit stays put at every step.

    The events matrix of each step is then diag(state): all counts sit on
    the diagonal (from-state == to-state).
    """

    def transition_rate_fn(_t, _state):
        return (jnp.array([0.0]), jnp.array([0.0]))

    state = jnp.array([[10, 1, 0]])
    tpmf = make_transition_prob_matrix_fn(
        transition_rate_fn, TIME_DELTA, INCIDENCE_MATRIX
    )
    key = random.key(0)
    times, all_events = evaltest(
        lambda: discrete_markov_simulation(
            tpmf, state, start=0.0, end=3.0, time_step=1.0, seed=key
        )
    )
    # broadcast_to below adapts to whatever shape comes back, so pin the
    # shape first (three unit steps, one batch member), consistent with the
    # two shape tests above.
    assert times.shape == (3,)
    assert all_events.shape == (3, 1, 3, 3)
    state_matrix = jnp.diag(state[0])
    expected_values = jnp.broadcast_to(state_matrix, all_events.shape)
    assert jnp.allclose(all_events, expected_values)


def test_discrete_markov_log_prob(evaltest):
    """Log-probability is finite for feasible event paths and not otherwise.

    A single unit starts in state 0. Moving it 0 -> 1 and then 1 -> 2 is
    feasible. Attempting 1 -> 2 while state 1 is still empty is not.
    """
    init_state = jnp.array([[1, 0, 0]])

    def transition_rate_fn(_t, _state):
        return (jnp.array([1.0]), jnp.array([1.0]))

    init_step = 0.0
    # Keep the step passed to the log-prob equal to the one baked into tpmf.
    time_delta = TIME_DELTA

    tpmf = make_transition_prob_matrix_fn(
        transition_rate_fn, TIME_DELTA, INCIDENCE_MATRIX
    )

    def log_prob(events):
        # Binding `events` as a parameter avoids the lambda late-binding a
        # name that is reassigned between calls.
        return evaltest(
            lambda: discrete_markov_log_prob(
                events,
                init_state,
                init_step,
                time_delta,
                tpmf,
                INCIDENCE_MATRIX,
            )
        )

    feasible_events = jnp.array([[[1, 0]], [[0, 1]]])
    assert np.isfinite(log_prob(feasible_events))

    infeasible_events = jnp.array([[[0, 1]], [[1, 0]]])
    # Use `not` rather than `~`: on a plain Python bool `~` gives -2/-1,
    # which is always truthy, so the assertion would silently pass.
    assert not np.isfinite(log_prob(infeasible_events))


def test_discrete_markov_log_prob_batches(evaltest):
    """Batched log-probability is finite only if every member's path is.

    Member 0 starts with one unit in state 0, member 1 with one unit in
    state 1. The final path asks member 1 to move 0 -> 1 while its state 0
    is empty, which must make the log-probability non-finite.
    """
    init_state = jnp.array([[1, 0, 0], [0, 1, 0]])

    def transition_rate_fn(_t, _state):
        return (jnp.array([1.0]), jnp.array([1.0]))

    init_step = 0.0
    # Keep the step passed to the log-prob equal to the one baked into tpmf.
    time_delta = TIME_DELTA

    tpmf = make_transition_prob_matrix_fn(
        transition_rate_fn, TIME_DELTA, INCIDENCE_MATRIX
    )

    def log_prob(events):
        # Binding `events` as a parameter avoids late-binding a reused name.
        return evaltest(
            lambda: discrete_markov_log_prob(
                events,
                init_state,
                init_step,
                time_delta,
                tpmf,
                INCIDENCE_MATRIX,
            )
        )

    feasible_a = jnp.array([[[1, 0], [0, 0]], [[0, 0], [0, 1]]])
    assert np.isfinite(log_prob(feasible_a))

    feasible_b = jnp.array([[[1, 0], [0, 0]], [[0, 1], [0, 1]]])
    assert np.isfinite(log_prob(feasible_b))

    infeasible = jnp.array([[[1, 0], [0, 0]], [[0, 1], [1, 0]]])
    # `not` instead of `~`: `~` on a Python bool is always truthy.
    assert not np.isfinite(log_prob(infeasible))


def test_discrete_markov_log_prob_time_dependent(evaltest):
    """Time-dependent rates forbid transitions while the rate is zero.

    Both rates are multiplied by (t > 0), so they vanish at t = 0. With
    init_step = 0.0, an event in the first step should be impossible, while
    the same event delayed to the second step is allowed.
    """
    init_state = jnp.array([[1, 0, 0]])

    def transition_rate_fn(t, _state):
        return ((t > 0) * jnp.array([1.0]), (t > 0) * jnp.array([1.0]))

    init_step = 0.0
    # Keep the step passed to the log-prob equal to the one baked into tpmf.
    time_delta = TIME_DELTA

    tpmf = make_transition_prob_matrix_fn(
        transition_rate_fn, TIME_DELTA, INCIDENCE_MATRIX
    )

    def log_prob(events):
        # Binding `events` as a parameter avoids late-binding a reused name.
        return evaltest(
            lambda: discrete_markov_log_prob(
                events,
                init_state,
                init_step,
                time_delta,
                tpmf,
                INCIDENCE_MATRIX,
            )
        )

    events_at_zero_rate = jnp.array([[[1, 0]], [[0, 1]]])
    # `not` instead of `~`: `~` on a Python bool is always truthy.
    assert not np.isfinite(log_prob(events_at_zero_rate))

    events_after_onset = jnp.array([[[0, 0]], [[1, 0]]])
    assert np.isfinite(log_prob(events_after_onset))
