"""Utility functions for model implementation code."""

import jax.numpy as jnp
import numpy as np
from jax import Array
from jax.typing import ArrayLike

from gemlib.math import cumsum_np


def batch_gather(arr: ArrayLike, indices: ArrayLike) -> Array:
    """Batched gather of ``indices`` from the right-most dimensions of ``arr``.

    Let ``arr`` have ``m`` dimensions and shape
    ``(a_1, ..., a_{m-d}, a_{m-d+1}, ..., a_m)`` and let ``indices`` have
    ``n`` dimensions and shape ``(b_1, ..., b_{n-1}, d)``.  Each length-``d``
    row of ``indices`` is a coordinate into the right-most ``d`` dimensions
    of ``arr``.  The output has shape
    ``(a_1, ..., a_{m-d}, b_1, ..., b_{n-1})``, where
    ``output[i_1, ..., i_{m-d}, j_1, ..., j_{n-1}] =
    arr[i_1, ..., i_{m-d}, indices[j_1, ..., j_{n-1}, 0], ...,
    indices[j_1, ..., j_{n-1}, d - 1]]``.

    This function is equivalent to::

        idx = jnp.moveaxis(jnp.asarray(indices, dtype=jnp.int32), -1, 0)
        return arr[..., *idx]

    but avoids fancy indexing: the indexed dimensions are flattened into
    one and a single :func:`jax.numpy.take` is performed along it.

    Note: because coordinates are combined into one flat (row-major)
    offset, they are not bounds-checked per dimension; an out-of-range
    coordinate may alias a different, valid element.  Offsets that fall
    outside the flattened axis are handled by the default ``mode`` of
    :func:`jax.numpy.take`.

    Args:
        arr (jax.Array): an array with at least ``d`` dimensions.
        indices (jax.Array): an integer array of coordinates into the
            right-most ``indices.shape[-1]`` dimensions of ``arr``; it is
            cast to ``int32``.

    Returns:
        jax.Array: Gathered values of dimension
            ``arr.ndim - indices.shape[-1] + indices.ndim - 1``.

    Raises:
        ValueError: if ``indices.shape[-1]`` exceeds ``arr.ndim``.
    """
    arr = jnp.asarray(arr)
    indices = jnp.asarray(indices, dtype=jnp.int32)
    index_dims = indices.shape[-1]
    if index_dims > arr.ndim:
        raise ValueError(
            f"indices.shape[-1] ({index_dims}) exceeds arr.ndim ({arr.ndim})"
        )

    # Split the shape into the kept (batch) dims and the indexed dims.
    # Slicing at ``arr.ndim - index_dims`` rather than ``-index_dims`` keeps
    # the split correct when ``index_dims == 0``.
    split = arr.ndim - index_dims
    batch_shape, event_shape = arr.shape[:split], arr.shape[split:]

    # Give the flattened size explicitly: ``reshape`` cannot infer a ``-1``
    # placeholder when a batch dimension has size zero.
    flat_size = int(np.prod(event_shape, dtype=np.int64))
    flat_arr = jnp.reshape(arr, batch_shape + (flat_size,))

    # Row-major strides of the indexed dims.  Shapes are static, so the
    # strides are computed eagerly as Python ints.
    strides = jnp.asarray(
        [
            int(np.prod(event_shape[i + 1 :], dtype=np.int64))
            for i in range(index_dims)
        ],
        dtype=indices.dtype,
    )
    flat_indices = jnp.dot(indices, strides)

    return jnp.take(flat_arr, flat_indices, axis=-1)


def transition_coords(
    incidence_matrix: np.ndarray, dtype=np.int32
) -> np.ndarray:
    """Compute coordinates of transitions in an incidence matrix or batch
    of incidence matrices.

    For each transition (column) of ``incidence_matrix``, return the row
    indices of its source and destination states.

    Note: this requires ``incidence_matrix`` to describe a state transition
    model, with a single negative (source) and a single positive
    (destination) entry in each column and all other entries ``0``.  This
    is not validated: a column with no negative (or no positive) entry
    yields ``0`` for that coordinate, and a column with several such
    entries yields the sum of their row indices.

    Args:
        incidence_matrix (np.ndarray): an array (or anything accepted by
            :func:`numpy.asarray`) of shape ``(..., S, R)`` containing a
            single incidence matrix or a batch of them, each representing
            ``R`` transitions between ``S`` states.
        dtype (type): Data type of returned array.

    Returns:
        np.ndarray: An array of shape ``(..., R, 2)``, where entry
            ``[..., r, 0]`` is the index of the source state of the r-th
            transition and ``[..., r, 1]`` the index of its destination.
    """
    incidence_matrix = np.asarray(incidence_matrix)
    num_states = incidence_matrix.shape[-2]

    # Shape (..., S, R, 2): channel 0 flags sources, channel 1 destinations.
    is_src_dest = np.stack(
        [incidence_matrix < 0, incidence_matrix > 0], axis=-1
    )

    # Weight each flag by its state (row) index and sum over the state
    # axis; with exactly one flag per column this yields that flag's row.
    state_index = np.arange(num_states, dtype=dtype).reshape(-1, 1, 1)
    coords = np.sum(is_src_dest * state_index, axis=-3)
    return coords.astype(dtype)


def _as_nested_tuple(value):
    """Recursively convert nested lists (from ``ndarray.tolist``) to tuples."""
    if isinstance(value, list):
        return tuple(_as_nested_tuple(item) for item in value)
    return value


def transition_coords_tuple(
    incidence_matrix: np.ndarray, dtype=jnp.int32
) -> tuple:
    """Return the result of :func:`transition_coords` as nested tuples.

    Every level of the ``(..., R, 2)`` coordinate array is converted to a
    Python ``tuple`` of Python scalars, so the result is hashable.  For a
    single incidence matrix this is
    ``((src_0, dest_0), ..., (src_{R-1}, dest_{R-1}))``.

    Args:
        incidence_matrix (np.ndarray): an array of shape ``(..., S, R)``;
            see :func:`transition_coords`.
        dtype (type): Data type of the intermediate coordinate array
            passed to :func:`transition_coords`.

    Returns:
        tuple: nested tuples mirroring the shape ``(..., R, 2)``.
    """
    # ``tolist`` yields nested lists; convert every level (not just the
    # outermost) so batched inputs also produce fully hashable tuples.
    return _as_nested_tuple(
        transition_coords(incidence_matrix, dtype).tolist()
    )


def states_from_transition_idx(
    transition_index: int,
    incidence_matrix: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Return source and destination state indices from an incidence matrix
    for a particular transition index.

    Given the index of a transition (column) in ``incidence_matrix``,
    return the indices of its source and destination states.

    Note: this requires the ``incidence_matrix`` to have one negative
    (source) and one positive (destination) entry in each column, and all
    other entries ``0``.

    Args:
        transition_index (int): the index (column) of the transition in
            ``incidence_matrix``.
        incidence_matrix (np.ndarray): an array of shape ``(..., S, R)``
            containing a single or a batch of incidence matrices,
            each representing ``R`` transitions within ``S`` states.

    Returns:
        tuple[np.ndarray, np.ndarray]: A pair of integer arrays with the
            batch shape ``(...)`` of ``incidence_matrix`` (0-d arrays for a
            single matrix), giving the source and destination state indices
            of transition ``transition_index``.
    """
    # Shape (..., 2): [source, destination] for the requested transition.
    src_dest = transition_coords(incidence_matrix)[..., transition_index, :]

    # Indexing with ``...`` keeps 0-d arrays (not NumPy scalars) when the
    # incidence matrix is unbatched.
    return src_dest[..., 0], src_dest[..., 1]
