Skip to content

Commit

Permalink
Add inference utilities to transform between unconstrained and constr…
Browse files Browse the repository at this point in the history
…ained space

Improve and simplify constrain_fn and unconstrain_fn implementation

Add missing doctstrings

Constrain/unconstrain functions now always consider param sites

Fix syntax for lint tests

Fix syntax for lint tests

Fix syntax for lint tests
  • Loading branch information
aymgal committed Mar 30, 2023
1 parent a7267d9 commit 19c003a
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 0 deletions.
44 changes: 44 additions & 0 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ def substitute_fn(site):
if site["type"] == "sample":
with helpful_support_errors(site):
return biject_to(site["fn"].support)(params[site["name"]])
elif site["type"] == "param":
constraint = site["kwargs"].pop("constraint", constraints.real)
with helpful_support_errors(site):
return biject_to(constraint)(params[site["name"]])
else:
return params[site["name"]]

Expand All @@ -193,6 +197,42 @@ def substitute_fn(site):
}


def get_transforms(model, model_args, model_kwargs, params):
"""
(EXPERIMENTAL INTERFACE) Retrieve (inverse) transforms via biject_to()
given a NumPyro model. This function supports 'param' sites.
NB: Parameter values are only used to retrieve the model trace.
:param model: a callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of values keyed by site names.
:return: `dict` of transformation keyed by site names.
"""
substituted_model = substitute(model, data=params)
transforms, _, _, _ = _get_model_transforms(
substituted_model, model_args, model_kwargs
)
return transforms


def unconstrain_fn(model, model_args, model_kwargs, params):
"""
(EXPERIMENTAL INTERFACE) Given a NumPyro model and a dict of parameters,
this function applies the right transformation to convert parameter values
from constrained space to unconstrained space.
:param model: a callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of constrained values keyed by site
names.
:return: `dict` of transformation keyed by site names.
"""
transforms = get_transforms(model, model_args, model_kwargs, params)
return transform_fn(transforms, params, invert=True)


def _unconstrain_reparam(params, site):
name = site["name"]
if name in params:
Expand Down Expand Up @@ -449,6 +489,10 @@ def _get_model_transforms(model, model_args=(), model_kwargs=None):
for arg in args:
if not isinstance(getattr(support, arg), (int, float)):
replay_model = True
elif v["type"] == "param":
constraint = v["kwargs"].pop("constraint", constraints.real)
with helpful_support_errors(v, raise_warnings=True):
inv_transforms[k] = biject_to(constraint)
elif v["type"] == "deterministic":
replay_model = True
return inv_transforms, replay_model, has_enumerate_support, model_trace
Expand Down
47 changes: 47 additions & 0 deletions test/infer/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
log_likelihood,
potential_energy,
transform_fn,
unconstrain_fn,
)
import numpyro.optim as optim

Expand Down Expand Up @@ -220,6 +221,52 @@ def model():
assert_allclose(actual_potential_energy, expected_potential_energy)


def test_constrain_unconstrain():
x_prior = dist.HalfNormal(2)
y_prior = dist.LogNormal(scale=3.0) # transformed distribution
z_constraint = constraints.positive

def model():
numpyro.sample("x", x_prior)
numpyro.sample("y", y_prior)
numpyro.param("z", init_value=2.0, constraint=z_constraint)

params = {"x": jnp.array(-5.0), "y": jnp.array(7.0), "z": jnp.array(3.0)}
model = handlers.seed(model, random.PRNGKey(0))
inv_transforms = {
"x": biject_to(x_prior.support),
"y": biject_to(y_prior.support),
"z": biject_to(z_constraint),
}
expected_constrained_samples = partial(transform_fn, inv_transforms)(params)
transforms = {
"x": biject_to(x_prior.support).inv,
"y": biject_to(y_prior.support).inv,
"z": biject_to(z_constraint).inv,
}
expected_unconstrained_samples = partial(transform_fn, transforms)(
expected_constrained_samples
)

actual_constrained_samples = constrain_fn(model, (), {}, params)
actual_unconstrained_samples = unconstrain_fn(
model, (), {}, actual_constrained_samples
)

assert_allclose(expected_constrained_samples["x"], actual_constrained_samples["x"])
assert_allclose(expected_constrained_samples["y"], actual_constrained_samples["y"])
assert_allclose(expected_constrained_samples["z"], actual_constrained_samples["z"])
assert_allclose(
expected_unconstrained_samples["x"], actual_unconstrained_samples["x"]
)
assert_allclose(
expected_unconstrained_samples["y"], actual_unconstrained_samples["y"]
)
assert_allclose(
expected_unconstrained_samples["z"], actual_unconstrained_samples["z"]
)


def test_model_with_mask_false():
def model():
x = numpyro.sample("x", dist.Normal())
Expand Down

0 comments on commit 19c003a

Please sign in to comment.