diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 424d09f79..1d41be393 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import namedtuple +from copy import deepcopy import functools from functools import partial from itertools import chain @@ -20,8 +21,7 @@ get_parameter_transform, ) from numpyro.contrib.funsor import config_enumerate, enum -from numpyro.distributions import Distribution, Normal -from numpyro.distributions.constraints import real +from numpyro.distributions import Distribution from numpyro.distributions.transforms import IdentityTransform from numpyro.infer.autoguide import AutoGuide from numpyro.infer.util import _guess_max_plate_nesting, transform_fn @@ -102,6 +102,38 @@ def __init__( enum=True, **static_kwargs, ): + if isinstance(guide, AutoGuide): + not_comptaible_guides = [ + "AutoIAFNormal", + "AutoBNAFNormal", + "AutoDAIS", + "AutoSemiDAIS", + "AutoSurrogateLikelihoodDAIS", + ] + guide_name = guide.__class__.__name__ + assert guide_name not in not_comptaible_guides, ( + f"SteinVI currently not compatible with {guide_name}. " + f"If you have a use case, feel free to open an issue." + ) + + init_loc_error_message = ( + "SteinVI is not compatible with init_to_feasible, init_to_value, " + "and init_to_uniform with radius=0. If you have a use case, " + "feel free to open an issue." + ) + if isinstance(guide.init_loc_fn, partial): + init_fn_name = guide.init_loc_fn.func.__name__ + if init_fn_name == "init_to_uniform": + assert ( + guide.init_loc_fn.keywords.get("radius", None) != 0 + ), init_loc_error_message + else: + init_fn_name = guide.init_loc_fn.__name__ + assert init_fn_name not in [ + "init_to_feasible", + "init_to_value", + ], init_loc_error_message + self._inference_model = model self.model = model self.guide = guide @@ -112,7 +144,7 @@ def __init__( ) self.kernel_fn = kernel_fn self.static_kwargs = static_kwargs - self.num_particles = num_stein_particles + self.num_stein_particles = num_stein_particles self.loss_temperature = loss_temperature self.repulsion_temperature = repulsion_temperature self.enum = enum @@ -167,48 +199,21 @@ def _calc_particle_info(self, uparams, num_particles, start_index=0): start_index = end_index return res, end_index - def _find_init_params(self, particle_seed, inner_guide, inner_guide_trace): - def extract_info(site): - nonlocal particle_seed - name = site["name"] - value = site["value"] - constraint = site["kwargs"].get("constraint", real) - transform = get_parameter_transform(site) - if ( - isinstance(inner_guide, AutoGuide) - and "_".join((inner_guide.prefix, "loc")) in name - ): - site_key, particle_seed = random.split(particle_seed) - unconstrained_shape = transform.inverse_shape(value.shape) - init_value = jnp.expand_dims( - transform.inv(value), 0 - ) + Normal( # Add gaussian noise - scale=0.1 - ).sample( - particle_seed, (self.num_particles, *unconstrained_shape) - ) - init_value = transform(init_value) - - else: - site_fn = site["fn"] - site_args = site["args"] - site_key, particle_seed = random.split(particle_seed) + def _find_init_params(self, particle_seed, inner_guide, model_args, model_kwargs): + def local_trace(key): + guide = deepcopy(inner_guide) - def _reinit(seed): - with handlers.seed(rng_seed=seed): - return site_fn(*site_args) + with handlers.seed(rng_seed=key), handlers.trace() as mixture_trace: + guide(*model_args, **model_kwargs) - init_value = vmap(_reinit)( - random.split(particle_seed, self.num_particles) - ) - return init_value, constraint + init_params = { + name: site["value"] + for name, site in mixture_trace.items() + if site.get("type") == "param" + } + return init_params - init_params = { - name: extract_info(site) - for name, site in inner_guide_trace.items() - if site.get("type") == "param" - } - return init_params + return vmap(local_trace)(random.split(particle_seed, self.num_stein_particles)) def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs): # 0. Separate model and guide parameters, since only guide parameters are updated using Stein @@ -352,7 +357,7 @@ def _update_force(attr_force, rep_force, jac): vmap(single_particle_grad)( stein_particles, attractive_force, repulsive_force ) - / self.num_particles + / self.num_stein_particles ) # 5. Decompose the monolithic particle forces back to concrete parameter values @@ -372,19 +377,25 @@ def init(self, rng_key: KeyArray, *args, **kwargs): :param kwargs: Keyword arguments to the model / guide. :return: initial :data:`SteinVIState` """ - rng_key, kernel_seed, model_seed, guide_seed = random.split(rng_key, 4) - model_init = handlers.seed(self.model, model_seed) - guide_init = handlers.seed(self.guide, guide_seed) - guide_trace = handlers.trace(guide_init).get_trace( - *args, **kwargs, **self.static_kwargs + + rng_key, kernel_seed, model_seed, guide_seed, particle_seed = random.split( + rng_key, 5 ) + + model_init = handlers.seed(self.model, model_seed) model_trace = handlers.trace(model_init).get_trace( *args, **kwargs, **self.static_kwargs ) - rng_key, particle_seed = random.split(rng_key) + guide_init_params = self._find_init_params( - particle_seed, self.guide, guide_trace + particle_seed, self.guide, args, kwargs ) + + guide_init = handlers.seed(self.guide, guide_seed) + guide_trace = handlers.trace(guide_init).get_trace( + *args, **kwargs, **self.static_kwargs + ) + params = {} transforms = {} inv_transforms = {} @@ -415,7 +426,7 @@ def init(self, rng_key: KeyArray, *args, **kwargs): "particle_transform", IdentityTransform() ) if site["name"] in guide_init_params: - pval, _ = guide_init_params[site["name"]] + pval = guide_init_params[site["name"]] if self.non_mixture_params_fn(site["name"]): pval = tree_map(lambda x: x[0], pval) else: diff --git a/test/contrib/einstein/test_steinvi.py b/test/contrib/einstein/test_steinvi.py index 1c0214fd2..14115db20 100644 --- a/test/contrib/einstein/test_steinvi.py +++ b/test/contrib/einstein/test_steinvi.py @@ -2,11 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 from collections import namedtuple +from functools import partial import string import numpy as np -from numpy.ma.testutils import assert_array_approx_equal import numpy.random as nrandom +from numpy.testing import assert_allclose import pytest from jax import random @@ -25,14 +26,19 @@ import numpyro.distributions as dist from numpyro.distributions import Bernoulli, Normal, Poisson from numpyro.distributions.transforms import AffineTransform -from numpyro.infer import SVI, Trace_ELBO +from numpyro.infer import Trace_ELBO, init_to_mean, init_to_value from numpyro.infer.autoguide import ( + AutoBNAFNormal, + AutoDAIS, AutoDelta, AutoDiagonalNormal, + AutoIAFNormal, AutoLaplaceApproximation, AutoLowRankMultivariateNormal, AutoMultivariateNormal, AutoNormal, + AutoSemiDAIS, + AutoSurrogateLikelihoodDAIS, ) from numpyro.infer.initialization import ( init_to_feasible, @@ -108,23 +114,13 @@ def model(features, labels): return true_coefs, (data, labels), model -######################################## -# Stein Exterior (Smoke tests) -######################################## - - @pytest.mark.parametrize("kernel", KERNELS) -@pytest.mark.parametrize( - "init_loc_fn", - (init_to_uniform(), init_to_sample(), init_to_median(), init_to_feasible()), -) -@pytest.mark.parametrize("auto_guide", (AutoDelta, AutoNormal)) @pytest.mark.parametrize("problem", (uniform_normal, regression)) -def test_steinvi_smoke(kernel, auto_guide, init_loc_fn, problem): +def test_kernel_smoke(kernel, problem): true_coefs, data, model = problem() stein = SteinVI( model, - auto_guide(model, init_loc_fn=init_loc_fn), + AutoNormal(model), Adam(1e-1), kernel, ) @@ -136,33 +132,65 @@ def test_steinvi_smoke(kernel, auto_guide, init_loc_fn, problem): ######################################## -@pytest.mark.parametrize("kernel", KERNELS) @pytest.mark.parametrize( - "init_loc_fn", - (init_to_uniform(), init_to_sample(), init_to_median(), init_to_feasible()), + "auto_guide", + [ + AutoIAFNormal, + AutoBNAFNormal, + AutoSemiDAIS, + AutoSurrogateLikelihoodDAIS, + AutoDAIS, + ], ) -@pytest.mark.parametrize("auto_guide", (AutoDelta, AutoNormal)) # add transforms -@pytest.mark.parametrize("problem", (regression, uniform_normal)) -def test_get_params(kernel, auto_guide, init_loc_fn, problem): - _, data, model = problem() - guide, optim, elbo = ( - auto_guide(model, init_loc_fn=init_loc_fn), - Adam(1e-1), - Trace_ELBO(), - ) - - stein = SteinVI(model, guide, optim, kernel) - stein_params = stein.get_params(stein.init(random.PRNGKey(0), *data)) +def test_incompatible_autoguide(auto_guide): + def model(): + return + + if auto_guide.__name__ == "AutoSurrogateLikelihoodDAIS": + guide = auto_guide(model, model) + elif auto_guide.__name__ == "AutoSemiDAIS": + guide = auto_guide(model, model, model) + else: + guide = auto_guide(model) + + try: + SteinVI( + model, + guide, + Adam(1.0), + RBFKernel(), + num_stein_particles=1, + ) + pytest.fail() + except AssertionError: + return - svi = SVI(model, guide, optim, elbo) - svi_params = svi.get_params(svi.init(random.PRNGKey(0), *data)) - assert svi_params.keys() == stein_params.keys() - for name, svi_param in svi_params.items(): - assert ( - stein_params[name].shape - == np.repeat(svi_param[None, ...], stein.num_particles, axis=0).shape +@pytest.mark.parametrize( + "init_loc", + [ + init_to_value, + init_to_feasible, + partial(init_to_value), + partial(init_to_feasible), + partial(init_to_uniform, radius=0), + ], +) +def test_incompatible_init_locs(init_loc): + def model(): + return + + try: + SteinVI( + model, + AutoDelta(model, init_loc_fn=init_loc), + Adam(1.0), + RBFKernel(), + num_stein_particles=1, ) + pytest.fail() + except AssertionError: + return @pytest.mark.parametrize( @@ -180,38 +208,42 @@ def test_get_params(kernel, auto_guide, init_loc_fn, problem): "init_loc_fn", [ init_to_uniform, - init_to_feasible, init_to_median, + init_to_mean, init_to_sample, ], ) @pytest.mark.parametrize("num_particles", [1, 2, 10]) -def test_auto_guide(auto_class, init_loc_fn, num_particles): +def test_init_auto_guide(auto_class, init_loc_fn, num_particles): latent_dim = 3 def model(obs): - a = numpyro.sample("a", Normal(0, 1)) + a = numpyro.sample("a", Normal(0, 1).expand((latent_dim,)).to_event(1)) return numpyro.sample("obs", Bernoulli(logits=a), obs=obs) obs = Bernoulli(0.5).sample(random.PRNGKey(0), (10, latent_dim)) rng_key = random.PRNGKey(0) guide_key, stein_key = random.split(rng_key) - inner_guide = auto_class(model, init_loc_fn=init_loc_fn()) - with handlers.seed(rng_seed=guide_key), handlers.trace() as inner_guide_tr: - inner_guide(obs) + guide = auto_class(model, init_loc_fn=init_loc_fn()) steinvi = SteinVI( model, - auto_class(model, init_loc_fn=init_loc_fn()), + guide, Adam(1.0), RBFKernel(), num_stein_particles=num_particles, ) state = steinvi.init(stein_key, obs) + init_params = steinvi.get_params(state) + inner_guide = auto_class(model, init_loc_fn=init_loc_fn()) + + with handlers.seed(rng_seed=guide_key), handlers.trace() as inner_guide_tr: + inner_guide(obs) + for name, site in inner_guide_tr.items(): if site.get("type") == "param": assert name in init_params @@ -223,9 +255,43 @@ def model(obs): assert np.alltrue(init_value != np.zeros(expected_shape)) assert np.unique(init_value).shape == init_value.reshape(-1).shape elif "scale" in name: - assert_array_approx_equal(init_value, np.full(expected_shape, 0.1)) - else: - assert_array_approx_equal(init_value, np.full(expected_shape, 0.0)) + assert_allclose(init_value[init_value != 0.0], 0.1, rtol=1e-6) + + +@pytest.mark.parametrize("num_particles", [1, 2, 10]) +def test_init_custom_guide(num_particles): + latent_dim = 3 + + def guide(obs): + aloc = numpyro.param( + "aloc", lambda rng_key: Normal().sample(rng_key, (latent_dim,)) + ) + numpyro.sample("a", Normal(aloc, 1).to_event(1)) + + def model(obs): + a = numpyro.sample("a", Normal(0, 1).expand((latent_dim,)).to_event(1)) + return numpyro.sample("obs", Bernoulli(logits=a), obs=obs) + + obs = Bernoulli(0.5).sample(random.PRNGKey(0), (10, latent_dim)) + + rng_key = random.PRNGKey(0) + guide_key, stein_key = random.split(rng_key) + + steinvi = SteinVI( + model, + guide, + Adam(1.0), + RBFKernel(), + num_stein_particles=num_particles, + ) + init_params = steinvi.get_params(steinvi.init(stein_key, obs)) + + init_value = init_params["aloc"] + expected_shape = (num_particles, latent_dim) + + assert expected_shape == init_value.shape + assert np.alltrue(init_value != np.zeros(expected_shape)) + assert np.unique(init_value).shape == init_value.reshape(-1).shape @pytest.mark.parametrize("length", [1, 2, 3, 6])