"""Mathematical operators and functions"""

import jax
import jax.numpy as jnp
import numpy as np

# Names exported by ``from <this module> import *``.
# NOTE(review): ``cumsum_np`` and ``multiply_no_nan`` are also defined in this
# file but are not listed here; confirm whether that omission is intentional.
__all__ = ["cumsum"]


def cumsum(
    x: np.typing.ArrayLike,
    axis: int = 0,
    exclusive: bool = False,
    reverse: bool = False,
) -> np.typing.ArrayLike:
    """Compute the cumulative sum of :code:`x` along :code:`axis`

    Follows the conventions of :code:`tf.math.cumsum`. For
    :code:`x = [a, b, c]`:

    * default: :code:`[a, a + b, a + b + c]`
    * :code:`exclusive=True`: :code:`[0, a, a + b]`
    * :code:`reverse=True`: :code:`[a + b + c, b + c, c]`
    * :code:`exclusive=True, reverse=True`: :code:`[b + c, c, 0]`

    Args:
      x: input array
      axis: Axis along which to compute the cumulative sum. Negative
        values count from the last axis.
      exclusive: if truthy, perform an exclusive cumsum, i.e. each output
        element excludes the corresponding input element
      reverse: if truthy, accumulate from the end of :code:`axis` towards
        the start; the result is returned in the original element order

    Returns:
      A JAX array with the same shape as :code:`x`
    """

    x = jnp.asarray(x)
    if axis < 0:  # Turn neg axis into pos axis so it can be compared below
        axis += x.ndim

    if reverse:
        x = jnp.flip(x, axis=axis)

    result = jnp.cumsum(x, axis=axis)

    if exclusive:
        # Shift every running total one step along `axis` and zero the first
        # slot: [a, a+b, a+b+c] -> [0, a, a+b].
        first_slot = tuple(
            slice(0, 1) if i == axis else slice(None) for i in range(x.ndim)
        )
        result = jnp.roll(result, 1, axis=axis).at[first_slot].set(0)

    if reverse:
        # Undo the initial flip so the output lines up with the input order.
        result = jnp.flip(result, axis=axis)

    return result


def cumsum_np(
    x: np.typing.ArrayLike,
    axis: int = 0,
    exclusive: bool = False,
    reverse: bool = False,
) -> np.typing.ArrayLike:
    """Compute the cumulative sum of :code:`x` along :code:`axis` with NumPy

    Follows the conventions of :code:`tf.math.cumsum`. For
    :code:`x = [a, b, c]`:

    * default: :code:`[a, a + b, a + b + c]`
    * :code:`exclusive=True`: :code:`[0, a, a + b]`
    * :code:`reverse=True`: :code:`[a + b + c, b + c, c]`
    * :code:`exclusive=True, reverse=True`: :code:`[b + c, c, 0]`

    Args:
      x: input array
      axis: Axis along which to compute the cumulative sum. Negative
        values count from the last axis.
      exclusive: if truthy, perform an exclusive cumsum, i.e. each output
        element excludes the corresponding input element
      reverse: if truthy, accumulate from the end of :code:`axis` towards
        the start; the result is returned in the original element order

    Returns:
      A NumPy array with the same shape as :code:`x`
    """

    x = np.asarray(x)
    if axis < 0:  # Turn neg axis into pos axis so it can be compared below
        axis += x.ndim

    if reverse:
        x = np.flip(x, axis=axis)

    result = np.cumsum(x, axis=axis)

    if exclusive:
        # Shift every running total one step along `axis` and zero the first
        # slot: [a, a+b, a+b+c] -> [0, a, a+b]. np.roll returns a fresh
        # array, so the in-place assignment cannot alias the caller's input.
        first_slot = tuple(
            slice(0, 1) if i == axis else slice(None) for i in range(x.ndim)
        )
        result = np.roll(result, 1, axis=axis)
        result[first_slot] = 0

    if reverse:
        # Undo the initial flip so the output lines up with the input order.
        result = np.flip(result, axis=axis)

    return result


def multiply_no_nan(x: jax.Array, y: jax.Array) -> jax.Array:
    """Multiply ``x`` by ``y`` elementwise, forcing zero wherever ``y == 0``.

    Intended as the JAX counterpart of ``tf.multiply_no_nan``: at every
    position where ``y`` is zero the output is zero, even if ``x`` is NaN or
    infinite there. (The original author observed that it appears to compile
    to the same HLO.)

    Args:
        x (jax.Array): First factor.
        y (jax.Array): Second factor; its zero entries select zero outputs.

    Returns:
        jax.Array: ``x * y`` (broadcast as by ``*``), with zeros at every
            position where ``y == 0``.
    """
    product = x * y
    y_is_zero = y == 0
    # Swap the (possibly NaN/inf) products at y == 0 for zeros of the
    # product's dtype. The product itself is still formed everywhere, so
    # only the forward value is masked here.
    return jnp.where(y_is_zero, jnp.zeros_like(product), product)
