-
-
Notifications
You must be signed in to change notification settings - Fork 983
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make ReparamMessenger aware of InitMessenger, simplifying initialization #2876
Changes from all commits
1394105
af641c0
c640879
9f89a55
d0dd8cf
f6479cd
9f3a9f0
507a483
ebfc347
9583d4a
a4ee8c3
7427a10
a7d10a2
940e946
1bc2dce
d803fad
60796d0
9d2721f
c7565d3
09d63f4
c8977c6
b2b2879
de983ed
c60abe0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
as an initial constrained value for a guide estimate. | ||
""" | ||
import functools | ||
from typing import Callable, Optional | ||
|
||
import torch | ||
from torch.distributions import transform_to | ||
|
@@ -41,7 +42,9 @@ def init_to_feasible(site=None): | |
|
||
value = site["fn"].sample().detach() | ||
t = transform_to(site["fn"].support) | ||
return t(torch.zeros_like(t.inv(value))) | ||
value = t(torch.zeros_like(t.inv(value))) | ||
value._pyro_custom_init = False | ||
return value | ||
|
||
|
||
def init_to_sample(site=None): | ||
|
@@ -51,16 +54,30 @@ def init_to_sample(site=None): | |
if site is None: | ||
return init_to_sample | ||
|
||
return site["fn"].sample().detach() | ||
value = site["fn"].sample().detach() | ||
value._pyro_custom_init = False | ||
return value | ||
|
||
|
||
def init_to_median(site=None, num_samples=15): | ||
def init_to_median( | ||
site=None, | ||
num_samples=15, | ||
*, | ||
fallback: Optional[Callable] = init_to_feasible, | ||
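There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've added fallback logic to a few init strategies (hoping to port these to NumPyro and combine with pyro-ppl/numpyro#1058), and set the fallback to `None` here, which would error if the new initialization logic were to fail. |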
||
): | ||
""" | ||
Initialize to the prior median; fallback to a feasible point if median is | ||
undefined. | ||
Initialize to the prior median; fallback to ``fallback`` (defaults to | ||
:func:`init_to_feasible`) if median is undefined. | ||
|
||
:param callable fallback: Fallback init strategy, used for sites whose | ||
median cannot be computed. | ||
:raises ValueError: If ``fallback=None`` and the median cannot be computed | ||
for a site. | ||
""" | ||
if site is None: | ||
return functools.partial(init_to_median, num_samples=num_samples) | ||
return functools.partial( | ||
init_to_median, num_samples=num_samples, fallback=fallback | ||
) | ||
|
||
# The median is undefined for multivariate distributions. | ||
if _is_multivariate(site["fn"]): | ||
|
@@ -73,18 +90,31 @@ def init_to_median(site=None, num_samples=15): | |
raise ValueError | ||
if hasattr(site["fn"], "_validate_sample"): | ||
site["fn"]._validate_sample(value) | ||
value._pyro_custom_init = False | ||
return value | ||
except (RuntimeError, ValueError): | ||
# Fall back to feasible point. | ||
return init_to_feasible(site) | ||
pass | ||
if fallback is not None: | ||
return fallback(site) | ||
raise ValueError(f"No init strategy specified for site {repr(site['name'])}") | ||
|
||
|
||
def init_to_mean(site=None): | ||
def init_to_mean( | ||
site=None, | ||
*, | ||
fallback: Optional[Callable] = init_to_median, | ||
): | ||
""" | ||
Initialize to the prior mean; fallback to median if mean is undefined. | ||
Initialize to the prior mean; fallback to ``fallback`` (defaults to | ||
:func:`init_to_median`) if mean is undefined. | ||
|
||
:param callable fallback: Fallback init strategy, used for sites whose | ||
mean is undefined or invalid. | ||
:raises ValueError: If ``fallback=None`` and the mean of a site is | ||
undefined or invalid. | ||
""" | ||
if site is None: | ||
return init_to_mean | ||
return functools.partial(init_to_mean, fallback=fallback) | ||
|
||
try: | ||
# Try .mean() method. | ||
|
@@ -93,42 +123,62 @@ def init_to_mean(site=None): | |
raise ValueError | ||
if hasattr(site["fn"], "_validate_sample"): | ||
site["fn"]._validate_sample(value) | ||
value._pyro_custom_init = False | ||
return value | ||
except (NotImplementedError, ValueError): | ||
# Fall back to a median. | ||
# This is required for distributions with infinite variance, e.g. Cauchy. | ||
return init_to_median(site) | ||
# This may happen for distributions with infinite variance, e.g. Cauchy. | ||
pass | ||
if fallback is not None: | ||
return fallback(site) | ||
raise ValueError(f"No init strategy specified for site {repr(site['name'])}") | ||
|
||
|
||
def init_to_uniform(site=None, radius=2): | ||
def init_to_uniform( | ||
site: Optional[dict] = None, | ||
radius: float = 2.0, | ||
): | ||
""" | ||
Initialize to a random point in the area ``(-radius, radius)`` of | ||
unconstrained domain. | ||
|
||
:param float radius: specifies the range to draw an initial point in the unconstrained domain. | ||
:param float radius: specifies the range to draw an initial point in the | ||
unconstrained domain. | ||
""" | ||
if site is None: | ||
return functools.partial(init_to_uniform, radius=radius) | ||
|
||
value = site["fn"].sample().detach() | ||
t = transform_to(site["fn"].support) | ||
return t(torch.rand_like(t.inv(value)) * (2 * radius) - radius) | ||
value = t(torch.rand_like(t.inv(value)) * (2 * radius) - radius) | ||
value._pyro_custom_init = False | ||
return value | ||
|
||
|
||
def init_to_value(site=None, values={}): | ||
def init_to_value( | ||
site: Optional[dict] = None, | ||
values: dict = {}, | ||
*, | ||
fallback: Optional[Callable] = init_to_uniform, | ||
): | ||
""" | ||
Initialize to the value specified in ``values``. We defer to | ||
:func:`init_to_uniform` strategy for sites which do not appear in ``values``. | ||
Initialize to the value specified in ``values``. Fallback to ``fallback`` | ||
(defaults to :func:`init_to_uniform`) strategy for sites not appearing in | ||
``values``. | ||
|
||
:param dict values: dictionary of initial values keyed by site name. | ||
:param callable fallback: Fallback init strategy, for sites not specified | ||
in ``values``. | ||
:raises ValueError: If ``fallback=None`` and no value for a site is given | ||
in ``values``. | ||
""" | ||
if site is None: | ||
return functools.partial(init_to_value, values=values) | ||
return functools.partial(init_to_value, values=values, fallback=fallback) | ||
|
||
if site["name"] in values: | ||
return values[site["name"]] | ||
else: | ||
return init_to_uniform(site) | ||
if fallback is not None: | ||
return fallback(site) | ||
raise ValueError(f"No init strategy specified for site {repr(site['name'])}") | ||
|
||
|
||
class _InitToGenerated: | ||
|
@@ -180,7 +230,7 @@ def __init__(self, init_fn): | |
super().__init__() | ||
|
||
def _pyro_sample(self, msg): | ||
if msg["done"] or msg["is_observed"] or type(msg["fn"]).__name__ == "_Subsample": | ||
if msg["value"] is not None or type(msg["fn"]).__name__ == "_Subsample": | ||
return | ||
with torch.no_grad(), helpful_support_errors(msg): | ||
value = self.init_fn(msg) | ||
|
@@ -194,4 +244,8 @@ def _pyro_sample(self, msg): | |
"{} provided invalid shape for site {}:\nexpected {}\nactual {}" | ||
.format(self.init_fn, msg["name"], msg["value"].shape, value.shape)) | ||
msg["value"] = value | ||
msg["done"] = True | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @eb8680 can you confirm this is ok? I think we want to avoid setting |
||
|
||
def _pyro_get_init_messengers(self, msg): | ||
if msg["value"] is None: | ||
msg["value"] = [] | ||
msg["value"].append(self) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added fallback logic to a few init strategies (hoping to port these to NumPyro and combine with pyro-ppl/numpyro#1058), and set the fallback to `None` here, which would error if the new initialization logic were to fail.