Metadata-Version: 2.1
Name: bayeux-ml
Version: 0.1.13
Summary: Stitching together probabilistic models and inference.
Keywords: 
Author-email: bayeux authors <bayeux@google.com>
Requires-Python: >=3.9
Description-Content-Type: text/markdown
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Intended Audience :: Science/Research
Requires-Dist: jax>=0.4.6
Requires-Dist: tensorflow-probability[jax]>=0.19.0
Requires-Dist: oryx>=0.2.7
Requires-Dist: arviz
Requires-Dist: optax
Requires-Dist: optimistix
Requires-Dist: blackjax
Requires-Dist: flowmc>=0.3.0
Requires-Dist: numpyro
Requires-Dist: jaxopt
Requires-Dist: pymc
Requires-Dist: pytest ; extra == "dev"
Requires-Dist: pytest-xdist ; extra == "dev"
Requires-Dist: pylint>=2.6.0 ; extra == "dev"
Requires-Dist: pyink ; extra == "dev"
Requires-Dist: mkdocs==1.5.3 ; extra == "docs"
Requires-Dist: mkdocs-material==9.5.11 ; extra == "docs"
Requires-Dist: pymdown-extensions==10.7 ; extra == "docs"
Requires-Dist: mkdocstrings==0.24.1 ; extra == "docs"
Requires-Dist: mknotebooks==0.8.0 ; extra == "docs"
Project-URL: changelog, https://github.com/jax-ml/bayeux/blob/main/CHANGELOG.md
Project-URL: documentation, https://jax-ml.github.io/bayeux
Project-URL: homepage, https://github.com/jax-ml/bayeux
Project-URL: repository, https://github.com/jax-ml/bayeux
Provides-Extra: dev
Provides-Extra: docs

# Bayeux

*Stitching together models and samplers*

[![Unittests](https://github.com/jax-ml/bayeux/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/jax-ml/bayeux/actions/workflows/pytest_and_autopublish.yml)
[![PyPI version](https://badge.fury.io/py/bayeux_ml.svg)](https://badge.fury.io/py/bayeux_ml)

`bayeux` lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. The API aims to be **simple**, **self-descriptive**, and **helpful**. Simply provide a log density function (which does not even have to be normalized), along with a single point (specified as a [pytree](https://jax.readthedocs.io/en/latest/pytrees.html)) where that log density is finite. Then let `bayeux` do the rest!
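Why can the log density be unnormalized? Many MCMC methods only use differences of log densities, and gradient-based methods such as NUTS only use its gradient, so an additive constant (the log normalizer) drops out. The sketch below is a plain random-walk Metropolis sampler in standard-library Python, written only to illustrate this point. It is not how `bayeux` or the libraries it wraps are implemented, and all function names are hypothetical. It also starts the chain at the test point, which is one reason that point needs a finite log density.

```python
import math
import random


def normalized_log_density(x):
    # Standard normal log density, normalizing constant included.
    return -0.5 * x * x - 0.5 * math.log(2 * math.pi)


def unnormalized_log_density(x):
    # The same density with the constant dropped.
    return -0.5 * x * x


def random_walk_metropolis(log_density, test_point, num_steps, step_size, seed):
    """Illustrative scalar random-walk Metropolis (hypothetical helper)."""
    lp = log_density(test_point)
    # The chain starts at the test point, so its log density must be finite.
    if not math.isfinite(lp):
        raise ValueError("log density must be finite at the test point")
    rng = random.Random(seed)
    x = test_point
    samples = []
    for _ in range(num_steps):
        proposal = x + rng.gauss(0.0, step_size)
        proposal_lp = log_density(proposal)
        # Only the difference of log densities enters the accept/reject test,
        # so any additive constant cancels. 1 - random() lies in (0, 1].
        if math.log(1.0 - rng.random()) < proposal_lp - lp:
            x, lp = proposal, proposal_lp
        samples.append(x)
    return samples


# With the same seed, both densities produce the same chain.
a = random_walk_metropolis(normalized_log_density, 1.0, 5000, 1.0, seed=0)
b = random_walk_metropolis(unnormalized_log_density, 1.0, 5000, 1.0, seed=0)
print(a == b)
```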

## Installation

```bash
pip install bayeux-ml
```

## Quickstart

We define a model by providing a log density in JAX. It can be written with a probabilistic programming language (PPL) such as [NumPyro](examples/numpyro_and_bayeux), [PyMC](examples/pymc_and_bayeux), [TFP](examples/tfp_and_bayeux), distrax, oryx, or coix, or directly in JAX.
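Here is a hedged sketch of what a log density over a pytree can look like: a two-parameter model whose parameters live in a dict. It uses only the standard-library `math` module so it runs without JAX; to hand it to `bx.Model`, the arithmetic would need to use `jax.numpy` (for example `jnp.exp` and `jnp.log`) so that JAX can trace and differentiate it. The data, parameter names, and priors are hypothetical.

```python
import math

# Hypothetical observed data, for illustration only.
DATA = [0.5, 1.2, -0.3]


def normal_logpdf(x, loc, scale):
    # Log density of Normal(loc, scale) evaluated at x.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def log_density(params):
    # params is a dict (a pytree). The scale is stored as log_sigma so both
    # parameters are unconstrained real numbers.
    mu, log_sigma = params["mu"], params["log_sigma"]
    sigma = math.exp(log_sigma)
    lp = normal_logpdf(mu, 0.0, 10.0)  # prior on mu
    lp += normal_logpdf(log_sigma, 0.0, 1.0)  # prior on log_sigma
    lp += sum(normal_logpdf(y, mu, sigma) for y in DATA)  # likelihood
    return lp


# A test point with the same structure as params, where the density is finite.
test_point = {"mu": 0.0, "log_sigma": 0.0}
print(log_density(test_point))
```

The test point only has to match the structure of the parameters and give a finite log density; it need not be near the posterior mode.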

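```python
import bayeux as bx
import jax

# An unnormalized log density plus a point where it is finite.
normal_density = bx.Model(
  log_density=lambda x: -x*x,
  test_point=1.)

seed = jax.random.key(0)

# Optimize the log density with Adam (from optax)...
opt_results = normal_density.optimize.optax_adam(seed=seed)
# ...OR sample with NUTS (from numpyro), returning ArviZ InferenceData...
idata = normal_density.mcmc.numpyro_nuts(seed=seed)
# ...OR fit a factored variational surrogate posterior (from TFP).
surrogate_posterior, loss = normal_density.vi.tfp_factored_surrogate_posterior(seed=seed)
```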
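The Quickstart log density `-x*x` is an unnormalized normal with mean 0 and variance 1/2 (standard deviation `sqrt(0.5)`); a standard normal would be `-x*x / 2`. The plain-Python check below confirms that the two differ only by the constant `-log(pi) / 2`. Consequently, the optimizer should approach the mode at 0 and the MCMC draws should have variance near 0.5. The helper names are hypothetical.

```python
import math


def quickstart_log_density(x):
    # The log density passed to bx.Model in the Quickstart.
    return -x * x


def normal_logpdf(x, loc, scale):
    # Fully normalized Normal(loc, scale) log density.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


scale = math.sqrt(0.5)  # variance 1/2
offsets = [
    normal_logpdf(x, 0.0, scale) - quickstart_log_density(x)
    for x in (-2.0, -0.5, 0.0, 1.0, 3.0)
]
# Each offset equals -log(pi) / 2, up to floating-point rounding.
print(offsets)
```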

## Read more

* [Defining models](inference)
* [Inspecting models](inspecting)
* [Testing and debugging](debug_mode)
* Also see `bayeux` integration with [numpyro](examples/numpyro_and_bayeux), [PyMC](examples/pymc_and_bayeux), and [TFP](examples/tfp_and_bayeux)!


*This is not an officially supported Google product.*

