"""Continuous time state transition model"""

from __future__ import annotations

import math as pymath
from collections.abc import Callable, Sequence

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax as tfp

from gemlib.distributions.continuous_markov import (
    EventList,
    compute_state,
    continuous_markov_simulation,
    continuous_time_log_likelihood,
)
from gemlib.func_util import maybe_combine_fn
from gemlib.prng_util import sanitize_key
from gemlib.tensor_util import broadcast_fn_to

# Short module-level aliases for frequently used types and namespaces.
tfd = tfp.distributions
NDArray = np.typing.NDArray
ArrayLike = jax.typing.ArrayLike
Array = jax.Array


class ContinuousTimeStateTransitionModel(tfd.Distribution):
    """Continuous time state transition model.

    A draw from this distribution is an :class:`EventList` whose ``time``,
    ``transition`` and ``unit`` fields each hold ``num_steps`` events.

    Example:
        A homogeneously mixing SIR model implementation::

            import jax.numpy as jnp
            import numpy as np
            from gemlib.distributions import ContinuousTimeStateTransitionModel

            # Initial state, counts per compartment (S, I, R), for one
            #   population
            initial_state = jnp.array([[99, 1, 0]], jnp.float32)

            # Note that the incidence_matrix is treated as a static parameter,
            # and so is supplied as a numpy ndarray, not a jax Array.
            incidence_matrix = np.array(
                [
                    [-1, 0],
                    [1, -1],
                    [0, 1],
                ],
                dtype=np.float32,
            )


            def si_rate(t, state):
                return 0.28 * state[:, 1] / jnp.sum(state, axis=-1)


            def ir_rate(t, state):
                return 0.14


            # Instantiate model
            sir = ContinuousTimeStateTransitionModel(
                transition_rate_fn=(si_rate, ir_rate),
                incidence_matrix=incidence_matrix,
                initial_state=initial_state,
                num_steps=100,
            )

            # One realisation of the epidemic process
            sim = sir.sample(seed=0)

            # Compute the log probability of observing `sim`
            sir.log_prob(sim)
    """

    def __init__(
        self,
        transition_rate_fn: Sequence[Callable[[float, ArrayLike], ArrayLike]]
        | Callable[[float, ArrayLike], tuple[ArrayLike, ...]],
        incidence_matrix: NDArray,
        initial_state: Array,
        num_steps: int,
        initial_time: float = 0.0,
        validate_args: bool = False,
        allow_nan_stats: bool = True,
        name: str = "ContinuousTimeStateTransitionModel",
    ):
        """
        Initializes a ContinuousTimeStateTransitionModel object.

        Args:
          transition_rate_fn: Either a sequence of callables of the form
            :code:`fn(t: float, state: Tensor) -> Tensor` or a single Python
            callable of the form
            :code:`fn(t: float, state: Tensor) -> tuple(Tensor, ...)`.  In the
            first (preferred) form, each callable in the sequence corresponds
            to the respective transition in :code:`incidence_matrix`.  In the
            second form, the callable should return a :code:`tuple` of
            transition rate tensors corresponding to transitions in
            :code:`incidence_matrix`.
            **Note**: the second form will be deprecated in future releases of
            :code:`gemlib`.
          incidence_matrix: Matrix representing the incidence of transitions
            between states; in the class example, rows are states and columns
            are transitions.  Treated as a static parameter and supplied as a
            numpy ndarray.
          initial_state: A :code:`[N, S]` tensor containing the initial state of
            the population of :code:`N` units in :code:`S` epidemiological
            classes.
          num_steps: the number of Markov jumps simulated per realisation.
          initial_time: Initial time of the model. Defaults to 0.0.
          validate_args: Forwarded to :class:`tfd.Distribution`. Defaults to
            False.
          allow_nan_stats: Forwarded to :class:`tfd.Distribution`. Defaults to
            True.
          name: Name of the model. Defaults to
            "ContinuousTimeStateTransitionModel".
        """
        # Must run before any other local is bound, so that only the
        # constructor arguments are captured.
        parameters = dict(locals())

        self._incidence_matrix = jnp.asarray(incidence_matrix)
        self._initial_state = initial_state
        self._initial_time = jnp.asarray(initial_time)

        # Normalise the sequence-of-callables / single-callable forms.
        self._transition_rate_fn = maybe_combine_fn(transition_rate_fn)

        # Event times inherit the dtype of initial_time; indices are int32.
        dtype = EventList(
            time=self._initial_time.dtype,
            transition=jnp.int32,
            unit=jnp.int32,
        )

        super().__init__(
            dtype=dtype,
            # NOTE(review): `transition` and `unit` are discrete int32 indices,
            # so FULLY_REPARAMETERIZED looks questionable (NOT_REPARAMETERIZED
            # may be more accurate). Confirm before changing, because this flag
            # drives gradient-estimator selection in downstream TFP code.
            reparameterization_type=tfd.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )

    @property
    def transition_rate_fn(self):
        """Transition rate function for the model, as supplied by the user."""
        return self._parameters["transition_rate_fn"]

    @property
    def incidence_matrix(self):
        """Incidence matrix for the model, as supplied by the user."""
        return self._parameters["incidence_matrix"]

    @property
    def initial_state(self):
        """Initial state of the model."""
        return self._parameters["initial_state"]

    @property
    def num_steps(self):
        """Number of events to simulate."""
        return self._parameters["num_steps"]

    @property
    def initial_time(self):
        """Initial wall clock for the model. Sets the time scale."""
        return self._parameters["initial_time"]

    def compute_state(
        self, event_list: EventList, include_final_state: bool = False
    ) -> jax.Array:
        """Compute state timeseries given an event list

        Args:
            event_list: the event list, assumed to be sorted by time.
            include_final_state: should the final state be included in the
                returned timeseries?  If `True`, then the time dimension of
                the returned tensor will be 1 greater than the length of the
                event list.  If `False` (default) these will be equal.

        Returns:
          A `[T, N, S]` tensor, where `N` is the number of units and `S` the
          number of states.  `T` equals the number of events, or the number
          of events plus one when `include_final_state` is `True`.
        """
        return compute_state(
            self.incidence_matrix,
            self.initial_state,
            event_list,
            include_final_state,
        )

    def _sample_n(self, n: int, seed: int | Array | None = None) -> EventList:
        """Draw ``n`` independent realisations of the Markov jump process.

        Args:
            n: the number of realisations to draw.  One PRNG key is split off
                per realisation, and the simulations are vectorised with
                :func:`jax.vmap`.
            seed: an integer seed or PRNG key, normalised by
                ``sanitize_key``.  Defaults to None.

        Returns:
            An :class:`EventList` whose fields carry a leading dimension of
            size ``n`` (added by :func:`jax.vmap`).
        """
        keys = jax.random.split(sanitize_key(seed), n)

        # Wrap the rate function against the unit dimensions of the initial
        # state (all but the last axis); presumably this broadcasts scalar
        # rates such as `ir_rate` in the class example -- see broadcast_fn_to.
        rate_fn = broadcast_fn_to(
            self._transition_rate_fn,
            self._initial_state.shape[:-1],
        )

        def one_sample(key):
            return continuous_markov_simulation(
                transition_rate_fn=rate_fn,
                incidence_matrix=self._incidence_matrix,
                initial_state=self._initial_state,
                initial_time=self._initial_time,
                num_markov_jumps=self.num_steps,
                seed=key,
            )

        return jax.vmap(one_sample)(keys)

    def _log_prob(self, value: EventList) -> Array:
        """Compute the log probability of one or more event lists.

        Args:
            value: an :class:`EventList` whose fields all have shape
                ``batch_shape + (num_events,)``.

        Returns:
            An array of log probabilities with shape ``batch_shape``.
        """
        value = jax.tree_util.tree_map(jnp.asarray, value)

        # Collapse any leading batch dimensions into a single axis so one vmap
        # covers every case.  With no batch dimensions, prod(()) == 1.
        batch_shape = value.time.shape[:-1]
        flat_shape = (pymath.prod(batch_shape), value.time.shape[-1])
        value = jax.tree_util.tree_map(lambda x: x.reshape(flat_shape), value)

        rate_fn = broadcast_fn_to(
            self._transition_rate_fn,
            self._initial_state.shape[:-1],
        )

        # NOTE(review): this reads the user-supplied incidence matrix and
        # initial_time (numpy / Python float), whereas `_sample_n` uses the
        # jnp-converted copies.  It may be deliberate, since the incidence
        # matrix is documented as static; confirm before unifying.
        def one_log_prob(event_list):
            return continuous_time_log_likelihood(
                transition_rate_fn=rate_fn,
                incidence_matrix=self.incidence_matrix,
                initial_state=self.initial_state,
                initial_time=self.initial_time,
                event_list=event_list,
            )

        log_probs = jax.vmap(one_log_prob)(value)
        return jnp.reshape(log_probs, batch_shape)

    def _event_shape(self) -> EventList:
        """Per-field event shape: ``(num_steps,)`` for each EventList field."""
        return EventList(
            time=(self.num_steps,),
            transition=(self.num_steps,),
            unit=(self.num_steps,),
        )

    def _batch_shape(self) -> tuple:
        """The model has no batch dimensions."""
        return ()
