Metadata-Version: 2.1
Name: eqxvision
Version: 0.1.1
Summary: Root package info.
Author-email: Contributing Authors <aditya.91.singh@gmail.com>
Requires-Python: >=3.7
Description-Content-Type: text/markdown
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3 :: Only
Requires-Dist: equinox
Requires-Dist: jax
Requires-Dist: jaxlib
Requires-Dist: torch >=1.7.1
Requires-Dist: torchvision >=0.10.0
Requires-Dist: jinja2==3.0.3 ; extra == "docs"
Requires-Dist: Markdown>=3.3 ; extra == "docs"
Requires-Dist: MarkupSafe>=1.1 ; extra == "docs"
Requires-Dist: mkdocs>=1.2 ; extra == "docs"
Requires-Dist: mkdocs-autorefs>=0.3.1 ; extra == "docs"
Requires-Dist: pymdown-extensions>=6.3 ; extra == "docs"
Requires-Dist: mkdocs==1.3.0 ; extra == "docs"
Requires-Dist: mkdocs-autorefs ; extra == "docs"
Requires-Dist: mkdocs_include_exclude_files==0.0.1 ; extra == "docs"
Requires-Dist: mkdocs-material==7.3.6 ; extra == "docs"
Requires-Dist: mkdocs-material-extensions ; extra == "docs"
Requires-Dist: mkdocstrings==0.17.0 ; extra == "docs"
Requires-Dist: mkdocstrings-python ; extra == "docs"
Requires-Dist: mkdocstrings-python-legacy ; extra == "docs"
Requires-Dist: mknotebooks==0.7.1 ; extra == "docs"
Requires-Dist: pymdown-extensions==9.4 ; extra == "docs"
Requires-Dist: pytkdocs_tweaks==0.0.5 ; extra == "docs"
Requires-Dist: pre-commit ; extra == "test"
Requires-Dist: bluepy ; extra == "test"
Requires-Dist: pytest ; extra == "test"
Requires-Dist: torch >=1.7.1 ; extra == "test"
Requires-Dist: torchvision >=0.10.0 ; extra == "test"
Project-URL: Bug Tracker, https://github.com/paganpasta/eqxvision/issues
Project-URL: Homepage, https://github.com/paganpasta/eqxvision
Provides-Extra: docs
Provides-Extra: test

# Eqxvision

Eqxvision is a package of popular computer vision model architectures built using [Equinox](https://docs.kidger.site/equinox/).

## Installation

Use the package manager [pip](https://pip.pypa.io/en/stable/) to install eqxvision.

```bash
pip install eqxvision
```

*requires:* `python>=3.7`

## Usage
???+ Example
    Importing a model and running a forward pass is as simple as:
    ```python
    import jax
    import jax.random as jr
    import equinox as eqx
    from eqxvision.models import alexnet

    @eqx.filter_jit
    def forward(net, images, key):
        # One PRNG key per image, so each example gets independent randomness
        keys = jr.split(key, images.shape[0])
        output = jax.vmap(net, axis_name="batch")(images, key=keys)
        ...
        return output

    net = alexnet(num_classes=1000)

    # Use separate keys for generating the input and for the forward pass
    data_key, run_key = jr.split(jr.PRNGKey(0))
    images = jr.uniform(data_key, shape=(1, 3, 224, 224))
    output = forward(net, images, run_key)
    ```

## What's New?
- `[Experimental]` Now supports loading PyTorch weights from `torchvision` for models **without** BatchNorm.

    !!! note
        Due to slight differences in the implementation of the underlying operations,
        the outputs of a network with loaded pretrained weights can differ from those of the original `torchvision` model.
       
## Tips
- Better to use `@equinox.filter_jit` instead of `@jax.jit`
- Advisable to use `jax.{v,p}map` with `axis_name='batch'` for all models
- Don't forget to switch to `inference` mode for evaluations
- Pass `eqx.filter(net, eqx.is_array)`, rather than `net` itself, to the `Optax` optimiser's `init`.



## Contributing
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

Please make sure to update tests as appropriate.

## Acknowledgements
- [Equinox](https://github.com/patrick-kidger/equinox)
- [Patrick Kidger](https://github.com/patrick-kidger)
- [Torchvision](https://pytorch.org/vision/stable/index.html)

## License
[MIT](https://choosealicense.com/licenses/mit/)
