From 4e37df33343e6a52c6cadce4dd64424f8b84fe5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ola=20R=C3=B8nning?= Date: Tue, 15 Aug 2023 14:52:50 +0200 Subject: [PATCH] Bug/steinvi reinit (#1626) * added separate guide for reinitialization. * added test case for reinit. --- numpyro/contrib/einstein/steinvi.py | 3 ++- test/contrib/einstein/test_steinvi.py | 23 ++++++++++++++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 1d41be393..6013f2f44 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -137,6 +137,7 @@ def __init__( self._inference_model = model self.model = model self.guide = guide + self._init_guide = deepcopy(guide) self.optim = optim self.stein_loss = SteinLoss( # TODO: @OlaRonning handle enum elbo_num_particles=num_elbo_particles, @@ -388,7 +389,7 @@ def init(self, rng_key: KeyArray, *args, **kwargs): ) guide_init_params = self._find_init_params( - particle_seed, self.guide, args, kwargs + particle_seed, self._init_guide, args, kwargs ) guide_init = handlers.seed(self.guide, guide_seed) diff --git a/test/contrib/einstein/test_steinvi.py b/test/contrib/einstein/test_steinvi.py index 14115db20..4af95a237 100644 --- a/test/contrib/einstein/test_steinvi.py +++ b/test/contrib/einstein/test_steinvi.py @@ -10,7 +10,7 @@ from numpy.testing import assert_allclose import pytest -from jax import random +from jax import numpy as jnp, random import numpyro from numpyro import handlers @@ -193,6 +193,27 @@ def model(): return +def test_stein_reinit(): + num_particles = 4 + + def model(): + numpyro.sample("x", Normal(0, 1.0)) + + stein = SteinVI( + model, + AutoDelta(model), + Adam(1.0), + RBFKernel(), + num_stein_particles=num_particles, + ) + + for i in range(2): + with handlers.seed(rng_seed=i): + params = stein.get_params(stein.init(numpyro.prng_key())) + xs = params["x_auto_loc"] + assert jnp.unique(xs).shape == xs.shape + + @pytest.mark.parametrize( "auto_class", [