* Bump to version 1.5.2 (#2755)
* PyTorch Lightning example
* fixes
* fix test
* update comments
* fix pip install pyro-ppl
* address comments
* add svi_lightning to toctree

Co-authored-by: Fritz Obermeyer <[email protected]>
1 parent 9afb089 · commit c6851b8
Showing 7 changed files with 153 additions and 1 deletion.
examples/svi_lightning.py
@@ -0,0 +1,116 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

# Distributed training via PyTorch Lightning.
#
# This tutorial demonstrates how to distribute SVI training across multiple
# machines (or multiple GPUs on one or more machines) using the PyTorch
# Lightning library. PyTorch Lightning enables data-parallel training by
# aggregating stochastic gradients at each step of training. We focus on the
# integration between PyTorch Lightning and Pyro. For further details on
# distributed computing with PyTorch Lightning, see
# https://lightning.ai/docs/pytorch/latest
#
# This assumes you have installed PyTorch Lightning, e.g. via
#   pip install pyro-ppl[lightning]
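#
# A typical multi-GPU invocation (mirroring the docs page added in this
# commit) is:
#   $ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp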

import argparse

import pytorch_lightning as pl
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule

# We define a model as usual, with no reference to PyTorch Lightning.
# This model is data parallel and supports subsampling.
class Model(PyroModule):
    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, covariates, data=None):
        coeff = pyro.sample("coeff", dist.Normal(0, 1))
        bias = pyro.sample("bias", dist.Normal(0, 1))
        scale = pyro.sample("scale", dist.LogNormal(0, 1))

        # Since we'll use a distributed dataloader during training, we need to
        # manually pass minibatches of (covariates, data) that are smaller than
        # the full self.size. In particular we cannot rely on pyro.plate to
        # automatically subsample, since that would lead to all workers drawing
        # identical subsamples.
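        # (Editorial note on pyro.plate semantics: passing the full
        # ``self.size`` together with the minibatch size ``len(covariates)``
        # makes the plate rescale the minibatch log-likelihood by
        # ``self.size / len(covariates)``, so each worker computes an
        # unbiased estimate of the full-data ELBO.)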
        with pyro.plate("data", self.size, len(covariates)):
            loc = bias + coeff * covariates
            return pyro.sample("obs", dist.Normal(loc, scale), obs=data)


# We define an ELBO loss, a PyTorch optimizer, and a training step in our
# PyroLightningModule. Note that we are using a PyTorch optimizer instead of
# a Pyro optimizer, and we are using ``training_step`` instead of Pyro's SVI
# machinery.
class PyroLightningModule(pl.LightningModule):
    def __init__(self, loss_fn: pyro.infer.elbo.ELBOModule, lr: float):
        super().__init__()
        self.loss_fn = loss_fn
        self.model = loss_fn.model
        self.guide = loss_fn.guide
        self.lr = lr
        self.predictive = pyro.infer.Predictive(
            self.model, guide=self.guide, num_samples=1
        )

    def forward(self, *args):
        return self.predictive(*args)

    def training_step(self, batch, batch_idx):
        """Training step for Pyro training."""
        loss = self.loss_fn(*batch)
        # Logging to TensorBoard by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Configure an optimizer."""
        # ``loss_fn.parameters()`` includes both model and guide parameters,
        # since the ELBOModule registers them as submodules.
        return torch.optim.Adam(self.loss_fn.parameters(), lr=self.lr)

def main(args):
    # Create a model, synthetic data, a guide, and a lightning module.
    pyro.set_rng_seed(args.seed)
    pyro.settings.set(module_local_params=True)
    model = Model(args.size)
    covariates = torch.randn(args.size)
    data = model(covariates)
    guide = AutoNormal(model)
    loss_fn = Trace_ELBO()(model, guide)
    training_plan = PyroLightningModule(loss_fn, args.learning_rate)

    # Create a dataloader.
    dataset = torch.utils.data.TensorDataset(covariates, data)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)
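    # (Editorial note, an assumption about Lightning's defaults: under the
    # "ddp" strategy, Lightning wraps this dataloader's sampler with a
    # DistributedSampler so each worker receives a distinct shard of the
    # data, which is why the model above takes explicit minibatches rather
    # than subsampling inside pyro.plate.)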

    # All relevant parameters need to be initialized before
    # ``configure_optimizers`` is called. Since we are using an AutoNormal
    # guide, our parameters have not been initialized yet, so we initialize
    # the model and guide by running one mini-batch through the loss.
    mini_batch = dataset[: args.batch_size]
    loss_fn(*mini_batch)

    # Run stochastic variational inference using the PyTorch Lightning Trainer.
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(training_plan, train_dataloaders=dataloader)
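    # A possible post-training inspection (editorial assumption, not part of
    # the commit): AutoNormal exposes approximate posterior medians, e.g.
    #   print(training_plan.guide.median())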


if __name__ == "__main__":
    assert pyro.__version__.startswith("1.8.4")
    parser = argparse.ArgumentParser(
        description="Distributed training via PyTorch Lightning"
    )
    parser.add_argument("--size", default=1000000, type=int)
    parser.add_argument("--batch_size", default=100, type=int)
    parser.add_argument("--learning_rate", default=0.01, type=float)
    parser.add_argument("--seed", default=20200723, type=int)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    main(args)
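
After training, the module's ``forward`` method draws posterior-predictive samples through the ``Predictive`` object built in ``__init__``. A minimal usage sketch (editorial addition, not part of the commit; it assumes training has finished in the same process):

    samples = training_plan(covariates)  # dict mapping site names to sample tensors
    obs_draw = samples["obs"]            # one posterior-predictive draw (num_samples=1)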
@@ -96,6 +96,7 @@ List of Tutorials
   prior_predictive
   jit
   svi_horovod
   svi_lightning

.. toctree::
   :maxdepth: 1
@@ -0,0 +1,17 @@
Example: distributed training via PyTorch Lightning
===================================================

This script passes argparse arguments to PyTorch Lightning ``Trainer`` automatically_, for example::

    $ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp
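
For a quick single-process smoke test on CPU (an editorial suggestion reusing the same Trainer flags, not part of the commit), something like the following should also work::

    $ python examples/svi_lightning.py --accelerator cpu --devices 1 --max_epochs 1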

.. _automatically: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-in-python-scripts

`View svi_lightning.py on github`__

.. _github: https://github.com/pyro-ppl/pyro/blob/dev/examples/svi_lightning.py

__ github_

.. literalinclude:: ../../examples/svi_lightning.py
    :language: python