diff --git a/tutorial/source/custom_objectives.ipynb b/tutorial/source/custom_objectives_training.ipynb similarity index 90% rename from tutorial/source/custom_objectives.ipynb rename to tutorial/source/custom_objectives_training.ipynb index dacc7a7414..4ad07c9b41 100644 --- a/tutorial/source/custom_objectives.ipynb +++ b/tutorial/source/custom_objectives_training.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Custom SVI Objectives\n", + "# Customizing SVI objectives and training loops\n", "\n", "Pyro provides support for various optimization-based approaches to Bayesian inference, with `Trace_ELBO` serving as the basic implementation of SVI (stochastic variational inference).\n", "See the [docs](http://docs.pyro.ai/en/dev/inference_algos.html#module-pyro.infer.svi) for more information on the various SVI implementations and SVI \n", @@ -13,7 +13,7 @@ "and [III](http://pyro.ai/examples/svi_part_iii.html) for background on SVI.\n", "\n", "In this tutorial we show how advanced users can modify and/or augment the variational\n", - "objectives (alternatively: loss functions) provided by Pyro to support special use cases." + "objectives (alternatively: loss functions) and the training step implementation provided by Pyro to support special use cases." ] }, { @@ -57,19 +57,28 @@ "- `SVI.step()` zeros gradients between gradient steps\n", "\n", "If we want more control, we can directly manipulate the differentiable loss method of \n", - "the various `ELBO` classes. For example, (assuming we know all the parameters in advance) \n", - "this is equivalent to the previous code snippet:\n", + "the various `ELBO` classes. 
For example, this optimization loop:\n", + "```python\n", + "svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())\n", + "for i in range(n_iter):\n", + " loss = svi.step(X_train, y_train)\n", + "```\n", + "is equivalent to this low-level pattern:\n", "\n", "```python\n", - "# define optimizer and loss function\n", - "optimizer = torch.optim.Adam(my_parameters, {\"lr\": 0.001, \"betas\": (0.90, 0.999)})\n", - "loss_fn = pyro.infer.Trace_ELBO().differentiable_loss\n", - "# compute loss\n", - "loss = loss_fn(model, guide, model_and_guide_args)\n", - "loss.backward()\n", - "# take a step and zero the parameter gradients\n", - "optimizer.step()\n", - "optimizer.zero_grad()\n", + "loss_fn = lambda model, guide: pyro.infer.Trace_ELBO().differentiable_loss(model, guide, X_train, y_train)\n", + "with pyro.poutine.trace(param_only=True) as param_capture:\n", + " loss = loss_fn(model, guide)\n", + "params = set(site[\"value\"].unconstrained()\n", + " for site in param_capture.trace.nodes.values())\n", + "optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.90, 0.999))\n", + "for i in range(n_iter):\n", + " # compute loss\n", + " loss = loss_fn(model, guide)\n", + " loss.backward()\n", + " # take a step and zero the parameter gradients\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", "```\n", "\n", "## Example: Custom Regularizer\n", @@ -106,6 +115,7 @@ "+ optimizer = pyro.optim.Adam({\"lr\": 0.001, \"betas\": (0.90, 0.999)}, {\"clip_norm\": 10.0})\n", "```\n", "\n", + "Further variants of gradient clipping can also be implemented manually by modifying the low-level pattern described above.\n", "\n", "## Example: Scaling the Loss\n", "\n",
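One such manual variant could be per-step clipping of the global gradient norm with `torch.nn.utils.clip_grad_norm_`, inserted between `loss.backward()` and `optimizer.step()` in the low-level pattern. A minimal standalone sketch (the toy quadratic loss here is a stand-in for `loss_fn(model, guide)`, so the snippet runs without a model or guide):

```python
import torch

# Stand-ins for the params and loss_fn(model, guide) of the low-level
# pattern; only the clip_grad_norm_ line would be new in that loop.
params = [torch.randn(3, requires_grad=True)]
optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.90, 0.999))

for i in range(5):
    loss = 100.0 * (params[0] ** 2).sum()  # stand-in for loss_fn(model, guide)
    loss.backward()
    # clip the global gradient norm before stepping; swap in
    # torch.nn.utils.clip_grad_value_ for value-based clipping instead
    torch.nn.utils.clip_grad_norm_(params, max_norm=10.0)
    optimizer.step()
    optimizer.zero_grad()
```

This performs the same clipping that `pyro.optim.Adam({...}, {"clip_norm": 10.0})` applies automatically, but exposes the call site so other schemes (value clipping, per-parameter norms) can be substituted.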