jax>=0.7.0
jaxlib>=0.7.0
flax>=0.11.2
numpy>=1.21.0

[all]
jraphx[dev,docs,examples,test]

[dev]
pytest>=7.0
pytest-xdist
black>=22.0
isort>=5.10
mypy>=0.990
ruff>=0.1.0

[docs]
sphinx==5.1.1
sphinx-autodoc-typehints==1.19.2
sphinx-copybutton>=0.5.2

[examples]
optax>=0.2.5
torch-geometric>=2.0.0
grain>=0.2.12

[test]
pytest>=7.0
pytest-xdist
optax>=0.2.5
