From c6851b80d9efa12c030cdf436367e2aea37eb5ab Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev <50752571+ordabayevy@users.noreply.github.com>
Date: Wed, 15 Mar 2023 23:43:00 -0400
Subject: [PATCH] PyTorch Lightning example (#3189)

* Bump to version 1.5.2 (#2755)

* PyTorch Lightning example

* fixes

* fix test

* update comments

* fix pip install pyro-ppl

* address comments

* add svi_lightning to toctree

---------

Co-authored-by: Fritz Obermeyer <fritz.obermeyer@gmail.com>
---
 examples/svi_horovod.py           |   2 +-
 examples/svi_lightning.py         | 116 ++++++++++++++++++++++++++++++
 setup.py                          |   1 +
 tests/common.py                   |   8 +++
 tests/test_examples.py            |   9 +++
 tutorial/source/index.rst         |   1 +
 tutorial/source/svi_lightning.rst |  17 +++++
 7 files changed, 153 insertions(+), 1 deletion(-)
 create mode 100644 examples/svi_lightning.py
 create mode 100644 tutorial/source/svi_lightning.rst

diff --git a/examples/svi_horovod.py b/examples/svi_horovod.py
index 361cba86cf..f43f73feed 100644
--- a/examples/svi_horovod.py
+++ b/examples/svi_horovod.py
@@ -12,7 +12,7 @@
 #   https://horovod.readthedocs.io/en/stable
 #
 # This assumes you have installed horovod, e.g. via
-#   pip install pyro[horovod]
+#   pip install pyro-ppl[horovod]
 # For detailed instructions see
 #   https://horovod.readthedocs.io/en/stable/install.html
 # On my mac laptop I was able to install horovod with
diff --git a/examples/svi_lightning.py b/examples/svi_lightning.py
new file mode 100644
index 0000000000..2079c11e27
--- /dev/null
+++ b/examples/svi_lightning.py
@@ -0,0 +1,116 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+# Distributed training via PyTorch Lightning.
+#
+# This tutorial demonstrates how to distribute SVI training across multiple
+# machines (or multiple GPUs on one or more machines) using the PyTorch Lightning
+# library. PyTorch Lightning enables data-parallel training by aggregating stochastic
+# gradients at each step of training. We focus on the integration between
+# PyTorch Lightning and Pyro.
+# For further details on distributed computing with PyTorch Lightning, see
+#   https://lightning.ai/docs/pytorch/latest
+#
+# This assumes you have installed PyTorch Lightning, e.g. via
+#   pip install pyro-ppl[lightning]
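+#
+# Example run (taken from the tutorial page; assumes two GPUs are available):
+#   python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp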
+
+import argparse
+
+import pytorch_lightning as pl
+import torch
+
+import pyro
+import pyro.distributions as dist
+from pyro.infer import Trace_ELBO
+from pyro.infer.autoguide import AutoNormal
+from pyro.nn import PyroModule
+
+
+# We define a model as usual, with no reference to PyTorch Lightning.
+# This model is data parallel and supports subsampling.
+class Model(PyroModule):
+    def __init__(self, size):
+        super().__init__()
+        self.size = size
+
+    def forward(self, covariates, data=None):
+        coeff = pyro.sample("coeff", dist.Normal(0, 1))
+        bias = pyro.sample("bias", dist.Normal(0, 1))
+        scale = pyro.sample("scale", dist.LogNormal(0, 1))
+
+        # Since we'll use a distributed dataloader during training, we need to
+        # manually pass minibatches of (covariates, data) that are smaller than
+        # the full self.size. In particular we cannot rely on pyro.plate to
+        # automatically subsample, since that would lead to all workers drawing
+        # identical subsamples.
+        with pyro.plate("data", self.size, len(covariates)):
+            loc = bias + coeff * covariates
+            return pyro.sample("obs", dist.Normal(loc, scale), obs=data)
+
+
+# We define an ELBO loss, a PyTorch optimizer, and a training step in our PyroLightningModule.
+# Note that we use a PyTorch optimizer instead of a Pyro optimizer, and
+# Lightning's ``training_step`` instead of Pyro's SVI machinery.
+class PyroLightningModule(pl.LightningModule):
+    def __init__(self, loss_fn: pyro.infer.elbo.ELBOModule, lr: float):
+        super().__init__()
+        self.loss_fn = loss_fn
+        self.model = loss_fn.model
+        self.guide = loss_fn.guide
+        self.lr = lr
+        self.predictive = pyro.infer.Predictive(
+            self.model, guide=self.guide, num_samples=1
+        )
+
+    def forward(self, *args):
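+        """Draw posterior predictive samples (latents from the guide, run through the model)."""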
+        return self.predictive(*args)
+
+    def training_step(self, batch, batch_idx):
+        """Training step for Pyro training."""
+        loss = self.loss_fn(*batch)
+        # Logging to TensorBoard by default
+        self.log("train_loss", loss)
+        return loss
+
+    def configure_optimizers(self):
+        """Configure an optimizer."""
+        return torch.optim.Adam(self.loss_fn.parameters(), lr=self.lr)
+
+
+def main(args):
+    # Create a model, synthetic data, a guide, and a lightning module.
+    pyro.set_rng_seed(args.seed)
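+    # With module_local_params=True, learnable parameters are stored on module
+    # instances (like the ELBOModule below) rather than in Pyro's global
+    # parameter store, so a plain torch.optim optimizer can manage them.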
+    pyro.settings.set(module_local_params=True)
+    model = Model(args.size)
+    covariates = torch.randn(args.size)
+    data = model(covariates)
+    guide = AutoNormal(model)
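+    # Trace_ELBO()(model, guide) returns an ELBOModule, a torch.nn.Module whose
+    # parameters() collects both model and guide parameters for the optimizer.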
+    loss_fn = Trace_ELBO()(model, guide)
+    training_plan = PyroLightningModule(loss_fn, args.learning_rate)
+
+    # Create a dataloader.
+    dataset = torch.utils.data.TensorDataset(covariates, data)
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)
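+    # Note: under a distributed strategy (e.g. --strategy ddp), Lightning by
+    # default wraps this dataloader with a DistributedSampler so that each
+    # worker sees a distinct shard of the data.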
+
+    # All relevant parameters need to be initialized before ``configure_optimizers`` is called.
+    # Since we are using an AutoNormal guide, our parameters have not been initialized yet.
+    # Therefore we initialize the model and guide by running one mini-batch through the loss.
+    mini_batch = dataset[: args.batch_size]
+    loss_fn(*mini_batch)
+
+    # Run stochastic variational inference using PyTorch Lightning Trainer.
+    trainer = pl.Trainer.from_argparse_args(args)
+    trainer.fit(training_plan, train_dataloaders=dataloader)
+
+
+if __name__ == "__main__":
+    assert pyro.__version__.startswith("1.8.4")
+    parser = argparse.ArgumentParser(
+        description="Distributed training via PyTorch Lightning"
+    )
+    parser.add_argument("--size", default=1000000, type=int)
+    parser.add_argument("--batch_size", default=100, type=int)
+    parser.add_argument("--learning_rate", default=0.01, type=float)
+    parser.add_argument("--seed", default=20200723, type=int)
+    parser = pl.Trainer.add_argparse_args(parser)
+    args = parser.parse_args()
+    main(args)
diff --git a/setup.py b/setup.py
index f96d7074ba..6689ca5758 100644
--- a/setup.py
+++ b/setup.py
@@ -137,6 +137,7 @@
             "yapf",
         ],
         "horovod": ["horovod[pytorch]>=0.19"],
+        "lightning": ["pytorch_lightning"],
         "funsor": [
             # This must be a released version when Pyro is released.
             # "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@7bb52d0eae3046d08a20d1b288544e1a21b4f461",
diff --git a/tests/common.py b/tests/common.py
index 9968baf136..f2beffdc71 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -68,6 +68,14 @@ def wrapper(*args, **kwargs):
     horovod is None, reason="horovod is not available"
 )
 
+try:
+    import pytorch_lightning
+except ImportError:
+    pytorch_lightning = None
+requires_lightning = pytest.mark.skipif(
+    pytorch_lightning is None, reason="pytorch lightning is not available"
+)
+
 try:
     import funsor
 except ImportError:
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 931a61ac14..8e62a7f770 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -14,6 +14,7 @@
     requires_cuda,
     requires_funsor,
     requires_horovod,
+    requires_lightning,
     xfail_param,
 )
 
@@ -110,6 +111,10 @@
     "sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto",
     "sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide easy",
     "svi_horovod.py --num-epochs=2 --size=400 --no-horovod",
+    pytest.param(
+        "svi_lightning.py --max_epochs=2 --size=400 --accelerator cpu --devices 1",
+        marks=[requires_lightning],
+    ),
     "toy_mixture_model_discrete_enumeration.py  --num-steps=1",
     "sparse_regression.py --num-steps=100 --num-data=100 --num-dimensions 11",
     "vae/ss_vae_M2.py --num-epochs=1",
@@ -177,6 +182,10 @@
     "sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --cuda",
     "sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --cuda",
     "svi_horovod.py --num-epochs=2 --size=400 --cuda --no-horovod",
+    pytest.param(
+        "svi_lightning.py --max_epochs=2 --size=400 --accelerator gpu --devices 1",
+        marks=[requires_lightning],
+    ),
     "vae/vae.py --num-epochs=1 --cuda",
     "vae/ss_vae_M2.py --num-epochs=1 --cuda",
     "vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda",
diff --git a/tutorial/source/index.rst b/tutorial/source/index.rst
index 9212b09979..5d4c0cc12c 100644
--- a/tutorial/source/index.rst
+++ b/tutorial/source/index.rst
@@ -96,6 +96,7 @@ List of Tutorials
    prior_predictive
    jit
    svi_horovod
+   svi_lightning
 
 .. toctree::
    :maxdepth: 1
diff --git a/tutorial/source/svi_lightning.rst b/tutorial/source/svi_lightning.rst
new file mode 100644
index 0000000000..0685fb621d
--- /dev/null
+++ b/tutorial/source/svi_lightning.rst
@@ -0,0 +1,17 @@
+Example: distributed training via PyTorch Lightning
+===================================================
+
+This script passes argparse arguments to the PyTorch Lightning ``Trainer`` automatically_, for example::
+
+    $ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp
+
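+For a quick single-process CPU run (a similar invocation is used by the example test), try::
+
+    $ python examples/svi_lightning.py --accelerator cpu --devices 1 --max_epochs 2
+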
+.. _automatically: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-in-python-scripts
+
+`View svi_lightning.py on github`__
+
+.. _github: https://github.com/pyro-ppl/pyro/blob/dev/examples/svi_lightning.py
+
+__ github_
+
+.. literalinclude:: ../../examples/svi_lightning.py
+    :language: python