"""Describes a DiscreteTimeStateTransitionModel."""

from __future__ import annotations

from collections.abc import Callable, Sequence

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax as tfp

from gemlib.distributions.discrete_markov import (
    compute_state,
    discrete_markov_log_prob,
    discrete_markov_simulation,
    make_transition_prob_matrix_fn,
)
from gemlib.func_util import maybe_combine_fn
from gemlib.math import cumsum_np
from gemlib.prng_util import sanitize_key
from gemlib.tensor_util import broadcast_fn_to
from gemlib.util import (
    batch_gather,
    transition_coords,
)

Array = jax.Array
ArrayLike = jax.typing.ArrayLike
tfd = tfp.distributions


class DiscreteTimeStateTransitionModel(tfd.Distribution):
    """Discrete-time state transition model

    A discrete-time state transition model assumes a population of
    individuals is divided into a number of mutually exclusive states,
    where transitions between states occur according to a Markov process.
    Such models are commonly found in epidemiological and ecological
    applications, where rapid implementation and modification is necessary.

    This class provides a programmable implementation of the discrete-time
    state transition model, compatible with TensorFlow Probability.

    Throughout, ``T`` denotes the number of time steps, ``N`` the number of
    units, ``S`` the number of states and ``R`` the number of transitions.
    The incidence matrix has shape ``(S, R)``.


    Example
    -------
    A homogeneously mixing SIR model implementation::

        import jax.numpy as jnp
        import numpy as np
        from gemlib.distributions import DiscreteTimeStateTransitionModel

        # Initial state, counts per compartment (S, I, R), for one
        #   population
        initial_state = jnp.array([[99, 1, 0]], jnp.float32)

        # Note that the incidence_matrix is treated as a static parameter,
        # and so is supplied as a numpy ndarray, not a jax Array.
        incidence_matrix = np.array(
            [
                [-1, 0],
                [1, -1],
                [0, 1],
            ],
            dtype=np.float32,
        )


        def si_rate(t, state):
            return 0.28 * state[:, 1] / jnp.sum(state, axis=-1)


        def ir_rate(t, state):
            return 0.14


        # Instantiate model
        sir = DiscreteTimeStateTransitionModel(
            transition_rate_fn=(si_rate, ir_rate),
            incidence_matrix=incidence_matrix,
            initial_state=initial_state,
            num_steps=100,
        )

        # One realisation of the epidemic process
        sim = sir.sample(seed=0)
    """

    def __init__(
        self,
        transition_rate_fn: Sequence[Callable[[float, ArrayLike], ArrayLike]]
        | Callable[[float, ArrayLike], tuple[ArrayLike, ...]],
        incidence_matrix: np.typing.NDArray,
        initial_state: ArrayLike,
        num_steps: int,
        initial_step: int = 0,
        time_delta: float = 1.0,
        validate_args: bool = False,
        allow_nan_stats: bool = True,
        name: str = "DiscreteTimeStateTransitionModel",
    ):
        """Initialise a discrete-time state transition model.

        Args:
          transition_rate_fn: Either a sequence of ``R`` callables
            ``(t: float, state: ArrayLike) -> ArrayLike``, each returning an
            array of rates of shape ``(N,)``, or a callable
            ``(t: float, state: ArrayLike) -> tuple[ArrayLike, ...]``,
            returning a tuple of ``R`` vectors of shape ``(N,)``.
            Here ``N`` is the number of units in the simulation.
            In the first (preferred) form, each callable returns the respective
            transition rate.
            In the second form, the single callable returns ``R`` transition
            rates.
            **Note**: the second form will be
            deprecated in future releases of ``gemlib``.
          incidence_matrix: incidence matrix of shape ``(S, R)``
            for ``S`` states and ``R`` transitions between states.
          initial_state: Initial state of the model, of shape
            ``(N,S)`` containing the counts in each of ``N`` units and
            ``S`` states.
            We require ``initial_state.shape[-1]==incidence_matrix.shape[-2]``
          num_steps: the number of time steps simulated by the model.
          initial_step: time ``t`` at the start of the simulation.
          time_delta: the duration of the discretized time step.
          validate_args: forwarded to the ``tfd.Distribution`` base class.
          allow_nan_stats: forwarded to the ``tfd.Distribution`` base class.
          name: name of the distribution, forwarded to the base class.
        """
        # Must stay the first statement: it snapshots the constructor
        # arguments before any other local name is introduced.
        parameters = dict(locals())

        # The incidence matrix is held as a (static) numpy array.
        self._incidence_matrix = np.asarray(incidence_matrix)

        self._source_states = _compute_source_states(self._incidence_matrix)

        initial_state = jnp.asarray(initial_state)
        # The rate function(s) are combined into one callable, broadcast to
        # the leading (unit) dimensions of the initial state, and wrapped
        # into a function returning a transition probability matrix.
        self._transition_prob_matrix_fn = make_transition_prob_matrix_fn(
            broadcast_fn_to(
                maybe_combine_fn(transition_rate_fn),
                jnp.shape(initial_state)[:-1],
            ),
            time_delta,
            self._incidence_matrix,
        )
        # NOTE(review): samples are event counts; confirm that
        # FULLY_REPARAMETERIZED is the intended reparameterization type.
        super().__init__(
            dtype=initial_state.dtype,
            reparameterization_type=tfd.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )

    @property
    def transition_rate_fn(self):
        """The ``transition_rate_fn`` argument as passed to the constructor."""
        return self._parameters["transition_rate_fn"]

    @property
    def incidence_matrix(self):
        """The ``incidence_matrix`` argument as passed to the constructor."""
        return self._parameters["incidence_matrix"]

    @property
    def initial_state(self):
        """The ``initial_state`` argument as passed to the constructor."""
        return self._parameters["initial_state"]

    @property
    def initial_step(self):
        """Time at the start of the simulation."""
        return self._parameters["initial_step"]

    @property
    def source_states(self):
        """Source state indices computed from the incidence matrix."""
        return self._source_states

    @property
    def time_delta(self):
        """Duration of one discretized time step."""
        return self._parameters["time_delta"]

    @property
    def num_steps(self):
        """Number of time steps simulated by the model."""
        return self._parameters["num_steps"]

    @property
    def num_units(self):
        """Number of units ``N`` (second-to-last axis of the initial state)."""
        return jnp.shape(self._parameters["initial_state"])[-2]

    @property
    def num_states(self):
        """Number of states ``S`` (last axis of the initial state)."""
        return jnp.shape(self._parameters["initial_state"])[-1]

    def _batch_shape(self):
        return []

    def _event_shape(self):
        shape = (
            int(self.num_steps),  # T
            jnp.shape(self.initial_state)[-2],  # N
            jnp.shape(self.incidence_matrix)[-1],  # R
        )
        return shape

    def compute_state(
        self, events: Array, include_final_state: bool = False
    ) -> Array:
        """Computes a state timeseries from a sequence of transition events.

        Args:
            events: an array of events of shape ``(T,N,R)`` where
                ``T`` is the number of timesteps,
                ``N=self.num_units`` the number of units and
                ``R=self.incidence_matrix.shape[1]``
                the number of transitions.

            include_final_state: If ``False`` (default)
                the result does not include the final state.
                Otherwise the result includes the final state.

        Returns:
            An array of shape ``(T, N, S)`` if
                ``include_final_state==False`` or ``(T+1, N, S)``
                if ``include_final_state==True``, where ``S=self.num_states``,
                giving the number of individuals in each state at each time
                point for each unit.
        """
        return compute_state(
            incidence_matrix=self._incidence_matrix,
            initial_state=self.initial_state,
            events=events,
            closed=include_final_state,
        )

    def transition_prob_matrix(
        self, events: None | ArrayLike = None
    ) -> Array:
        """Compute the Markov transition probability matrix.

        Args:
            events: None (default), or an array of shape ``(T, N, R)``, where
              ``T`` is the number of timesteps,
              ``N=self.num_units`` the number of units and
              ``R=self.incidence_matrix.shape[1]``
              the number of transitions.

        Returns:
            Transition probability matrix. If ``events`` is None, this is
            the matrix associated with the initial state
            (``self.initial_state``, of shape ``(N, S)``) at time
            ``self.initial_step``, and has shape ``(N, S, S)``, where
            ``S=self.num_states`` is the number of distinct states.
            Otherwise, it has shape ``(T, N, S, S)``, giving the
            transition probability matrix at each timestep.
        """
        if events is None:
            return self._transition_prob_matrix_fn(
                self.initial_step, self.initial_state
            )

        state = self.compute_state(events)
        # Build the time grid from an integer range so that it always has
        # exactly ``num_steps`` entries.  A float-stepped ``arange`` can gain
        # an extra element through rounding of the stop value (e.g.
        # 0.1 * 3 > 0.3), which would break the vmap below.
        times = self.initial_step + self.time_delta * jnp.arange(
            self.num_steps
        )

        return jax.vmap(self._transition_prob_matrix_fn)(times, state)

    def _sample_n(self, n: int, seed: int | Array | None = None) -> Array:
        """Runs `n` simulations of the epidemic model.

        Args:
            n: number of simulations to run.
            seed: an integer or a JAX PRNG key.

        Returns:
            an array of shape ``(n, T, N, R)``
            for ``n`` samples, ``T=self.num_steps`` time steps,
            ``N=self.num_units`` units and ``R=self.incidence_matrix.shape[1]``
            transitions, containing the
            number of events for each timestep, unit and transition.
        """
        key = sanitize_key(seed)
        # NOTE(review): the second sanitize_key call looks redundant; it is
        # kept so that the generated random streams are unchanged.
        keys = jax.random.split(sanitize_key(key), num=n)

        def one_sample(sample_key):
            # Only the simulated events are needed; the first returned
            # element is discarded.
            _, events = discrete_markov_simulation(
                transition_prob_matrix_fn=self._transition_prob_matrix_fn,
                state=self.initial_state,
                start=self.initial_step,
                end=self.initial_step + self.num_steps * self.time_delta,
                time_step=self.time_delta,
                seed=sample_key,
            )
            return events

        # One independent simulation per split key.
        sim = jax.vmap(one_sample)(keys)
        # Pick out the entries of the simulation output that correspond to
        # the transitions described by the incidence matrix.
        indices = transition_coords(self._incidence_matrix)

        return batch_gather(sim, indices)

    def _log_prob(self, y):
        """Log-probability of one or more event timeseries.

        Args:
            y: array whose trailing three axes hold one event timeseries;
               any leading axes are treated as batch dimensions.

        Returns:
            An array with the shape of the leading batch dimensions of
            ``y``, holding one log-probability per event timeseries.
        """
        y = jnp.asarray(y)

        # Collapse all leading batch dimensions into one so a single vmap
        # suffices, then restore them on the result.
        batch_shape = y.shape[:-3]
        y_flat = y.reshape((-1,) + y.shape[-3:])

        def one_log_prob(events):
            return discrete_markov_log_prob(
                events=events,
                init_state=jnp.asarray(self.initial_state, dtype=events.dtype),
                init_step=jnp.asarray(self.initial_step, dtype=events.dtype),
                time_delta=jnp.asarray(self.time_delta, dtype=events.dtype),
                transition_prob_matrix_fn=self._transition_prob_matrix_fn,
                incidence_matrix=self._incidence_matrix,
            )

        log_probs = jax.vmap(one_log_prob)(y_flat)
        return log_probs.reshape(batch_shape)


# NOTE: it is unclear why this helper is not simply replaced by
# gemlib.util.transition_coords.
def _compute_source_states(
    incidence_matrix: np.ndarray, dtype=np.int32
) -> np.ndarray:
    """Locate the source state of each transition in a state transition model.

    Args:
        incidence_matrix (np.ndarray): incidence matrix of shape ``(S, R)``
            for ``S`` states and ``R`` transitions between states.
        dtype: dtype of the returned index array (``np.int32`` by default).

    Returns:
        np.ndarray: an array of shape ``(R,)`` containing indices of source
            states.
    """
    # Work transition-by-transition: one row per transition, one column
    # per state.
    by_transition = np.transpose(incidence_matrix)

    # Negative incidence entries become 1, everything else becomes 0.
    outflow_flags = np.clip(-by_transition, a_min=0, a_max=1)

    # Presumably a reversed, exclusive cumulative sum along the state axis
    # (cumsum_np is defined in gemlib.math -- TODO confirm); summing it
    # along that axis then gives one value per transition.
    trailing_totals = cumsum_np(
        outflow_flags,
        axis=-1,
        reverse=True,
        exclusive=True,
    )
    state_indices = np.sum(trailing_totals, axis=-1)

    return state_indices.astype(dtype)
