From 295c7df894cff5037f806bc2a93a3e45e0d0d5f4 Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Wed, 21 Jun 2023 11:49:11 +0200 Subject: [PATCH 1/8] changed `steinvi` to use `init_loc_fn` for all particles. --- numpyro/contrib/einstein/steinvi.py | 43 ++++++++++++++++----------- test/contrib/einstein/test_steinvi.py | 2 +- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 424d09f79..084ab6596 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -20,11 +20,15 @@ get_parameter_transform, ) from numpyro.contrib.funsor import config_enumerate, enum -from numpyro.distributions import Distribution, Normal +from numpyro.distributions import Distribution, Normal, biject_to from numpyro.distributions.constraints import real from numpyro.distributions.transforms import IdentityTransform from numpyro.infer.autoguide import AutoGuide -from numpyro.infer.util import _guess_max_plate_nesting, transform_fn +from numpyro.infer.util import ( + _guess_max_plate_nesting, + helpful_support_errors, + transform_fn, +) from numpyro.optim import _NumPyroOptim from numpyro.util import fori_collect, ravel_pytree @@ -112,7 +116,7 @@ def __init__( ) self.kernel_fn = kernel_fn self.static_kwargs = static_kwargs - self.num_particles = num_stein_particles + self.num_stein_particles = num_stein_particles self.loss_temperature = loss_temperature self.repulsion_temperature = repulsion_temperature self.enum = enum @@ -170,24 +174,29 @@ def _calc_particle_info(self, uparams, num_particles, start_index=0): def _find_init_params(self, particle_seed, inner_guide, inner_guide_trace): def extract_info(site): nonlocal particle_seed + nonlocal inner_guide_trace name = site["name"] - value = site["value"] constraint = site["kwargs"].get("constraint", real) transform = get_parameter_transform(site) - if ( - isinstance(inner_guide, AutoGuide) - and "_".join((inner_guide.prefix, "loc")) in name + if isinstance(inner_guide, AutoGuide) and name.endswith( + "_".join(("", inner_guide.prefix, "loc")) ): site_key, particle_seed = random.split(particle_seed) - unconstrained_shape = transform.inverse_shape(value.shape) - init_value = jnp.expand_dims( - transform.inv(value), 0 - ) + Normal( # Add gaussian noise - scale=0.1 - ).sample( - particle_seed, (self.num_particles, *unconstrained_shape) + sample_site = inner_guide_trace[ + name.replace("_".join(("", inner_guide.prefix, "loc")), "") + ].copy() + sample_site_shape = sample_site["kwargs"]["sample_shape"] + sample_site["kwargs"]["sample_shape"] = ( + self.num_stein_particles, + *sample_site_shape, ) - init_value = transform(init_value) + sample_site["kwargs"]["rng_key"] = site_key + sample_site["value"] = None + init_value_particles = inner_guide.init_loc_fn(sample_site) + with helpful_support_errors(sample_site): + sample_transform = biject_to(sample_site["fn"].support) + init_value_particles = sample_transform.inv(init_value_particles) + init_value = transform(init_value_particles) else: site_fn = site["fn"] @@ -199,7 +208,7 @@ def _reinit(seed): return site_fn(*site_args) init_value = vmap(_reinit)( - random.split(particle_seed, self.num_particles) + random.split(particle_seed, self.num_stein_particles) ) return init_value, constraint @@ -352,7 +361,7 @@ def _update_force(attr_force, rep_force, jac): vmap(single_particle_grad)( stein_particles, attractive_force, repulsive_force ) - / self.num_particles + / self.num_stein_particles ) # 5. Decompose the monolithic particle forces back to concrete parameter values diff --git a/test/contrib/einstein/test_steinvi.py b/test/contrib/einstein/test_steinvi.py index 1c0214fd2..15f543945 100644 --- a/test/contrib/einstein/test_steinvi.py +++ b/test/contrib/einstein/test_steinvi.py @@ -161,7 +161,7 @@ def test_get_params(kernel, auto_guide, init_loc_fn, problem): for name, svi_param in svi_params.items(): assert ( stein_params[name].shape - == np.repeat(svi_param[None, ...], stein.num_particles, axis=0).shape + == np.repeat(svi_param[None, ...], stein.num_stein_particles, axis=0).shape ) From c84ab89cf85c8aa9abc7af9f87fc90edd9e45e7f Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Mon, 26 Jun 2023 10:12:19 +0200 Subject: [PATCH 2/8] removed unused import --- numpyro/contrib/einstein/steinvi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 084ab6596..2ad241168 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -20,7 +20,7 @@ get_parameter_transform, ) from numpyro.contrib.funsor import config_enumerate, enum -from numpyro.distributions import Distribution, Normal, biject_to +from numpyro.distributions import Distribution, biject_to from numpyro.distributions.constraints import real from numpyro.distributions.transforms import IdentityTransform from numpyro.infer.autoguide import AutoGuide From bec064051442b01cbbb1946513f1cf561589bdec Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Wed, 5 Jul 2023 22:21:10 +0200 Subject: [PATCH 3/8] sketched `_find_init_params` --- .../einstein/mixture_guide_predictive.py | 4 +- numpyro/contrib/einstein/steinvi.py | 120 +++++++++--------- numpyro/infer/util.py | 4 +- test/contrib/einstein/test_steinvi.py | 91 +++++++++++-- 4 files changed, 145 insertions(+), 74 deletions(-) diff --git a/numpyro/contrib/einstein/mixture_guide_predictive.py b/numpyro/contrib/einstein/mixture_guide_predictive.py index 3eca7513b..00e4ad50b 100644 --- a/numpyro/contrib/einstein/mixture_guide_predictive.py +++ b/numpyro/contrib/einstein/mixture_guide_predictive.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial -from typing import Callable, Dict, Optional, Sequence +from typing import Callable, Dict, Optional, Sequence, Set from jax import numpy as jnp, random, tree_map, vmap from jax.tree_util import tree_flatten @@ -40,7 +40,7 @@ def __init__( params: Dict, guide_sites: Sequence, num_samples: Optional[int] = None, - return_sites: Optional[Sequence[str]] = None, + return_sites: Optional[Sequence[str] | Set[str]] = None, mixture_assignment_sitename="mixture_assignments", ): self.model_predictive = partial( diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 2ad241168..1d41be393 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import namedtuple +from copy import deepcopy import functools from functools import partial from itertools import chain @@ -20,15 +21,10 @@ get_parameter_transform, ) from numpyro.contrib.funsor import config_enumerate, enum -from numpyro.distributions import Distribution, biject_to -from numpyro.distributions.constraints import real +from numpyro.distributions import Distribution from numpyro.distributions.transforms import IdentityTransform from numpyro.infer.autoguide import AutoGuide -from numpyro.infer.util import ( - _guess_max_plate_nesting, - helpful_support_errors, - transform_fn, -) +from numpyro.infer.util import _guess_max_plate_nesting, transform_fn from numpyro.optim import _NumPyroOptim from numpyro.util import fori_collect, ravel_pytree @@ -106,6 +102,38 @@ def __init__( enum=True, **static_kwargs, ): + if isinstance(guide, AutoGuide): + not_comptaible_guides = [ + "AutoIAFNormal", + "AutoBNAFNormal", + "AutoDAIS", + "AutoSemiDAIS", + "AutoSurrogateLikelihoodDAIS", + ] + guide_name = guide.__class__.__name__ + assert guide_name not in not_comptaible_guides, ( + f"SteinVI currently not compatible with {guide_name}. " + f"If you have a use case, feel free to open an issue." + ) + + init_loc_error_message = ( + "SteinVI is not compatible with init_to_feasible, init_to_value, " + "and init_to_uniform with radius=0. If you have a use case, " + "feel free to open an issue." + ) + if isinstance(guide.init_loc_fn, partial): + init_fn_name = guide.init_loc_fn.func.__name__ + if init_fn_name == "init_to_uniform": + assert ( + guide.init_loc_fn.keywords.get("radius", None) != 0 + ), init_loc_error_message + else: + init_fn_name = guide.init_loc_fn.__name__ + assert init_fn_name not in [ + "init_to_feasible", + "init_to_value", + ], init_loc_error_message + self._inference_model = model self.model = model self.guide = guide @@ -171,53 +199,21 @@ def _calc_particle_info(self, uparams, num_particles, start_index=0): start_index = end_index return res, end_index - def _find_init_params(self, particle_seed, inner_guide, inner_guide_trace): - def extract_info(site): - nonlocal particle_seed - nonlocal inner_guide_trace - name = site["name"] - constraint = site["kwargs"].get("constraint", real) - transform = get_parameter_transform(site) - if isinstance(inner_guide, AutoGuide) and name.endswith( - "_".join(("", inner_guide.prefix, "loc")) - ): - site_key, particle_seed = random.split(particle_seed) - sample_site = inner_guide_trace[ - name.replace("_".join(("", inner_guide.prefix, "loc")), "") - ].copy() - sample_site_shape = sample_site["kwargs"]["sample_shape"] - sample_site["kwargs"]["sample_shape"] = ( - self.num_stein_particles, - *sample_site_shape, - ) - sample_site["kwargs"]["rng_key"] = site_key - sample_site["value"] = None - init_value_particles = inner_guide.init_loc_fn(sample_site) - with helpful_support_errors(sample_site): - sample_transform = biject_to(sample_site["fn"].support) - init_value_particles = sample_transform.inv(init_value_particles) - init_value = transform(init_value_particles) - - else: - site_fn = site["fn"] - site_args = site["args"] - site_key, particle_seed = random.split(particle_seed) + def _find_init_params(self, particle_seed, inner_guide, model_args, model_kwargs): + def local_trace(key): + guide = deepcopy(inner_guide) - def _reinit(seed): - with handlers.seed(rng_seed=seed): - return site_fn(*site_args) + with handlers.seed(rng_seed=key), handlers.trace() as mixture_trace: + guide(*model_args, **model_kwargs) - init_value = vmap(_reinit)( - random.split(particle_seed, self.num_stein_particles) - ) - return init_value, constraint + init_params = { + name: site["value"] + for name, site in mixture_trace.items() + if site.get("type") == "param" + } + return init_params - init_params = { - name: extract_info(site) - for name, site in inner_guide_trace.items() - if site.get("type") == "param" - } - return init_params + return vmap(local_trace)(random.split(particle_seed, self.num_stein_particles)) def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs): # 0. Separate model and guide parameters, since only guide parameters are updated using Stein @@ -381,19 +377,25 @@ def init(self, rng_key: KeyArray, *args, **kwargs): :param kwargs: Keyword arguments to the model / guide. :return: initial :data:`SteinVIState` """ - rng_key, kernel_seed, model_seed, guide_seed = random.split(rng_key, 4) - model_init = handlers.seed(self.model, model_seed) - guide_init = handlers.seed(self.guide, guide_seed) - guide_trace = handlers.trace(guide_init).get_trace( - *args, **kwargs, **self.static_kwargs + + rng_key, kernel_seed, model_seed, guide_seed, particle_seed = random.split( + rng_key, 5 ) + + model_init = handlers.seed(self.model, model_seed) model_trace = handlers.trace(model_init).get_trace( *args, **kwargs, **self.static_kwargs ) - rng_key, particle_seed = random.split(rng_key) + guide_init_params = self._find_init_params( - particle_seed, self.guide, guide_trace + particle_seed, self.guide, args, kwargs ) + + guide_init = handlers.seed(self.guide, guide_seed) + guide_trace = handlers.trace(guide_init).get_trace( + *args, **kwargs, **self.static_kwargs + ) + params = {} transforms = {} inv_transforms = {} @@ -424,7 +426,7 @@ def init(self, rng_key: KeyArray, *args, **kwargs): "particle_transform", IdentityTransform() ) if site["name"] in guide_init_params: - pval, _ = guide_init_params[site["name"]] + pval = guide_init_params[site["name"]] if self.non_mixture_params_fn(site["name"]): pval = tree_map(lambda x: x[0], pval) else: diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 4a343510b..1e71bd1ad 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -4,7 +4,7 @@ from collections import namedtuple from contextlib import contextmanager from functools import partial -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Set import warnings import numpy as np @@ -899,7 +899,7 @@ def __init__( guide: Optional[Callable] = None, params: Optional[Dict] = None, num_samples: Optional[int] = None, - return_sites: Optional[List[str]] = None, + return_sites: Optional[List[str] | Set[str]] = None, infer_discrete: bool = False, parallel: bool = False, batch_ndims: Optional[int] = None, diff --git a/test/contrib/einstein/test_steinvi.py b/test/contrib/einstein/test_steinvi.py index 15f543945..81714deea 100644 --- a/test/contrib/einstein/test_steinvi.py +++ b/test/contrib/einstein/test_steinvi.py @@ -2,11 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 from collections import namedtuple +from functools import partial import string import numpy as np -from numpy.ma.testutils import assert_array_approx_equal import numpy.random as nrandom +from numpy.testing import assert_allclose import pytest from jax import random @@ -25,14 +26,19 @@ import numpyro.distributions as dist from numpyro.distributions import Bernoulli, Normal, Poisson from numpyro.distributions.transforms import AffineTransform -from numpyro.infer import SVI, Trace_ELBO +from numpyro.infer import SVI, Trace_ELBO, init_to_mean, init_to_value from numpyro.infer.autoguide import ( + AutoBNAFNormal, + AutoDAIS, AutoDelta, AutoDiagonalNormal, + AutoIAFNormal, AutoLaplaceApproximation, AutoLowRankMultivariateNormal, AutoMultivariateNormal, AutoNormal, + AutoSemiDAIS, + AutoSurrogateLikelihoodDAIS, ) from numpyro.infer.initialization import ( init_to_feasible, @@ -165,6 +171,67 @@ def test_get_params(kernel, auto_guide, init_loc_fn, problem): ) +@pytest.mark.parametrize( + "auto_guide", + [ + AutoIAFNormal, + AutoBNAFNormal, + AutoSemiDAIS, + AutoSurrogateLikelihoodDAIS, + AutoDAIS, + ], +) +def test_incompatiable_autoguide(auto_guide): + def model(): + return + + if auto_guide.__name__ == "AutoSurrogateLikelihoodDAIS": + guide = auto_guide(model, model) + elif auto_guide.__name__ == "AutoSemiDAIS": + guide = auto_guide(model, model, model) + else: + guide = auto_guide(model) + + try: + SteinVI( + model, + guide, + Adam(1.0), + RBFKernel(), + num_stein_particles=1, + ) + pytest.fail() + except AssertionError: + return + + +@pytest.mark.parametrize( + "init_loc", + [ + init_to_value, + init_to_feasible, + partial(init_to_value), + partial(init_to_feasible), + partial(init_to_uniform, radius=0), + ], +) +def test_incompatible_init_locs(init_loc): + def model(): + return + + try: + SteinVI( + model, + AutoDelta(model, init_loc_fn=init_loc), + Adam(1.0), + RBFKernel(), + num_stein_particles=1, + ) + pytest.fail() + except AssertionError: + return + + @pytest.mark.parametrize( "auto_class", [ @@ -180,8 +247,8 @@ def test_get_params(kernel, auto_guide, init_loc_fn, problem): "init_loc_fn", [ init_to_uniform, - init_to_feasible, init_to_median, + init_to_mean, init_to_sample, ], ) @@ -190,28 +257,32 @@ def test_auto_guide(auto_class, init_loc_fn, num_particles): latent_dim = 3 def model(obs): - a = numpyro.sample("a", Normal(0, 1)) + a = numpyro.sample("a", Normal(0, 1).expand((latent_dim,)).to_event(1)) return numpyro.sample("obs", Bernoulli(logits=a), obs=obs) obs = Bernoulli(0.5).sample(random.PRNGKey(0), (10, latent_dim)) rng_key = random.PRNGKey(0) guide_key, stein_key = random.split(rng_key) - inner_guide = auto_class(model, init_loc_fn=init_loc_fn()) - with handlers.seed(rng_seed=guide_key), handlers.trace() as inner_guide_tr: - inner_guide(obs) + guide = auto_class(model, init_loc_fn=init_loc_fn()) steinvi = SteinVI( model, - auto_class(model, init_loc_fn=init_loc_fn()), + guide, Adam(1.0), RBFKernel(), num_stein_particles=num_particles, ) state = steinvi.init(stein_key, obs) + init_params = steinvi.get_params(state) + inner_guide = auto_class(model, init_loc_fn=init_loc_fn()) + + with handlers.seed(rng_seed=guide_key), handlers.trace() as inner_guide_tr: + inner_guide(obs) + for name, site in inner_guide_tr.items(): if site.get("type") == "param": assert name in init_params @@ -223,9 +294,7 @@ def model(obs): assert np.alltrue(init_value != np.zeros(expected_shape)) assert np.unique(init_value).shape == init_value.reshape(-1).shape elif "scale" in name: - assert_array_approx_equal(init_value, np.full(expected_shape, 0.1)) - else: - assert_array_approx_equal(init_value, np.full(expected_shape, 0.0)) + assert_allclose(init_value[init_value != 0.0], 0.1, rtol=1e-6) @pytest.mark.parametrize("length", [1, 2, 3, 6]) From 414d2bdfe3058c61bb13e8eef6bbe5d2adb7c794 Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Wed, 5 Jul 2023 22:28:32 +0200 Subject: [PATCH 4/8] | requires >=python3.10 --- numpyro/contrib/einstein/mixture_guide_predictive.py | 2 +- numpyro/infer/util.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/contrib/einstein/mixture_guide_predictive.py b/numpyro/contrib/einstein/mixture_guide_predictive.py index 00e4ad50b..f3d29dd04 100644 --- a/numpyro/contrib/einstein/mixture_guide_predictive.py +++ b/numpyro/contrib/einstein/mixture_guide_predictive.py @@ -40,7 +40,7 @@ def __init__( params: Dict, guide_sites: Sequence, num_samples: Optional[int] = None, - return_sites: Optional[Sequence[str] | Set[str]] = None, + return_sites: Optional[Sequence[str]] = None, mixture_assignment_sitename="mixture_assignments", ): self.model_predictive = partial( diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 1e71bd1ad..9835ea65e 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -899,7 +899,7 @@ def __init__( guide: Optional[Callable] = None, params: Optional[Dict] = None, num_samples: Optional[int] = None, - return_sites: Optional[List[str] | Set[str]] = None, + return_sites: Optional[List[str]] = None, infer_discrete: bool = False, parallel: bool = False, batch_ndims: Optional[int] = None, From 24bc0e068f1a5e6921f8c689545643e59abbb61d Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Wed, 5 Jul 2023 22:32:23 +0200 Subject: [PATCH 5/8] removed unused imports --- numpyro/contrib/einstein/mixture_guide_predictive.py | 2 +- numpyro/infer/util.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/contrib/einstein/mixture_guide_predictive.py b/numpyro/contrib/einstein/mixture_guide_predictive.py index f3d29dd04..3eca7513b 100644 --- a/numpyro/contrib/einstein/mixture_guide_predictive.py +++ b/numpyro/contrib/einstein/mixture_guide_predictive.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial -from typing import Callable, Dict, Optional, Sequence, Set +from typing import Callable, Dict, Optional, Sequence from jax import numpy as jnp, random, tree_map, vmap from jax.tree_util import tree_flatten diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 9835ea65e..4a343510b 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -4,7 +4,7 @@ from collections import namedtuple from contextlib import contextmanager from functools import partial -from typing import Callable, Dict, List, Optional, Set +from typing import Callable, Dict, List, Optional import warnings import numpy as np From c157170c6500732e053568103dd157c66afabe98 Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Thu, 6 Jul 2023 08:29:08 +0200 Subject: [PATCH 6/8] reduced smoke test (took long to run). added custom guide test. --- test/contrib/einstein/test_steinvi.py | 54 ++++++++++++++++++++------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/test/contrib/einstein/test_steinvi.py b/test/contrib/einstein/test_steinvi.py index 81714deea..9955f3ded 100644 --- a/test/contrib/einstein/test_steinvi.py +++ b/test/contrib/einstein/test_steinvi.py @@ -114,23 +114,13 @@ def model(features, labels): return true_coefs, (data, labels), model -######################################## -# Stein Exterior (Smoke tests) -######################################## - - @pytest.mark.parametrize("kernel", KERNELS) -@pytest.mark.parametrize( - "init_loc_fn", - (init_to_uniform(), init_to_sample(), init_to_median(), init_to_feasible()), -) -@pytest.mark.parametrize("auto_guide", (AutoDelta, AutoNormal)) @pytest.mark.parametrize("problem", (uniform_normal, regression)) -def test_steinvi_smoke(kernel, auto_guide, init_loc_fn, problem): +def test_kernel_smoke(kernel, problem): true_coefs, data, model = problem() stein = SteinVI( model, - auto_guide(model, init_loc_fn=init_loc_fn), + AutoNormal(model), Adam(1e-1), kernel, ) @@ -181,7 +171,7 @@ def test_get_params(kernel, auto_guide, init_loc_fn, problem): AutoDAIS, ], ) -def test_incompatiable_autoguide(auto_guide): +def test_incompatible_autoguide(auto_guide): def model(): return @@ -253,7 +243,7 @@ def model(): ], ) @pytest.mark.parametrize("num_particles", [1, 2, 10]) -def test_auto_guide(auto_class, init_loc_fn, num_particles): +def test_init_auto_guide(auto_class, init_loc_fn, num_particles): latent_dim = 3 def model(obs): @@ -297,6 +287,42 @@ def model(obs): assert_allclose(init_value[init_value != 0.0], 0.1, rtol=1e-6) +@pytest.mark.parametrize("num_particles", [1, 2, 10]) +def test_init_custom_guide(num_particles): + latent_dim = 3 + + def guide(obs): + aloc = numpyro.param( + "aloc", lambda rng_key: Normal().sample(rng_key, (latent_dim,)) + ) + numpyro.sample("a", Normal(aloc, 1).to_event(1)) + + def model(obs): + a = numpyro.sample("a", Normal(0, 1).expand((latent_dim,)).to_event(1)) + return numpyro.sample("obs", Bernoulli(logits=a), obs=obs) + + obs = Bernoulli(0.5).sample(random.PRNGKey(0), (10, latent_dim)) + + rng_key = random.PRNGKey(0) + guide_key, stein_key = random.split(rng_key) + + steinvi = SteinVI( + model, + guide, + Adam(1.0), + RBFKernel(), + num_stein_particles=num_particles, + ) + init_params = steinvi.get_params(steinvi.init(stein_key, obs)) + + init_value = init_params["aloc"] + expected_shape = (num_particles, latent_dim) + + assert expected_shape == init_value.shape + assert np.alltrue(init_value != np.zeros(expected_shape)) + assert np.unique(init_value).shape == init_value.reshape(-1).shape + + @pytest.mark.parametrize("length", [1, 2, 3, 6]) @pytest.mark.parametrize("depth", [1, 3, 5]) @pytest.mark.parametrize("t", [list, tuple]) # add dict, set From d77b34cc9274f1df107d2fb765eca589746cc99f Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Thu, 6 Jul 2023 10:13:29 +0200 Subject: [PATCH 7/8] removed tests covered by new test cases. --- test/contrib/einstein/test_steinvi.py | 29 --------------------------- 1 file changed, 29 deletions(-) diff --git a/test/contrib/einstein/test_steinvi.py b/test/contrib/einstein/test_steinvi.py index 9955f3ded..6f8b1d4ff 100644 --- a/test/contrib/einstein/test_steinvi.py +++ b/test/contrib/einstein/test_steinvi.py @@ -132,35 +132,6 @@ def test_kernel_smoke(kernel, problem): ######################################## -@pytest.mark.parametrize("kernel", KERNELS) -@pytest.mark.parametrize( - "init_loc_fn", - (init_to_uniform(), init_to_sample(), init_to_median(), init_to_feasible()), -) -@pytest.mark.parametrize("auto_guide", (AutoDelta, AutoNormal)) # add transforms -@pytest.mark.parametrize("problem", (regression, uniform_normal)) -def test_get_params(kernel, auto_guide, init_loc_fn, problem): - _, data, model = problem() - guide, optim, elbo = ( - auto_guide(model, init_loc_fn=init_loc_fn), - Adam(1e-1), - Trace_ELBO(), - ) - - stein = SteinVI(model, guide, optim, kernel) - stein_params = stein.get_params(stein.init(random.PRNGKey(0), *data)) - - svi = SVI(model, guide, optim, elbo) - svi_params = svi.get_params(svi.init(random.PRNGKey(0), *data)) - assert svi_params.keys() == stein_params.keys() - - for name, svi_param in svi_params.items(): - assert ( - stein_params[name].shape - == np.repeat(svi_param[None, ...], stein.num_stein_particles, axis=0).shape - ) - - @pytest.mark.parametrize( "auto_guide", [ From 6c2129f4b848082bebdd3c49ced2723e15d4e7ea Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Thu, 6 Jul 2023 10:24:53 +0200 Subject: [PATCH 8/8] fixed imports. --- test/contrib/einstein/test_steinvi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/contrib/einstein/test_steinvi.py b/test/contrib/einstein/test_steinvi.py index 6f8b1d4ff..14115db20 100644 --- a/test/contrib/einstein/test_steinvi.py +++ b/test/contrib/einstein/test_steinvi.py @@ -26,7 +26,7 @@ import numpyro.distributions as dist from numpyro.distributions import Bernoulli, Normal, Poisson from numpyro.distributions.transforms import AffineTransform -from numpyro.infer import SVI, Trace_ELBO, init_to_mean, init_to_value +from numpyro.infer import Trace_ELBO, init_to_mean, init_to_value from numpyro.infer.autoguide import ( AutoBNAFNormal, AutoDAIS,