Make ReparamMessenger aware of InitMessenger, simplifying initialization #2876

Merged
merged 24 commits on Jun 29, 2021
Changes from all commits
23 changes: 1 addition & 22 deletions pyro/contrib/epidemiology/compartmental.py
@@ -734,13 +734,11 @@ def _heuristic(self, haar, **options):
assert isinstance(init_values, dict)
assert "auxiliary" in init_values, \
".heuristic() did not define auxiliary value"
if haar:
haar.user_to_aux(init_values)
logger.info("Heuristic init: {}".format(", ".join(
"{}={:0.3g}".format(k, v.item())
for k, v in sorted(init_values.items())
if v.numel() == 1)))
return init_to_value(values=init_values)
return init_to_value(values=init_values, fallback=None)
Member Author:
I've added fallback logic to a few init strategies (hoping to port these to NumPyro and combine with pyro-ppl/numpyro#1058), and set the fallback to None here, which would error if the new initialization logic were to fail.
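
For example, a rough sketch (not from this diff; the site dicts below are minimal stand-ins for real trace sites) of how ``fallback=None`` changes behavior:

```python
import torch
import pyro.distributions as dist
from pyro.infer.autoguide.initialization import init_to_value

strategy = init_to_value(values={"x": torch.tensor(0.5)}, fallback=None)

# Sites named in `values` get exactly the specified value.
site_x = {"name": "x", "fn": dist.Normal(0.0, 1.0)}
strategy(site_x)  # tensor(0.5000)

# With fallback=None, a site missing from `values` raises instead of
# silently falling back to init_to_uniform.
site_y = {"name": "y", "fn": dist.Normal(0.0, 1.0)}
strategy(site_y)  # ValueError: No init strategy specified for site 'y'
```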


def _concat_series(self, samples, trace, forecast=0):
"""
@@ -1142,25 +1140,6 @@ def reparam(self, model):

return model

def user_to_aux(self, samples):
"""
Convert from user-facing samples to auxiliary samples, in-place.
"""
# Transform to Haar coordinates.
for name, dim in self.dims.items():
x = samples.pop(name)
x = biject_to(self.supports[name]).inv(x)
x = HaarTransform(dim=dim, flip=True)(x)
samples[name + "_haar"] = x

if self.split:
# Split into low- and high-frequency parts.
splits = [self.split, self.duration - self.split]
for name, dim in self.dims.items():
x0, x1 = samples.pop(name + "_haar").split(splits, dim=dim)
samples[name + "_haar_split_0"] = x0
samples[name + "_haar_split_1"] = x1

def aux_to_user(self, samples):
"""
Convert from auxiliary samples to user-facing samples, in-place.
104 changes: 79 additions & 25 deletions pyro/infer/autoguide/initialization.py
@@ -10,6 +10,7 @@
as an initial constrained value for a guide estimate.
"""
import functools
from typing import Callable, Optional

import torch
from torch.distributions import transform_to
@@ -41,7 +42,9 @@ def init_to_feasible(site=None):

value = site["fn"].sample().detach()
t = transform_to(site["fn"].support)
return t(torch.zeros_like(t.inv(value)))
value = t(torch.zeros_like(t.inv(value)))
value._pyro_custom_init = False
return value


def init_to_sample(site=None):
Expand All @@ -51,16 +54,30 @@ def init_to_sample(site=None):
if site is None:
return init_to_sample

return site["fn"].sample().detach()
value = site["fn"].sample().detach()
value._pyro_custom_init = False
return value


def init_to_median(site=None, num_samples=15):
def init_to_median(
site=None,
num_samples=15,
*,
fallback: Optional[Callable] = init_to_feasible,
Member Author:
I've added fallback kwargs to a few init strategies mainly to make it easier for users to catch unintended initialization.
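
For instance, a hedged sketch (not part of this diff; the site dict is a minimal stand-in), where ``fallback=None`` turns an undefined mean into a hard error instead of a silent fallback:

```python
import pyro.distributions as dist
from pyro.infer.autoguide.initialization import init_to_mean

# Default: an undefined mean silently falls back to init_to_median.
lenient = init_to_mean()

# With fallback=None the same situation raises, surfacing the problem.
strict = init_to_mean(fallback=None)

site = {"name": "x", "fn": dist.Cauchy(0.0, 1.0)}  # mean is undefined
lenient(site)  # returns an empirical median via the fallback chain
strict(site)   # ValueError: No init strategy specified for site 'x'
```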

):
"""
Initialize to the prior median; fallback to a feasible point if median is
undefined.
Initialize to the prior median; fall back to ``fallback`` (defaults to
:func:`init_to_feasible`) if the median is undefined.

:param callable fallback: Fallback init strategy, used when the median is
undefined.
:raises ValueError: If ``fallback=None`` and the median is undefined.
"""
if site is None:
return functools.partial(init_to_median, num_samples=num_samples)
return functools.partial(
init_to_median, num_samples=num_samples, fallback=fallback
)

# The median is undefined for multivariate distributions.
if _is_multivariate(site["fn"]):
@@ -73,18 +90,31 @@ def init_to_median(site=None, num_samples=15):
raise ValueError
if hasattr(site["fn"], "_validate_sample"):
site["fn"]._validate_sample(value)
value._pyro_custom_init = False
return value
except (RuntimeError, ValueError):
# Fall back to feasible point.
return init_to_feasible(site)
pass
if fallback is not None:
return fallback(site)
raise ValueError(f"No init strategy specified for site {repr(site['name'])}")


def init_to_mean(site=None):
def init_to_mean(
site=None,
*,
fallback: Optional[Callable] = init_to_median,
):
"""
Initialize to the prior mean; fallback to median if mean is undefined.
Initialize to the prior mean; fall back to ``fallback`` (defaults to
:func:`init_to_median`) if the mean is undefined.

:param callable fallback: Fallback init strategy, used when the mean is
undefined.
:raises ValueError: If ``fallback=None`` and the mean is undefined.
"""
if site is None:
return init_to_mean
return functools.partial(init_to_mean, fallback=fallback)

try:
# Try .mean() method.
@@ -93,42 +123,62 @@ def init_to_mean(site=None):
raise ValueError
if hasattr(site["fn"], "_validate_sample"):
site["fn"]._validate_sample(value)
value._pyro_custom_init = False
return value
except (NotImplementedError, ValueError):
# Fall back to a median.
# This is required for distributions with infinite variance, e.g. Cauchy.
return init_to_median(site)
# This may happen for distributions with infinite variance, e.g. Cauchy.
pass
if fallback is not None:
return fallback(site)
raise ValueError(f"No init strategy specified for site {repr(site['name'])}")


def init_to_uniform(site=None, radius=2):
def init_to_uniform(
site: Optional[dict] = None,
radius: float = 2.0,
):
"""
Initialize to a random point in the area ``(-radius, radius)`` of
unconstrained domain.

:param float radius: specifies the range to draw an initial point in the unconstrained domain.
:param float radius: specifies the range to draw an initial point in the
unconstrained domain.
"""
if site is None:
return functools.partial(init_to_uniform, radius=radius)

value = site["fn"].sample().detach()
t = transform_to(site["fn"].support)
return t(torch.rand_like(t.inv(value)) * (2 * radius) - radius)
value = t(torch.rand_like(t.inv(value)) * (2 * radius) - radius)
value._pyro_custom_init = False
return value


def init_to_value(site=None, values={}):
def init_to_value(
site: Optional[dict] = None,
values: dict = {},
*,
fallback: Optional[Callable] = init_to_uniform,
):
"""
Initialize to the value specified in ``values``. We defer to
:func:`init_to_uniform` strategy for sites which do not appear in ``values``.
Initialize to the values specified in ``values``. Fall back to the
``fallback`` strategy (defaults to :func:`init_to_uniform`) for sites not
appearing in ``values``.

:param dict values: dictionary of initial values keyed by site name.
:param callable fallback: Fallback init strategy, for sites not specified
in ``values``.
:raises ValueError: If ``fallback=None`` and no value for a site is given
in ``values``.
"""
if site is None:
return functools.partial(init_to_value, values=values)
return functools.partial(init_to_value, values=values, fallback=fallback)

if site["name"] in values:
return values[site["name"]]
else:
return init_to_uniform(site)
if fallback is not None:
return fallback(site)
raise ValueError(f"No init strategy specified for site {repr(site['name'])}")


class _InitToGenerated:
@@ -180,7 +230,7 @@ def __init__(self, init_fn):
super().__init__()

def _pyro_sample(self, msg):
if msg["done"] or msg["is_observed"] or type(msg["fn"]).__name__ == "_Subsample":
if msg["value"] is not None or type(msg["fn"]).__name__ == "_Subsample":
return
with torch.no_grad(), helpful_support_errors(msg):
value = self.init_fn(msg)
@@ -194,4 +244,8 @@ def _pyro_sample(self, msg):
"{} provided invalid shape for site {}:\nexpected {}\nactual {}"
.format(self.init_fn, msg["name"], msg["value"].shape, value.shape))
msg["value"] = value
msg["done"] = True
Member Author:
@eb8680 can you confirm this is ok? I think we want to avoid setting done here so as to enable subsequent reparametrization.


def _pyro_get_init_messengers(self, msg):
if msg["value"] is None:
msg["value"] = []
msg["value"].append(self)
21 changes: 16 additions & 5 deletions pyro/infer/reparam/conjugate.py
@@ -50,8 +50,11 @@ def reparam_guide():
def __init__(self, guide):
self.guide = guide

def __call__(self, name, fn, obs):
assert obs is None, "PosteriorReparam does not support observe statements"
def apply(self, msg):
name = msg["name"]
fn = msg["fn"]
value = msg["value"]
is_observed = msg["is_observed"]

# Compute a guide distribution, either static or dependent.
guide_dist = self.guide
@@ -68,8 +71,16 @@ def __call__(self, name, fn, obs):
# handling of traced sites than the crude _do_not_trace flag below.
raise NotImplementedError("ConjugateReparam inference supports only reparameterized "
"distributions, but got {}".format(type(fn)))
value = pyro.sample("{}_updated".format(name), fn,
infer={"is_auxiliary": True, "_do_not_trace": True})
value = pyro.sample(
f"{name}_updated",
fn,
obs=value,
infer={
"is_observed": is_observed,
"is_auxiliary": True,
"_do_not_trace": True,
},
)

# Compute importance weight. Let p(z) be the original fn, q(z|x) be
# the guide, and u(z) be the conjugate_updated distribution. Then
Expand All @@ -91,4 +102,4 @@ def __call__(self, name, fn, obs):

# Return an importance-weighted point estimate.
new_fn = dist.Delta(value, log_density=log_density, event_dim=fn.event_dim)
return new_fn, value
return {"fn": new_fn, "value": value, "is_observed": True}
61 changes: 46 additions & 15 deletions pyro/infer/reparam/hmm.py
@@ -60,7 +60,12 @@ def __init__(self, init=None, trans=None, obs=None):
self.trans = trans
self.obs = obs

def __call__(self, name, fn, obs):
def apply(self, msg):
name = msg["name"]
fn = msg["fn"]
value = msg["value"]
is_observed = msg["is_observed"]

fn, event_dim = self._unwrap(fn)
assert isinstance(fn, (dist.LinearHMM, dist.IndependentHMM))
if fn.duration is None:
@@ -69,38 +74,64 @@ def __call__(self, name, fn, obs):

# Unwrap IndependentHMM.
if isinstance(fn, dist.IndependentHMM):
if obs is not None:
obs = obs.transpose(-1, -2).unsqueeze(-1)
hmm, obs = self(name, fn.base_dist.to_event(1), obs)
indep_value = None
if value is not None:
indep_value = value.transpose(-1, -2).unsqueeze(-1)
msg = self.apply({
"name": name,
"fn": fn.base_dist.to_event(1),
"value": indep_value,
"is_observed": is_observed,
})
hmm = msg["fn"]
hmm = dist.IndependentHMM(hmm.to_event(-1))
if obs is not None:
obs = obs.squeeze(-1).transpose(-1, -2)
return hmm, obs
if msg["value"] is not indep_value:
value = msg["value"].squeeze(-1).transpose(-1, -2)
return {"fn": hmm, "value": value, "is_observed": is_observed}

# Reparameterize the initial distribution as conditionally Gaussian.
init_dist = fn.initial_dist
if self.init is not None:
init_dist, _ = self.init("{}_init".format(name),
self._wrap(init_dist, event_dim - 1), None)
msg = self.init.apply({
"name": f"{name}_init",
"fn": self._wrap(init_dist, event_dim - 1),
"value": None,
"is_observed": False,
})
init_dist = msg["fn"]
init_dist = init_dist.to_event(1 - init_dist.event_dim)

# Reparameterize the transition distribution as conditionally Gaussian.
trans_dist = fn.transition_dist
if self.trans is not None:
if trans_dist.batch_shape[-1] != fn.duration:
trans_dist = trans_dist.expand(trans_dist.batch_shape[:-1] + (fn.duration,))
trans_dist, _ = self.trans("{}_trans".format(name),
self._wrap(trans_dist, event_dim), None)
trans_dist = trans_dist.expand(
trans_dist.batch_shape[:-1] + (fn.duration,)
)
msg = self.trans.apply({
"name": f"{name}_trans",
"fn": self._wrap(trans_dist, event_dim),
"value": None,
"is_observed": False,
})
trans_dist = msg["fn"]
trans_dist = trans_dist.to_event(1 - trans_dist.event_dim)

# Reparameterize the observation distribution as conditionally Gaussian.
obs_dist = fn.observation_dist
if self.obs is not None:
if obs_dist.batch_shape[-1] != fn.duration:
obs_dist = obs_dist.expand(obs_dist.batch_shape[:-1] + (fn.duration,))
obs_dist, obs = self.obs("{}_obs".format(name),
self._wrap(obs_dist, event_dim), obs)
msg = self.obs.apply({
"name": f"{name}_obs",
"fn": self._wrap(obs_dist, event_dim),
"value": value,
"is_observed": is_observed,
})
obs_dist = msg["fn"]
obs_dist = obs_dist.to_event(1 - obs_dist.event_dim)
value = msg["value"]
is_observed = msg["is_observed"]

# Reparameterize the entire HMM as conditionally Gaussian.
hmm = dist.GaussianHMM(init_dist, fn.transition_matrix, trans_dist,
@@ -111,4 +142,4 @@
if fn.transforms:
hmm = dist.TransformedDistribution(hmm, fn.transforms)

return hmm, obs
return {"fn": hmm, "value": value, "is_observed": is_observed}