Add documentation site and two example notebooks.
PiperOrigin-RevId: 595462010
ColCarroll authored and The bayeux Authors committed Jan 3, 2024
1 parent 6d88459 commit d5e570f
Showing 8 changed files with 1,183 additions and 159 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/build_docs.yml
@@ -0,0 +1,38 @@
name: Build docs

on:
  push:
    branches:
      - main

jobs:
  build:
    strategy:
      matrix:
        python-version: [ 3.11 ]
        os: [ ubuntu-latest ]
    runs-on: ${{ matrix.os }}
    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install .[docs]
      - name: Build docs
        run: |
          mkdocs build
          mkdocs build  # twice, see https://github.com/patrick-kidger/pytkdocs_tweaks
      - name: Upload docs
        uses: actions/upload-artifact@v2
        with:
          name: docs
          path: site  # where `mkdocs build` puts the built site
3 changes: 3 additions & 0 deletions .gitignore
@@ -23,3 +23,6 @@ poetry.lock

# PyCharm
.idea

# mkdocs static files
site/
273 changes: 120 additions & 153 deletions README.md
@@ -5,112 +5,147 @@
[![Unittests](https://github.com/jax-ml/bayeux/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/jax-ml/bayeux/actions/workflows/pytest_and_autopublish.yml)
[![PyPI version](https://badge.fury.io/py/bayeux_ml.svg)](https://badge.fury.io/py/bayeux_ml)

`bayeux` lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. The API aims to be **simple**, **self descriptive**, and **helpful**. Simply provide a log density function (which doesn't even have to be normalized), along with a single point (specified as a [pytree](https://jax.readthedocs.io/en/latest/pytrees.html)) where that log density is finite. Then let `bayeux` do the rest!

## Installation

```bash
pip install bayeux-ml
```
## Quickstart

We define a model by providing a log density in JAX. This could be defined using a probabilistic programming language (PPL) like numpyro, PyMC, TFP, distrax, oryx, coix, or directly in JAX.

```python
import bayeux as bx
import jax

normal_density = bx.Model(
  log_density=lambda x: -x*x,
  test_point=1.)
```

```python
seed = jax.random.PRNGKey(0)
```

## Simple
Every inference algorithm in `bayeux` will (try to) run with just a seed as an argument:

```python
opt_results = normal_density.optimize.optax_adam(seed=seed)
# OR!
idata = normal_density.mcmc.numpyro_nuts(seed=seed)
# OR!
surrogate_posterior, loss = normal_density.vi.tfp_factored_surrogate_posterior(seed=seed)
```
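Every MCMC method returns an `arviz.InferenceData` object, so the usual [ArviZ](https://python.arviz.org/) workflow applies directly. A minimal sketch, assuming `arviz` is installed:

```python
import arviz as az

idata = normal_density.mcmc.numpyro_nuts(seed=seed)
print(az.summary(idata))  # posterior means, ESS, r_hat, ...
```

Optimization methods instead return a `namedtuple` of `params` and the optimizer `state`, with `params` holding results for a whole batch of particles.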

An (only rarely) optional third argument to `bx.Model` is `transform_fn`, which maps a real number to the support of the distribution. The [oryx](https://github.com/jax-ml/oryx) library is used to automatically compute the inverse and Jacobian determinants for changes of variables, but the user can supply these if known.

```python
half_normal_density = bx.Model(
  lambda x: -x*x,
  test_point=1.,
  transform_fn=jax.nn.softplus)
```
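For a fuller example, a model written with a PPL such as TFP on JAX can be wrapped the same way: pin the observed data, pick any point with finite log density as `test_point`, and use `transform_fn` to keep constrained parameters (here `sigma`) in their support. A sketch, assuming the JAX substrate of `tensorflow_probability` is installed:

```python
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Generate some linear regression data.
np.random.seed(0)
ndims, ndata = 5, 100
X = np.random.randn(ndata, ndims)
w_ = np.random.randn(ndims)             # hidden
noise_ = 0.1 * np.random.randn(ndata)   # hidden
y_obs = X.dot(w_) + noise_

# Write a joint distribution and condition on the observed data.
@tfd.JointDistributionCoroutineAutoBatched
def tfd_model():
  sigma = yield tfd.HalfNormal(1, name='sigma')
  w = yield tfd.Sample(tfd.Normal(0, sigma), sample_shape=ndims, name='w')
  yield tfd.Normal(jnp.einsum('...jk,...k->...j', X, w), 0.1, name='y')

pinned = tfd_model.experimental_pin(y=y_obs)
test_point = pinned.sample_unpinned(seed=jax.random.PRNGKey(1))
# Map an unconstrained sigma back to the positive reals; w is unconstrained.
transform = lambda pt: pt._replace(sigma=jnp.exp(pt.sigma))

model = bx.Model(pinned.unnormalized_log_prob,
                 test_point,
                 transform_fn=transform)
idata = model.mcmc.numpyro_nuts(seed=jax.random.PRNGKey(2))
```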

## Self descriptive

Since `bayeux` is built on top of other fantastic libraries, it tries not to get in the way of them. Each algorithm has a `.get_kwargs()` method that tells you how it will be called, and what functions are being called:

```python
normal_density.optimize.jaxopt_bfgs.get_kwargs()

{jaxopt._src.bfgs.BFGS: {'value_and_grad': False,
'has_aux': False,
'maxiter': 500,
'tol': 0.001,
'stepsize': 0.0,
'linesearch': 'zoom',
'linesearch_init': 'increase',
'condition': None,
'maxls': 30,
'decrease_factor': None,
'increase_factor': 1.5,
'max_stepsize': 1.0,
'min_stepsize': 1e-06,
'implicit_diff': True,
'implicit_diff_solve': None,
'jit': True,
'unroll': 'auto',
'verbose': False},
'extra_parameters': {'chain_method': 'vectorized',
'num_particles': 8,
'num_iters': 1000,
'apply_transform': True}}
```

If you pass an argument into `.get_kwargs()`, this will also tell you what will be passed on to the actual algorithms.

```python
normal_density.mcmc.blackjax_nuts.get_kwargs(
    num_chains=5,
    target_acceptance_rate=0.99)

{<function blackjax.adaptation.window_adaptation.window_adaptation>: {'is_mass_matrix_diagonal': True,
'initial_step_size': 1.0,
'target_acceptance_rate': 0.99,
'progress_bar': False,
'algorithm': blackjax.mcmc.nuts.nuts},
blackjax.mcmc.nuts.nuts: {'max_num_doublings': 10,
'divergence_threshold': 1000,
'integrator': blackjax.mcmc.integrators.velocity_verlet,
'step_size': 0.01},
'extra_parameters': {'chain_method': 'vectorized',
'num_chains': 5,
'num_draws': 500,
'num_adapt_draws': 500,
'return_pytree': False}}
```
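The same keyword arguments can be passed straight to the sampling call itself. Note that every sub-key matching a name in the nested configuration gets replaced, and some values (for example `step_size`) may later be overwritten by adaptation. A sketch reusing the arguments inspected above:

```python
idata = normal_density.mcmc.blackjax_nuts(
    seed=seed,
    num_chains=5,
    target_acceptance_rate=0.99)
```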

A full list of available algorithms and how to call them can be seen with

```python
print(normal_density)

mcmc
    .blackjax_hmc
    .blackjax_nuts
    .blackjax_hmc_pathfinder
    .blackjax_nuts_pathfinder
    .numpyro_hmc
    .numpyro_nuts
optimize
    .jaxopt_bfgs
    .jaxopt_gradient_descent
    .jaxopt_lbfgs
    .jaxopt_nonlinear_cg
    .optax_adabelief
    .optax_adafactor
    .optax_adagrad
    .optax_adam
    .optax_adamw
    .optax_adamax
    .optax_amsgrad
    .optax_fromage
    .optax_lamb
    .optax_lion
    .optax_noisy_sgd
    .optax_novograd
    .optax_radam
    .optax_rmsprop
    .optax_sgd
    .optax_sm3
    .optax_yogi
vi
    .tfp_factored_surrogate_posterior

```


## Helpful

### Debug mode

Algorithms come with a built-in `debug` mode that attempts to fail quickly and in a way that helps locate the problem. The signature for `debug` accepts `verbosity` and `catch_exceptions` arguments, as well as a `kwargs` dictionary with whatever the user plans to pass to the algorithm itself.

```python
normal_density.mcmc.numpyro_nuts.debug(seed=seed)

Checking test_point shape ✓
Computing test point log density ✓
...
Computing gradients of transformed log density ✓
True
```
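By default the individual checks catch exceptions and report them as failed checks; passing `catch_exceptions=False` should instead let them propagate, which is handy when the full traceback is wanted. A minimal sketch:

```python
# Let errors inside the checks raise instead of being reported as failures.
normal_density.mcmc.numpyro_nuts.debug(seed=seed, catch_exceptions=False)
```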

Here is an example of a bad model, run with a higher verbosity:
```python
import jax.numpy as jnp

bad_model = bx.Model(
  log_density=jnp.sqrt,
  test_point=-1.)

bad_model.mcmc.blackjax_nuts.debug(jax.random.PRNGKey(0),
                                   verbosity=3, kwargs={"num_chains": 17})

Checking test_point shape ✓
Test point has shape
...
Array(nan, dtype=float32, weak_type=True)

Loading keyword arguments...
Keyword arguments are
{<function window_adaptation at 0x77feef9308b0>: {'algorithm': <class 'blackjax.mcmc.nuts.nuts'>,
'initial_step_size': 1.0,
'is_mass_matrix_diagonal': True,
'progress_bar': False,
'target_acceptance_rate': 0.8},
'extra_parameters': {'chain_method': 'vectorized',
'num_adapt_draws': 500,
'num_chains': 17,
'num_draws': 500,
'return_pytree': False},
<class 'blackjax.mcmc.nuts.nuts'>: {'divergence_threshold': 1000,
'integrator': <function velocity_verlet at 0x77feefbf4b80>,
'max_num_doublings': 10,
'step_size': 0.01}}
✓✓✓✓✓✓✓✓✓✓
...
Transformed state log density has shape
(17,)
✓✓✓✓✓✓✓✓✓✓

Computing gradients of transformed log density ×
The gradient contains NaNs! Initial gradients has shape
(17,)
False
```
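At `verbosity=0` the checks print nothing and just return a boolean saying whether the model looks ready to run, which makes `debug` easy to drop into scripts or tests. A sketch using the same `bad_model`:

```python
ok = bad_model.mcmc.blackjax_nuts.debug(
    seed=jax.random.PRNGKey(0), verbosity=0)
assert not ok  # jnp.sqrt of a negative test point gives a NaN log density
```

The goal is to detect every problem before an inference run starts, so errors that are not caught are worth reporting upstream.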

*This is not an officially supported Google product.*