From 512c80f6d35f97d041d804cdd99b7eef1aff83c7 Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Sun, 19 Jun 2022 16:36:56 -0400 Subject: [PATCH 01/11] initial commit of autosldais --- numpyro/infer/autoguide.py | 173 +++++++++++++++++++++++++++++++++++ test/infer/test_autoguide.py | 71 ++++++++++++++ 2 files changed, 244 insertions(+) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 5738c5666..b23def813 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -64,6 +64,7 @@ "AutoIAFNormal", "AutoDelta", "AutoSemiDAIS", + "AutoSurrogateLikelihoodDAIS", ] @@ -865,6 +866,178 @@ def _single_sample(_rng_key): return _single_sample(rng_key) +class AutoSurrogateLikelihoodDAIS(AutoDAIS): + """ + This implementation of :class:`AutoSurrogateLikelihoodDAIS` provides a + mini-batchable family of variational distributions as described in [1]. + It combines a surrogate likelihood with Differentiable Annealed + Importance Sampling (DAIS) [1, 2]. It is not applicable to models with + local latent variables. The surrogate likelihood is provided by the user. + Unlike :class:`AutoDAIS`, it *can* be used in conjunction with data subsampling. + + **Reference:** + + 1. *Surrogate likelihoods for variational annealed importance sampling*, + Martin Jankowiak, Du Phan + 2. *MCMC Variational Inference via Uncorrected Hamiltonian Annealing*, + Tomas Geffner, Justin Domke + 3. *Differentiable Annealed Importance Sampling and the Perils of Gradient Noise*, + Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse + + Usage:: + + guide = AutoSurrogateLikelihoodDAIS(model, surrogate_model) + svi = SVI(model, guide, ...) + + :param callable model: A NumPyro model. + :param str prefix: A prefix that will be prefixed to all param internal sites. + :param int K: A positive integer that controls the number of HMC steps used. + Defaults to 4. + :param str base_dist: Controls whether the base Normal variational distribution + is parameterized by a "diagonal" covariance matrix or a full-rank covariance + matrix parameterized by a lower-triangular "cholesky" factor. Defaults to "diagonal". + :param float eta_init: The initial value of the step size used in HMC. Defaults + to 0.01. + :param float eta_max: The maximum value of the learnable step size used in HMC. + Defaults to 0.1. + :param float gamma_init: The initial value of the learnable damping factor used + during partial momentum refreshments in HMC. Defaults to 0.9. + :param callable init_loc_fn: A per-site initialization function. + See :ref:`init_strategy` section for available functions. + :param float init_scale: Initial scale for the standard deviation of + the base variational distribution for each (unconstrained transformed) + latent variable. Defaults to 0.1. + """ + + def __init__( + self, + model, + surrogate_model, + *, + K=8, + eta_init=0.01, + eta_max=0.1, + gamma_init=0.9, + prefix="auto", + base_dist="diagonal", + init_loc_fn=init_to_uniform, + init_scale=0.1, + ): + super().__init__( + model, + K=K, + eta_init=eta_init, + eta_max=eta_max, + gamma_init=gamma_init, + prefix=prefix, + init_loc_fn=init_loc_fn, + init_scale=init_scale, + base_dist=base_dist, + ) + + self.surrogate_model = surrogate_model + + def _setup_prototype(self, *args, **kwargs): + AutoContinuous._setup_prototype(self, *args, **kwargs) + + rng_key = numpyro.prng_key() + + with numpyro.handlers.block(): + (_, self._surrogate_potential_fn, _, _) = initialize_model( + rng_key, + self.surrogate_model, + init_strategy=self.init_loc_fn, + dynamic_args=False, + model_args=(), + model_kwargs={}, + ) + + def _sample_latent(self, *args, **kwargs): + def blocked_surrogate_model(x): + x_unpack = self._unpack_latent(x) + with numpyro.handlers.block(hide_fn=lambda site: site["type"] != "param"): + return -self._surrogate_potential_fn(x_unpack) + + eta0 = numpyro.param( + "{}_eta0".format(self.prefix), + self.eta_init, + constraint=constraints.interval(0, self.eta_max), + ) + eta_coeff = numpyro.param("{}_eta_coeff".format(self.prefix), 0.0) + + gamma = numpyro.param( + "{}_gamma".format(self.prefix), + self.gamma_init, + constraint=constraints.interval(0, 1), + ) + betas = numpyro.param( + "{}_beta_increments".format(self.prefix), + jnp.ones(self.K), + constraint=constraints.positive, + ) + betas = jnp.cumsum(betas) + betas = betas / betas[-1] # K-dimensional with betas[-1] = 1 + + mass_matrix = numpyro.param( + "{}_mass_matrix".format(self.prefix), + jnp.ones(self.latent_dim), + constraint=constraints.positive, + ) + inv_mass_matrix = 0.5 / mass_matrix + + init_z_loc = numpyro.param("{}_z_0_loc".format(self.prefix), self._init_latent) + + if self.base_dist == "diagonal": + init_z_scale = numpyro.param( + "{}_z_0_scale".format(self.prefix), + jnp.full(self.latent_dim, self._init_scale), + constraint=constraints.positive, + ) + base_z_dist = dist.Normal(init_z_loc, init_z_scale).to_event() + else: + scale_tril = numpyro.param( + "{}_scale_tril".format(self.prefix), + jnp.identity(self.latent_dim) * self._init_scale, + constraint=constraints.scaled_unit_lower_cholesky, + ) + base_z_dist = dist.MultivariateNormal(init_z_loc, scale_tril=scale_tril) + + z_0 = numpyro.sample( + "{}_z_0".format(self.prefix), base_z_dist, infer={"is_auxiliary": True} + ) + + base_z_dist_log_prob = base_z_dist.log_prob + + momentum_dist = dist.Normal(0, mass_matrix).to_event() + eps = numpyro.sample( + "{}_momentum".format(self.prefix), + momentum_dist.expand((self.K,)).to_event().mask(False), + infer={"is_auxiliary": True}, + ) + + def scan_body(carry, eps_beta): + eps, beta = eps_beta + eta = eta0 + eta_coeff * beta + eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max) + z_prev, v_prev, log_factor = carry + z_half = z_prev + v_prev * eta * inv_mass_matrix + q_grad = (1.0 - beta) * grad(base_z_dist_log_prob)(z_half) + p_grad = beta * grad(blocked_surrogate_model)(z_half) + v_hat = v_prev + eta * (q_grad + p_grad) + z = z_half + v_hat * eta * inv_mass_matrix + v = gamma * v_hat + jnp.sqrt(1 - gamma**2) * eps + delta_ke = momentum_dist.log_prob(v_prev) - momentum_dist.log_prob(v_hat) + log_factor = log_factor + delta_ke + return (z, v, log_factor), None + + v_0 = eps[-1] # note the return value of scan doesn't depend on eps[-1] + (z, _, log_factor), _ = jax.lax.scan(scan_body, (z_0, v_0, 0.0), (eps, betas)) + + numpyro.factor("{}_factor".format(self.prefix), log_factor) + + return z + + def _subsample_model(model, *args, **kwargs): data = kwargs.pop("_subsample_idx", {}) with handlers.substitute(data=data): diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 7d647c8ef..0ae4a48f4 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -41,6 +41,7 @@ AutoMultivariateNormal, AutoNormal, AutoSemiDAIS, + AutoSurrogateLikelihoodDAIS, ) from numpyro.infer.initialization import ( init_to_feasible, @@ -859,3 +860,73 @@ def model2(): svi = SVI(model2, guide, optim.Adam(0.01), Trace_ELBO()) with pytest.raises(RuntimeError, match="are no local variables"): svi.run(random.PRNGKey(0), 10) + + +def test_autosldais(N=64, D=3, num_steps=45000, num_samples=2000): + def _model(X, Y): + theta = numpyro.sample( + "theta", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1) + ) + with numpyro.plate("N", N, subsample_size=2 * N // 3): + X_batch = numpyro.subsample(X, event_dim=1) + Y_batch = numpyro.subsample(Y, event_dim=0) + numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch) + + def _surrogate_model(X, Y): + theta = numpyro.sample( + "theta", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1) + ) + omegas = numpyro.param( + "omegas", 2.0 * jnp.ones(N // 2), constraint=dist.constraints.positive + ) + + with numpyro.plate("N", N // 2), numpyro.handlers.scale(scale=omegas): + X_batch = numpyro.subsample(X, event_dim=1) + Y_batch = numpyro.subsample(Y, event_dim=0) + numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch) + + X = RandomState(0).randn(N, D) + X[:, 2] = X[:, 0] + X[:, 1] + logits = X[:, 0] - 0.5 * X[:, 1] + Y = dist.Bernoulli(logits=logits).sample(random.PRNGKey(0)) + + model = partial(_model, X, Y) + surrogate_model = partial(_surrogate_model, X[::2], Y[::2]) + + def _get_optim(): + scheduler = piecewise_constant_schedule( + 1.0e-3, {15 * 1000: 1.0e-4, 30 * 1000: 1.0e-5} + ) + return optax.chain( + optax.scale_by_adam(), optax.scale_by_schedule(scheduler), optax.scale(-1.0) + ) + + guide = AutoSurrogateLikelihoodDAIS( + model, surrogate_model, K=3, eta_max=0.25, eta_init=0.005 + ) + svi_result = SVI(model, guide, _get_optim(), Trace_ELBO()).run( + random.PRNGKey(1), num_steps + ) + + samples = guide.sample_posterior(random.PRNGKey(2), svi_result.params) + assert samples["theta"].shape == (D,) + + dais_elbo = Trace_ELBO(num_particles=num_samples).loss( + random.PRNGKey(0), svi_result.params, model, guide + ) + dais_elbo = -dais_elbo.item() + + def create_plates(): + return numpyro.plate("N", N, subsample_size=2 * N // 3) + + mf_guide = AutoNormal(model, create_plates=create_plates) + mf_svi_result = SVI(model, mf_guide, _get_optim(), Trace_ELBO()).run( + random.PRNGKey(0), num_steps + ) + + mf_elbo = Trace_ELBO(num_particles=num_samples).loss( + random.PRNGKey(0), mf_svi_result.params, model, mf_guide + ) + mf_elbo = -mf_elbo.item() + + assert dais_elbo > mf_elbo + 0.1 From b51b22f5bcac53ed4beb0fe820e3d29d73261abc Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Sun, 19 Jun 2022 16:45:22 -0400 Subject: [PATCH 02/11] add usage example --- numpyro/infer/autoguide.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index b23def813..81904a00c 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -886,10 +886,40 @@ class AutoSurrogateLikelihoodDAIS(AutoDAIS): Usage:: + # logistic regression model for data {X, Y} + def model(X, Y): + theta = numpyro.sample( + "theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1) + ) + with numpyro.plate("N", 100, subsample_size=10): + X_batch = numpyro.subsample(X, event_dim=1) + Y_batch = numpyro.subsample(Y, event_dim=0) + numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch) + + # surrogate model defined by prior and surrogate likelihood. + # the latter is specified by computing the likelihood on the data subset + # {X_surr, Y_surr} of size 20. + def surrogate_model(X_surr, Y_surr): + theta = numpyro.sample( + "theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1) + ) + omegas = numpyro.param( + "omegas", 5.0 * jnp.ones(20), constraint=dist.constraints.positive + ) + + with numpyro.plate("N", 20), numpyro.handlers.scale(scale=omegas): + numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_surr.T), obs=Y_surr) + guide = AutoSurrogateLikelihoodDAIS(model, surrogate_model) svi = SVI(model, guide, ...) :param callable model: A NumPyro model. + :param callable surrogate_model: A NumPyro model that is used as a surrogate model + for guiding the HMC dynamics that define the variational distribution. In particular + `surrogate_model` should contain the same prior as `model` but should contain a + cheap-to-evaluate parametric ansatz for the likelihood. A simple ansatz for the latter + involves computing the likelihood for a fixed subset of the data and scaling the resulting + log likelihood by a learnable vector of positive weights. See the usage example above.f :param str prefix: A prefix that will be prefixed to all param internal sites. :param int K: A positive integer that controls the number of HMC steps used. Defaults to 4. From 0e6f3d20703523e168f3150b1cd4a9a6ddbe06b6 Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Sun, 19 Jun 2022 16:48:55 -0400 Subject: [PATCH 03/11] more docs --- README.md | 1 + docs/source/autoguide.rst | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/README.md b/README.md index ef4e273d4..f6bce70d9 100644 --- a/README.md +++ b/README.md @@ -211,6 +211,7 @@ Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete l - [AutoDelta](https://num.pyro.ai/en/latest/autoguide.html#autodelta) is used for computing point estimates via MAP (maximum a posteriori estimation). See [here](https://github.com/pyro-ppl/numpyro/blob/bbe1f879eede79eebfdd16dfc49c77c4d1fc727c/examples/zero_inflated_poisson.py#L101) for example usage. - [AutoBNAFNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoBNAFNormal) and [AutoIAFNormal](https://num.pyro.ai/en/latest/autoguide.html#autoiafnormal) offer flexible variational distributions parameterized by normalizing flows. - [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#autodais) is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model. + - [AutoSurrogateLikelihoodDAIS](https://num.pyro.ai/en/latest/autoguide.html#autosurrogatelikelihooddais) is a powerful variational inference algorithm that leverages HMC and that supports data subsampling. - [AutoSemiDAIS](https://num.pyro.ai/en/latest/autoguide.html#autosemidais) constructs a posterior approximation like [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#autodais) for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables. - [AutoLaplaceApproximation](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoLaplaceApproximation) can be used to compute a Laplace approximation. diff --git a/docs/source/autoguide.rst b/docs/source/autoguide.rst index 0383ba881..fbb289f0e 100644 --- a/docs/source/autoguide.rst +++ b/docs/source/autoguide.rst @@ -8,6 +8,7 @@ We provide a brief overview of the automatically generated guides available in N * `AutoDelta `_ is used for computing point estimates via MAP (maximum a posteriori estimation). See `here `_ for example usage. * `AutoBNAFNormal `_ and `AutoIAFNormal `_ offer flexible variational distributions parameterized by normalizing flows. * `AutoDAIS `_ is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model. +* `AutoSurrogateLikelihoodDAIS `_ is a powerful variational inference algorithm that leverages HMC and that supports data subsampling. * `AutoSemiDAIS `_ constructs a posterior approximation like `AutoDAIS `_ for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables. * `AutoLaplaceApproximation `_ can be used to compute a Laplace approximation. @@ -108,3 +109,11 @@ AutoSemiDAIS :undoc-members: :show-inheritance: :member-order: bysource + +AutoSurrogateLikelihoodDAIS +--------------------------- +.. autoclass:: numpyro.infer.autoguide.AutoSurrogateLikelihoodDAIS + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource From b9ae072bf940138adf91905cda15130cba58e0c2 Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Sun, 19 Jun 2022 16:51:01 -0400 Subject: [PATCH 04/11] tweak --- numpyro/infer/autoguide.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 81904a00c..b7644bbd6 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -897,8 +897,8 @@ def model(X, Y): numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch) # surrogate model defined by prior and surrogate likelihood. - # the latter is specified by computing the likelihood on the data subset - # {X_surr, Y_surr} of size 20. + # a convenient choice for specifying the latter is to computing the likelihood on + # a randomly chosen data subset, here {X_surr, Y_surr} of size 20. def surrogate_model(X_surr, Y_surr): theta = numpyro.sample( "theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1) From d0b4da379bc28bb551add8df0d3f859c983a5961 Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Sun, 19 Jun 2022 16:54:03 -0400 Subject: [PATCH 05/11] fix typo --- numpyro/infer/autoguide.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index b7644bbd6..1ae5c919e 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -919,7 +919,7 @@ def surrogate_model(X_surr, Y_surr): `surrogate_model` should contain the same prior as `model` but should contain a cheap-to-evaluate parametric ansatz for the likelihood. A simple ansatz for the latter involves computing the likelihood for a fixed subset of the data and scaling the resulting - log likelihood by a learnable vector of positive weights. See the usage example above.f + log likelihood by a learnable vector of positive weights. See the usage example above. :param str prefix: A prefix that will be prefixed to all param internal sites. :param int K: A positive integer that controls the number of HMC steps used. Defaults to 4. From 65f6f168caa45586d05b6af25dbcd08f1e67e89b Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Sun, 19 Jun 2022 18:09:48 -0400 Subject: [PATCH 06/11] tweak docstring --- numpyro/infer/autoguide.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 1ae5c919e..2bad97805 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -897,8 +897,9 @@ def model(X, Y): numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch) # surrogate model defined by prior and surrogate likelihood. - # a convenient choice for specifying the latter is to computing the likelihood on - # a randomly chosen data subset, here {X_surr, Y_surr} of size 20. + # a convenient choice for specifying the latter is to compute the likelihood on + # a randomly chosen data subset (here {X_surr, Y_surr} of size 20) and then use + # handlers.scale to scale the log likelihood by a vector of learnable weights. def surrogate_model(X_surr, Y_surr): theta = numpyro.sample( "theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1) @@ -906,7 +907,6 @@ def surrogate_model(X_surr, Y_surr): omegas = numpyro.param( "omegas", 5.0 * jnp.ones(20), constraint=dist.constraints.positive ) - with numpyro.plate("N", 20), numpyro.handlers.scale(scale=omegas): numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_surr.T), obs=Y_surr) From a726080329c22ba9277e334c841cda0753a43902 Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Mon, 20 Jun 2022 11:06:59 -0400 Subject: [PATCH 07/11] improve docstring --- numpyro/infer/autoguide.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 2bad97805..986291fd6 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -870,10 +870,10 @@ class AutoSurrogateLikelihoodDAIS(AutoDAIS): """ This implementation of :class:`AutoSurrogateLikelihoodDAIS` provides a mini-batchable family of variational distributions as described in [1]. - It combines a surrogate likelihood with Differentiable Annealed - Importance Sampling (DAIS) [1, 2]. It is not applicable to models with - local latent variables. The surrogate likelihood is provided by the user. - Unlike :class:`AutoDAIS`, it *can* be used in conjunction with data subsampling. + It combines a user-provided surrogate likelihood with Differentiable Annealed + Importance Sampling (DAIS) [2, 3]. It is not applicable to models with local + latent variables (see :class:`AutoSemiDAIS`), but unlike :class:`AutoDAIS`, it + *can* be used in conjunction with data subsampling. **Reference:** @@ -944,7 +944,7 @@ def __init__( model, surrogate_model, *, - K=8, + K=4, eta_init=0.01, eta_max=0.1, gamma_init=0.9, From a488923bfa9a831fb8781e706c4191655c96b3dd Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Tue, 21 Jun 2022 09:08:43 -0400 Subject: [PATCH 08/11] clean-up test and add expose_types to block --- numpyro/handlers.py | 5 ++++- test/infer/test_autoguide.py | 24 +++++++++++++----------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 53ca80f4e..6a2270dc0 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -235,6 +235,7 @@ class block(Messenger): :param callable hide_fn: function which when given a dictionary containing site-level metadata returns whether it should be blocked. :param list hide: list of site names to hide. + :param list expose_types: list of site types to expose, e.g. `['param']`. **Example:** @@ -259,11 +260,13 @@ class block(Messenger): >>> assert 'b' in trace_block_a """ - def __init__(self, fn=None, hide_fn=None, hide=None): + def __init__(self, fn=None, hide_fn=None, hide=None, expose_types=None): if hide_fn is not None: self.hide_fn = hide_fn elif hide is not None: self.hide_fn = lambda msg: msg.get("name") in hide + elif expose_types is not None: + self.hide_fn = lambda msg: msg.get("type") not in expose_types else: self.hide_fn = lambda msg: True super(block, self).__init__(fn) diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 0ae4a48f4..27fd78a9b 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -862,28 +862,30 @@ def model2(): svi.run(random.PRNGKey(0), 10) -def test_autosldais(N=64, D=3, num_steps=45000, num_samples=2000): +def test_autosldais( + N=64, subsample_size=48, num_surrogate=32, D=3, num_steps=45000, num_samples=2000 +): def _model(X, Y): theta = numpyro.sample( "theta", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1) ) - with numpyro.plate("N", N, subsample_size=2 * N // 3): + with numpyro.plate("N", N, subsample_size=subsample_size): X_batch = numpyro.subsample(X, event_dim=1) Y_batch = numpyro.subsample(Y, event_dim=0) numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch) - def _surrogate_model(X, Y): + def _surrogate_model(X_surr, Y_surr): theta = numpyro.sample( "theta", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1) ) omegas = numpyro.param( - "omegas", 2.0 * jnp.ones(N // 2), constraint=dist.constraints.positive + "omegas", + 2.0 * jnp.ones(num_surrogate), + constraint=dist.constraints.positive, ) - with numpyro.plate("N", N // 2), numpyro.handlers.scale(scale=omegas): - X_batch = numpyro.subsample(X, event_dim=1) - Y_batch = numpyro.subsample(Y, event_dim=0) - numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch) + with numpyro.plate("N", num_surrogate), numpyro.handlers.scale(scale=omegas): + numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_surr.T), obs=Y_surr) X = RandomState(0).randn(N, D) X[:, 2] = X[:, 0] + X[:, 1] @@ -891,7 +893,7 @@ def _surrogate_model(X, Y): Y = dist.Bernoulli(logits=logits).sample(random.PRNGKey(0)) model = partial(_model, X, Y) - surrogate_model = partial(_surrogate_model, X[::2], Y[::2]) + surrogate_model = partial(_surrogate_model, X[:num_surrogate], Y[:num_surrogate]) def _get_optim(): scheduler = piecewise_constant_schedule( @@ -917,7 +919,7 @@ def _get_optim(): dais_elbo = -dais_elbo.item() def create_plates(): - return numpyro.plate("N", N, subsample_size=2 * N // 3) + return numpyro.plate("N", N, subsample_size=subsample_size) mf_guide = AutoNormal(model, create_plates=create_plates) mf_svi_result = SVI(model, mf_guide, _get_optim(), Trace_ELBO()).run( @@ -925,7 +927,7 @@ def create_plates(): ) mf_elbo = Trace_ELBO(num_particles=num_samples).loss( - random.PRNGKey(0), mf_svi_result.params, model, mf_guide + random.PRNGKey(1), mf_svi_result.params, model, mf_guide ) mf_elbo = -mf_elbo.item() From b8d4faded485a9ea9dd276bb50a0d752718bc919 Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Tue, 21 Jun 2022 09:10:21 -0400 Subject: [PATCH 09/11] use expose_types --- numpyro/infer/autoguide.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 986291fd6..fac158997 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -985,7 +985,7 @@ def _setup_prototype(self, *args, **kwargs): def _sample_latent(self, *args, **kwargs): def blocked_surrogate_model(x): x_unpack = self._unpack_latent(x) - with numpyro.handlers.block(hide_fn=lambda site: site["type"] != "param"): + with numpyro.handlers.block(expose_types=['param']): return -self._surrogate_potential_fn(x_unpack) eta0 = numpyro.param( From 4fc30c6965a1a1e59faa2e6980df819b58c1e406 Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Tue, 21 Jun 2022 09:11:56 -0400 Subject: [PATCH 10/11] fewer_steps --- test/infer/test_autoguide.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 27fd78a9b..ede9cade6 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -863,7 +863,7 @@ def model2(): def test_autosldais( - N=64, subsample_size=48, num_surrogate=32, D=3, num_steps=45000, num_samples=2000 + N=64, subsample_size=48, num_surrogate=32, D=3, num_steps=40000, num_samples=2000 ): def _model(X, Y): theta = numpyro.sample( From b368a90e4403cc7d177c7ea922216897f557e0ca Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Tue, 21 Jun 2022 09:25:59 -0400 Subject: [PATCH 11/11] lint --- numpyro/infer/autoguide.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index fac158997..578f87b1b 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -985,7 +985,7 @@ def _setup_prototype(self, *args, **kwargs): def _sample_latent(self, *args, **kwargs): def blocked_surrogate_model(x): x_unpack = self._unpack_latent(x) - with numpyro.handlers.block(expose_types=['param']): + with numpyro.handlers.block(expose_types=["param"]): return -self._surrogate_potential_fn(x_unpack) eta0 = numpyro.param(