"""ODE solver for models specified with gemlib's state transition model interface."""

from __future__ import annotations

from collections.abc import Callable
from typing import NamedTuple

import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import Array
from jax.typing import ArrayLike

from gemlib.func_util import maybe_combine_fn
from gemlib.tensor_util import broadcast_fn_to
from gemlib.util import batch_gather, transition_coords

# Names exported by ``from <module> import *``; only the solver entry point
# is part of the public interface.
__all__ = [
    "ode_model",
]


def _total_flux(
    transition_rates: ArrayLike, state: ArrayLike, incidence_matrix: ArrayLike
) -> Array:
    """Compute the total flux along each transition for a given `state`.

    Per-unit transition rates are scaled by the occupancy of each
    transition's source state, so the result accounts for how many units
    are available to make the transition.

    Args
    ----
        transition_rates: the `R` per-unit transition rates, one for each
                          column of `incidence_matrix`, in any form accepted
                          by `jnp.stack` (e.g. a sequence of `R` arrays).
                          Each rate has shape `[..., N]` for `N`
                          aggregation units.
        state: a `[..., N, S]` array of `N` aggregation units and `S` states.
        incidence_matrix: an `[S, R]` matrix describing the change in `S`
                          for each transition `R`.  It is converted with
                          `numpy.array`, so it must be a concrete (not
                          traced) value.

    Returns
    -------
    A `[..., R, N]` array of total flux along each transition.
    """
    # Move the transition index to the trailing axis: [..., N, R].
    rates_by_unit = jnp.stack(transition_rates, axis=-1)

    # Column 0 of the transition coordinates is taken as the source-state
    # index of each transition.
    coords = transition_coords(np.array(incidence_matrix))
    src_idx = coords[:, 0]

    # Occupancy of each transition's source state.  The einsum below
    # requires this to line up with `rates_by_unit` as [..., N, R].
    src_occupancy = batch_gather(state, indices=src_idx[:, None])

    # Element-wise product, with the transition axis moved before units.
    return jnp.einsum("...nr,...nr->...rn", rates_by_unit, src_occupancy)


class ODEResults(NamedTuple):
    """Solution of an ODE model, as returned by :func:`ode_model`.

    Attributes:
      times: the times at which the solver reported the solution.
      states: the solver's state output corresponding to :attr:`times`.
    """

    times: Array
    states: Array


def ode_model(
    transition_rate_fn: list[Callable[[float, ArrayLike], Array]]
    | Callable[[float, ArrayLike], tuple[Array]],
    incidence_matrix: ArrayLike,
    initial_state: ArrayLike,
    num_steps: int | None = None,
    initial_time: float = 0.0,
    time_delta: float = 1.0,
    times: ArrayLike | None = None,
    solver: str = "DormandPrince",
    solver_kwargs: dict | None = None,
) -> ODEResults:
    """Solve a system of differential equations.

    The right-hand side of the ODE system is built from a state transition
    model.  Each transition's per-unit rate is multiplied by the occupancy
    of its source state.  The resulting fluxes are then mapped onto state
    changes by :code:`incidence_matrix`.

    Args:
      transition_rate_fn: Either a list of callables of the form
        :code:`fn(t: float, state: Tensor) -> Tensor` or a Python callable
        of the form :code:`fn(t: float, state: Tensor) -> tuple(Tensor,...)`.
        In the first (preferred) form, each callable in the list corresponds
        to the respective transition in :code:`incidence_matrix`.  In the
        second form, the callable should return a :code:`tuple` of transition
        rate tensors corresponding to transitions in :code:`incidence_matrix`.
        **Note**: the second form will be deprecated in future releases of
        :code:`gemlib`.
      incidence_matrix: a :code:`[S, R]` matrix describing the change in
        :code:`S` resulting from transitions :code:`R`.  It is converted to a
        NumPy array, so it must be a concrete (non-traced) value.
      initial_state: a :code:`[..., N, S]` (batched) tensor with the state
        values for :code:`N` units and :code:`S` states at
        :code:`initial_time`.
      num_steps: the number of evenly spaced solution times to generate,
        namely :code:`initial_time + time_delta * k` for
        :code:`k = 0, ..., num_steps - 1`.  Exactly one of :code:`num_steps`
        or :code:`times` must be given.
      initial_time: the time at which :code:`initial_state` applies.  It is
        also the first solution time when :code:`num_steps` is used.
      time_delta: the spacing between solution times when :code:`num_steps`
        is used; ignored otherwise.
      times: a 1-D tensor of times at which the ODE solution is required.
        Exactly one of :code:`num_steps` or :code:`times` must be given.
      solver: the name of the ODE solver to use, either
        :code:`"DormandPrince"` (default) or :code:`"BDF"`.  See the
        `TensorFlow Probability documentation`_ for details.
      solver_kwargs: a dictionary of keyword arguments passed to the solver
        constructor.  See the solver documentation for details.

    Returns:
      An :class:`ODEResults` holding the solution times and the states
      returned by the solver at those times.

    Raises:
      ValueError: if both or neither of :code:`num_steps` and :code:`times`
        are given, or if :code:`solver` is not a recognised name.

    .. _TensorFlow Probability documentation:
           https://www.tensorflow.org/probability/api_docs/python/tfp/math/ode
    """

    if (num_steps is not None) and (times is not None):
        raise ValueError("Must specify exactly one of `num_steps` or `times`")

    if num_steps is not None:
        # Build the grid from an integer index so that it always has exactly
        # `num_steps` points starting at `initial_time`.  A float-stepped
        # `arange(start, stop, step)` can gain or lose a point through
        # rounding, and its stop value must include the `initial_time`
        # offset.
        times = initial_time + time_delta * jnp.arange(num_steps, dtype=float)
    elif times is not None:
        times = jnp.asarray(times)
    else:
        raise ValueError("Must specify either `num_steps` or `times`")

    transition_rate_fn = maybe_combine_fn(transition_rate_fn)

    if solver_kwargs is None:
        solver_kwargs = {}

    if solver == "DormandPrince":
        solver_fn = tfp.math.ode.DormandPrince(**solver_kwargs)
    elif solver == "BDF":
        solver_fn = tfp.math.ode.BDF(**solver_kwargs)
    else:
        raise ValueError("`solver` must be one of 'DormandPrince' or 'BDF'")

    # The incidence matrix must be concrete (`_total_flux` converts it with
    # NumPy).  Convert it once here so that list inputs also work with the
    # matmul below.
    incidence = np.asarray(incidence_matrix)

    # The broadcast rate function depends only on the unit/batch shape, not
    # on `t` or `state`, so build it once rather than on every RHS call.
    unit_shape = jnp.shape(initial_state)[:-1]
    rate_fn = broadcast_fn_to(transition_rate_fn, unit_shape)

    def derivs(t, state):
        rates = rate_fn(t, state)
        flux = _total_flux(rates, state, incidence)  # [..., R, N]
        dstate = jnp.matmul(incidence, flux)  # [..., S, N]
        # Swap only the trailing two axes back to [..., N, S].  `.T` would
        # reverse every axis and scramble any batch dimensions.
        return jnp.swapaxes(dstate, -1, -2)

    solver_results = solver_fn.solve(
        ode_fn=derivs,
        initial_time=initial_time,
        initial_state=initial_state,
        solution_times=times,
    )

    return ODEResults(
        times=solver_results.times,
        states=solver_results.states,
    )
