jax==0.8.0
jaxdecomp==0.2.9
tensorflow
orbax-checkpoint==0.11.28
optax==0.2.6
jaxopt
numpy
matplotlib
setuptools>=70.1.1
