Custom loss documentation #3122

Merged · 3 commits · Jul 31, 2022
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Custom SVI Objectives\n",
"# Customizing SVI objectives and training loops\n",
"\n",
"Pyro provides support for various optimization-based approaches to Bayesian inference, with `Trace_ELBO` serving as the basic implementation of SVI (stochastic variational inference).\n",
"See the [docs](http://docs.pyro.ai/en/dev/inference_algos.html#module-pyro.infer.svi) for more information on the various SVI implementations and SVI \n",
@@ -13,7 +13,7 @@
"and [III](http://pyro.ai/examples/svi_part_iii.html) for background on SVI.\n",
"\n",
"In this tutorial we show how advanced users can modify and/or augment the variational\n",
"objectives (alternatively: loss functions) provided by Pyro to support special use cases."
"objectives (alternatively: loss functions) and the training step implementation provided by Pyro to support special use cases."
]
},
{
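A minimal, self-contained sketch of the basic `SVI`/`Trace_ELBO` usage discussed in the prose above. The `model`, `guide`, and `data` here are hypothetical stand-ins invented for illustration; only `pyro.infer.SVI`, `pyro.optim.Adam`, and `pyro.infer.Trace_ELBO` come from the tutorial itself:

```python
import torch
import pyro
import pyro.distributions as dist
import pyro.infer
import pyro.optim

# Hypothetical model/guide pair: infer the mean of noisy observations.
def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

def guide(data):
    loc_q = pyro.param("loc_q", torch.tensor(0.0))
    scale_q = pyro.param("scale_q", torch.tensor(1.0),
                         constraint=dist.constraints.positive)
    pyro.sample("loc", dist.Normal(loc_q, scale_q))

data = 3.0 + torch.randn(100)
optimizer = pyro.optim.Adam({"lr": 0.01})
svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())
for i in range(1000):
    loss = svi.step(data)  # one gradient step on the ELBO; returns the loss value
```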
@@ -57,19 +57,28 @@
"- `SVI.step()` zeros gradients between gradient steps\n",
"\n",
"If we want more control, we can directly manipulate the differentiable loss method of \n",
"the various `ELBO` classes. For example, (assuming we know all the parameters in advance) \n",
"this is equivalent to the previous code snippet:\n",
"the various `ELBO` classes. For example, this optimization loop:\n",
"```python\n",
"svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())\n",
"for i in range(n_iter):\n",
" loss = svi.step(X_train, y_train)\n",
"```\n",
"is equivalent to this low-level pattern:\n",
"\n",
"```python\n",
"# define optimizer and loss function\n",
"optimizer = torch.optim.Adam(my_parameters, {\"lr\": 0.001, \"betas\": (0.90, 0.999)})\n",
"loss_fn = pyro.infer.Trace_ELBO().differentiable_loss\n",
"# compute loss\n",
"loss = loss_fn(model, guide, model_and_guide_args)\n",
"loss.backward()\n",
"# take a step and zero the parameter gradients\n",
"optimizer.step()\n",
"optimizer.zero_grad()\n",
"loss_fn = lambda model, guide: pyro.infer.Trace_ELBO().differentiable_loss(model, guide, X_train, y_train)\n",
"with pyro.poutine.trace(param_only=True) as param_capture:\n",
" loss = loss_fn(model, guide)\n",
"params = set(site[\"value\"].unconstrained()\n",
" for site in param_capture.trace.nodes.values())\n",
"optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.90, 0.999))\n",
"for i in range(n_iter):\n",
" # compute loss\n",
" loss = loss_fn(model, guide)\n",
" loss.backward()\n",
" # take a step and zero the parameter gradients\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"```\n",
"\n",
"## Example: Custom Regularizer\n",
@@ -106,6 +115,7 @@
"+ optimizer = pyro.optim.Adam({\"lr\": 0.001, \"betas\": (0.90, 0.999)}, {\"clip_norm\": 10.0})\n",
"```\n",
"\n",
"Further variants of gradient clipping can also be implemented manually by modifying the low-level pattern described above.\n",
"\n",
"## Example: Scaling the Loss\n",
"\n",