diff --git a/numpyro/handlers.py b/numpyro/handlers.py
index 278387578..c4a9f7235 100644
--- a/numpyro/handlers.py
+++ b/numpyro/handlers.py
@@ -674,6 +674,7 @@ class seed(Messenger):
 
     :param fn: Python callable with NumPyro primitives.
     :param rng_seed: a random number generator seed.
     :type rng_seed: int, jnp.ndarray scalar, or jax.random.PRNGKey
+    :param list hide_types: an optional list of site types to skip seeding, e.g. ['plate'].
 
     .. note::
@@ -703,7 +704,7 @@
         >>> assert x == y
     """
 
-    def __init__(self, fn=None, rng_seed=None):
+    def __init__(self, fn=None, rng_seed=None, hide_types=None):
         if isinstance(rng_seed, int) or (
             isinstance(rng_seed, (np.ndarray, jnp.ndarray)) and not jnp.shape(rng_seed)
         ):
@@ -715,19 +716,19 @@ def __init__(self, fn=None, rng_seed=None):
         ):
             raise TypeError("Incorrect type for rng_seed: {}".format(type(rng_seed)))
         self.rng_key = rng_seed
+        self.hide_types = [] if hide_types is None else hide_types
         super(seed, self).__init__(fn)
 
     def process_message(self, msg):
-        if (
-            msg["type"] == "sample"
-            and not msg["is_observed"]
-            and msg["kwargs"]["rng_key"] is None
-        ) or msg["type"] in ["prng_key", "plate", "control_flow"]:
-            if msg["value"] is not None:
-                # no need to create a new key when value is available
-                return
-            self.rng_key, rng_key_sample = random.split(self.rng_key)
-            msg["kwargs"]["rng_key"] = rng_key_sample
+        if msg["type"] in self.hide_types:
+            return
+        if msg["type"] not in ["sample", "prng_key", "plate", "control_flow"]:
+            return
+        if (msg["kwargs"]["rng_key"] is not None) or (msg["value"] is not None):
+            # no need to create a new key when value is available
+            return
+        self.rng_key, rng_key_sample = random.split(self.rng_key)
+        msg["kwargs"]["rng_key"] = rng_key_sample
 
 
 class substitute(Messenger):
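
A note on the pattern the new ``hide_types`` argument enables: two ``seed`` handlers can be stacked so that plate subsampling and latent sampling draw from separate keys. This is what ``RenyiELBO.loss`` below relies on to keep subsample indices fixed across particles. A minimal sketch; the toy model and keys are illustrative, not part of the patch::

    import numpyro
    import numpyro.distributions as dist
    from jax import random
    from numpyro.handlers import seed, trace

    def model():
        with numpyro.plate("N", 100, subsample_size=10):
            numpyro.sample("x", dist.Normal(0, 1))

    # The inner handler seeds everything except plates.
    m = seed(model, random.PRNGKey(0), hide_types=["plate"])
    # The outer handler seeds only plates, so reusing PRNGKey(1) here yields
    # identical subsample indices while the latents still vary.
    m = seed(m, random.PRNGKey(1), hide_types=["sample", "prng_key", "control_flow"])
    tr = trace(m).get_trace()
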
diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index 3c4bbb8cf..b0a39a4bc 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -312,6 +312,27 @@ class RenyiELBO(ELBO):
     :param num_particles: The number of particles/samples used to form the objective
         (gradient) estimator. Default is 2.
+
+    Example::
+
+        def model(data):
+            with numpyro.plate("batch", 10000, subsample_size=100):
+                latent = numpyro.sample("latent", dist.Normal(0, 1))
+                batch = numpyro.subsample(data, event_dim=0)
+                numpyro.sample("data", dist.Bernoulli(logits=latent), obs=batch)
+
+        def guide(data):
+            w_loc = numpyro.param("w_loc", 1.)
+            w_scale = numpyro.param("w_scale", 1.)
+            with numpyro.plate("batch", 10000, subsample_size=100):
+                batch = numpyro.subsample(data, event_dim=0)
+                loc = w_loc * batch
+                scale = jnp.exp(w_scale * batch)
+                numpyro.sample("latent", dist.Normal(loc, scale))
+
+        elbo = RenyiELBO(num_particles=10)
+        svi = SVI(model, guide, optax.adam(0.1), elbo)
+
     **References:**
 
     1. *Renyi Divergence Variational Inference*, Yingzhen Li, Richard E.
        Turner
@@ -327,37 +348,99 @@ def __init__(self, alpha=0, num_particles=2):
         self.alpha = alpha
         super().__init__(num_particles=num_particles)
 
-    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
-        def single_particle_elbo(rng_key):
-            model_seed, guide_seed = random.split(rng_key)
-            seeded_model = seed(model, model_seed)
-            seeded_guide = seed(guide, guide_seed)
-            guide_log_density, guide_trace = log_density(
-                seeded_guide, args, kwargs, param_map
-            )
-            # NB: we only want to substitute params not available in guide_trace
-            model_param_map = {
-                k: v for k, v in param_map.items() if k not in guide_trace
-            }
-            seeded_model = replay(seeded_model, guide_trace)
-            model_log_density, model_trace = log_density(
-                seeded_model, args, kwargs, model_param_map
-            )
-            check_model_guide_match(model_trace, guide_trace)
-            _validate_model(model_trace, plate_warning="loose")
+    def _single_particle_elbo(self, model, guide, param_map, args, kwargs, rng_key):
+        model_seed, guide_seed = random.split(rng_key)
+        seeded_model = seed(model, model_seed)
+        seeded_guide = seed(guide, guide_seed)
+        model_trace, guide_trace = get_importance_trace(
+            seeded_model, seeded_guide, args, kwargs, param_map
+        )
+        check_model_guide_match(model_trace, guide_trace)
+        _validate_model(model_trace, plate_warning="loose")
+
+        site_plates = {
+            name: {frame for frame in site["cond_indep_stack"]}
+            for name, site in model_trace.items()
+            if site["type"] == "sample"
+        }
+        # We will compute Renyi elbos separately across the dimensions
+        # defined in indep_plates. The final elbo is then the sum
+        # of those independent elbos.
+        if site_plates:
+            indep_plates = set.intersection(*site_plates.values())
+        else:
+            indep_plates = set()
+        for frame in set.union(set(), *site_plates.values()):
+            if frame not in indep_plates:
+                subsample_size = frame.size
+                size = model_trace[frame.name]["args"][0]
+                if size > subsample_size:
+                    raise ValueError(
+                        "RenyiELBO only supports subsampling in plates that are common"
+                        " to all sample sites, e.g. a data plate that encloses the"
+                        " entire model."
+                    )
 
-            # log p(z) - log q(z)
-            elbo = model_log_density - guide_log_density
-            return elbo
+        indep_plate_scale = 1.0
+        for frame in indep_plates:
+            subsample_size = frame.size
+            size = model_trace[frame.name]["args"][0]
+            if size > subsample_size:
+                indep_plate_scale = indep_plate_scale * size / subsample_size
+        indep_plate_dims = {frame.dim for frame in indep_plates}
+
+        log_densities = {}
+        for trace_type, tr in {"guide": guide_trace, "model": model_trace}.items():
+            log_densities[trace_type] = 0.0
+            for site in tr.values():
+                if site["type"] != "sample":
+                    continue
+                log_prob = site["log_prob"]
+                squeeze_axes = ()
+                for dim in range(log_prob.ndim):
+                    neg_dim = dim - log_prob.ndim
+                    if neg_dim in indep_plate_dims:
+                        continue
+                    log_prob = jnp.sum(log_prob, axis=dim, keepdims=True)
+                    squeeze_axes = squeeze_axes + (dim,)
+                log_prob = jnp.squeeze(log_prob, squeeze_axes)
+                log_densities[trace_type] = log_densities[trace_type] + log_prob
+
+        # log p(z) - log q(z)
+        elbo = log_densities["model"] - log_densities["guide"]
+        # Log probabilities at indep_plates dimensions are scaled to MC-approximate
+        # the "full size" log probabilities. Because we want to compute Renyi elbos
+        # separately across indep_plates dimensions, we remove that scale here and
+        # apply it again after the per-particle Renyi elbos are combined.
+        return elbo / indep_plate_scale, indep_plate_scale
+
+    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
+        plate_key, rng_key = random.split(rng_key)
+        model = seed(
+            model, plate_key, hide_types=["sample", "prng_key", "control_flow"]
+        )
+        guide = seed(
+            guide, plate_key, hide_types=["sample", "prng_key", "control_flow"]
+        )
+        single_particle_elbo = partial(
+            self._single_particle_elbo, model, guide, param_map, args, kwargs
+        )
 
         rng_keys = random.split(rng_key, self.num_particles)
-        elbos = vmap(single_particle_elbo)(rng_keys)
+        elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys)
+        assert common_plate_scale.shape == (self.num_particles,)
+        assert elbos.shape[0] == self.num_particles
         scaled_elbos = (1.0 - self.alpha) * elbos
-        avg_log_exp = logsumexp(scaled_elbos) - jnp.log(self.num_particles)
+        avg_log_exp = logsumexp(scaled_elbos, axis=0) - jnp.log(self.num_particles)
+        assert avg_log_exp.shape == elbos.shape[1:]
         weights = jnp.exp(scaled_elbos - avg_log_exp)
         renyi_elbo = avg_log_exp / (1.0 - self.alpha)
-        weighted_elbo = jnp.dot(stop_gradient(weights), elbos) / self.num_particles
-        return -(stop_gradient(renyi_elbo - weighted_elbo) + weighted_elbo)
+        weighted_elbo = (stop_gradient(weights) * elbos).mean(0)
+        assert renyi_elbo.shape == elbos.shape[1:]
+        assert weighted_elbo.shape == elbos.shape[1:]
+        loss = -(stop_gradient(renyi_elbo - weighted_elbo) + weighted_elbo)
+        # common_plate_scale should be the same across particles.
+        return loss.sum() * common_plate_scale[0]
 
 
 def _get_plate_stacks(trace):
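
The particle reduction in the new ``loss`` is elementwise over the surviving plate dimensions rather than over a scalar, which is why ``logsumexp`` now takes ``axis=0``. A standalone sketch of that reduction, with dummy shapes and an arbitrary ``alpha``::

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    alpha, num_particles = 0.5, 10
    # per-particle elbos over two independent plate dims, e.g. (N=3, K=5)
    elbos = jnp.ones((num_particles, 3, 5))

    scaled_elbos = (1.0 - alpha) * elbos
    # reduce only over the particle axis: one Renyi elbo per plate element
    avg_log_exp = logsumexp(scaled_elbos, axis=0) - jnp.log(num_particles)
    renyi_elbo = avg_log_exp / (1.0 - alpha)
    assert renyi_elbo.shape == (3, 5)
    # the final objective sums the elementwise elbos and reapplies the
    # subsampling scale, as in `loss.sum() * common_plate_scale[0]`
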
@@ -994,12 +1077,12 @@ def single_particle_elbo(rng_key):
             for key in deps:
                 site = guide_trace[key]
                 if site["infer"].get("enumerate") == "parallel":
-                    for plate in (
+                    for p in (
                         frozenset(site["log_measure"].inputs) & elim_plates
                     ):
                         raise ValueError(
                             "Expected model enumeration to be no more global than guide enumeration, but found "
-                            f"model enumeration sites upstream of guide site '{key}' in plate('{plate}')."
+                            f"model enumeration sites upstream of guide site '{key}' in plate('{p}'). "
                             "Try converting some model enumeration sites to guide enumeration sites."
                         )
             cost_terms.append((cost, scale, deps))
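
To make the scale bookkeeping concrete, the arithmetic below is what ``test_renyi_local`` in the test diff that follows asserts; the sizes and the 0.6 success probability come from that test, and the intermediate names are illustrative::

    import jax.numpy as jnp

    size, subsample_size = 100, 2
    indep_plate_scale = size / subsample_size        # 50.0
    # per element, log_p - log_q = log(0.6): the Normal prior and guide cancel,
    # leaving only the Bernoulli(0.6) observation term
    elbo = jnp.full((subsample_size,), jnp.log(0.6))
    subsample_loss = -(elbo.sum() * indep_plate_scale)
    full_loss = -jnp.log(0.6) * 100
    assert jnp.allclose(subsample_loss, full_loss)
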
diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py
index 929f923b3..fd6f8ab75 100644
--- a/test/infer/test_svi.py
+++ b/test/infer/test_svi.py
@@ -62,8 +62,116 @@ def renyi_loss_fn(x):
     assert_allclose(elbo_grad, renyi_grad, rtol=1e-6)
 
 
+def test_renyi_local():
+    def model(subsample_size=None):
+        with numpyro.plate("N", 100, subsample_size=subsample_size):
+            numpyro.sample("x", dist.Normal(0, 1))
+            numpyro.sample("obs", dist.Bernoulli(0.6), obs=1)
+
+    def guide(subsample_size=None):
+        with numpyro.plate("N", 100, subsample_size=subsample_size):
+            numpyro.sample("x", dist.Normal(0, 1))
+
+    def renyi_loss_fn(subsample_size=None):
+        return RenyiELBO(num_particles=10).loss(
+            random.PRNGKey(0), {}, model, guide, subsample_size
+        )
+
+    # Test that the scales are applied correctly.
+    # Here for each particle, log_p - log_q = log(0.6)
+    full_loss = renyi_loss_fn()
+    subsample_loss = renyi_loss_fn(subsample_size=2)
+    assert_allclose(full_loss, -jnp.log(0.6) * 100, rtol=1e-6)
+    assert_allclose(subsample_loss, full_loss, rtol=1e-6)
+
+
+def test_renyi_nonnested_plates():
+    def model():
+        with numpyro.plate("N", 10):
+            numpyro.sample("x", dist.Normal(0, 1))
+
+        with numpyro.plate("M", 10):
+            numpyro.sample("y", dist.Normal(0, 1))
+
+    def get_elbo():
+        renyi_elbo = RenyiELBO(num_particles=10)
+        return renyi_elbo._single_particle_elbo(
+            model,
+            model,
+            {},
+            (),
+            {},
+            random.PRNGKey(0),
+        )
+
+    elbo, _ = get_elbo()
+    assert elbo.shape == ()
+
+
+@pytest.mark.parametrize(
+    "n,k",
+    [(3, 5), (2, 5), (3, 3), (2, 3)],
+    ids=str,
+)
+def test_renyi_create_plates(n, k):
+    P = 10
+    N, M, K = 3, 4, 5
+    data = jnp.linspace(0.1, 0.9, N * M * K).reshape((N, M, K))
+
+    def model(data, n=N, k=K, fix_indices=True):
+        with numpyro.plate("N", N, subsample_size=n, dim=-3):
+            with numpyro.plate("M", M, dim=-2):
+                with numpyro.plate("K", K, subsample_size=k, dim=-1):
+                    if fix_indices:
+                        batch = data[:n, :, :k]
+                    else:
+                        batch = numpyro.subsample(data, event_dim=0)
+                    numpyro.sample("data", dist.Bernoulli(batch), obs=1)
+
+    def guide(data, n=N, k=K, fix_indices=True):
+        pass
+
+    def get_elbo(n=N, k=K, fix_indices=True):
+        renyi_elbo = RenyiELBO(num_particles=P)
+        return renyi_elbo._single_particle_elbo(
+            model,
+            guide,
+            {},
+            (data,),
+            dict(n=n, k=k, fix_indices=fix_indices),
+            random.PRNGKey(0),
+        )
+
+    def get_renyi(n=N, k=K, fix_indices=True):
+        renyi_elbo = RenyiELBO(num_particles=P)
+        return -renyi_elbo.loss(
+            random.PRNGKey(0), {}, model, guide, data, n=n, k=k, fix_indices=fix_indices
+        )
+
+    elbo, scale = get_elbo(n=n, k=k)
+    expected_shape = (n, M, k)
+    expected_scale = N * K / n / k
+    expected_elbo = jnp.log(data)[:n, :, :k]
+    assert elbo.shape == expected_shape
+    assert_allclose(scale, expected_scale, rtol=1e-6)
+    assert_allclose(elbo, expected_elbo, rtol=1e-6)
+
+    renyi = get_renyi(n=n, k=k)
+    assert_allclose(renyi, elbo.sum() * scale, rtol=1e-6)
+
+    if (n, k) == (2, 5):
+        renyi_random = get_renyi(n=2, fix_indices=False)
+        renyi_idx01 = jnp.log(data)[jnp.array([0, 1])].sum() * N / 2
+        renyi_idx02 = jnp.log(data)[jnp.array([0, 2])].sum() * N / 2
+        renyi_idx12 = jnp.log(data)[jnp.array([1, 2])].sum() * N / 2
+        atol = jnp.min(
+            jnp.abs(jnp.stack([renyi_idx01, renyi_idx02, renyi_idx12]) - renyi_random)
+        )
+        assert_allclose(atol, 0.0, atol=1e-5)
+
+
 @pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
-@pytest.mark.parametrize("optimizer", [optim.Adam(0.05), optimizers.adam(0.05)])
+@pytest.mark.parametrize("optimizer", [optim.Adam(0.01), optimizers.adam(0.01)])
 def test_beta_bernoulli(elbo, optimizer):
     data = jnp.array([1.0] * 8 + [0.0] * 2)
 
@@ -85,13 +193,14 @@ def body_fn(i, val):
         svi_state, _ = svi.update(val, data)
         return svi_state
 
-    svi_state = fori_loop(0, 2000, body_fn, svi_state)
+    svi_state = fori_loop(0, 10000, body_fn, svi_state)
 
     params = svi.get_params(svi_state)
+    actual_posterior_mean = (data.sum() + 1) / (data.shape[0] + 2)
     assert_allclose(
         params["alpha_q"] / (params["alpha_q"] + params["beta_q"]),
-        0.8,
-        atol=0.05,
-        rtol=0.05,
+        actual_posterior_mean,
+        atol=0.03,
+        rtol=0.03,
     )
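
Finally, a minimal end-to-end sketch of the new capability (a local latent variable inside a subsampled plate, trained with ``RenyiELBO``), mirroring the docstring example above; the data, sizes, and step counts are illustrative and assume ``optax`` is installed::

    import jax
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    import optax
    from numpyro.infer import SVI, RenyiELBO

    def model(data):
        with numpyro.plate("batch", data.shape[0], subsample_size=100):
            latent = numpyro.sample("latent", dist.Normal(0, 1))
            batch = numpyro.subsample(data, event_dim=0)
            numpyro.sample("obs", dist.Bernoulli(logits=latent), obs=batch)

    def guide(data):
        loc = numpyro.param("loc", 0.0)
        with numpyro.plate("batch", data.shape[0], subsample_size=100):
            numpyro.sample("latent", dist.Normal(loc, 1.0))

    data = jnp.concatenate([jnp.zeros(500), jnp.ones(500)])
    svi = SVI(model, guide, optax.adam(0.05), RenyiELBO(num_particles=10))
    svi_result = svi.run(jax.random.PRNGKey(0), 200, data)
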