numpy
typing_extensions
saiunit==0.0.16

[cpu]
jax[cpu]

[cuda12]
jax[cuda12]

[testing]
pytest

[tpu]
jax[tpu]
