Change coordinatization of AutoMultivariateNormal #2963

Merged 6 commits on Nov 11, 2021
39 changes: 26 additions & 13 deletions pyro/infer/autoguide/guides.py
@@ -862,6 +862,7 @@ class AutoMultivariateNormal(AutoContinuous):
    (unconstrained transformed) latent variable.
    """

+    scale_constraint = constraints.softplus_positive
    scale_tril_constraint = constraints.softplus_lower_cholesky
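For context, a rough sketch of what these softplus-based constraints amount to when PyroParam maps the stored unconstrained tensor into constrained space. The helper names below are illustrative and the real transforms registered by Pyro may differ in detail:

```python
import torch
import torch.nn.functional as F

def softplus_positive_example(x):
    # elementwise softplus keeps the constrained value strictly positive
    return F.softplus(x)

def softplus_lower_cholesky_example(x):
    # keep the strictly-lower triangle unconstrained and make the diagonal
    # positive via softplus, yielding a valid lower Cholesky factor
    return x.tril(-1) + F.softplus(x.diagonal(dim1=-2, dim2=-1)).diag_embed()
```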

    def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):
@@ -874,27 +875,31 @@ def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)
        # Initialize guide params
        self.loc = nn.Parameter(self._init_loc())
+        self.scale = PyroParam(
+            torch.full_like(self.loc, self._init_scale), self.scale_constraint
+        )
        self.scale_tril = PyroParam(
-            eye_like(self.loc, self.latent_dim) * self._init_scale,
-            self.scale_tril_constraint,
+            eye_like(self.loc, self.latent_dim), self.scale_tril_constraint
        )
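A quick illustrative check (not part of the diff) that the new coordinatization starts from the same initial covariance as the old one; `D` and `init_scale` below are placeholder values:

```python
import torch

D, init_scale = 4, 0.1
# old coordinatization: a single Cholesky-factor parameter
old_scale_tril = init_scale * torch.eye(D)
# new coordinatization: per-dimension scale times a unit-scale Cholesky factor
scale = torch.full((D,), init_scale)
scale_tril = torch.eye(D)
assert torch.allclose(scale[..., None] * scale_tril, old_scale_tril)
```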

    def get_base_dist(self):
        return dist.Normal(
-            torch.zeros_like(self.loc), torch.zeros_like(self.loc)
+            torch.zeros_like(self.loc), torch.ones_like(self.loc)
Member: Hmm, what a bug!? I'm surprised we didn't catch this earlier.

Member Author: Well, apparently nobody uses AutoMultivariateNormal with NeuTraReparam. My motivation for this PR is to write a tutorial about autoguides, so hopefully they'll get more exposure.
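For readers wondering how this path is exercised: a rough sketch of the NeuTraReparam workflow, following the smoke test later in this PR (learning rate and step count here are illustrative):

```python
from pyro import optim
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal
from pyro.infer.reparam import NeuTraReparam

dim = 10
guide = AutoMultivariateNormal(neals_funnel)   # neals_funnel as defined in the test file
svi = SVI(neals_funnel, guide, optim.Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(100):                           # illustrative step count
    svi.step(dim)

# NeuTraReparam consumes the guide's base distribution and transform, which is
# why get_base_dist() must return a standard Normal (ones, not zeros).
neutra = NeuTraReparam(guide.requires_grad_(False))
neutra_model = neutra.reparam(neals_funnel)    # run NUTS/HMC on this model instead
```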

        ).to_event(1)

    def get_transform(self, *args, **kwargs):
-        return dist.transforms.LowerCholeskyAffine(self.loc, scale_tril=self.scale_tril)
+        scale_tril = self.scale[..., None] * self.scale_tril
+        return dist.transforms.LowerCholeskyAffine(self.loc, scale_tril=scale_tril)

    def get_posterior(self, *args, **kwargs):
        """
        Returns a MultivariateNormal posterior distribution.
        """
-        return dist.MultivariateNormal(self.loc, scale_tril=self.scale_tril)
+        scale_tril = self.scale[..., None] * self.scale_tril
+        return dist.MultivariateNormal(self.loc, scale_tril=scale_tril)

    def _loc_scale(self, *args, **kwargs):
-        return self.loc, self.scale_tril.diag()
+        return self.loc, self.scale * self.scale_tril.diag()
Collaborator: Does scale need an unsqueeze here?

Member Author: Nope, scale is already the correct shape.

Collaborator: Oh, I see, this is only for the marginals.
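To make the shapes concrete, a small illustrative sketch (D is a placeholder latent dimension): scale has shape (D,) and scale_tril has shape (D, D), so the matrix-valued paths need the trailing unsqueeze while _loc_scale, which only touches the diagonal, does not:

```python
import torch

D = 3
scale = torch.rand(D) + 0.1        # guide.scale, shape (D,)
scale_tril = torch.eye(D)          # guide.scale_tril, shape (D, D), unit scale

full_tril = scale[..., None] * scale_tril    # (D, 1) broadcasts over rows -> (D, D)
cov = full_tril @ full_tril.T                # full posterior covariance
loc_scale = scale * scale_tril.diag()        # elementwise product, shape (D,)
```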



class AutoDiagonalNormal(AutoContinuous):
@@ -937,7 +942,7 @@ def _setup_prototype(self, *args, **kwargs):

    def get_base_dist(self):
        return dist.Normal(
-            torch.zeros_like(self.loc), torch.zeros_like(self.loc)
+            torch.zeros_like(self.loc), torch.ones_like(self.loc)
        ).to_event(1)

    def get_transform(self, *args, **kwargs):
@@ -1167,15 +1172,23 @@ def laplace_approximation(self, *args, **kwargs):
        loss = guide_trace.log_prob_sum() - model_trace.log_prob_sum()

        H = hessian(loss, self.loc)
-        cov = H.inverse()
-        loc = self.loc
-        scale_tril = torch.linalg.cholesky(cov)
+        with torch.no_grad():
+            loc = self.loc.detach()
+            cov = H.inverse()
+            scale = cov.diagonal().sqrt()
+            cov /= scale[:, None]
+            cov /= scale[None, :]
+            scale_tril = torch.linalg.cholesky(cov)

        gaussian_guide = AutoMultivariateNormal(self.model)
        gaussian_guide._setup_prototype(*args, **kwargs)
-        # Set loc, scale_tril parameters as computed above.
-        gaussian_guide.loc = loc
-        gaussian_guide.scale_tril = scale_tril
+        # Set detached loc, scale, scale_tril parameters as computed above.
+        del gaussian_guide.loc
+        del gaussian_guide.scale
+        del gaussian_guide.scale_tril
+        gaussian_guide.register_buffer("loc", loc)
+        gaussian_guide.register_buffer("scale", scale)
+        gaussian_guide.register_buffer("scale_tril", scale_tril)
        return gaussian_guide
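The block above factors the Laplace covariance into per-coordinate scales and a unit-diagonal (correlation-like) Cholesky factor. A small numeric sketch of that decomposition, using a made-up covariance matrix:

```python
import torch

cov = torch.tensor([[2.0, 0.6], [0.6, 1.0]])    # stand-in for H.inverse()
scale = cov.diagonal().sqrt()                    # per-coordinate standard deviations
corr = cov / scale[:, None] / scale[None, :]     # unit-diagonal matrix
scale_tril = torch.linalg.cholesky(corr)         # stored in the guide's scale_tril
# scale[:, None] * scale_tril recovers the Cholesky factor of the original cov
assert torch.allclose(scale[:, None] * scale_tril, torch.linalg.cholesky(cov), atol=1e-5)
```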


@@ -49,7 +49,7 @@ def compute_target(self, N):
            ) * self.target_auto_diag_cov[n + 1]

    def test_multivariatate_normal_auto(self):
-        self.do_test_auto(3, reparameterized=True, n_steps=10001)
+        self.do_test_auto(3, reparameterized=True, n_steps=1001)

    def do_test_auto(self, N, reparameterized, n_steps):
        logger.debug("\nGoing to do AutoGaussianChain test...")
@@ -70,20 +70,21 @@ def do_test_auto(self, N, reparameterized, n_steps):
        )

        # TODO speed up with parallel num_particles > 1
-        adam = optim.Adam({"lr": 0.001, "betas": (0.95, 0.999)})
-        svi = SVI(self.model, self.guide, adam, loss=Trace_ELBO())
+        adam = optim.Adam({"lr": 0.01, "betas": (0.95, 0.999)})
+        elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
+        svi = SVI(self.model, self.guide, adam, elbo)

        for k in range(n_steps):
            loss = svi.step(reparameterized)
            assert np.isfinite(loss), loss

-            if k % 1000 == 0 and k > 0 or k == n_steps - 1:
+            if k % 100 == 0 and k > 0 or k == n_steps - 1:
                logger.debug(
                    "[step {}] guide mean parameter: {}".format(
                        k, self.guide.loc.detach().cpu().numpy()
                    )
                )
-                L = self.guide.scale_tril
+                L = self.guide.scale_tril * self.guide.scale[:, None]
                diag_cov = torch.mm(L, L.t()).diag()
                logger.debug(
                    "[step {}] auto_diag_cov: {}".format(
35 changes: 15 additions & 20 deletions tests/infer/reparam/test_neutra.py
@@ -9,10 +9,14 @@
from pyro import optim
from pyro.distributions.transforms import ComposeTransform
from pyro.infer import MCMC, NUTS, SVI, Trace_ELBO
-from pyro.infer.autoguide import AutoIAFNormal
+from pyro.infer.autoguide import (
+    AutoDiagonalNormal,
+    AutoIAFNormal,
+    AutoMultivariateNormal,
+)
from pyro.infer.mcmc.util import initialize_model
from pyro.infer.reparam import NeuTraReparam
-from tests.common import assert_close, xfail_param
+from tests.common import assert_close

from .util import check_init_reparam

@@ -31,25 +35,23 @@ def dirichlet_categorical(data):
    return p_latent


@pytest.mark.parametrize("jit", [False, True])
@pytest.mark.parametrize(
"jit",
[
False,
xfail_param(True, reason="https://github.com/pyro-ppl/pyro/issues/2292"),
],
"Guide",
[AutoDiagonalNormal, AutoMultivariateNormal, AutoIAFNormal],
)
def test_neals_funnel_smoke(jit):
def test_neals_funnel_smoke(Guide, jit):
dim = 10

guide = AutoIAFNormal(neals_funnel)
guide = Guide(neals_funnel)
svi = SVI(neals_funnel, guide, optim.Adam({"lr": 1e-10}), Trace_ELBO())
for _ in range(1000):
for _ in range(10):
svi.step(dim)

neutra = NeuTraReparam(guide.requires_grad_(False))
model = neutra.reparam(neals_funnel)
nuts = NUTS(model, jit_compile=jit)
mcmc = MCMC(nuts, num_samples=50, warmup_steps=50)
nuts = NUTS(model, jit_compile=jit, ignore_jit_warnings=True)
mcmc = MCMC(nuts, num_samples=10, warmup_steps=10)
mcmc.run(dim)
samples = mcmc.get_samples()
# XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites, not uniformly at -max_plate_nesting-1;
@@ -65,14 +67,7 @@ def test_neals_funnel_smoke(jit):
"model, kwargs",
[
(neals_funnel, {"dim": 10}),
(
dirichlet_categorical,
{
"data": torch.ones(
10,
)
},
),
(dirichlet_categorical, {"data": torch.ones(10)}),
],
)
def test_reparam_log_joint(model, kwargs):