"""Tests for the dense (``pdist``) and sparse (``sparse_pdist``) pairwise distance functions."""

import jax.numpy as jnp
import numpy as np
import pytest
import scipy.spatial.distance as spsp

from .sp_dist import pdist, sparse_pdist


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_pdist(dtype):
    """Check ``pdist`` against hand-computed distances on the unit square.

    The four corners of the unit square are measured against each other, so
    the expected matrix holds 0 on the diagonal, 1 for edge-adjacent corners
    and sqrt(2) for opposite corners.
    """
    coords = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=dtype)

    # Squared distances between the corners; the square root is taken below.
    expected = np.sqrt(
        np.array(
            [
                [0.0, 1.0, 1.0, 2.0],
                [1.0, 0.0, 2.0, 1.0],
                [1.0, 2.0, 0.0, 1.0],
                [2.0, 1.0, 1.0, 0.0],
            ],
            dtype=dtype,
        )
    )

    actual = pdist(coords, coords)

    # Floating-point distances (the sqrt(2) entries in particular) should be
    # compared with a tolerance, not bit-for-bit. The tolerance is sized for
    # float32 because JAX computes in 32-bit unless jax_enable_x64 is set, so
    # a float64 input may come back with float32 precision
    # (NOTE(review): confirm how pdist handles float64 inputs).
    tol = 1e-6
    np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol)


def test_sparse_pdist(coords):
    """Check that ``sparse_pdist`` keeps exactly the pairs inside (0, max_dist).

    A dense SciPy distance matrix serves as the reference. The sparse result
    must be square in the number of points, and its number of stored elements
    must equal the number of reference distances strictly between 0 and
    ``max_dist``.
    """
    max_dist = 0.101
    num_points = coords.shape[-2]

    # Keep only strictly positive distances below the cutoff; the lower bound
    # excludes zero distances such as the diagonal.
    def keep(dist):
        return jnp.less(0, dist) & jnp.less(dist, max_dist)

    sparse_dist = sparse_pdist(coords, include_fn=keep, chunk_size=32)

    dense_ref = spsp.squareform(spsp.pdist(coords))
    in_range = (dense_ref > 0) & (dense_ref < max_dist)

    assert sparse_dist.shape == (num_points, num_points)
    assert sparse_dist.nse == np.count_nonzero(in_range)
