"""Tests for mathematical operators and functions."""

import numpy as np
import pytest

from gemlib.math import cumsum, multiply_no_nan


@pytest.fixture
def tensor3d():
    """Provide a (2, 3, 4) integer array filled with 1..24 in row-major order."""
    flat = np.arange(1, 2 * 3 * 4 + 1)
    return flat.reshape(2, 3, 4)


def test_cumsum_simple(tensor3d):
    """Plain (inclusive, forward) cumsum must agree with ``np.cumsum``.

    Checks the innermost axis first, then the middle axis.
    """
    for ax in (-1, -2):
        np.testing.assert_array_equal(
            cumsum(tensor3d, axis=ax), np.cumsum(tensor3d, axis=ax)
        )

    # Test basic cumsum on middle dimension
    np.testing.assert_array_equal(
        cumsum(tensor3d, axis=-2), np.cumsum(tensor3d, axis=-2)
    )


@pytest.mark.parametrize("axis", [0, 1, -1])
def test_cumsum_exclusive(tensor3d, axis):
    """Exclusive cumsum is the inclusive cumsum shifted one step along ``axis``.

    The reference drops the final partial sum along ``axis`` and inserts a
    zero at the front, so element ``i`` holds the sum of elements ``0..i-1``.
    """
    rank = tensor3d.ndim
    # Normalise a negative axis to its index counted from the left.
    target = axis if axis >= 0 else rank + axis

    inclusive = np.cumsum(tensor3d, axis=axis)

    # Keep everything except the last entry along the target axis.
    keep = [slice(None)] * rank
    keep[target] = slice(None, -1)
    truncated = inclusive[tuple(keep)]

    # Prepend a single zero along the target axis only.
    padding = [(0, 0)] * rank
    padding[target] = (1, 0)
    reference = np.pad(truncated, pad_width=padding)

    np.testing.assert_array_equal(
        cumsum(tensor3d, axis=axis, exclusive=True), reference
    )


def test_cumsum_reverse(tensor3d):
    """Check ``reverse=True`` cumsum along the innermost axis.

    A reverse cumsum accumulates from the end of the axis towards the start,
    so result element ``i`` holds ``sum(x[..., i:])``. This is the convention
    of ``tf.math.cumsum(..., reverse=True)``, whose ``axis``/``exclusive``/
    ``reverse`` keywords ``cumsum`` is called with here.
    NOTE(review): confirm gemlib.math.cumsum documents the same convention.
    """
    # Cumsum of the flipped array gives the suffix sums in reversed order ...
    flipped_cumsum = np.cumsum(np.flip(tensor3d, axis=-1), axis=-1)
    # ... so flip back to put sum(x[..., i:]) at position i. Without this the
    # expectation for row [1, 2, 3, 4] would be [4, 7, 9, 10] rather than
    # the correct reverse cumsum [10, 9, 7, 4].
    expected = np.flip(flipped_cumsum, axis=-1)

    result = cumsum(tensor3d, axis=-1, reverse=True)
    np.testing.assert_array_equal(result, expected)

    # Independent, hand-computed anchor: tensor3d[0, 0] is [1, 2, 3, 4].
    np.testing.assert_array_equal(result[0, 0], [10, 9, 7, 4])


def test_multiply_no_nan_zero_y():
    """A zero multiplier gives zero even where x is infinite or NaN."""
    x = np.array([-np.inf, np.nan, 1.0, np.inf])
    result = multiply_no_nan(x, np.zeros_like(x))
    np.testing.assert_array_equal(result, np.zeros(4, dtype=int))


def test_multiply_no_nan_nonzero_y():
    """Only entries with y == 0 are forced to zero; the rest multiply normally.

    Expected per element: -inf * 0 -> 0, nan * 1 -> nan, 1 * 2 -> 2,
    inf * 3 -> inf.
    """
    x = np.array([-np.inf, np.nan, 1.0, np.inf])
    y = np.arange(4, dtype=float)
    want = np.array([0.0, np.nan, 2.0, np.inf])
    np.testing.assert_array_equal(multiply_no_nan(x, y), want)


def test_multiply_no_nan_scalar():
    """Scalar inputs: NaN times zero yields zero rather than NaN."""
    np.testing.assert_array_equal(multiply_no_nan(np.nan, 0), 0)


def test_multiply_no_nan_scalar_nonzero_y():
    """Scalar inputs: NaN times a non-zero y stays NaN."""
    nan_x, unit_y = np.nan, 1
    np.testing.assert_array_equal(multiply_no_nan(nan_x, unit_y), np.nan)
