"""Functions for chain binomial simulation."""

from collections.abc import Callable, Iterable

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax as tfp

from gemlib.math import multiply_no_nan
from gemlib.util import transition_coords_tuple

# Short alias for the TensorFlow Probability (JAX substrate) distributions
# namespace; used below for ``tfd.Binomial`` sampling.
tfd = tfp.distributions

# Public API of this module (the names exported by ``import *``).
# NOTE(review): ``chain_binomial_propagate`` has a public-looking name but is
# not listed here -- confirm whether that is intentional.
__all__ = [
    "make_transition_prob_matrix_fn",
    "compute_state",
    "discrete_markov_log_prob",
    "discrete_markov_simulation",
]


# scatter to transition matrix
def _scatter_to_transition_matrix(
    rates: Iterable[jax.Array],
    rate_coords: tuple[tuple[int, int], ...],
    num_states: int,
) -> jax.Array:
    """Build an un-normalised Markov transition rate matrix from rates and
       coordinates

    Args:
        rates (Iterable[jax.Array]): An iterable (tuple, list, ...) of ``R``
            jax.Arrays containing transition rates.
        rate_coords (tuple[tuple[int, int],...]): A tuple of
            ``(from_state, to_state)`` coordinate pairs for each of the ``R``
            rates.
        num_states (int): The number of possible states ``S``.

    Returns:
        jax.Array: An array of shape ``rates[0].shape + (S, S)`` representing
            the un-normalised transition rate matrix.
    """

    # The each shape is (...,S) for  S compartments
    # We're going to scatter the rates into a (S,S,...) data structure
    # to avoid strided scatter
    matrix_shape = (num_states, num_states) + rates[0].shape

    output = jnp.zeros(matrix_shape, dtype=rates[0].dtype)

    # We iterate over each rate vector and coordinate pair, inserting the
    #   transition rates into the coord-th slice of the right-hand two "S"
    #   dimensions of ``output``.rate_coords
    for rate, coord in zip(rates, rate_coords, strict=True):
        output = output.at[coord[-2], coord[-1], ...].set(
            rate,
            indices_are_sorted=True,
            unique_indices=True,
            mode="promise_in_bounds",
        )

    # Transform from (S,S,...) to (...,S,S)
    return jnp.moveaxis(output, [0, 1], [-2, -1])


def _approx_expm(rates: jax.Array) -> jax.Array:
    """Approximate a Markov transition probability matrix from a rate matrix.

    For each row ``i`` let ``r_i = sum_j rates[..., i, j]``. The probability
    of leaving state ``i`` within the step is taken to be ``1 - exp(-r_i)``.
    It is shared between destinations in proportion to their rates, so the
    off-diagonal entries are ``rates[..., i, j] * (1 - exp(-r_i)) / r_i``.
    The diagonal is set to one minus the rest of the row. Each row is
    therefore treated as a set of competing exponential risks, and onward
    transitions within the same step are ignored. This approximates
    ``expm(rates)``.

    Args:
        rates (jax.Array): Un-normalised square rate matrix (already scaled
            by the time step) with a zero diagonal, of shape ``(..., S, S)``.
            Any number of leading batch dimensions is accepted, e.g.
            ``(S, S)`` or ``(B, S, S)``.

    Returns:
        jax.Array: Approximate Markov transition matrix with the same shape
            as ``rates``.
    """
    total_rates = jnp.sum(rates, axis=-1, keepdims=True)

    # ``-expm1(-x)`` equals ``1 - exp(-x)`` but avoids catastrophic
    # cancellation when ``x`` (rate * time step) is small; in float32,
    # ``1 - exp(-1e-8)`` would evaluate to exactly 0.
    prob = -jnp.expm1(-total_rates)

    # For a row with zero total rate, ``prob / total_rates`` is 0/0 = nan.
    # NOTE(review): this relies on ``multiply_no_nan`` returning 0 where
    # ``rates`` is 0 (as its TF namesake does); confirm against gemlib.
    partial_matrix = multiply_no_nan(prob / total_rates, rates)

    diagonal_values = 1.0 - jnp.sum(partial_matrix, axis=-1)

    # ``jnp.where`` broadcasts the (S, S) mask over any leading batch dims.
    mask = jnp.eye(rates.shape[-1], dtype=bool)
    return jnp.where(mask, diagonal_values[..., None], partial_matrix)


def make_transition_prob_matrix_fn(
    transition_rate_fn: Callable[[float, jax.Array], tuple[jax.Array, ...]],
    time_delta: float,
    incidence_matrix: np.ndarray,
) -> Callable[[float, jax.Array], jax.Array]:
    """Create a function returning the per-step Markov transition matrix.

    The returned function evaluates ``transition_rate_fn``, scatters the
    resulting rates into a square rate matrix at the coordinates derived from
    ``incidence_matrix``, scales by ``time_delta`` and converts the result to
    transition probabilities with ``_approx_expm``.

    Args:
        transition_rate_fn (Callable[[float, jax.Array], tuple[jax.Array, ...]])
            : Maps a scalar time ``t`` and a ``state`` of shape ``(N, S)`` to
            a tuple of ``R`` transition-rate arrays of shape ``(N,)``.
        time_delta (float): Length of the discrete time step.
        incidence_matrix (np.ndarray): Array of shape ``(S, R)`` describing
            the allowed transitions between compartments.

    Returns:
        Callable[[float, jax.Array], jax.Array]: A function of a scalar time
            and a state of shape ``(N, S)`` returning a Markov transition
            matrix of shape ``(N, S, S)``.
    """
    coords = transition_coords_tuple(incidence_matrix)

    def transition_prob_matrix(t: float, state: jax.Array) -> jax.Array:
        # rate_matrix[n, i, j] holds the i -> j rate for unit n (i != j);
        # entries with no corresponding transition stay zero.
        rate_matrix = _scatter_to_transition_matrix(
            transition_rate_fn(t, state), coords, state.shape[-1]
        )
        # Rates per unit time -> expected events per step -> probabilities.
        return _approx_expm(time_delta * rate_matrix)

    return transition_prob_matrix


@jax.custom_vjp
def _multinomial_log_prob(total_count, probs, counts):
    """Multinomial log-probability mass with a custom (probs-only) gradient.

    Computes ``sum_j c_j * log(p_j) + log(n! / prod_j c_j!)`` over the last
    axis, where ``n = total_count``, ``p = probs`` and ``c = counts``.

    NOTE(review): ``multiply_no_nan`` is relied on to give 0 where ``counts``
    is 0 (so ``0 * log(0)`` contributes 0); confirm against gemlib.

    Args:
        total_count: Number of trials per multinomial; its shape matches the
            leading (non-category) dimensions of ``counts``.
        probs: Category probabilities, last axis indexing categories.
        counts: Category counts, last axis indexing categories.

    Returns:
        Log-probabilities with the last (category) axis reduced.
    """
    log_unnorm_prob = jnp.sum(multiply_no_nan(jnp.log(probs), counts), axis=-1)
    neg_log_normalizer = tfp.math.log_combinations(total_count, counts)
    return log_unnorm_prob + neg_log_normalizer


# Forward pass
def _multinomial_log_prob_fwd(total_count, probs, counts):
    """Primal output of ``_multinomial_log_prob`` plus backward residuals."""
    # Reuse the primal rather than duplicating its body (the pattern shown
    # in the JAX custom_vjp documentation).
    output = _multinomial_log_prob(total_count, probs, counts)
    return output, (probs, counts)


def _sum_to_shape(x, shape):
    """Sum ``x`` over broadcast dimensions so that the result has ``shape``."""
    shape = tuple(shape)
    num_leading = x.ndim - len(shape)
    # Reduce every extra leading axis, plus any axis that ``shape`` holds as
    # size 1 but ``x`` has expanded by broadcasting.
    axes = tuple(range(num_leading)) + tuple(
        num_leading + i
        for i, size in enumerate(shape)
        if size == 1 and x.shape[num_leading + i] != 1
    )
    return jnp.sum(x, axis=axes).reshape(shape)


# Backward pass
def _multinomial_log_prob_bwd(res, g):
    """VJP of ``_multinomial_log_prob`` with respect to ``probs`` only.

    ``d/dp_j sum_j c_j log p_j = c_j / p_j``. This is scaled by the incoming
    cotangent ``g`` and summed back to ``probs.shape``, which is needed when
    ``probs`` was broadcast against ``counts`` / ``total_count`` in the
    forward pass. ``total_count`` and ``counts`` receive no gradient (None).
    """
    probs, counts = res
    # nan_to_num turns 1/0 = inf into a large finite value, so a zero count
    # paired with a zero probability yields a zero (not nan) gradient.
    inv_probs = jnp.nan_to_num(1.0 / probs)
    grad_probs = inv_probs * counts * jnp.expand_dims(g, axis=-1)
    return (None, _sum_to_shape(grad_probs, probs.shape), None)


# Registering the custom forward and backward functions for
# _multinomial_log_prob
_multinomial_log_prob.defvjp(
    _multinomial_log_prob_fwd, _multinomial_log_prob_bwd
)


def compute_state(
    initial_state: jax.Array,
    events: jax.Array,
    incidence_matrix: np.ndarray,
    closed: bool = False,
):
    """Compute the state array at multiple points in time from the initial state
    and event array.

    Args:
        initial_state (jax.Array): Initial state with shape ``(N, S)``.
        events (jax.Array): Time series of events with shape ``(T, N, R)``.
            Leading batch dimensions are also accepted.
        incidence_matrix (np.ndarray): A matrix with shape ``(S, R)``
            describing how each transition updates the state.
        closed (bool): If ``True``, return the state at times ``0..T``;
            otherwise return it at times ``0..T-1``.

    Returns:
        jax.Array: Array of shape ``(T, N, S)`` (times ``0..T-1``) if
            ``closed=False``, or ``(T+1, N, S)`` (times ``0..T``) if
            ``closed=True``. Element ``t`` is the state *before* the events
            of timestep ``t`` are applied. When ``T == 0`` the result is
            empty when not closed, and just ``initial_state`` when closed.
    """
    # Net change in each state caused by each timestep's events.
    increments = jnp.einsum("...tmr,sr->...tms", events, incidence_matrix)
    num_times = increments.shape[-3]

    # Zero increment for t=0. It is built with an explicit length-1 time
    # axis so it still exists when there are no events (T == 0); slicing
    # ``increments[..., :1]`` would be empty in that case.
    padding = jnp.zeros(
        increments.shape[:-3] + (1,) + increments.shape[-2:],
        dtype=increments.dtype,
    )
    # cum_increments[..., t, :, :] is the net change accumulated before
    # timestep t, for t = 0..T.
    cum_increments = jnp.concatenate(
        (padding, jnp.cumsum(increments, axis=-3)), axis=-3
    )
    if not closed:
        # Drop the state after the final timestep, keeping times 0..T-1.
        cum_increments = cum_increments[..., :num_times, :, :]

    return cum_increments + jnp.expand_dims(initial_state, axis=-3)


def chain_binomial_propagate(
    transition_matrix_fn: Callable[[float, jax.Array], jax.Array],
) -> Callable[[float, jax.Array, jax.Array], tuple[jax.Array, jax.Array]]:
    """Propagates the state of a population according to discrete time dynamics.

    Each row of the transition matrix is sampled as a multinomial through a
    chain of binomials. For destination ``j = 0 .. S-2``, draw
    ``Binomial(remaining, p_j / (1 - sum_{k<j} p_k))`` and subtract the draw
    from ``remaining``. Whatever is left is assigned to destination ``S-1``.

    Args:
        transition_matrix_fn: A function ``(t: float, state: Array) -> Array``
            which takes a time ``t`` and a state of shape ``(N, S)`` and
            returns a Markov transition probability matrix of shape
            ``(N, S, S)``. A suitable function is generated by
            ``make_transition_prob_matrix_fn``.

    Returns:
        A function ``(t, state, seed) -> (events, new_state)``. It takes a
            state of shape ``(N, S)`` at time ``t`` and a JAX PRNG key
            ``seed`` (split internally into ``S - 1`` keys). It returns an
            ``events`` array of shape ``(N, S, S)``, where ``events[n, i, j]``
            counts unit ``n`` individuals moving from state ``i`` to state
            ``j``, and ``new_state`` of shape ``(N, S)`` (the sum of
            ``events`` over source states).
    """

    def propagate_fn(t, state, seed):
        markov_transition_matrix = transition_matrix_fn(t, state)
        num_states = markov_transition_matrix.shape[-1]

        # Individuals in each source state not yet allocated a destination.
        remaining = state.astype(markov_transition_matrix.dtype)
        # Probability mass of the destinations already sampled, per source.
        prev_probs = jnp.zeros_like(markov_transition_matrix[..., :, 0])

        # One independent key per binomial draw.
        keys = jax.random.split(seed, num_states - 1)

        # Collect the per-destination columns and stack once at the end,
        # rather than re-concatenating a growing array on every iteration.
        columns = []
        for i in range(num_states - 1):
            probs = markov_transition_matrix[..., :, i]
            # Probability of destination i given none of 0..i-1 was chosen.
            # The epsilon avoids 0/0 and the clip absorbs rounding error.
            cond_probs = jnp.clip(probs / (1.0 - prev_probs + 1e-10), 0.0, 1.0)
            sample = tfd.Binomial(
                total_count=remaining, probs=cond_probs
            ).sample(seed=keys[i])
            columns.append(sample)
            remaining = remaining - sample
            prev_probs = prev_probs + probs

        # Everyone still unallocated goes to the final destination state.
        columns.append(remaining)
        counts = jnp.stack(columns, axis=-1)

        # Aggregate new state: total arrivals into each destination state.
        new_state = jnp.sum(counts, axis=-2)

        return counts, new_state

    return propagate_fn


def discrete_markov_simulation(
    transition_prob_matrix_fn: Callable[[float, jax.Array], jax.Array],
    state: jax.Array,
    start: float,
    end: float,
    time_step: float,
    seed: jax.Array,
):
    """Simulates from a discrete time Markov state transition model using
    multinomial sampling across rows of the transition matrix.

    Args:
        transition_prob_matrix_fn (Callable[[float, jax.Array], jax.Array]):
            A function that takes a scalar ``time`` and state of shape
            ``(N, S)`` to a Markov transition matrix of shape ``(N, S, S)``.
            Here ``N`` is the number of units and ``S`` the dimension of the
            state of each unit.
        state (jax.Array): Initial state of shape ``(N, S)``. A 1-D state of
            shape ``(S,)`` is treated as a single unit (``N = 1``).
        start (float): Start time of simulation.
        end (float): End time of simulation (exclusive, as in ``arange``).
        time_step (float): Duration of discrete timestep.
        seed (jax.Array): JAX PRNG key; it is split into one key per timestep.

    Returns:
        tuple[jax.Array, jax.Array]: A vector of shape ``(T,)`` of the times
            of the start of each timestep, and an array of shape
            ``(T, N, S, S)`` containing the transitions within each timestep.
    """
    state = jnp.asarray(state)
    if state.ndim == 1:
        state = state[None, ...]

    propagate = chain_binomial_propagate(transition_prob_matrix_fn)

    # Times must be floating point even when ``state`` holds integer counts;
    # tying them to an integer state dtype would truncate a fractional
    # ``time_step``. Floating-point states keep their own dtype, as before.
    if jnp.issubdtype(state.dtype, jnp.floating):
        time_dtype = state.dtype
    else:
        time_dtype = jnp.result_type(float)
    times = jnp.arange(start, end, time_step, dtype=time_dtype)
    keys = jax.random.split(seed, times.shape[0])

    def scan_fn(state, elems):
        t, current_key = elems
        event_counts, new_state = propagate(t, state, current_key)
        # Keep the scan carry's dtype fixed across iterations.
        new_state = new_state.astype(state.dtype)
        return new_state, event_counts

    _, all_events = jax.lax.scan(scan_fn, state, (times, keys))

    return times, all_events


def discrete_markov_log_prob(
    events: jax.Array,
    init_state: jax.Array,
    init_step: float,
    time_delta: float,
    transition_prob_matrix_fn: Callable[[float, jax.Array], jax.Array],
    incidence_matrix: np.ndarray,
) -> jax.Array:
    """Log-probability of an event time series under a discrete-time model.

    The state at the start of each timestep is reconstructed from
    ``init_state`` and ``events``. Each timestep's events, together with the
    individuals who stay put, form a ``(from_state, to_state)`` count matrix.
    Each row of that matrix is scored as a multinomial draw from the
    corresponding row of the transition matrix, and the results are summed.

    Args:
        events (jax.Array): A ``(T, N, R)`` array of transition events, where
            the simulation has ``T`` times, ``N`` units and ``R``
            transitions.
        init_state (jax.Array): An array of shape ``(N, S)`` giving the
            initial state of the epidemic, where ``S`` is the number of
            states.
        init_step (float): The initial time.
        time_delta (float): The duration of the discrete time step.
        transition_prob_matrix_fn (Callable[[float, jax.Array], jax.Array]):
            A function that takes a scalar time and a state of shape
            ``(N, S)`` to a Markov transition matrix of shape ``(N, S, S)``.
        incidence_matrix (np.ndarray): A matrix with shape ``(S, R)``
            describing how transitions update the state.

    Returns:
        jax.Array: A scalar log probability for these events given the
            initial state and model.
    """
    n_states = init_state.shape[-1]

    # State at the start of every timestep: shape (T, N, S).
    states = compute_state(init_state, events, incidence_matrix)
    step_times = init_step + time_delta * jnp.arange(events.shape[-3])

    # Transition probabilities for every timestep: shape (T, N, S, S).
    probs = jax.vmap(transition_prob_matrix_fn)(step_times, states)

    # counts[t, n, i, j] = number moving from state i to state j at
    # timestep t in unit n (off-diagonal entries only at this point).
    counts = _scatter_to_transition_matrix(
        jnp.unstack(events, axis=-1),
        transition_coords_tuple(incidence_matrix),
        n_states,
    )

    # Whoever did not leave state i stayed there: fill in the diagonal.
    diag = jnp.arange(n_states)
    stayers = states - counts.sum(axis=-1)
    counts = counts.at[..., diag, diag].set(stayers)

    return _multinomial_log_prob(states, probs, counts).sum()
