From f812861cf847185d178fe47f21c7962d4d61b307 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev <yordabay@broadinstitute.org> Date: Fri, 3 Feb 2023 23:22:45 +0000 Subject: [PATCH 1/7] fix subsample scaling --- numpyro/infer/elbo.py | 43 ++++++----- test/contrib/test_enum_elbo.py | 127 +++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 17 deletions(-) diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index c5f91cec5..c05705801 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -839,7 +839,7 @@ def guess_max_plate_nesting(model, guide, args, kwargs, param_map): class TraceEnum_ELBO(ELBO): """ - A TraceEnum implementation of ELBO-based SVI. The gradient estimator + (EXPERIMENTAL) A TraceEnum implementation of ELBO-based SVI. The gradient estimator is constructed along the lines of reference [1] specialized to the case of the ELBO. It supports arbitrary dependency structure for the model and guide. @@ -952,32 +952,41 @@ def single_particle_elbo(rng_key): *(frozenset(f.inputs) & group_plates for f in group_factors) ) elim_plates = group_plates - outermost_plates - cost = funsor.sum_product.sum_product( - funsor.ops.logaddexp, - funsor.ops.add, - group_factors, - plates=group_plates, - eliminate=group_sum_vars | elim_plates, - ) - # incorporate the effects of subsampling and handlers.scale through a common scale factor - group_scales = {} + # incorporate the effects of subsampling and handlers.scale + plate_to_scale = {} for name in group_names: for plate, value in ( model_trace[name].get("plate_to_scale", {}).items() ): - if plate in group_scales: - if value != group_scales[plate]: + if plate in plate_to_scale: + if value != plate_to_scale[plate]: raise ValueError( "Expected all enumerated sample sites to share a common scale factor, " f"but found different scales at plate('{plate}')." 
) else: - group_scales[plate] = value - scale = ( - reduce(lambda a, b: a * b, group_scales.values()) - if group_scales - else None + plate_to_scale[plate] = value + group_scales = tuple( + [ + value + for plate, value in plate_to_scale.items() + if (plate in f.inputs) or (plate is None) + ] + for f in group_factors + ) + scaled_group_factors = tuple( + reduce(lambda a, b: a * b, scales, factor) + for scales, factor in zip(group_scales, group_factors) + ) + + cost = funsor.sum_product.sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + scaled_group_factors, + plates=group_plates, + eliminate=group_sum_vars | elim_plates, ) + scale = None # combine deps deps = frozenset().union( *[model_deps[name] for name in group_names] diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index 8874b2cbe..5ccc1a8ed 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -2255,3 +2255,130 @@ def actual_loss_fn(params): assert_equal(expected_loss, actual_loss, prec=1e-5) assert_equal(expected_grad, actual_grad, prec=1e-5) + + +def test_model_enum_subsample_1(): + data = jnp.ones(4) + + @config_enumerate + def model(data, params): + locs = pyro.param("locs", params["locs"]) + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("b_axis", len(data)): + pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data) + + @config_enumerate + def model_subsample(data, params): + locs = pyro.param("locs", params["locs"]) + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("b_axis", len(data), subsample_size=2) as ind: + pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data[ind]) + + def guide(data, params): + pass + + params = { + "locs": jnp.array([0.0, 1.0]), + "probs_a": jnp.array([0.4, 0.6]), + } + transform = dist.biject_to(dist.constraints.simplex) + params_raw = {"locs": params["locs"], "probs_a": transform.inv(params["probs_a"])} + + elbo = infer.TraceEnum_ELBO() + + # Expected grads w/o subsampling + def expected_loss_fn(params_raw): + params = { + "locs": params_raw["locs"], + "probs_a": transform(params_raw["probs_a"]), + } + return elbo.loss(random.PRNGKey(0), {}, model, guide, data, params) + + expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw) + + # Actual grads w/ subsampling + def actual_loss_fn(params_raw): + params = { + "locs": params_raw["locs"], + "probs_a": transform(params_raw["probs_a"]), + } + return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, data, params) + + actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) + + assert_equal(actual_loss, expected_loss, prec=1e-5) + assert_equal(actual_grads, expected_grads, prec=1e-5) + + +def test_model_enum_subsample_2(): + data_b = jnp.zeros(6) + data_c = jnp.ones(4) + + @config_enumerate + def model(data_b, data_c, params): + locs = pyro.param("locs", params["locs"]) + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("b_axis", len(data_b)): + pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data_b) + + with pyro.plate("c_axis", len(data_c)): + pyro.sample("c", dist.Normal(locs[a], 1.0), obs=data_c) + + @config_enumerate + def model_subsample(data_b, data_c, params): + locs = pyro.param("locs", params["locs"]) + probs_a = 
pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("b_axis", len(data_b), subsample_size=2) as ind: + pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data_b[ind]) + + with pyro.plate("c_axis", len(data_c), subsample_size=2) as ind: + pyro.sample("c", dist.Normal(locs[a], 1.0), obs=data_c[ind]) + + def guide(data_b, data_c, params): + pass + + params = { + "locs": jnp.array([0.0, 1.0]), + "probs_a": jnp.array([0.4, 0.6]), + } + transform = dist.biject_to(dist.constraints.simplex) + params_raw = {"locs": params["locs"], "probs_a": transform.inv(params["probs_a"])} + + elbo = infer.TraceEnum_ELBO() + + # Expected grads w/o subsampling + def expected_loss_fn(params_raw): + params = { + "locs": params_raw["locs"], + "probs_a": transform(params_raw["probs_a"]), + } + return elbo.loss(random.PRNGKey(0), {}, model, guide, data_b, data_c, params) + + expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw) + + # Actual grads w/ subsampling + def actual_loss_fn(params_raw): + params = { + "locs": params_raw["locs"], + "probs_a": transform(params_raw["probs_a"]), + } + return elbo.loss( + random.PRNGKey(0), {}, model_subsample, guide, data_b, data_c, params + ) + + actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) + + assert_equal(actual_loss, expected_loss, prec=1e-5) + assert_equal(actual_grads, expected_grads, prec=1e-5) From 9c5196f2aba55b9a6a64a72ce2f2f74779e975e4 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev <yordabay@broadinstitute.org> Date: Sat, 4 Feb 2023 17:32:50 +0000 Subject: [PATCH 2/7] add more tests --- numpyro/infer/elbo.py | 4 +- test/contrib/test_enum_elbo.py | 127 +++++++++++++++++++++++++-------- 2 files changed, 98 insertions(+), 33 deletions(-) diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index c05705801..8e3b419a8 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -970,7 +970,7 @@ def single_particle_elbo(rng_key): [ value for plate, value in plate_to_scale.items() - if (plate in f.inputs) or (plate is None) + if plate in f.inputs ] for f in group_factors ) @@ -986,7 +986,7 @@ def single_particle_elbo(rng_key): plates=group_plates, eliminate=group_sum_vars | elim_plates, ) - scale = None + scale = plate_to_scale.get(None, None) # combine deps deps = frozenset().union( *[model_deps[name] for name in group_names] diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index 5ccc1a8ed..b16dffa89 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -2257,30 +2257,31 @@ def actual_loss_fn(params): assert_equal(expected_grad, actual_grad, prec=1e-5) -def test_model_enum_subsample_1(): - data = jnp.ones(4) - +@pytest.mark.parametrize("scale", [1, 10]) +def test_model_enum_subsample_1(scale): @config_enumerate - def model(data, params): + @handlers.scale(scale=scale) + def model(params): locs = pyro.param("locs", params["locs"]) probs_a = pyro.param( "probs_a", params["probs_a"], constraint=constraints.simplex ) a = pyro.sample("a", dist.Categorical(probs_a)) - with pyro.plate("b_axis", len(data)): - pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data) + with pyro.plate("b_axis", size=3): + pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0) @config_enumerate - def model_subsample(data, params): + @handlers.scale(scale=scale) + def model_subsample(params): locs = pyro.param("locs", params["locs"]) probs_a = pyro.param( "probs_a", params["probs_a"], 
constraint=constraints.simplex ) a = pyro.sample("a", dist.Categorical(probs_a)) - with pyro.plate("b_axis", len(data), subsample_size=2) as ind: - pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data[ind]) + with pyro.plate("b_axis", size=3, subsample_size=2): + pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0) - def guide(data, params): + def guide(params): pass params = { @@ -2298,7 +2299,7 @@ def expected_loss_fn(params_raw): "locs": params_raw["locs"], "probs_a": transform(params_raw["probs_a"]), } - return elbo.loss(random.PRNGKey(0), {}, model, guide, data, params) + return elbo.loss(random.PRNGKey(0), {}, model, guide, params) expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw) @@ -2308,7 +2309,7 @@ def actual_loss_fn(params_raw): "locs": params_raw["locs"], "probs_a": transform(params_raw["probs_a"]), } - return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, data, params) + return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params) actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) @@ -2316,37 +2317,37 @@ def actual_loss_fn(params_raw): assert_equal(actual_grads, expected_grads, prec=1e-5) -def test_model_enum_subsample_2(): - data_b = jnp.zeros(6) - data_c = jnp.ones(4) - +@pytest.mark.parametrize("scale", [1, 10]) +def test_model_enum_subsample_2(scale): @config_enumerate - def model(data_b, data_c, params): + @handlers.scale(scale=scale) + def model(params): locs = pyro.param("locs", params["locs"]) probs_a = pyro.param( "probs_a", params["probs_a"], constraint=constraints.simplex ) a = pyro.sample("a", dist.Categorical(probs_a)) - with pyro.plate("b_axis", len(data_b)): - pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data_b) + with pyro.plate("b_axis", size=3): + pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0) - with pyro.plate("c_axis", len(data_c)): - pyro.sample("c", dist.Normal(locs[a], 1.0), obs=data_c) + with pyro.plate("c_axis", size=6): + pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1) @config_enumerate - def model_subsample(data_b, data_c, params): + @handlers.scale(scale=scale) + def model_subsample(params): locs = pyro.param("locs", params["locs"]) probs_a = pyro.param( "probs_a", params["probs_a"], constraint=constraints.simplex ) a = pyro.sample("a", dist.Categorical(probs_a)) - with pyro.plate("b_axis", len(data_b), subsample_size=2) as ind: - pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data_b[ind]) + with pyro.plate("b_axis", size=3, subsample_size=2): + pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0) - with pyro.plate("c_axis", len(data_c), subsample_size=2) as ind: - pyro.sample("c", dist.Normal(locs[a], 1.0), obs=data_c[ind]) + with pyro.plate("c_axis", size=6, subsample_size=3): + pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1) - def guide(data_b, data_c, params): + def guide(params): pass params = { @@ -2364,7 +2365,7 @@ def expected_loss_fn(params_raw): "locs": params_raw["locs"], "probs_a": transform(params_raw["probs_a"]), } - return elbo.loss(random.PRNGKey(0), {}, model, guide, data_b, data_c, params) + return elbo.loss(random.PRNGKey(0), {}, model, guide, params) expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw) @@ -2374,11 +2375,75 @@ def actual_loss_fn(params_raw): "locs": params_raw["locs"], "probs_a": transform(params_raw["probs_a"]), } - return elbo.loss( - random.PRNGKey(0), {}, model_subsample, guide, data_b, data_c, params - ) + return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params) actual_loss, 
actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) assert_equal(actual_loss, expected_loss, prec=1e-5) assert_equal(actual_grads, expected_grads, prec=1e-5) + + +@pytest.mark.parametrize("scale", [1, 10]) +def test_model_enum_subsample_3(scale): + @config_enumerate + @handlers.scale(scale=scale) + def model(params): + locs = pyro.param("locs", params["locs"]) + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + with pyro.plate("a_axis", size=3): + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("b_axis", size=6): + pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0) + with pyro.plate("c_axis", size=9): + pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1) + + @config_enumerate + @handlers.scale(scale=scale) + def model_subsample(params): + locs = pyro.param("locs", params["locs"]) + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + with pyro.plate("a_axis", size=3, subsample_size=2): + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("b_axis", size=6, subsample_size=3): + pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0) + with pyro.plate("c_axis", size=9, subsample_size=4): + pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1) + + def guide(params): + pass + + params = { + "locs": jnp.array([0.0, 1.0]), + "probs_a": jnp.array([0.4, 0.6]), + } + transform = dist.biject_to(dist.constraints.simplex) + params_raw = {"locs": params["locs"], "probs_a": transform.inv(params["probs_a"])} + + elbo = infer.TraceEnum_ELBO() + + # Expected grads w/o subsampling + def expected_loss_fn(params_raw): + params = { + "locs": params_raw["locs"], + "probs_a": transform(params_raw["probs_a"]), + } + return elbo.loss(random.PRNGKey(0), {}, model, guide, params) + + expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw) + + # Actual grads w/ subsampling + def actual_loss_fn(params_raw): + params = { + "locs": params_raw["locs"], + "probs_a": transform(params_raw["probs_a"]), + } + return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params) + + actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) + + assert_equal(actual_loss, expected_loss, prec=1e-3) + assert_equal(actual_grads, expected_grads, prec=1e-5) From 4ef88cc6adf8f7375bea4154b561d79c255bf514 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev <yordabay@broadinstitute.org> Date: Sat, 4 Feb 2023 21:30:56 +0000 Subject: [PATCH 3/7] pass tests --- numpyro/infer/elbo.py | 27 ++++++++++++--------------- test/contrib/test_enum_elbo.py | 6 ++++++ 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index 8e3b419a8..0deff5d46 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -965,28 +965,25 @@ def single_particle_elbo(rng_key): f"but found different scales at plate('{plate}')." 
) else: - plate_to_scale[plate] = value - group_scales = tuple( - [ - value - for plate, value in plate_to_scale.items() - if plate in f.inputs - ] - for f in group_factors - ) - scaled_group_factors = tuple( - reduce(lambda a, b: a * b, scales, factor) - for scales, factor in zip(group_scales, group_factors) - ) + plate_to_scale[plate] = to_funsor(value) cost = funsor.sum_product.sum_product( funsor.ops.logaddexp, funsor.ops.add, - scaled_group_factors, + group_factors, plates=group_plates, eliminate=group_sum_vars | elim_plates, + scales=plate_to_scale, + ) + scale = reduce( + funsor.ops.mul, + [ + value + for plate, value in plate_to_scale.items() + if plate not in elim_plates + ], + funsor.Number(1.0), ) - scale = plate_to_scale.get(None, None) # combine deps deps = frozenset().union( *[model_deps[name] for name in group_names] diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index b16dffa89..cf108c2d0 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -2385,6 +2385,12 @@ def actual_loss_fn(params_raw): @pytest.mark.parametrize("scale", [1, 10]) def test_model_enum_subsample_3(scale): + # +--------------------+ + # | +----------+ | + # a ----> b ----> c | | + # | | N=2 | | + # | M=2 +----------+ | + # +--------------------+ @config_enumerate @handlers.scale(scale=scale) def model(params): From aecebe4068a11f3e693982c151e22633e4629bbf Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev <yordabay@broadinstitute.org> Date: Sat, 4 Feb 2023 21:58:00 +0000 Subject: [PATCH 4/7] pytest.raises --- numpyro/infer/elbo.py | 34 ++++++++----------------- test/contrib/test_enum_elbo.py | 46 +++++++++++++++++++++++----------- 2 files changed, 42 insertions(+), 38 deletions(-) diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index 0deff5d46..aab355b73 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -952,38 +952,26 @@ def single_particle_elbo(rng_key): *(frozenset(f.inputs) & group_plates for f in group_factors) ) elim_plates = group_plates - outermost_plates - # incorporate the effects of subsampling and handlers.scale - plate_to_scale = {} - for name in group_names: - for plate, value in ( - model_trace[name].get("plate_to_scale", {}).items() - ): - if plate in plate_to_scale: - if value != plate_to_scale[plate]: - raise ValueError( - "Expected all enumerated sample sites to share a common scale factor, " - f"but found different scales at plate('{plate}')." - ) - else: - plate_to_scale[plate] = to_funsor(value) - cost = funsor.sum_product.sum_product( funsor.ops.logaddexp, funsor.ops.add, group_factors, plates=group_plates, eliminate=group_sum_vars | elim_plates, - scales=plate_to_scale, ) - scale = reduce( - funsor.ops.mul, + # incorporate the effects of subsampling and handlers.scale through a common scale factor + scales_set = set( [ - value - for plate, value in plate_to_scale.items() - if plate not in elim_plates - ], - funsor.Number(1.0), + model_trace[name]["scale"] + for name in (group_names | group_sum_vars) + ] ) + if len(scales_set) != 1: + raise ValueError( + "Expected all enumerated sample sites to share a common scale, " + f"but found {len(scales_set)} different scales." 
+ ) + scale = next(iter(scales_set)) # combine deps deps = frozenset().union( *[model_deps[name] for name in group_names] diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index cf108c2d0..237fa9cf2 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -2259,6 +2259,8 @@ def actual_loss_fn(params): @pytest.mark.parametrize("scale", [1, 10]) def test_model_enum_subsample_1(scale): + # Model: enumerate a + # a - [-> b ] @config_enumerate @handlers.scale(scale=scale) def model(params): @@ -2311,14 +2313,22 @@ def actual_loss_fn(params_raw): } return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params) - actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) + with pytest.raises( + ValueError, match="Expected all enumerated sample sites to share a common scale" + ): + # This never gets run because we don't support this yet. + actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) - assert_equal(actual_loss, expected_loss, prec=1e-5) - assert_equal(actual_grads, expected_grads, prec=1e-5) + assert_equal(actual_loss, expected_loss, prec=1e-5) + assert_equal(actual_grads, expected_grads, prec=1e-5) @pytest.mark.parametrize("scale", [1, 10]) def test_model_enum_subsample_2(scale): + # Model: enumerate a + # a - [-> b ] + # \ + # - [-> c ] @config_enumerate @handlers.scale(scale=scale) def model(params): @@ -2377,20 +2387,22 @@ def actual_loss_fn(params_raw): } return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params) - actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) + with pytest.raises( + ValueError, match="Expected all enumerated sample sites to share a common scale" + ): + # This never gets run because we don't support this yet. + actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) - assert_equal(actual_loss, expected_loss, prec=1e-5) - assert_equal(actual_grads, expected_grads, prec=1e-5) + assert_equal(actual_loss, expected_loss, prec=1e-5) + assert_equal(actual_grads, expected_grads, prec=1e-5) @pytest.mark.parametrize("scale", [1, 10]) def test_model_enum_subsample_3(scale): - # +--------------------+ - # | +----------+ | - # a ----> b ----> c | | - # | | N=2 | | - # | M=2 +----------+ | - # +--------------------+ + # Model: enumerate a + # [ a - [----> b ] + # [ \ [ ] + # [ - [- [-> c ] ] @config_enumerate @handlers.scale(scale=scale) def model(params): @@ -2449,7 +2461,11 @@ def actual_loss_fn(params_raw): } return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params) - actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) + with pytest.raises( + ValueError, match="Expected all enumerated sample sites to share a common scale" + ): + # This never gets run because we don't support this yet. 
+ actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) - assert_equal(actual_loss, expected_loss, prec=1e-3) - assert_equal(actual_grads, expected_grads, prec=1e-5) + assert_equal(actual_loss, expected_loss, prec=1e-3) + assert_equal(actual_grads, expected_grads, prec=1e-5) From 578664a5a30260d3c9afd31561a13c76405a973d Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev <yordabay@broadinstitute.org> Date: Sat, 4 Feb 2023 22:01:13 +0000 Subject: [PATCH 5/7] improve comments --- test/contrib/test_enum_elbo.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index 237fa9cf2..7a760ddbe 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -2259,7 +2259,8 @@ def actual_loss_fn(params): @pytest.mark.parametrize("scale", [1, 10]) def test_model_enum_subsample_1(scale): - # Model: enumerate a + # Enumerate: a + # Subsample: b # a - [-> b ] @config_enumerate @handlers.scale(scale=scale) @@ -2325,7 +2326,8 @@ def actual_loss_fn(params_raw): @pytest.mark.parametrize("scale", [1, 10]) def test_model_enum_subsample_2(scale): - # Model: enumerate a + # Enumerate: a + # Subsample: b, c # a - [-> b ] # \ # - [-> c ] @@ -2399,7 +2401,8 @@ def actual_loss_fn(params_raw): @pytest.mark.parametrize("scale", [1, 10]) def test_model_enum_subsample_3(scale): - # Model: enumerate a + # Enumerate: a + # Subsample: a, b, c # [ a - [----> b ] # [ \ [ ] # [ - [- [-> c ] ] From ead0446ada31f69b4ad39b5bf7f1d3c72d7f254c Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev <yordabay@broadinstitute.org> Date: Sun, 5 Feb 2023 23:45:00 +0000 Subject: [PATCH 6/7] convert scale to float --- numpyro/infer/elbo.py | 16 ++++++++++------ test/contrib/test_enum_elbo.py | 2 +- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index aab355b73..847b235fe 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -960,12 +960,16 @@ def single_particle_elbo(rng_key): eliminate=group_sum_vars | elim_plates, ) # incorporate the effects of subsampling and handlers.scale through a common scale factor - scales_set = set( - [ - model_trace[name]["scale"] - for name in (group_names | group_sum_vars) - ] - ) + scales_set = set() + for name in group_names | group_sum_vars: + site_scale = model_trace[name]["scale"] + if site_scale is None: + site_scale = 1.0 + if isinstance(site_scale, jnp.ndarray) and site_scale.ndim: + raise ValueError( + "enumeration only supports scalar handlers.scale" + ) + scales_set.add(float(site_scale)) if len(scales_set) != 1: raise ValueError( "Expected all enumerated sample sites to share a common scale, " diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index 7a760ddbe..b4d1e9f5c 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -139,7 +139,7 @@ def hand_loss_fn(params_raw): assert_equal(auto_grad, hand_grad, prec=1e-5) -@pytest.mark.parametrize("scale", [1, 10]) +@pytest.mark.parametrize("scale", [1, 10, jnp.array(10)]) def test_elbo_enumerate_2(scale): params = {} params["guide_probs_x"] = jnp.array([0.1, 0.9]) From 9cd99c6a5869913a43340a29e8afdad2fa9c026a Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev <yordabay@broadinstitute.org> Date: Mon, 6 Feb 2023 01:09:21 +0000 Subject: [PATCH 7/7] raise error if jnp.array --- numpyro/infer/elbo.py | 4 ++-- test/contrib/test_enum_elbo.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git 
a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index 847b235fe..f8c9281f6 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -965,9 +965,9 @@ def single_particle_elbo(rng_key): site_scale = model_trace[name]["scale"] if site_scale is None: site_scale = 1.0 - if isinstance(site_scale, jnp.ndarray) and site_scale.ndim: + if isinstance(site_scale, jnp.ndarray): raise ValueError( - "enumeration only supports scalar handlers.scale" + "Enumeration only supports scalar handlers.scale" ) scales_set.add(float(site_scale)) if len(scales_set) != 1: diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index b4d1e9f5c..7a760ddbe 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -139,7 +139,7 @@ def hand_loss_fn(params_raw): assert_equal(auto_grad, hand_grad, prec=1e-5) -@pytest.mark.parametrize("scale", [1, 10, jnp.array(10)]) +@pytest.mark.parametrize("scale", [1, 10]) def test_elbo_enumerate_2(scale): params = {} params["guide_probs_x"] = jnp.array([0.1, 0.9])
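
Summary of the end state after patch 7/7: TraceEnum_ELBO no longer tracks per-plate scale factors for an enumerated group. It collects the `scale` recorded at every sample site in the group into a set, requires that set to hold exactly one scalar value, and raises a ValueError otherwise. Below is a minimal standalone sketch of that check; `collect_common_scale` and the dict-shaped `model_trace` are illustrative assumptions for readability, not NumPyro API.

import jax.numpy as jnp


def collect_common_scale(model_trace, group_names):
    """Hypothetical standalone restatement of the common-scale check
    from patches 4/7-7/7; not an actual NumPyro function."""
    scales_set = set()
    for name in group_names:
        # Each trace site records the cumulative effect of handlers.scale
        # and plate subsampling in its "scale" entry.
        site_scale = model_trace[name]["scale"]
        if site_scale is None:
            site_scale = 1.0
        # Patch 7/7 rejects array-valued scales outright instead of
        # special-casing 0-dim arrays as patch 6/7 did.
        if isinstance(site_scale, jnp.ndarray):
            raise ValueError("Enumeration only supports scalar handlers.scale")
        scales_set.add(float(site_scale))
    if len(scales_set) != 1:
        raise ValueError(
            "Expected all enumerated sample sites to share a common scale, "
            f"but found {len(scales_set)} different scales."
        )
    return next(iter(scales_set))


# Example with an assumed trace layout: both sites carry the same
# handlers.scale(10.0), so the check succeeds and returns 10.0.
trace = {"a": {"scale": 10.0}, "b": {"scale": 10.0}}
assert collect_common_scale(trace, {"a", "b"}) == 10.0

This is also why the subsample tests end up wrapped in pytest.raises from patch 4/7 onward: a plate with subsample_size multiplies the scales of the sites it encloses by size/subsample_size, so an enumerated site outside the plate (or inside a differently subsampled plate) no longer shares a common scale with the sites inside it, and the check above rejects that configuration instead of silently mis-weighting the enumerated cost terms.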