Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise an error if there is no common scale when model enumerated #1536

Merged
merged 7 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 12 additions & 18 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def guess_max_plate_nesting(model, guide, args, kwargs, param_map):

class TraceEnum_ELBO(ELBO):
"""
A TraceEnum implementation of ELBO-based SVI. The gradient estimator
(EXPERIMENTAL) A TraceEnum implementation of ELBO-based SVI. The gradient estimator
is constructed along the lines of reference [1] specialized to the case
of the ELBO. It supports arbitrary dependency structure for the model
and guide.
Expand Down Expand Up @@ -960,24 +960,18 @@ def single_particle_elbo(rng_key):
eliminate=group_sum_vars | elim_plates,
)
# incorporate the effects of subsampling and handlers.scale through a common scale factor
group_scales = {}
for name in group_names:
for plate, value in (
model_trace[name].get("plate_to_scale", {}).items()
):
if plate in group_scales:
if value != group_scales[plate]:
raise ValueError(
"Expected all enumerated sample sites to share a common scale factor, "
f"but found different scales at plate('{plate}')."
)
else:
group_scales[plate] = value
scale = (
reduce(lambda a, b: a * b, group_scales.values())
if group_scales
else None
scales_set = set(
[
model_trace[name]["scale"]
Copy link
Member

@fehiepsi fehiepsi Feb 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how to check the uniqueness of jax array scales because jax arrays are not hashable. Could you replace this with something like

scale = None
for name in (group_names | group_sum_vars):
    site_scale = model_trace[name]["scale"]
    if isinstance(scale, (int, float)) and isinstance(site_scale, (int, float, type(None))) and (site_scale != scale):
        raise ValueError(...)
    scale = site_scale

Btw, does this mean that we don't support enumeration for models with both global and local variables?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied the implementation from Pyro which first checks that array is scalar, converts it to scalar, and then compares them. Does it look alright to you?

https://github.com/pyro-ppl/pyro/blob/dev/pyro/infer/traceenum_elbo.py#L37-L41

Btw, does this mean that we don't support enumeration for models with both global and local variables?

Yeah, if you enumerate a global variable then you cannot subsample a local variable that depends on it.
I'm working on it but that will require first changing funsor.sum_product.sum_product to handle plate-wise scaling there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we can't convert a tracer to float, float(site_scale) will fail when scale is a tracer. Maybe we need to raise error if it is a jnp.array?

Copy link
Member Author

@ordabayevy ordabayevy Feb 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Changed it so that it raises an error if scale is jnp.array.

for name in (group_names | group_sum_vars)
]
)
if len(scales_set) != 1:
raise ValueError(
"Expected all enumerated sample sites to share a common scale, "
f"but found {len(scales_set)} different scales."
)
scale = next(iter(scales_set))
# combine deps
deps = frozenset().union(
*[model_deps[name] for name in group_names]
Expand Down
217 changes: 217 additions & 0 deletions test/contrib/test_enum_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2255,3 +2255,220 @@ def actual_loss_fn(params):

assert_equal(expected_loss, actual_loss, prec=1e-5)
assert_equal(expected_grad, actual_grad, prec=1e-5)


@pytest.mark.parametrize("scale", [1, 10])
def test_model_enum_subsample_1(scale):
    """TraceEnum_ELBO must raise when an enumerated global ("a") is combined
    with a subsampled plate ("b_axis"): the sites no longer share a common
    scale, which the estimator cannot handle yet.

    Dependency structure:
        Enumerate: a
        Subsample: b
        a - [-> b ]
    """

    @config_enumerate
    @handlers.scale(scale=scale)
    def model(params):
        locs = pyro.param("locs", params["locs"])
        probs_a = pyro.param(
            "probs_a", params["probs_a"], constraint=constraints.simplex
        )
        a = pyro.sample("a", dist.Categorical(probs_a))
        with pyro.plate("b_axis", size=3):
            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)

    @config_enumerate
    @handlers.scale(scale=scale)
    def model_subsample(params):
        locs = pyro.param("locs", params["locs"])
        probs_a = pyro.param(
            "probs_a", params["probs_a"], constraint=constraints.simplex
        )
        a = pyro.sample("a", dist.Categorical(probs_a))
        with pyro.plate("b_axis", size=3, subsample_size=2):
            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)

    def guide(params):
        pass

    params = {
        "locs": jnp.array([0.0, 1.0]),
        "probs_a": jnp.array([0.4, 0.6]),
    }
    # Optimize in unconstrained space; map back through the simplex bijector.
    transform = dist.biject_to(dist.constraints.simplex)
    params_raw = {"locs": params["locs"], "probs_a": transform.inv(params["probs_a"])}

    elbo = infer.TraceEnum_ELBO()

    # Sanity check: the fully-observed (non-subsampled) model is supported
    # and must produce a finite loss and gradients.
    def expected_loss_fn(params_raw):
        params = {
            "locs": params_raw["locs"],
            "probs_a": transform(params_raw["probs_a"]),
        }
        return elbo.loss(random.PRNGKey(0), {}, model, guide, params)

    expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw)
    assert jnp.isfinite(expected_loss)

    # The subsampled variant is not supported yet: the ELBO must raise
    # rather than silently return a wrong loss.
    def actual_loss_fn(params_raw):
        params = {
            "locs": params_raw["locs"],
            "probs_a": transform(params_raw["probs_a"]),
        }
        return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)

    with pytest.raises(
        ValueError, match="Expected all enumerated sample sites to share a common scale"
    ):
        jax.value_and_grad(actual_loss_fn)(params_raw)


@pytest.mark.parametrize("scale", [1, 10])
def test_model_enum_subsample_2(scale):
    """TraceEnum_ELBO must raise when an enumerated global ("a") feeds two
    subsampled plates ("b_axis", "c_axis"): the sites no longer share a
    common scale, which the estimator cannot handle yet.

    Dependency structure:
        Enumerate: a
        Subsample: b, c
        a - [-> b ]
          \\
           - [-> c ]
    """

    @config_enumerate
    @handlers.scale(scale=scale)
    def model(params):
        locs = pyro.param("locs", params["locs"])
        probs_a = pyro.param(
            "probs_a", params["probs_a"], constraint=constraints.simplex
        )
        a = pyro.sample("a", dist.Categorical(probs_a))
        with pyro.plate("b_axis", size=3):
            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)

        with pyro.plate("c_axis", size=6):
            pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1)

    @config_enumerate
    @handlers.scale(scale=scale)
    def model_subsample(params):
        locs = pyro.param("locs", params["locs"])
        probs_a = pyro.param(
            "probs_a", params["probs_a"], constraint=constraints.simplex
        )
        a = pyro.sample("a", dist.Categorical(probs_a))
        with pyro.plate("b_axis", size=3, subsample_size=2):
            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)

        with pyro.plate("c_axis", size=6, subsample_size=3):
            pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1)

    def guide(params):
        pass

    params = {
        "locs": jnp.array([0.0, 1.0]),
        "probs_a": jnp.array([0.4, 0.6]),
    }
    # Optimize in unconstrained space; map back through the simplex bijector.
    transform = dist.biject_to(dist.constraints.simplex)
    params_raw = {"locs": params["locs"], "probs_a": transform.inv(params["probs_a"])}

    elbo = infer.TraceEnum_ELBO()

    # Sanity check: the fully-observed (non-subsampled) model is supported
    # and must produce a finite loss and gradients.
    def expected_loss_fn(params_raw):
        params = {
            "locs": params_raw["locs"],
            "probs_a": transform(params_raw["probs_a"]),
        }
        return elbo.loss(random.PRNGKey(0), {}, model, guide, params)

    expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw)
    assert jnp.isfinite(expected_loss)

    # The subsampled variant is not supported yet: the ELBO must raise
    # rather than silently return a wrong loss.
    def actual_loss_fn(params_raw):
        params = {
            "locs": params_raw["locs"],
            "probs_a": transform(params_raw["probs_a"]),
        }
        return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)

    with pytest.raises(
        ValueError, match="Expected all enumerated sample sites to share a common scale"
    ):
        jax.value_and_grad(actual_loss_fn)(params_raw)


@pytest.mark.parametrize("scale", [1, 10])
def test_model_enum_subsample_3(scale):
    """TraceEnum_ELBO must raise when the enumerated site "a" itself lives in
    a subsampled plate and feeds nested subsampled plates: the sites do not
    share a common scale, which the estimator cannot handle yet.

    Dependency structure:
        Enumerate: a
        Subsample: a, b, c
        [ a - [----> b ]
        [  \\  [       ]
        [   - [- [-> c ] ]
    """

    @config_enumerate
    @handlers.scale(scale=scale)
    def model(params):
        locs = pyro.param("locs", params["locs"])
        probs_a = pyro.param(
            "probs_a", params["probs_a"], constraint=constraints.simplex
        )
        with pyro.plate("a_axis", size=3):
            a = pyro.sample("a", dist.Categorical(probs_a))
            with pyro.plate("b_axis", size=6):
                pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)
                with pyro.plate("c_axis", size=9):
                    pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1)

    @config_enumerate
    @handlers.scale(scale=scale)
    def model_subsample(params):
        locs = pyro.param("locs", params["locs"])
        probs_a = pyro.param(
            "probs_a", params["probs_a"], constraint=constraints.simplex
        )
        with pyro.plate("a_axis", size=3, subsample_size=2):
            a = pyro.sample("a", dist.Categorical(probs_a))
            with pyro.plate("b_axis", size=6, subsample_size=3):
                pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)
                with pyro.plate("c_axis", size=9, subsample_size=4):
                    pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1)

    def guide(params):
        pass

    params = {
        "locs": jnp.array([0.0, 1.0]),
        "probs_a": jnp.array([0.4, 0.6]),
    }
    # Optimize in unconstrained space; map back through the simplex bijector.
    transform = dist.biject_to(dist.constraints.simplex)
    params_raw = {"locs": params["locs"], "probs_a": transform.inv(params["probs_a"])}

    elbo = infer.TraceEnum_ELBO()

    # Sanity check: the fully-observed (non-subsampled) model is supported
    # and must produce a finite loss and gradients.
    def expected_loss_fn(params_raw):
        params = {
            "locs": params_raw["locs"],
            "probs_a": transform(params_raw["probs_a"]),
        }
        return elbo.loss(random.PRNGKey(0), {}, model, guide, params)

    expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw)
    assert jnp.isfinite(expected_loss)

    # The subsampled variant is not supported yet: the ELBO must raise
    # rather than silently return a wrong loss.
    def actual_loss_fn(params_raw):
        params = {
            "locs": params_raw["locs"],
            "probs_a": transform(params_raw["probs_a"]),
        }
        return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)

    with pytest.raises(
        ValueError, match="Expected all enumerated sample sites to share a common scale"
    ):
        jax.value_and_grad(actual_loss_fn)(params_raw)