From c6851b80d9efa12c030cdf436367e2aea37eb5ab Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev <50752571+ordabayevy@users.noreply.github.com> Date: Wed, 15 Mar 2023 23:43:00 -0400 Subject: [PATCH] PyTorch Lightning example (#3189) * Bump to version 1.5.2 (#2755) * PyTorch Lightning example * fixes * fix test * update comments * fix pip install pyro-ppl * address comments * add svi_lightning to toctree --------- Co-authored-by: Fritz Obermeyer <fritz.obermeyer@gmail.com> --- examples/svi_horovod.py | 2 +- examples/svi_lightning.py | 116 ++++++++++++++++++++++++++++++ setup.py | 1 + tests/common.py | 8 +++ tests/test_examples.py | 9 +++ tutorial/source/index.rst | 1 + tutorial/source/svi_lightning.rst | 17 +++++ 7 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 examples/svi_lightning.py create mode 100644 tutorial/source/svi_lightning.rst diff --git a/examples/svi_horovod.py b/examples/svi_horovod.py index 361cba86cf..f43f73feed 100644 --- a/examples/svi_horovod.py +++ b/examples/svi_horovod.py @@ -12,7 +12,7 @@ # https://horovod.readthedocs.io/en/stable # # This assumes you have installed horovod, e.g. via -# pip install pyro[horovod] +# pip install pyro-ppl[horovod] # For detailed instructions see # https://horovod.readthedocs.io/en/stable/install.html # On my mac laptop I was able to install horovod with diff --git a/examples/svi_lightning.py b/examples/svi_lightning.py new file mode 100644 index 0000000000..2079c11e27 --- /dev/null +++ b/examples/svi_lightning.py @@ -0,0 +1,116 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +# Distributed training via Pytorch Lightning. +# +# This tutorial demonstrates how to distribute SVI training across multiple +# machines (or multiple GPUs on one or more machines) using the PyTorch Lightning +# library. PyTorch Lightning enables data-parallel training by aggregating stochastic +# gradients at each step of training. We focus on integration between PyTorch Lightning and Pyro. +# For further details on distributed computing with PyTorch Lightning, see +# https://lightning.ai/docs/pytorch/latest +# +# This assumes you have installed pytorch lightning, e.g. via +# pip install pyro-ppl[lightning] + +import argparse + +import pytorch_lightning as pl +import torch + +import pyro +import pyro.distributions as dist +from pyro.infer import Trace_ELBO +from pyro.infer.autoguide import AutoNormal +from pyro.nn import PyroModule + + +# We define a model as usual, with no reference to Pytorch Lightning. +# This model is data parallel and supports subsampling. +class Model(PyroModule): + def __init__(self, size): + super().__init__() + self.size = size + + def forward(self, covariates, data=None): + coeff = pyro.sample("coeff", dist.Normal(0, 1)) + bias = pyro.sample("bias", dist.Normal(0, 1)) + scale = pyro.sample("scale", dist.LogNormal(0, 1)) + + # Since we'll use a distributed dataloader during training, we need to + # manually pass minibatches of (covariates,data) that are smaller than + # the full self.size. In particular we cannot rely on pyro.plate to + # automatically subsample, since that would lead to all workers drawing + # identical subsamples. + with pyro.plate("data", self.size, len(covariates)): + loc = bias + coeff * covariates + return pyro.sample("obs", dist.Normal(loc, scale), obs=data) + + +# We define an ELBO loss, a PyTorch optimizer, and a training step in our PyroLightningModule. +# Note that we are using a PyTorch optimizer instead of a Pyro optimizer and +# we are using ``training_step`` instead of Pyro's SVI machinery. +class PyroLightningModule(pl.LightningModule): + def __init__(self, loss_fn: pyro.infer.elbo.ELBOModule, lr: float): + super().__init__() + self.loss_fn = loss_fn + self.model = loss_fn.model + self.guide = loss_fn.guide + self.lr = lr + self.predictive = pyro.infer.Predictive( + self.model, guide=self.guide, num_samples=1 + ) + + def forward(self, *args): + return self.predictive(*args) + + def training_step(self, batch, batch_idx): + """Training step for Pyro training.""" + loss = self.loss_fn(*batch) + # Logging to TensorBoard by default + self.log("train_loss", loss) + return loss + + def configure_optimizers(self): + """Configure an optimizer.""" + return torch.optim.Adam(self.loss_fn.parameters(), lr=self.lr) + + +def main(args): + # Create a model, synthetic data, a guide, and a lightning module. + pyro.set_rng_seed(args.seed) + pyro.settings.set(module_local_params=True) + model = Model(args.size) + covariates = torch.randn(args.size) + data = model(covariates) + guide = AutoNormal(model) + loss_fn = Trace_ELBO()(model, guide) + training_plan = PyroLightningModule(loss_fn, args.learning_rate) + + # Create a dataloader. + dataset = torch.utils.data.TensorDataset(covariates, data) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size) + + # All relevant parameters need to be initialized before ``configure_optimizer`` is called. + # Since we used AutoNormal guide our parameters have not be initialized yet. + # Therefore we initialize the model and guide by running one mini-batch through the loss. + mini_batch = dataset[: args.batch_size] + loss_fn(*mini_batch) + + # Run stochastic variational inference using PyTorch Lightning Trainer. + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit(training_plan, train_dataloaders=dataloader) + + +if __name__ == "__main__": + assert pyro.__version__.startswith("1.8.4") + parser = argparse.ArgumentParser( + description="Distributed training via PyTorch Lightning" + ) + parser.add_argument("--size", default=1000000, type=int) + parser.add_argument("--batch_size", default=100, type=int) + parser.add_argument("--learning_rate", default=0.01, type=float) + parser.add_argument("--seed", default=20200723, type=int) + parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args() + main(args) diff --git a/setup.py b/setup.py index f96d7074ba..6689ca5758 100644 --- a/setup.py +++ b/setup.py @@ -137,6 +137,7 @@ "yapf", ], "horovod": ["horovod[pytorch]>=0.19"], + "lightning": ["pytorch_lightning"], "funsor": [ # This must be a released version when Pyro is released. # "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@7bb52d0eae3046d08a20d1b288544e1a21b4f461", diff --git a/tests/common.py b/tests/common.py index 9968baf136..f2beffdc71 100644 --- a/tests/common.py +++ b/tests/common.py @@ -68,6 +68,14 @@ def wrapper(*args, **kwargs): horovod is None, reason="horovod is not available" ) +try: + import pytorch_lightning +except ImportError: + pytorch_lightning = None +requires_lightning = pytest.mark.skipif( + pytorch_lightning is None, reason="pytorch lightning is not available" +) + try: import funsor except ImportError: diff --git a/tests/test_examples.py b/tests/test_examples.py index 931a61ac14..8e62a7f770 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -14,6 +14,7 @@ requires_cuda, requires_funsor, requires_horovod, + requires_lightning, xfail_param, ) @@ -110,6 +111,10 @@ "sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto", "sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide easy", "svi_horovod.py --num-epochs=2 --size=400 --no-horovod", + pytest.param( + "svi_lightning.py --max_epochs=2 --size=400 --accelerator cpu --devices 1", + marks=[requires_lightning], + ), "toy_mixture_model_discrete_enumeration.py --num-steps=1", "sparse_regression.py --num-steps=100 --num-data=100 --num-dimensions 11", "vae/ss_vae_M2.py --num-epochs=1", @@ -177,6 +182,10 @@ "sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --cuda", "sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --cuda", "svi_horovod.py --num-epochs=2 --size=400 --cuda --no-horovod", + pytest.param( + "svi_lightning.py --max_epochs=2 --size=400 --accelerator gpu --devices 1", + marks=[requires_lightning], + ), "vae/vae.py --num-epochs=1 --cuda", "vae/ss_vae_M2.py --num-epochs=1 --cuda", "vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda", diff --git a/tutorial/source/index.rst b/tutorial/source/index.rst index 9212b09979..5d4c0cc12c 100644 --- a/tutorial/source/index.rst +++ b/tutorial/source/index.rst @@ -96,6 +96,7 @@ List of Tutorials prior_predictive jit svi_horovod + svi_lightning .. toctree:: :maxdepth: 1 diff --git a/tutorial/source/svi_lightning.rst b/tutorial/source/svi_lightning.rst new file mode 100644 index 0000000000..0685fb621d --- /dev/null +++ b/tutorial/source/svi_lightning.rst @@ -0,0 +1,17 @@ +Example: distributed training via PyTorch Lightning +=================================================== + +This script passes argparse arguments to PyTorch Lightning ``Trainer`` automatically_, for example:: + + $ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp + +.. _automatically: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-in-python-scripts + +`View svi_lightning.py on github`__ + +.. _github: https://github.com/pyro-ppl/pyro/blob/dev/examples/svi_lightning.py + +__ github_ + +.. literalinclude:: ../../examples/svi_lightning.py + :language: python