numpy>=1.21.0
scipy>=1.7.0
matplotlib>=3.5.0
pandas>=1.3.0
sympy>=1.9.0
numba>=0.56.0
joblib>=1.1.0
torch>=1.12.0
jax>=0.3.0
jaxlib>=0.3.0
optax>=0.1.0
flax>=0.5.0
torch-geometric>=2.1.0
cupy-cuda12x>=12.0.0; sys_platform != "darwin"

[dev]
black>=22.0.0
flake8>=5.0.0
mypy>=0.991
pre-commit>=2.20.0
pytest>=7.0.0
pytest-cov>=4.0.0

[docs]
myst-parser>=0.18.0
sphinx>=5.0.0
sphinx-rtd-theme>=1.0.0

[gpu]
cupy-cuda12x>=12.0.0
torch>=1.12.0
jax[cuda12]>=0.4.26

[ml]
torch>=1.12.0
jax>=0.3.0
jaxlib>=0.3.0
optax>=0.1.0
flax>=0.5.0
torch-geometric>=2.1.0
