Skip to content

Commit

Permalink
Predictive fix when deterministic sites are present (pyro-ppl#1789)
Browse files Browse the repository at this point in the history
* added custom effect handler for predictive

* added test, fixed predictive_substitute

* fixed typo, removed unneeded custom substitute calls

* removed custom effect handler, improved readability

* reverted formatting of imports

* added conditional arg for handling deterministic sites to predictive

* changed arg name to exclude_deterministic

* updated exclude_deterministic description

* changed substitute to condition in infer_discrete _predctive workflow

---------

Co-authored-by: kylejcaron <[email protected]>
  • Loading branch information
2 people authored and OlaRonning committed May 6, 2024
1 parent 1702a72 commit 275f079
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
25 changes: 21 additions & 4 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ def _predictive(
return_sites=None,
infer_discrete=False,
parallel=True,
exclude_deterministic: bool = True,
model_args=(),
model_kwargs={},
):
Expand All @@ -774,7 +775,7 @@ def _predictive(
posterior_samples,
)
prototype_trace = trace(
seed(substitute(masked_model, prototype_sample), subkey)
seed(condition(masked_model, prototype_sample), subkey)
).get_trace(*model_args, **model_kwargs)
first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1

Expand All @@ -795,9 +796,20 @@ def single_prediction(val):
**model_kwargs,
)
else:
model_trace = trace(
seed(substitute(masked_model, samples), rng_key)
).get_trace(*model_args, **model_kwargs)

def _samples_wo_deterministic(msg):
return (
samples.get(msg["name"]) if msg["type"] != "deterministic" else None
)

substituted_model = (
substitute(masked_model, substitute_fn=_samples_wo_deterministic)
if exclude_deterministic
else substitute(masked_model, samples)
)
model_trace = trace(seed(substituted_model, rng_key)).get_trace(
*model_args, **model_kwargs
)
pred_samples = {name: site["value"] for name, site in model_trace.items()}

if return_sites is not None:
Expand Down Expand Up @@ -870,6 +882,7 @@ class Predictive(object):
+ set `batch_ndims=1` to get predictions from a one dimensional batch of the guide and parameters
with shapes `(num_samples x batch_size x ...)`
:param exclude_deterministic: indicates whether to ignore deterministic sites from the posterior samples.
:return: dict of samples from the predictive distribution.
Expand Down Expand Up @@ -907,6 +920,7 @@ def __init__(
infer_discrete: bool = False,
parallel: bool = False,
batch_ndims: Optional[int] = None,
exclude_deterministic: bool = True,
):
if posterior_samples is None and num_samples is None:
raise ValueError(
Expand Down Expand Up @@ -967,6 +981,7 @@ def __init__(
self.parallel = parallel
self.batch_ndims = batch_ndims
self._batch_shape = batch_shape
self.exclude_deterministic = exclude_deterministic

def _call_with_params(self, rng_key, params, args, kwargs):
posterior_samples = self.posterior_samples
Expand All @@ -983,6 +998,7 @@ def _call_with_params(self, rng_key, params, args, kwargs):
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
exclude_deterministic=self.exclude_deterministic,
)
model = substitute(self.model, self.params)
return _predictive(
Expand All @@ -995,6 +1011,7 @@ def _call_with_params(self, rng_key, params, args, kwargs):
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
exclude_deterministic=self.exclude_deterministic,
)

def __call__(self, rng_key, *args, **kwargs):
Expand Down
39 changes: 39 additions & 0 deletions test/infer/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,22 @@ def model(data=None):
return model, data, true_probs


def linear_regression():
N = 800
X = dist.Normal(0, 1).sample(random.PRNGKey(0), (N,))
y = 1.5 + X * 0.7

def model(X, y=None):
alpha = numpyro.sample("alpha", dist.Normal(0.0, 5))
beta = numpyro.sample("beta", dist.Normal(0.0, 1.0))
sigma = numpyro.sample("sigma", dist.Exponential(1.0))
with numpyro.plate("plate", len(X)):
mu = numpyro.deterministic("mu", alpha + X * beta)
numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

return model, X, y


@pytest.mark.parametrize("parallel", [True, False])
def test_predictive(parallel):
model, data, true_probs = beta_bernoulli()
Expand All @@ -74,6 +90,29 @@ def test_predictive(parallel):
assert_allclose(obs.mean(0), true_probs, rtol=0.1)


@pytest.mark.parametrize("parallel", [True, False])
def test_predictive_with_deterministic(parallel):
"""Tests that the default behavior when predicting from models with
deterministic sites doesn't lead to static deterministic sites in the predictive.
"""
n_preds = 400
model, X, y = linear_regression()
mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
mcmc.run(random.PRNGKey(0), X=X, y=y)
samples = mcmc.get_samples()
predictive = Predictive(model, samples, parallel=parallel)
# change the input (X) shape to make sure the deterministic shape changes
predictive_samples = predictive(random.PRNGKey(1), X=X[:n_preds])
assert predictive_samples.keys() == {"mu", "obs"}

predictive.return_sites = ["beta", "mu", "obs"]
# change the input (X) shape to make sure the deterministic shape changes
predictive_samples = predictive(random.PRNGKey(1), X=X[:n_preds])
# check shapes
assert predictive_samples["mu"].shape == (100,) + X[:n_preds].shape
assert predictive_samples["obs"].shape == (100,) + X[:n_preds].shape


def test_predictive_with_guide():
data = jnp.array([1] * 8 + [0] * 2)

Expand Down

0 comments on commit 275f079

Please sign in to comment.