import jax
import numpy as np

from gemlib.spatial import sparse_pdist


# Profile gemlib's chunked `sparse_pdist` on a large random point cloud.
# The JAX profiler trace is written to /tmp/jax-trace, and a Perfetto UI
# link is printed when the trace context exits (create_perfetto_link=True).

N = 188361  # number of random points
CHUNK_SIZE = 512  # forwarded to sparse_pdist as `chunk_size`
MAX_DIST = 0.01  # threshold applied by `include_fn` below
SEED = 0  # fixed seed so repeated profiling runs use identical inputs

# N points drawn uniformly from the unit square [0, 1)^2, shape (N, 2).
coords = np.random.default_rng(SEED).uniform(size=(N, 2))

with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    # NOTE(review): presumably `include_fn` receives pairwise distances and
    # keeps entries strictly below MAX_DIST -- confirm against gemlib's
    # `sparse_pdist` documentation.
    # NOTE(review): if sparse_pdist is jit-compiled, compilation time is
    # included in this trace; add a warm-up call beforehand if only steady-
    # state execution should be profiled.
    sp_m = sparse_pdist(
        coords, include_fn=lambda x: x < MAX_DIST, chunk_size=CHUNK_SIZE
    )
    # JAX dispatches device work asynchronously. Without this, the trace
    # context may close before the computation finishes, and the profile
    # would show little more than dispatch overhead. Any non-JAX leaves in
    # the result are passed through unchanged.
    sp_m = jax.block_until_ready(sp_m)
