From 72e4053d79acd163b42bce6fbcc3222231806533 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Wed, 1 May 2024 08:34:08 -0400 Subject: [PATCH 1/6] exclude sample sites with "~" --- numpyro/infer/mcmc.py | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 8ffd38871..b49e1e973 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -188,11 +188,17 @@ def _sample_fn_nojit_args(state, sampler, args, kwargs): return (sampler.sample(state[0], args, kwargs),) -def _collect_fn(collect_fields): - @cached_by(_collect_fn, collect_fields) +def _collect_fn(collect_fields, remove_sites): + @cached_by(_collect_fn, collect_fields, remove_sites) def collect(x): if collect_fields: - return attrgetter(*collect_fields)(x[0]) + fields = list(attrgetter(*collect_fields)(x[0])) + # fields[0] is guaranteed to be `sample_field` + sample_sites = fields[0].copy() + for site in remove_sites: + del sample_sites[site] + fields[0] = sample_sites + return fields else: return x[0] @@ -419,7 +425,7 @@ def _get_cached_init_state(self, rng_key, args, kwargs): except TypeError: return None - def _single_chain_mcmc(self, init, args, kwargs, collect_fields): + def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites): rng_key, init_state, init_params = init # Check if _sample_fn is None, then we need to initialize the sampler. if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None): @@ -452,7 +458,7 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): upper_idx, sample_fn, init_val, - transform=_collect_fn(collect_fields), + transform=_collect_fn(collect_fields, remove_sites), progbar=self.progress_bar, return_last_val=True, thinning=self.thinning, @@ -626,18 +632,25 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): " as `num_chains`." ) assert isinstance(extra_fields, (tuple, list)) - collect_fields = tuple( - set( - (self._sample_field,) - + tuple(self._default_fields) - + tuple(extra_fields) - ) - ) + + collect_fields = {} # we use a dictionary to ensure `sample_field` always comes first + remove_sites = [] + for field_name in ( + (self._sample_field,) + tuple(self._default_fields) + tuple(extra_fields) + ): + if field_name.startswith(f"~{self._sample_field}."): + remove_sites.append(field_name[len(self._sample_field) + 2 :]) + else: + collect_fields[field_name] = None + collect_fields = tuple(collect_fields.keys()) + remove_sites = tuple(remove_sites) + partial_map_fn = partial( self._single_chain_mcmc, args=args, kwargs=kwargs, collect_fields=collect_fields, + remove_sites=remove_sites, ) map_args = (rng_key, init_state, init_params) if self.num_chains == 1: From 9e6b587c1a333c854c3a4a87351b2493e0a4cc3c Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Wed, 1 May 2024 08:36:30 -0400 Subject: [PATCH 2/6] handle repeat remove_sites --- numpyro/infer/mcmc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index b49e1e973..566ec75bf 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -634,16 +634,16 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): assert isinstance(extra_fields, (tuple, list)) collect_fields = {} # we use a dictionary to ensure `sample_field` always comes first - remove_sites = [] + remove_sites = {} for field_name in ( (self._sample_field,) + tuple(self._default_fields) + tuple(extra_fields) ): if field_name.startswith(f"~{self._sample_field}."): - remove_sites.append(field_name[len(self._sample_field) + 2 :]) + remove_sites[(field_name[len(self._sample_field) + 2 :])] = None else: collect_fields[field_name] = None collect_fields = tuple(collect_fields.keys()) - remove_sites = tuple(remove_sites) + remove_sites = tuple(remove_sites.keys()) partial_map_fn = partial( self._single_chain_mcmc, From 85f8be5cd9c76d6982925d70a07c3c14123ca328 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Wed, 1 May 2024 08:50:35 -0400 Subject: [PATCH 3/6] test exclude sites --- test/infer/test_mcmc.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index fa1fea460..2d1d2ba36 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -1186,3 +1186,16 @@ def model(data): mcmc.run(rng_key, data, extra_fields=("num_steps",)) num_steps_list = np.array(mcmc.get_extra_fields()["num_steps"]) assert all(step == num_steps for step in num_steps_list) + + +@pytest.mark.parametrize("remove_sites", [("~z.a", "~z.b"), ("~z.a", "~z.a")]) +def test_remove_sites(remove_sites): + def model(): + numpyro.sample("a", dist.Normal()) + numpyro.sample("b", dist.Normal()) + + mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) + mcmc.run(random.PRNGKey(0)) + samps = mcmc.get_samples() + + assert all(site not in samps for site in remove_sites) From ae302b3c837244e53a71f66d3abba122a4a791d3 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Wed, 1 May 2024 09:50:11 -0400 Subject: [PATCH 4/6] fix test case and len(1) collect_fields edge case --- numpyro/infer/mcmc.py | 5 +++-- test/infer/test_mcmc.py | 9 +++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 566ec75bf..d11e943e5 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -192,13 +192,14 @@ def _collect_fn(collect_fields, remove_sites): @cached_by(_collect_fn, collect_fields, remove_sites) def collect(x): if collect_fields: - fields = list(attrgetter(*collect_fields)(x[0])) + fields = attrgetter(*collect_fields)(x[0]) + fields = [fields] if len(collect_fields) == 1 else list(fields) # fields[0] is guaranteed to be `sample_field` sample_sites = fields[0].copy() for site in remove_sites: del sample_sites[site] fields[0] = sample_sites - return fields + return fields[0] if len(collect_fields) == 1 else fields else: return x[0] diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 2d1d2ba36..480aba80d 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -1188,14 +1188,15 @@ def model(data): assert all(step == num_steps for step in num_steps_list) +@pytest.mark.parametrize("kernel_cls", [NUTS, BarkerMH]) @pytest.mark.parametrize("remove_sites", [("~z.a", "~z.b"), ("~z.a", "~z.a")]) -def test_remove_sites(remove_sites): +def test_remove_sites(kernel_cls, remove_sites): def model(): numpyro.sample("a", dist.Normal()) numpyro.sample("b", dist.Normal()) - mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) - mcmc.run(random.PRNGKey(0)) + mcmc = MCMC(kernel_cls(model), num_warmup=10, num_samples=10) + mcmc.run(random.PRNGKey(0), extra_fields=remove_sites) samps = mcmc.get_samples() - assert all(site not in samps for site in remove_sites) + assert all([site[3:] not in samps for site in remove_sites]) From 604f1f7b8270511f93e6a2d3f353cc4d9e00afc1 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Thu, 2 May 2024 09:20:06 -0400 Subject: [PATCH 5/6] add dict check, switch to list, add documentation --- numpyro/infer/mcmc.py | 47 +++++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index d11e943e5..16bea4766 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -193,13 +193,18 @@ def _collect_fn(collect_fields, remove_sites): def collect(x): if collect_fields: fields = attrgetter(*collect_fields)(x[0]) - fields = [fields] if len(collect_fields) == 1 else list(fields) - # fields[0] is guaranteed to be `sample_field` - sample_sites = fields[0].copy() - for site in remove_sites: - del sample_sites[site] - fields[0] = sample_sites - return fields[0] if len(collect_fields) == 1 else fields + + if remove_sites != (): + fields = [fields] if len(collect_fields) == 1 else list(fields) + assert isinstance(fields[0], dict) + + sample_sites = fields[0].copy() + for site in remove_sites: + sample_sites.pop(site) + fields[0] = sample_sites + fields = fields[0] if len(collect_fields) == 1 else fields + + return fields else: return x[0] @@ -563,7 +568,8 @@ def warmup( These are typically the arguments needed by the `model`. :param extra_fields: Extra fields (aside from :meth:`~numpyro.infer.mcmc.MCMCKernel.default_fields`) from the state object (e.g. :data:`numpyro.infer.hmc.HMCState` for HMC) to collect during - the MCMC run. + the MCMC run. Exclude sample sites from collection with "~`sampler.sample_field`.`sample_site`". + e.g. "~z.a" will prevent site "a" from being collected if you're using the NUTS sampler. :type extra_fields: tuple or list :param bool collect_warmup: Whether to collect samples from the warmup phase. Defaults to `False`. @@ -598,7 +604,9 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): :param extra_fields: Extra fields (aside from `"z"`, `"diverging"`) from the state object (e.g. :data:`numpyro.infer.hmc.HMCState` for HMC) to be collected during the MCMC run. Note that subfields can be accessed using dots, e.g. - `"adapt_state.step_size"` can be used to collect step sizes at each step. + `"adapt_state.step_size"` can be used to collect step sizes at each step. Exclude sample sites from + collection with "~`sampler.sample_field`.`sample_site`". e.g. "~z.a" will prevent site "a" from + being collected if you're using the NUTS sampler. :type extra_fields: tuple or list of str :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn` provided to the kernel. If the kernel is @@ -634,24 +642,23 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): ) assert isinstance(extra_fields, (tuple, list)) - collect_fields = {} # we use a dictionary to ensure `sample_field` always comes first - remove_sites = {} - for field_name in ( - (self._sample_field,) + tuple(self._default_fields) + tuple(extra_fields) - ): + field_names = set(tuple(self._default_fields) + tuple(extra_fields)) + field_names.discard(self._sample_field) + remove_sites = [] + collect_fields = [self._sample_field] + + for field_name in field_names: if field_name.startswith(f"~{self._sample_field}."): - remove_sites[(field_name[len(self._sample_field) + 2 :])] = None + remove_sites.append(field_name[len(self._sample_field) + 2 :]) else: - collect_fields[field_name] = None - collect_fields = tuple(collect_fields.keys()) - remove_sites = tuple(remove_sites.keys()) + collect_fields.append(field_name) partial_map_fn = partial( self._single_chain_mcmc, args=args, kwargs=kwargs, - collect_fields=collect_fields, - remove_sites=remove_sites, + collect_fields=tuple(collect_fields), + remove_sites=tuple(remove_sites), ) map_args = (rng_key, init_state, init_params) if self.num_chains == 1: From 32eeb3ee3c0cd5973776768f57848ef7205bee88 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Thu, 2 May 2024 10:10:59 -0400 Subject: [PATCH 6/6] back to dict solution --- numpyro/infer/mcmc.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 16bea4766..9b159187a 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -642,23 +642,24 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): ) assert isinstance(extra_fields, (tuple, list)) - field_names = set(tuple(self._default_fields) + tuple(extra_fields)) - field_names.discard(self._sample_field) - remove_sites = [] - collect_fields = [self._sample_field] - - for field_name in field_names: + collect_fields = {} + remove_sites = {} + for field_name in ( + (self._sample_field,) + tuple(self._default_fields) + tuple(extra_fields) + ): if field_name.startswith(f"~{self._sample_field}."): - remove_sites.append(field_name[len(self._sample_field) + 2 :]) + remove_sites[(field_name[len(self._sample_field) + 2 :])] = None else: - collect_fields.append(field_name) + collect_fields[field_name] = None + collect_fields = tuple(collect_fields.keys()) + remove_sites = tuple(remove_sites.keys()) partial_map_fn = partial( self._single_chain_mcmc, args=args, kwargs=kwargs, - collect_fields=tuple(collect_fields), - remove_sites=tuple(remove_sites), + collect_fields=collect_fields, + remove_sites=remove_sites, ) map_args = (rng_key, init_state, init_params) if self.num_chains == 1: