jax>=0.6.0
flax<0.12.0,>=0.10.6
hydra-core>=1.3.2
optax>=0.2.4
pandas>=2.3.2
pyarrow>=21.0.0
progress-table>=3.1.2
rax>=0.4.0
torch>=2.7.0

[cuda12]
jax[cuda12]>=0.6.0
