jax==0.7.1
flax==0.12.0
optax==0.2.6
transformers==4.56.2
