diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 916545cda..7d57cbbd0 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -761,6 +761,7 @@ def _predictive( return_sites=None, infer_discrete=False, parallel=True, + exclude_deterministic: bool = True, model_args=(), model_kwargs={}, ): @@ -774,7 +775,7 @@ def _predictive( posterior_samples, ) prototype_trace = trace( - seed(substitute(masked_model, prototype_sample), subkey) + seed(condition(masked_model, prototype_sample), subkey) ).get_trace(*model_args, **model_kwargs) first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1 @@ -795,9 +796,20 @@ def single_prediction(val): **model_kwargs, ) else: - model_trace = trace( - seed(substitute(masked_model, samples), rng_key) - ).get_trace(*model_args, **model_kwargs) + + def _samples_wo_deterministic(msg): + return ( + samples.get(msg["name"]) if msg["type"] != "deterministic" else None + ) + + substituted_model = ( + substitute(masked_model, substitute_fn=_samples_wo_deterministic) + if exclude_deterministic + else substitute(masked_model, samples) + ) + model_trace = trace(seed(substituted_model, rng_key)).get_trace( + *model_args, **model_kwargs + ) pred_samples = {name: site["value"] for name, site in model_trace.items()} if return_sites is not None: @@ -870,6 +882,7 @@ class Predictive(object): + set `batch_ndims=1` to get predictions from a one dimensional batch of the guide and parameters with shapes `(num_samples x batch_size x ...)` + :param exclude_deterministic: indicates whether to ignore deterministic sites from the posterior samples. :return: dict of samples from the predictive distribution. @@ -907,6 +920,7 @@ def __init__( infer_discrete: bool = False, parallel: bool = False, batch_ndims: Optional[int] = None, + exclude_deterministic: bool = True, ): if posterior_samples is None and num_samples is None: raise ValueError( @@ -967,6 +981,7 @@ def __init__( self.parallel = parallel self.batch_ndims = batch_ndims self._batch_shape = batch_shape + self.exclude_deterministic = exclude_deterministic def _call_with_params(self, rng_key, params, args, kwargs): posterior_samples = self.posterior_samples @@ -983,6 +998,7 @@ def _call_with_params(self, rng_key, params, args, kwargs): parallel=self.parallel, model_args=args, model_kwargs=kwargs, + exclude_deterministic=self.exclude_deterministic, ) model = substitute(self.model, self.params) return _predictive( @@ -995,6 +1011,7 @@ def _call_with_params(self, rng_key, params, args, kwargs): parallel=self.parallel, model_args=args, model_kwargs=kwargs, + exclude_deterministic=self.exclude_deterministic, ) def __call__(self, rng_key, *args, **kwargs): diff --git a/test/infer/test_infer_util.py b/test/infer/test_infer_util.py index 4412916b8..aecf7ca2c 100644 --- a/test/infer/test_infer_util.py +++ b/test/infer/test_infer_util.py @@ -53,6 +53,22 @@ def model(data=None): return model, data, true_probs +def linear_regression(): + N = 800 + X = dist.Normal(0, 1).sample(random.PRNGKey(0), (N,)) + y = 1.5 + X * 0.7 + + def model(X, y=None): + alpha = numpyro.sample("alpha", dist.Normal(0.0, 5)) + beta = numpyro.sample("beta", dist.Normal(0.0, 1.0)) + sigma = numpyro.sample("sigma", dist.Exponential(1.0)) + with numpyro.plate("plate", len(X)): + mu = numpyro.deterministic("mu", alpha + X * beta) + numpyro.sample("obs", dist.Normal(mu, sigma), obs=y) + + return model, X, y + + @pytest.mark.parametrize("parallel", [True, False]) def test_predictive(parallel): model, data, true_probs = beta_bernoulli() @@ -74,6 +90,29 @@ def test_predictive(parallel): assert_allclose(obs.mean(0), true_probs, rtol=0.1) +@pytest.mark.parametrize("parallel", [True, False]) +def test_predictive_with_deterministic(parallel): + """Tests that the default behavior when predicting from models with + deterministic sites doesn't lead to static deterministic sites in the predictive. + """ + n_preds = 400 + model, X, y = linear_regression() + mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100) + mcmc.run(random.PRNGKey(0), X=X, y=y) + samples = mcmc.get_samples() + predictive = Predictive(model, samples, parallel=parallel) + # change the input (X) shape to make sure the deterministic shape changes + predictive_samples = predictive(random.PRNGKey(1), X=X[:n_preds]) + assert predictive_samples.keys() == {"mu", "obs"} + + predictive.return_sites = ["beta", "mu", "obs"] + # change the input (X) shape to make sure the deterministic shape changes + predictive_samples = predictive(random.PRNGKey(1), X=X[:n_preds]) + # check shapes + assert predictive_samples["mu"].shape == (100,) + X[:n_preds].shape + assert predictive_samples["obs"].shape == (100,) + X[:n_preds].shape + + def test_predictive_with_guide(): data = jnp.array([1] * 8 + [0] * 2)