Simplify the Oryx examples #468

Merged
merged 1 commit on Jan 16, 2023
29 changes: 14 additions & 15 deletions docs/examples/howto_use_oryx.md
@@ -19,7 +19,7 @@ Oryx is a probabilistic programming library written in JAX, it is thus natively

We reproduce the [example in Oryx's documentation](https://www.tensorflow.org/probability/oryx/notebooks/probabilistic_programming#case_study_bayesian_neural_network) and train a Bayesian Neural Network (BNN) on the iris dataset:

-```{code-cell} python
+```{code-cell} ipython3
from sklearn import datasets

iris = datasets.load_iris()
@@ -28,7 +28,7 @@ num_features = features.shape[-1]
num_classes = len(iris.target_names)
```

-```{code-cell} python
+```{code-cell} ipython3
:tags: [hide-input]

print(f"Number of features: {num_features}")
@@ -38,7 +38,7 @@ print(f"Number of data points: {features.shape[0]}")

Oryx's approach, like Aesara's, is to implement probabilistic models as generative models and then apply transformations to obtain the log-probability density function. We begin by implementing a dense layer with a normal prior on the weights, using the function `random_variable` to define random variables:

-```{code-cell} python
+```{code-cell} ipython3
import jax
from oryx.core.ppl import random_variable

@@ -68,7 +68,7 @@ def dense(dim_out, activation=jax.nn.relu):

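The full body of `dense` is collapsed in the diff above. For reference, here is a minimal sketch of such a layer, loosely following the Oryx documentation example this tutorial reproduces; the `w`/`b` names, the standard normal priors, and the use of `tfd.Sample` are assumptions, not necessarily the hidden code:

```python
import jax
import jax.numpy as jnp
from oryx.core.ppl import random_variable
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


def dense(dim_out, activation=jax.nn.relu):
    def forward(key, x):
        dim_in = x.shape[-1]
        w_key, b_key = jax.random.split(key)
        # Standard normal priors on the weights and biases, tagged by name so
        # that `joint_sample` and `joint_log_prob` can later find them.
        w = random_variable(
            tfd.Sample(tfd.Normal(0.0, 1.0), sample_shape=(dim_out, dim_in)),
            name="w",
        )(w_key)
        b = random_variable(
            tfd.Sample(tfd.Normal(0.0, 1.0), sample_shape=(dim_out,)),
            name="b",
        )(b_key)
        return activation(jnp.dot(w, x) + b)

    return forward
```
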
We now use this layer to build a multi-layer perceptron. The `nest` function creates "scope tags" that let us re-use our `dense` layer multiple times without name collisions in the dictionary that will hold the parameters:

-```{code-cell} python
+```{code-cell} ipython3
from oryx.core.ppl import nest


@@ -88,7 +88,7 @@ def mlp(hidden_sizes, num_classes):

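To build intuition for `nest`, here is a small, stand-alone illustration (a toy model, not part of the tutorial): two draws of the same tagged variable would otherwise collide, but wrapping each call in `nest` with a different `scope` keeps them under separate keys.

```python
import jax
from oryx.core.ppl import joint_sample, nest, random_variable
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


def unit_normal(key):
    # A single random variable tagged "w".
    return random_variable(tfd.Normal(0.0, 1.0), name="w")(key)


def two_draws(key):
    key1, key2 = jax.random.split(key)
    # Without `nest`, both draws would be tagged "w" and collide.
    first = nest(unit_normal, scope="first")(key1)
    second = nest(unit_normal, scope="second")(key2)
    return first + second


samples = joint_sample(two_draws)(jax.random.PRNGKey(0))
print(samples)  # expected: {'first': {'w': ...}, 'second': {'w': ...}}
```
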
Finally, we model the labels as categorical random variables:

-```{code-cell} python
+```{code-cell} ipython3
import functools

def predict(mlp):
@@ -101,21 +101,20 @@ def predict(mlp):
return forward
```
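
The middle of `predict` is collapsed in the diff. Its key step is that the observed labels are themselves a tagged random variable, so observed values can later be scored or conditioned on. A stand-alone sketch of that idea (the `y` name and the use of `tfd.Categorical` are assumptions):

```python
import jax.numpy as jnp
from oryx.core.ppl import joint_log_prob, random_variable
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


def label_model(key, logits):
    # The class label is itself a tagged random variable ("y"), so observed
    # values can be plugged in for it when computing log-densities.
    return random_variable(tfd.Categorical(logits=logits), name="y")(key)


logits = jnp.array([0.1, 2.0, -1.0])
# Log-probability of observing class 1 under these logits; positional
# arguments after the dictionary are forwarded to the model.
print(joint_log_prob(label_model)({"y": 1}, logits))
```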


We can now build the BNN and sample an initial position for the inference algorithm using `joint_sample`:

-```{code-cell} python
+```{code-cell} ipython3
import jax.numpy as jnp
from oryx.core.ppl import joint_sample


-bnn = mlp([200, 200], num_classes)
+bnn = mlp([50, 50], num_classes)
initial_weights = joint_sample(bnn)(jax.random.PRNGKey(0), jnp.ones(num_features))

print(initial_weights.keys())
```

-```{code-cell} python
+```{code-cell} ipython3
:tags: [hide-input]

num_parameters = sum([layer.size for layer in jax.tree_util.tree_flatten(initial_weights)[0]])
@@ -124,7 +123,7 @@ print(f"Number of parameters in the model: {num_parameters}")

To sample from this model we need the log-probability of its joint distribution, which we obtain with `joint_log_prob`:

-```{code-cell} python
+```{code-cell} ipython3
from oryx.core.ppl import joint_log_prob

def logdensity_fn(weights):
@@ -133,7 +132,7 @@ def logdensity_fn(weights):

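The collapsed body of `logdensity_fn` conditions the model on the observed labels and features. As a smaller, self-contained illustration of how `joint_sample` and `joint_log_prob` fit together (a toy model with assumed variable names, not the tutorial's code):

```python
import jax
from oryx.core.ppl import joint_log_prob, joint_sample, random_variable
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


def toy_model(key):
    z_key, x_key = jax.random.split(key)
    z = random_variable(tfd.Normal(0.0, 1.0), name="z")(z_key)
    return random_variable(tfd.Normal(z, 1.0), name="x")(x_key)


# `joint_sample` draws every tagged variable and returns them by name ...
draw = joint_sample(toy_model)(jax.random.PRNGKey(0))
print(draw)
# ... and `joint_log_prob` scores such a dictionary under the joint density,
# here log N(z; 0, 1) + log N(x; z, 1).
print(joint_log_prob(toy_model)(draw))
```
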
We can now run the window adaptation to get good values for the parameters of the NUTS algorithm:

-```{code-cell} python
+```{code-cell} ipython3
%%time
import blackjax

@@ -144,7 +143,7 @@ last_state, kernel, _ = adapt.run(rng_key, initial_weights, 100)

and sample from the model's posterior distribution:

-```{code-cell} python
+```{code-cell} ipython3
:tags: [hide-cell]

def inference_loop(rng_key, kernel, initial_state, num_samples):
@@ -158,15 +157,15 @@ def inference_loop(rng_key, kernel, initial_state, num_samples):
return states, infos
```

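The body of `inference_loop` is collapsed above. The standard BlackJAX pattern it follows is a `jax.lax.scan` over the kernel; a sketch under that assumption, not necessarily the exact hidden code:

```python
import jax


def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        # One transition of the (adapted) NUTS kernel.
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos
```
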
-```{code-cell} python
+```{code-cell} ipython3
%%time

states, infos = inference_loop(rng_key, kernel, last_state, 100)
```

We can now use our samples to compute an estimate of the accuracy, averaged over the posterior distribution. We use `intervene` to "inject" the posterior values of the weights instead of sampling them from the prior distribution:

-```{code-cell} python
+```{code-cell} ipython3
from oryx.core.ppl import intervene

posterior_weights = states.position
@@ -180,7 +179,7 @@ output_logits = jax.vmap(
output_probs = jax.nn.softmax(output_logits)
```
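
For intuition on `intervene`: it fixes tagged random variables to given values, so the transformed function no longer samples them from their priors. A minimal illustration, reusing the `toy_model` from the sketch above (a toy model, not the tutorial's code):

```python
import jax
from oryx.core.ppl import intervene

# Reusing `toy_model` from above: z ~ N(0, 1), x ~ N(z, 1).
# `intervene` clamps "z" to 10.0 instead of drawing it from its prior, so the
# returned sample of "x" is now centered around 10.
fixed_model = intervene(toy_model, z=10.0)
print(fixed_model(jax.random.PRNGKey(0)))
```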

-```{code-cell} python
+```{code-cell} ipython3
:tags: [hide-input]

print('Average sample accuracy:', (