numpy>=1.21.0
scipy>=1.7.0

[all]
jax>=0.4.0
jaxlib>=0.4.0
torch>=2.0.0
optax>=0.1.0
flax>=0.7.0
chex>=0.1.0

[dev]
pytest>=7.0
pytest-benchmark>=4.0
black>=23.0
isort>=5.0
mypy>=1.0

[examples]
matplotlib>=3.5.0
jupyter>=1.0.0
seaborn>=0.11.0

[jax]
jax>=0.4.0
jaxlib>=0.4.0
optax>=0.1.0
flax>=0.7.0
chex>=0.1.0

[torch]
torch>=2.0.0
