Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tutorials using normalizing flows #3302

Merged
merged 15 commits into from
Jan 14, 2024
46 changes: 46 additions & 0 deletions pyro/contrib/zuko.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch import Size, Tensor

import pyro


class Zuko2Pyro(pyro.distributions.TorchDistribution):
francois-rozet marked this conversation as resolved.
Show resolved Hide resolved
r"""Wraps a Zuko (or PyTorch) distribution as a Pyro distribution."""
francois-rozet marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, dist: torch.distributions.Distribution):
self.dist = dist
self.cache = {}

@property
def has_rsample(self) -> bool:
return self.dist.has_rsample

@property
def event_shape(self) -> Size:
return self.dist.event_shape

@property
def batch_shape(self) -> Size:
return self.dist.batch_shape

def __call__(self, shape: Size = ()) -> Tensor:
if hasattr(self.dist, "rsample_and_log_prob"): # fast sampling + scoring
x, self.cache[x] = self.dist.rsample_and_log_prob(shape)
elif self.has_rsample:
x = self.dist.rsample(shape)
else:
x = self.dist.sample(shape)

return x

def log_prob(self, x: Tensor) -> Tensor:
if x in self.cache:
return self.cache[x]
francois-rozet marked this conversation as resolved.
Show resolved Hide resolved
else:
return self.dist.log_prob(x)

def expand(self, *args, **kwargs):
return Zuko2Pyro(self.dist.expand(*args, **kwargs))
56 changes: 56 additions & 0 deletions tests/contrib/test_zuko.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


import pytest
import torch

import pyro
from pyro.contrib.zuko import Zuko2Pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam


@pytest.mark.parametrize("multivariate", [True, False])
def test_Zuko2Pyro(multivariate: bool):
francois-rozet marked this conversation as resolved.
Show resolved Hide resolved
# Distribution
if multivariate:
normal = torch.distributions.MultivariateNormal
mu = torch.zeros(3)
sigma = torch.eye(3)
else:
normal = torch.distributions.Normal
mu = torch.zeros(())
sigma = torch.ones(())

dist = normal(mu, sigma)

# Sample
x1 = pyro.sample("x1", Zuko2Pyro(dist))

assert x1.shape == dist.event_shape

# Sample within plate
with pyro.plate("data", 4):
x2 = pyro.sample("x2", Zuko2Pyro(dist))

assert x2.shape == (4, *dist.event_shape)

# SVI
def model():
pyro.sample("a", Zuko2Pyro(dist))

with pyro.plate("data", 4):
pyro.sample("b", Zuko2Pyro(dist))

def guide():
mu_ = pyro.param("mu", mu)
sigma_ = pyro.param("sigma", sigma)

pyro.sample("a", Zuko2Pyro(normal(mu_, sigma_)))

with pyro.plate("data", 4):
pyro.sample("b", Zuko2Pyro(normal(mu_, sigma_)))

svi = SVI(model, guide, optim=Adam({"lr": 1e-3}), loss=Trace_ELBO())
svi.step()
4 changes: 3 additions & 1 deletion tutorial/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ List of Tutorials
jit
svi_horovod
svi_lightning
svi_flow_guide

.. toctree::
:maxdepth: 1
Expand All @@ -106,7 +107,8 @@ List of Tutorials
vae
ss-vae
cvae
normalizing_flows_i
normalizing_flows_intro
vae_flow_prior
dmm
air
cevae
Expand Down
Loading
Loading