diff --git a/docs/examples/howto_use_oryx.md b/docs/examples/howto_use_oryx.md
index 581b3b0b3..c63ed99f3 100644
--- a/docs/examples/howto_use_oryx.md
+++ b/docs/examples/howto_use_oryx.md
@@ -19,7 +19,7 @@ Oryx is a probabilistic programming library written in JAX, it is thus natively
 
 We reproduce the [example in Oryx's documentation](https://www.tensorflow.org/probability/oryx/notebooks/probabilistic_programming#case_study_bayesian_neural_network) and train a Bayesian Neural Network (BNN) on the iris dataset:
 
-```{code-cell} python
+```{code-cell} ipython3
 from sklearn import datasets
 
 iris = datasets.load_iris()
@@ -28,7 +28,7 @@ num_features = features.shape[-1]
 num_classes = len(iris.target_names)
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 :tags: [hide-input]
 
 print(f"Number of features: {num_features}")
@@ -38,7 +38,7 @@ print(f"Number of data points: {features.shape[0]}")
 
 Oryx's approach, like Aesara's, is to implement probabilistic models as generative models and then apply transformations to get the log-probability density function. We begin with implementing a dense layer with normal prior probability on the weights and use the function `random_variable` to define random variables:
 
-```{code-cell} python
+```{code-cell} ipython3
 import jax
 from oryx.core.ppl import random_variable
 
@@ -68,7 +68,7 @@ def dense(dim_out, activation=jax.nn.relu):
 
 We now use this layer to build a multi-layer perceptron. The `nest` function is used to create "scope tags" that allows in this context to re-use our `dense` layer multiple times without name collision in the dictionary that will contain the parameters:
 
-```{code-cell} python
+```{code-cell} ipython3
 from oryx.core.ppl import nest
 
 
@@ -88,7 +88,7 @@ def mlp(hidden_sizes, num_classes):
 
 Finally, we model the labels as categorical random variables:
 
-```{code-cell} python
+```{code-cell} ipython3
 import functools
 
 def predict(mlp):
@@ -101,21 +101,20 @@ def predict(mlp):
     return forward
 ```
 
-
 We can now build the BNN and sample an initial position for the inference algorithm using `joint_sample`:
 
-```{code-cell} python
+```{code-cell} ipython3
 import jax.numpy as jnp
 from oryx.core.ppl import joint_sample
 
 
-bnn = mlp([200, 200], num_classes)
+bnn = mlp([50, 50], num_classes)
 initial_weights = joint_sample(bnn)(jax.random.PRNGKey(0), jnp.ones(num_features))
 
 print(initial_weights.keys())
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 :tags: [hide-input]
 
 num_parameters = sum([layer.size for layer in jax.tree_util.tree_flatten(initial_weights)[0]])
@@ -124,7 +123,7 @@ print(f"Number of parameters in the model: {num_parameters}")
 
 To sample from this model we will need to obtain its joint distribution log-probability using `joint_log_prob`:
 
-```{code-cell} python
+```{code-cell} ipython3
 from oryx.core.ppl import joint_log_prob
 
 def logdensity_fn(weights):
@@ -133,7 +132,7 @@
 
 We can now run the window adaptation to get good values for the parameters of the NUTS algorithm:
 
-```{code-cell} python
+```{code-cell} ipython3
 %%time
 import blackjax
 
@@ -144,7 +143,7 @@ last_state, kernel, _ = adapt.run(rng_key, initial_weights, 100)
 
 and sample from the model's posterior distribution:
 
-```{code-cell} python
+```{code-cell} ipython3
 :tags: [hide-cell]
 
 def inference_loop(rng_key, kernel, initial_state, num_samples):
@@ -158,7 +157,7 @@ def inference_loop(rng_key, kernel, initial_state, num_samples):
     return states, infos
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 %%time
 
 states, infos = inference_loop(rng_key, kernel, last_state, 100)
@@ -166,7 +165,7 @@ states, infos = inference_loop(rng_key, kernel, last_state, 100)
 
 We can now use our samples to take an estimate of the accuracy that is averaged over the posterior distribution. We use `intervene` to "inject" the posterior values of the weights instead of sampling from the prior distribution:
 
-```{code-cell} python
+```{code-cell} ipython3
 from oryx.core.ppl import intervene
 
 posterior_weights = states.position
@@ -180,7 +179,7 @@ output_logits = jax.vmap(
 output_probs = jax.nn.softmax(output_logits)
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 :tags: [hide-input]
 
 print('Average sample accuracy:', (