From bf7446e3a2a21f5bdea3fe65b73892351360f6f0 Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Thu, 10 Feb 2022 21:49:45 +0100 Subject: [PATCH 1/2] change particle info to allow nested params. --- numpyro/contrib/einstein/steinvi.py | 17 +++++++++++------ test/contrib/einstein/test_stein.py | 25 ++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index cb5be0e2a..0206db4bf 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -111,16 +111,21 @@ def _param_size(self, param): return sum(map(self._param_size, param)) return param.size - def _calc_particle_info(self, uparams, num_particles): + def _calc_particle_info(self, uparams, num_particles, start_index=0): uparam_keys = list(uparams.keys()) uparam_keys.sort() - start_index = 0 res = {} for k in uparam_keys: - end_index = start_index + self._param_size(uparams[k]) // num_particles - res[k] = (start_index, end_index) + if isinstance(uparams[k], dict): + res_sub, end_index = self._calc_particle_info( + uparams[k], num_particles, start_index + ) + res[k] = res_sub + else: + end_index = start_index + self._param_size(uparams[k]) // num_particles + res[k] = (start_index, end_index) start_index = end_index - return res + return res, end_index def _find_init_params(self, particle_seed, inner_guide, inner_guide_trace): def extract_info(site): @@ -180,7 +185,7 @@ def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs): stein_particles, unravel_pytree, unravel_pytree_batched = batch_ravel_pytree( stein_uparams, nbatch_dims=1 ) - particle_info = self._calc_particle_info( + particle_info, _ = self._calc_particle_info( stein_uparams, stein_particles.shape[0] ) diff --git a/test/contrib/einstein/test_stein.py b/test/contrib/einstein/test_stein.py index 8605d8040..f2eeb5d0c 100644 --- a/test/contrib/einstein/test_stein.py +++ b/test/contrib/einstein/test_stein.py @@ -382,12 +382,35 @@ def test_calc_particle_info(num_params, num_particles): expected_pinfo = dict(zip(string.ascii_lowercase[:num_params], expected_start_end)) stein = SteinVI(id, id, Adam(1.0), Trace_ELBO(), RBFKernel()) - pinfo = stein._calc_particle_info(uparams, num_particles) + pinfo, _ = stein._calc_particle_info(uparams, num_particles) for k in pinfo.keys(): assert pinfo[k] == expected_pinfo[k], f"Failed for seed {seed}" +def test_calc_particle_info_nested(): + num_params = 3 + num_particles = 10 + seed = random.PRNGKey(42) + sizes = Poisson(5).sample(seed, (100, nrandom.randint(1, 10))) + 1 + uparam = tuple(np.empty(tuple(size)) for size in sizes) + uparams = { + string.ascii_lowercase[i]: { + string.ascii_lowercase[j]: uparam for j in range(num_params) + } + for i in range(num_params) + } + + stein = SteinVI(id, id, Adam(1.0), Trace_ELBO(), RBFKernel()) + pinfo, _ = stein._calc_particle_info(uparams, num_particles) + start = 0 + tot_size = sum(map(lambda size: size.prod(), sizes)) // num_particles + for val in pinfo.values(): + for v in val.values(): + assert v == (start, start + tot_size) + start += tot_size + + ######################################## # Stein Kernels ######################################## From 1ae354b86c50867a11a8028556d1c464b9a15253 Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Thu, 10 Feb 2022 23:02:16 +0100 Subject: [PATCH 2/2] fixed unassigned end_index in `_calc_particle_info` --- numpyro/contrib/einstein/steinvi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 0206db4bf..6ec2af5eb 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -115,6 +115,7 @@ def _calc_particle_info(self, uparams, num_particles, start_index=0): uparam_keys = list(uparams.keys()) uparam_keys.sort() res = {} + end_index = start_index for k in uparam_keys: if isinstance(uparams[k], dict): res_sub, end_index = self._calc_particle_info(