"""Compute a sparse distance matrix given coordinates"""

from __future__ import annotations

from collections.abc import Callable
from functools import partial

import jax
import jax.experimental.sparse as jsp
import jax.numpy as jnp
import numpy as np
from tqdm import tqdm

# Public API of this module; the remaining helpers are implementation details.
__all__ = ["pdist", "sparse_pdist"]

# Shorthand type aliases used in the annotations of this module.
Array = jax.Array
ArrayLike = jax.typing.ArrayLike
# ``np.bool_`` (not ``np.bool``) is used on purpose: the ``np.bool`` name is
# missing from NumPy 1.24-1.26, and this alias is evaluated at import time.
BooleanArray = np.typing.NDArray[np.bool_]


def pdist(a: ArrayLike, b: ArrayLike) -> Array:
    """Compute the pairwise Euclidean distances between a and b

    Args:
        a: a :code:`[N, D]` array-like of coordinates
        b: a :code:`[M, D]` array-like of coordinates

    Returns:
        A :code:`[N, M]` matrix of Euclidean distances between
        coordinates, i.e. element ``[i, j]`` is the distance between
        ``a[i]`` and ``b[j]``.
    """
    # Convert up front so that any array-like input (nested lists, NumPy
    # arrays, JAX arrays) supports the tuple indexing used below.
    a = jnp.asarray(a)
    b = jnp.asarray(b)

    # Broadcast [N, 1, D] against [1, M, D] to get all coordinate differences.
    delta = a[..., np.newaxis, :] - b[np.newaxis, ...]
    sqdist = jnp.sum(delta * delta, axis=-1)

    return jnp.sqrt(sqdist)


def include_all(x: np.typing.ArrayLike):
    """Inclusion predicate that accepts every element

    Args:
        x: an array-like of values

    Returns:
        A boolean array with the same shape as :code:`x`, filled
        with :code:`True`.
    """
    return jnp.ones(jnp.asarray(x).shape, dtype=bool)


@partial(jax.jit, static_argnums=2)
def _pdist_indices_mask(a, b, include_fn):
    """Compute flattened distances, their dense indices and an inclusion mask

    Args:
        a: a :code:`[N, D]` array of coordinates
        b: a :code:`[M, D]` array of coordinates
        include_fn: callable applied to the flattened distances, returning
                    a boolean mask of the same shape.  It is a static
                    argument of the jitted function, so it must be hashable.

    Returns:
        A tuple :code:`(values, indices, mask)` where :code:`values` is the
        row-major flattened :code:`[N * M]` distance matrix,
        :code:`indices` is the :code:`[N * M, 2]` array of
        :code:`(row, col)` positions of each value in the dense
        :code:`[N, M]` matrix, and :code:`mask` is the result of
        :code:`include_fn(values)`.
    """
    num_cols = b.shape[0]
    values = pdist(a, b).flatten()
    mask = include_fn(values)

    # Position of every element in the row-major flattened [N, M] matrix.
    # The (row, col) pair must refer to the dense matrix, independently of
    # which elements the mask keeps, and the row stride is the number of
    # columns (M), not the number of rows.
    flat_index = jnp.arange(values.shape[0])
    row_idx = flat_index // num_cols
    col_idx = flat_index % num_cols
    indices = jnp.stack([row_idx, col_idx], axis=-1)
    return values, indices, mask


def compress_distance(a, b, include_fn):
    """Return the distances between :code:`a` and :code:`b` selected
    by :code:`include_fn`, in coordinate (COO) form.

    Args:
        a: first array of coordinates
        b: second array of coordinates
        include_fn: callable mapping the flattened distances to a mask
                    selecting the entries to keep

    Returns:
        A tuple :code:`(values, indices)` holding the selected distances
        and the matching rows of the index array produced by
        :code:`_pdist_indices_mask`.
    """
    dists, pair_idx, keep = _pdist_indices_mask(a, b, include_fn)

    # The masked selection has a data-dependent size, so it is done here,
    # outside of the jitted helper.
    kept_dists = dists[keep]
    kept_idx = pair_idx[keep]
    return kept_dists, kept_idx


def sparse_pdist(
    coords: ArrayLike,
    include_fn: Callable[[ArrayLike], BooleanArray] = include_all,
    chunk_size: int | None = None,
) -> jsp.BCOO:
    """Compute a sparse distance matrix

    Compute a sparse Euclidean distance matrix between all pairs of
    :code:`coords`, storing only the distances selected by
    :code:`include_fn`.

    Args:
        coords: a ``[N, D]`` array of coordinates
        include_fn: a callable that takes an array of distances and returns
                    a boolean array of the same shape, :code:`True` where
                    the distance should be included as a "non-zero" element
                    of the returned sparse matrix.  By default all
                    distances are included.  The callable is passed as a
                    static argument to a jitted helper, so it must be
                    hashable.
        chunk_size: If memory is limited, compute the distances in stripes
                    of ``[chunk_size, N]``.  Defaults to computing all
                    ``N`` rows at once.

    Returns:
        A ``[N, N]`` :code:`BCOO` sparse matrix of the Euclidean distances
        satisfying `include_fn`.

    Raises:
        ValueError: if :code:`coords` is not 2-dimensional or if
                    :code:`chunk_size` is less than 1.

    Example:

        >>> import numpy as np
        >>> from gemlib.spatial import sparse_pdist
        >>> coords = np.random.uniform(size=(1000, 2))
        >>> d_sparse = sparse_pdist(
        ...     coords, include_fn=lambda d: d < 0.01, chunk_size=200
        ... )
        >>> d_sparse.shape
        (1000, 1000)

    """
    coords = np.asarray(coords)
    if coords.ndim != 2:
        raise ValueError(
            f"coords must be a [N, D] array, got shape {coords.shape}"
        )
    num_coords = coords.shape[0]

    if chunk_size is None:
        # max(..., 1) keeps the range() step valid for empty input
        chunk_size = max(num_coords, 1)
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be at least 1, got {chunk_size}")

    # Results of each stripe are moved to the CPU device before being
    # accumulated.
    cpu = jax.devices("cpu")[0]
    values_accum = []
    indices_accum = []

    for start in tqdm(
        range(0, num_coords, chunk_size),
        unit="rows",
        unit_scale=chunk_size,
        miniters=1,
    ):
        stop = min(start + chunk_size, num_coords)
        values, indices = compress_distance(
            coords[start:stop], coords, include_fn
        )
        # Row indices come back relative to the stripe; shift them so they
        # address rows of the full [N, N] matrix.
        indices = indices.at[:, 0].add(start)
        values_accum.append(jax.device_put(values, cpu))
        indices_accum.append(jax.device_put(indices, cpu))

    if not values_accum:
        # No coordinates: return an empty 0 x 0 matrix instead of failing
        # when concatenating an empty list.
        values_accum.append(jax.device_put(jnp.zeros((0,)), cpu))
        indices_accum.append(
            jax.device_put(jnp.zeros((0, 2), dtype=int), cpu)
        )

    res = jsp.BCOO(
        (jnp.concatenate(values_accum, 0), jnp.concatenate(indices_accum, 0)),
        shape=(num_coords, num_coords),
        indices_sorted=True,
        unique_indices=True,
    )

    return res
