Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for local variables in RenyiELBO #1608

Merged
merged 19 commits into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ class seed(Messenger):
:param fn: Python callable with NumPyro primitives.
:param rng_seed: a random number generator seed.
:type rng_seed: int, jnp.ndarray scalar, or jax.random.PRNGKey
:param list hide_types: an optional list of side types to skip seeding, e.g. ['plate'].

.. note::

Expand Down Expand Up @@ -703,7 +704,7 @@ class seed(Messenger):
>>> assert x == y
"""

def __init__(self, fn=None, rng_seed=None):
def __init__(self, fn=None, rng_seed=None, hide_types=None):
if isinstance(rng_seed, int) or (
isinstance(rng_seed, (np.ndarray, jnp.ndarray)) and not jnp.shape(rng_seed)
):
Expand All @@ -715,19 +716,19 @@ def __init__(self, fn=None, rng_seed=None):
):
raise TypeError("Incorrect type for rng_seed: {}".format(type(rng_seed)))
self.rng_key = rng_seed
self.hide_types = [] if hide_types is None else hide_types
super(seed, self).__init__(fn)

def process_message(self, msg):
if (
msg["type"] == "sample"
and not msg["is_observed"]
and msg["kwargs"]["rng_key"] is None
) or msg["type"] in ["prng_key", "plate", "control_flow"]:
if msg["value"] is not None:
# no need to create a new key when value is available
return
self.rng_key, rng_key_sample = random.split(self.rng_key)
msg["kwargs"]["rng_key"] = rng_key_sample
if msg["type"] in self.hide_types:
return
if msg["type"] not in ["sample", "prng_key", "plate", "control_flow"]:
return
if (msg["kwargs"]["rng_key"] is not None) or (msg["value"] is not None):
# no need to create a new key when value is available
return
self.rng_key, rng_key_sample = random.split(self.rng_key)
msg["kwargs"]["rng_key"] = rng_key_sample


class substitute(Messenger):
Expand Down
126 changes: 99 additions & 27 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,27 @@ class RenyiELBO(ELBO):
:param num_particles: The number of particles/samples
used to form the objective (gradient) estimator. Default is 2.

Example::

def model(data):
with numpyro.plate("batch", 10000, subsample_size=100):
latent = numpyro.sample("latent", dist.Normal(0, 1))
batch = numpyro.subsample(data, event_dim=0)
numpyro.sample("data", dist.Bernoulli(logits=latent), obs=batch)

def guide(data):
w_loc = numpyro.param("w_loc", 1.)
w_scale = numpyro.param("w_scale", 1.)
with numpyro.plate("batch", 10000, subsample_size=100):
batch = numpyro.subsample(data, event_dim=0)
loc = w_loc * batch
scale = jnp.exp(w_scale * batch)
numpyro.sample("latent", dist.Normal(loc, scale))

elbo = RenyiELBO(num_particles=10)
svi = SVI(model, guide, optax.adam(0.1), elbo)


**References:**

1. *Renyi Divergence Variational Inference*, Yingzhen Li, Richard E. Turner
Expand All @@ -327,37 +348,88 @@ def __init__(self, alpha=0, num_particles=2):
self.alpha = alpha
super().__init__(num_particles=num_particles)

def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
def single_particle_elbo(rng_key):
model_seed, guide_seed = random.split(rng_key)
seeded_model = seed(model, model_seed)
seeded_guide = seed(guide, guide_seed)
guide_log_density, guide_trace = log_density(
seeded_guide, args, kwargs, param_map
)
# NB: we only want to substitute params not available in guide_trace
model_param_map = {
k: v for k, v in param_map.items() if k not in guide_trace
}
seeded_model = replay(seeded_model, guide_trace)
model_log_density, model_trace = log_density(
seeded_model, args, kwargs, model_param_map
)
check_model_guide_match(model_trace, guide_trace)
_validate_model(model_trace, plate_warning="loose")
def _single_particle_elbo(self, model, guide, param_map, args, kwargs, rng_key):
model_seed, guide_seed = random.split(rng_key)
seeded_model = seed(model, model_seed)
seeded_guide = seed(guide, guide_seed)
model_trace, guide_trace = get_importance_trace(
seeded_model, seeded_guide, args, kwargs, param_map
)
check_model_guide_match(model_trace, guide_trace)
_validate_model(model_trace, plate_warning="loose")

site_plates = {
name: {frame for frame in site["cond_indep_stack"]}
for name, site in model_trace.items()
if site["type"] == "sample"
}
# We will compute Renyi elbos separately across dimensions
# defined in indep_plates. Then the final elbo is the sum
# of those independent elbos.
if site_plates:
indep_plates = set.intersection(*site_plates.values())
else:
indep_plates = set()
indep_plate_scale = 1.0
for frame in indep_plates:
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
subsample_size = frame.size
size = model_trace[frame.name]["args"][0]
if size > subsample_size:
indep_plate_scale = indep_plate_scale * size / subsample_size
indep_plate_dims = {frame.dim for frame in indep_plates}

log_densities = {}
for trace_type, tr in {"guide": guide_trace, "model": model_trace}.items():
log_densities[trace_type] = 1.0
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
for site in tr.values():
if site["type"] != "sample":
continue
log_prob = site["log_prob"]
squeeze_axes = ()
for dim in range(log_prob.ndim):
neg_dim = dim - log_prob.ndim
if neg_dim in indep_plate_dims:
continue
log_prob = jnp.sum(log_prob, axis=dim, keepdims=True)
squeeze_axes = squeeze_axes + (dim,)
log_prob = jnp.squeeze(log_prob, squeeze_axes)
log_densities[trace_type] = log_densities[trace_type] + log_prob

# log p(z) - log q(z)
elbo = log_densities["model"] - log_densities["guide"]
# Log probabilities at indep_plates dimensions are scaled to MC approximate
# the "full size" log probabilities. Because we want to compute Renyi elbos
# separately across indep_plates dimensions, we will remove such scale now.
# We will apply such scale after getting those Renyi elbos.
return elbo / indep_plate_scale, indep_plate_scale

# log p(z) - log q(z)
elbo = model_log_density - guide_log_density
return elbo
def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
plate_key, rng_key = random.split(rng_key)
model = seed(
model, plate_key, hide_types=["sample", "prng_key", "control_flow"]
)
guide = seed(
guide, plate_key, hide_types=["sample", "prng_key", "control_flow"]
)
single_particle_elbo = partial(
self._single_particle_elbo, model, guide, param_map, args, kwargs
)

rng_keys = random.split(rng_key, self.num_particles)
elbos = vmap(single_particle_elbo)(rng_keys)
elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys)
assert common_plate_scale.shape == (self.num_particles,)
assert elbos.shape[0] == self.num_particles
scaled_elbos = (1.0 - self.alpha) * elbos
avg_log_exp = logsumexp(scaled_elbos) - jnp.log(self.num_particles)
avg_log_exp = logsumexp(scaled_elbos, axis=0) - jnp.log(self.num_particles)
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
assert avg_log_exp.shape == elbos.shape[1:]
weights = jnp.exp(scaled_elbos - avg_log_exp)
renyi_elbo = avg_log_exp / (1.0 - self.alpha)
weighted_elbo = jnp.dot(stop_gradient(weights), elbos) / self.num_particles
return -(stop_gradient(renyi_elbo - weighted_elbo) + weighted_elbo)
weighted_elbo = (stop_gradient(weights) * elbos).mean(0)
assert renyi_elbo.shape == elbos.shape[1:]
assert weighted_elbo.shape == elbos.shape[1:]
loss = -(stop_gradient(renyi_elbo - weighted_elbo) + weighted_elbo)
# common_plate_scale should be the same across particles.
return loss.sum() * common_plate_scale[0]


def _get_plate_stacks(trace):
Expand Down Expand Up @@ -994,12 +1066,12 @@ def single_particle_elbo(rng_key):
for key in deps:
site = guide_trace[key]
if site["infer"].get("enumerate") == "parallel":
for plate in (
for p in (
frozenset(site["log_measure"].inputs) & elim_plates
):
raise ValueError(
"Expected model enumeration to be no more global than guide enumeration, but found "
f"model enumeration sites upstream of guide site '{key}' in plate('{plate}')."
f"model enumeration sites upstream of guide site '{key}' in plate('{p}')."
"Try converting some model enumeration sites to guide enumeration sites."
)
cost_terms.append((cost, scale, deps))
Expand Down
113 changes: 111 additions & 2 deletions test/infer/test_svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,114 @@ def renyi_loss_fn(x):
assert_allclose(elbo_grad, renyi_grad, rtol=1e-6)


def test_renyi_local():
def model(subsample_size=None):
with numpyro.plate("N", 100, subsample_size=subsample_size):
numpyro.sample("x", dist.Normal(0, 1))
numpyro.sample("obs", dist.Bernoulli(0.6), obs=1)

def guide(subsample_size=None):
with numpyro.plate("N", 100, subsample_size=subsample_size):
numpyro.sample("x", dist.Normal(0, 1))

def renyi_loss_fn(subsample_size=None):
return RenyiELBO(num_particles=10).loss(
random.PRNGKey(0), {}, model, guide, subsample_size
)

# Test that the scales are applied correctly.
# Here for each particle, log_p - log_q = log(0.6)
full_loss = renyi_loss_fn()
subsample_loss = renyi_loss_fn(subsample_size=2)
assert_allclose(full_loss, -jnp.log(0.6) * 100, rtol=1e-6)
assert_allclose(subsample_loss, full_loss, rtol=1e-6)


def test_renyi_nonnested_plates():
def model():
with numpyro.plate("N", 10, subsample_size=2):
numpyro.sample("x", dist.Normal(0, 1))

with numpyro.plate("M", 10, subsample_size=2):
numpyro.sample("y", dist.Normal(0, 1))

def get_elbo():
renyi_elbo = RenyiELBO(num_particles=10)
return renyi_elbo._single_particle_elbo(
model,
model,
{},
(),
{},
random.PRNGKey(0),
)

elbo, _ = get_elbo()
assert elbo.shape == ()


@pytest.mark.parametrize(
"n,k",
[(3, 5), (2, 5), (3, 3), (2, 3)],
ids=str,
)
def test_renyi_create_plates(n, k):
P = 10
N, M, K = 3, 4, 5
data = jnp.linspace(0.1, 0.9, N * M * K).reshape((N, M, K))

def model(data, n=N, k=K, fix_indices=True):
with numpyro.plate("N", N, subsample_size=n, dim=-3):
with numpyro.plate("M", M, dim=-2):
with numpyro.plate("K", K, subsample_size=k, dim=-1):
if fix_indices:
batch = data[:n, :, :k]
else:
batch = numpyro.subsample(data, event_dim=0)
numpyro.sample("data", dist.Bernoulli(batch), obs=1)

def guide(data, n=N, k=K, fix_indices=True):
pass

def get_elbo(n=N, k=K, fix_indices=True):
renyi_elbo = RenyiELBO(num_particles=P)
return renyi_elbo._single_particle_elbo(
model,
guide,
{},
(data,),
dict(n=n, k=k, fix_indices=fix_indices),
random.PRNGKey(0),
)

def get_renyi(n=N, k=K, fix_indices=True):
renyi_elbo = RenyiELBO(num_particles=P)
return -renyi_elbo.loss(
random.PRNGKey(0), {}, model, guide, data, n=n, k=k, fix_indices=fix_indices
)

elbo, scale = get_elbo(n=n, k=k)
expected_shape = (n, M, k)
expected_scale = N * K / n / k
expected_elbo = jnp.log(data)[:n, :, :k]
assert elbo.shape == expected_shape
assert_allclose(scale, expected_scale, rtol=1e-6)
assert_allclose(elbo, expected_elbo, rtol=1e-6)

renyi = get_renyi(n=n, k=k)
assert_allclose(renyi, elbo.sum() * scale, rtol=1e-6)

if (n, k) == (2, 5):
renyi_random = get_renyi(n=2, fix_indices=False)
renyi_idx01 = jnp.log(data)[jnp.array([0, 1])].sum() * N / 2
renyi_idx02 = jnp.log(data)[jnp.array([0, 2])].sum() * N / 2
renyi_idx12 = jnp.log(data)[jnp.array([1, 2])].sum() * N / 2
atol = jnp.min(
jnp.abs(jnp.stack([renyi_idx01, renyi_idx02, renyi_idx12]) - renyi_random)
)
assert_allclose(atol, 0.0, atol=1e-5)


@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
@pytest.mark.parametrize("optimizer", [optim.Adam(0.05), optimizers.adam(0.05)])
def test_beta_bernoulli(elbo, optimizer):
Expand All @@ -85,11 +193,12 @@ def body_fn(i, val):
svi_state, _ = svi.update(val, data)
return svi_state

svi_state = fori_loop(0, 2000, body_fn, svi_state)
svi_state = fori_loop(0, 4000, body_fn, svi_state)
params = svi.get_params(svi_state)
actual_posterior_mean = 0.75 # (8 + 1) / (8 + 1 + 2 + 1)
assert_allclose(
params["alpha_q"] / (params["alpha_q"] + params["beta_q"]),
0.8,
actual_posterior_mean,
martinjankowiak marked this conversation as resolved.
Show resolved Hide resolved
atol=0.05,
rtol=0.05,
)
Expand Down