Skip to content

Commit

Permalink
Raise an error if there is no common scale when the model is enumerated (#1536)
Browse files Browse the repository at this point in the history
* fix subsample scaling

* add more tests

* pass tests

* pytest.raises

* improve comments

* convert scale to float

* raise error if jnp.array
  • Loading branch information
ordabayevy authored Feb 6, 2023
1 parent 5d83e65 commit 643412d
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 19 deletions.
36 changes: 17 additions & 19 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def guess_max_plate_nesting(model, guide, args, kwargs, param_map):

class TraceEnum_ELBO(ELBO):
"""
A TraceEnum implementation of ELBO-based SVI. The gradient estimator
(EXPERIMENTAL) A TraceEnum implementation of ELBO-based SVI. The gradient estimator
is constructed along the lines of reference [1] specialized to the case
of the ELBO. It supports arbitrary dependency structure for the model
and guide.
Expand Down Expand Up @@ -960,24 +960,22 @@ def single_particle_elbo(rng_key):
eliminate=group_sum_vars | elim_plates,
)
# incorporate the effects of subsampling and handlers.scale through a common scale factor
group_scales = {}
for name in group_names:
for plate, value in (
model_trace[name].get("plate_to_scale", {}).items()
):
if plate in group_scales:
if value != group_scales[plate]:
raise ValueError(
"Expected all enumerated sample sites to share a common scale factor, "
f"but found different scales at plate('{plate}')."
)
else:
group_scales[plate] = value
scale = (
reduce(lambda a, b: a * b, group_scales.values())
if group_scales
else None
)
scales_set = set()
for name in group_names | group_sum_vars:
site_scale = model_trace[name]["scale"]
if site_scale is None:
site_scale = 1.0
if isinstance(site_scale, jnp.ndarray):
raise ValueError(
"Enumeration only supports scalar handlers.scale"
)
scales_set.add(float(site_scale))
if len(scales_set) != 1:
raise ValueError(
"Expected all enumerated sample sites to share a common scale, "
f"but found {len(scales_set)} different scales."
)
scale = next(iter(scales_set))
# combine deps
deps = frozenset().union(
*[model_deps[name] for name in group_names]
Expand Down
217 changes: 217 additions & 0 deletions test/contrib/test_enum_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2255,3 +2255,220 @@ def actual_loss_fn(params):

assert_equal(expected_loss, actual_loss, prec=1e-5)
assert_equal(expected_grad, actual_grad, prec=1e-5)


@pytest.mark.parametrize("scale", [1, 10])
def test_model_enum_subsample_1(scale):
    # Enumerate: a
    # Subsample: b
    # a - [-> b ]
    def make_model(subsample_size):
        """Build the model; ``subsample_size=None`` disables subsampling."""

        @config_enumerate
        @handlers.scale(scale=scale)
        def model(params):
            locs = pyro.param("locs", params["locs"])
            probs_a = pyro.param(
                "probs_a", params["probs_a"], constraint=constraints.simplex
            )
            a = pyro.sample("a", dist.Categorical(probs_a))
            with pyro.plate("b_axis", size=3, subsample_size=subsample_size):
                pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)

        return model

    def guide(params):
        pass

    params = {
        "locs": jnp.array([0.0, 1.0]),
        "probs_a": jnp.array([0.4, 0.6]),
    }
    transform = dist.biject_to(dist.constraints.simplex)
    params_raw = {"locs": params["locs"], "probs_a": transform.inv(params["probs_a"])}

    elbo = infer.TraceEnum_ELBO()

    def loss_fn_for(model_fn):
        # Map unconstrained parameters back to the constrained space and
        # evaluate the ELBO loss of ``model_fn`` against the trivial guide.
        def loss_fn(raw):
            constrained = {
                "locs": raw["locs"],
                "probs_a": transform(raw["probs_a"]),
            }
            return elbo.loss(random.PRNGKey(0), {}, model_fn, guide, constrained)

        return loss_fn

    # Reference loss/grads from the fully observed (non-subsampled) model.
    expected_loss, expected_grads = jax.value_and_grad(loss_fn_for(make_model(None)))(
        params_raw
    )

    # Subsampling under enumeration is unsupported; a common-scale error is raised.
    with pytest.raises(
        ValueError, match="Expected all enumerated sample sites to share a common scale"
    ):
        # This never gets run because we don't support this yet.
        actual_loss, actual_grads = jax.value_and_grad(loss_fn_for(make_model(2)))(
            params_raw
        )

        assert_equal(actual_loss, expected_loss, prec=1e-5)
        assert_equal(actual_grads, expected_grads, prec=1e-5)


@pytest.mark.parametrize("scale", [1, 10])
def test_model_enum_subsample_2(scale):
    # Enumerate: a
    # Subsample: b, c
    # a - [-> b ]
    #  \
    #   - [-> c ]
    def make_model(b_subsample, c_subsample):
        """Build the model; ``None`` subsample sizes disable subsampling."""

        @config_enumerate
        @handlers.scale(scale=scale)
        def model(params):
            locs = pyro.param("locs", params["locs"])
            probs_a = pyro.param(
                "probs_a", params["probs_a"], constraint=constraints.simplex
            )
            a = pyro.sample("a", dist.Categorical(probs_a))
            with pyro.plate("b_axis", size=3, subsample_size=b_subsample):
                pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)

            with pyro.plate("c_axis", size=6, subsample_size=c_subsample):
                pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1)

        return model

    def guide(params):
        pass

    params = {
        "locs": jnp.array([0.0, 1.0]),
        "probs_a": jnp.array([0.4, 0.6]),
    }
    transform = dist.biject_to(dist.constraints.simplex)
    params_raw = {"locs": params["locs"], "probs_a": transform.inv(params["probs_a"])}

    elbo = infer.TraceEnum_ELBO()

    def loss_fn_for(model_fn):
        # Map unconstrained parameters back to the constrained space and
        # evaluate the ELBO loss of ``model_fn`` against the trivial guide.
        def loss_fn(raw):
            constrained = {
                "locs": raw["locs"],
                "probs_a": transform(raw["probs_a"]),
            }
            return elbo.loss(random.PRNGKey(0), {}, model_fn, guide, constrained)

        return loss_fn

    # Reference loss/grads from the fully observed (non-subsampled) model.
    expected_loss, expected_grads = jax.value_and_grad(
        loss_fn_for(make_model(None, None))
    )(params_raw)

    # Subsampling under enumeration is unsupported; a common-scale error is raised.
    with pytest.raises(
        ValueError, match="Expected all enumerated sample sites to share a common scale"
    ):
        # This never gets run because we don't support this yet.
        actual_loss, actual_grads = jax.value_and_grad(loss_fn_for(make_model(2, 3)))(
            params_raw
        )

        assert_equal(actual_loss, expected_loss, prec=1e-5)
        assert_equal(actual_grads, expected_grads, prec=1e-5)


@pytest.mark.parametrize("scale", [1, 10])
def test_model_enum_subsample_3(scale):
    # Enumerate: a
    # Subsample: a, b, c
    # [ a - [----> b ]
    # [  \  [        ]
    # [   - [- [-> c ] ]
    def make_model(a_subsample, b_subsample, c_subsample):
        """Build the model; ``None`` subsample sizes disable subsampling."""

        @config_enumerate
        @handlers.scale(scale=scale)
        def model(params):
            locs = pyro.param("locs", params["locs"])
            probs_a = pyro.param(
                "probs_a", params["probs_a"], constraint=constraints.simplex
            )
            with pyro.plate("a_axis", size=3, subsample_size=a_subsample):
                a = pyro.sample("a", dist.Categorical(probs_a))
                with pyro.plate("b_axis", size=6, subsample_size=b_subsample):
                    pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)
                    with pyro.plate("c_axis", size=9, subsample_size=c_subsample):
                        pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1)

        return model

    def guide(params):
        pass

    params = {
        "locs": jnp.array([0.0, 1.0]),
        "probs_a": jnp.array([0.4, 0.6]),
    }
    transform = dist.biject_to(dist.constraints.simplex)
    params_raw = {"locs": params["locs"], "probs_a": transform.inv(params["probs_a"])}

    elbo = infer.TraceEnum_ELBO()

    def loss_fn_for(model_fn):
        # Map unconstrained parameters back to the constrained space and
        # evaluate the ELBO loss of ``model_fn`` against the trivial guide.
        def loss_fn(raw):
            constrained = {
                "locs": raw["locs"],
                "probs_a": transform(raw["probs_a"]),
            }
            return elbo.loss(random.PRNGKey(0), {}, model_fn, guide, constrained)

        return loss_fn

    # Reference loss/grads from the fully observed (non-subsampled) model.
    expected_loss, expected_grads = jax.value_and_grad(
        loss_fn_for(make_model(None, None, None))
    )(params_raw)

    # Subsampling under enumeration is unsupported; a common-scale error is raised.
    with pytest.raises(
        ValueError, match="Expected all enumerated sample sites to share a common scale"
    ):
        # This never gets run because we don't support this yet.
        actual_loss, actual_grads = jax.value_and_grad(
            loss_fn_for(make_model(2, 3, 4))
        )(params_raw)

        assert_equal(actual_loss, expected_loss, prec=1e-3)
        assert_equal(actual_grads, expected_grads, prec=1e-5)

0 comments on commit 643412d

Please sign in to comment.