diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 8ffd38871..9b159187a 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -188,11 +188,23 @@ def _sample_fn_nojit_args(state, sampler, args, kwargs): return (sampler.sample(state[0], args, kwargs),) -def _collect_fn(collect_fields): - @cached_by(_collect_fn, collect_fields) +def _collect_fn(collect_fields, remove_sites=()): + @cached_by(_collect_fn, collect_fields, remove_sites) def collect(x): if collect_fields: - return attrgetter(*collect_fields)(x[0]) + fields = attrgetter(*collect_fields)(x[0]) + + if remove_sites: + fields = [fields] if len(collect_fields) == 1 else list(fields) + assert isinstance(fields[0], dict) + + sample_sites = fields[0].copy() + for site in remove_sites: + sample_sites.pop(site) + fields[0] = sample_sites + fields = fields[0] if len(collect_fields) == 1 else fields + + return fields else: return x[0] @@ -419,7 +431,7 @@ def _get_cached_init_state(self, rng_key, args, kwargs): except TypeError: return None - def _single_chain_mcmc(self, init, args, kwargs, collect_fields): + def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites=()): rng_key, init_state, init_params = init # Check if _sample_fn is None, then we need to initialize the sampler. if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None): @@ -452,7 +464,7 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): upper_idx, sample_fn, init_val, - transform=_collect_fn(collect_fields), + transform=_collect_fn(collect_fields, remove_sites), progbar=self.progress_bar, return_last_val=True, thinning=self.thinning, @@ -556,7 +568,8 @@ def warmup( These are typically the arguments needed by the `model`. :param extra_fields: Extra fields (aside from :meth:`~numpyro.infer.mcmc.MCMCKernel.default_fields`) from the state object (e.g. :data:`numpyro.infer.hmc.HMCState` for HMC) to collect during - the MCMC run. + the MCMC run. 
To exclude a sample site from collection, include "~`sampler.sample_field`.`sample_site`". + For example, "~z.a" will prevent site "a" from being collected when using the NUTS sampler. :type extra_fields: tuple or list :param bool collect_warmup: Whether to collect samples from the warmup phase. Defaults to `False`. @@ -591,7 +604,9 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): :param extra_fields: Extra fields (aside from `"z"`, `"diverging"`) from the state object (e.g. :data:`numpyro.infer.hmc.HMCState` for HMC) to be collected during the MCMC run. Note that subfields can be accessed using dots, e.g. - `"adapt_state.step_size"` can be used to collect step sizes at each step. + `"adapt_state.step_size"` can be used to collect step sizes at each step. To exclude a sample site + from collection, include "~`sampler.sample_field`.`sample_site`". For example, "~z.a" will prevent + site "a" from being collected when using the NUTS sampler. :type extra_fields: tuple or list of str :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn` provided to the kernel. If the kernel is @@ -626,18 +641,25 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): " as `num_chains`." 
) assert isinstance(extra_fields, (tuple, list)) - collect_fields = tuple( - set( - (self._sample_field,) - + tuple(self._default_fields) - + tuple(extra_fields) - ) - ) + + collect_fields = {} + remove_sites = {} + for field_name in ( + (self._sample_field,) + tuple(self._default_fields) + tuple(extra_fields) + ): + if field_name.startswith(f"~{self._sample_field}."): + remove_sites[field_name[len(self._sample_field) + 2 :]] = None + else: + collect_fields[field_name] = None + collect_fields = tuple(collect_fields) + remove_sites = tuple(remove_sites) + partial_map_fn = partial( self._single_chain_mcmc, args=args, kwargs=kwargs, collect_fields=collect_fields, + remove_sites=remove_sites, ) map_args = (rng_key, init_state, init_params) if self.num_chains == 1: diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index fa1fea460..480aba80d 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -1186,3 +1186,17 @@ def model(data): mcmc.run(rng_key, data, extra_fields=("num_steps",)) num_steps_list = np.array(mcmc.get_extra_fields()["num_steps"]) assert all(step == num_steps for step in num_steps_list) + + +@pytest.mark.parametrize("kernel_cls", [NUTS, BarkerMH]) +@pytest.mark.parametrize("remove_sites", [("~z.a", "~z.b"), ("~z.a", "~z.a")]) +def test_remove_sites(kernel_cls, remove_sites): + def model(): + numpyro.sample("a", dist.Normal()) + numpyro.sample("b", dist.Normal()) + + mcmc = MCMC(kernel_cls(model), num_warmup=10, num_samples=10) + mcmc.run(random.PRNGKey(0), extra_fields=remove_sites) + samps = mcmc.get_samples() + + assert all(site[3:] not in samps for site in remove_sites)