From f812861cf847185d178fe47f21c7962d4d61b307 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev <yordabay@broadinstitute.org>
Date: Fri, 3 Feb 2023 23:22:45 +0000
Subject: [PATCH 1/7] fix subsample scaling

---
 numpyro/infer/elbo.py          |  43 ++++++-----
 test/contrib/test_enum_elbo.py | 127 +++++++++++++++++++++++++++++++++
 2 files changed, 153 insertions(+), 17 deletions(-)

diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index c5f91cec5..c05705801 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -839,7 +839,7 @@ def guess_max_plate_nesting(model, guide, args, kwargs, param_map):
 
 class TraceEnum_ELBO(ELBO):
     """
-    A TraceEnum implementation of ELBO-based SVI. The gradient estimator
+    (EXPERIMENTAL) A TraceEnum implementation of ELBO-based SVI. The gradient estimator
     is constructed along the lines of reference [1] specialized to the case
     of the ELBO. It supports arbitrary dependency structure for the model
     and guide.
@@ -952,32 +952,41 @@ def single_particle_elbo(rng_key):
                         *(frozenset(f.inputs) & group_plates for f in group_factors)
                     )
                     elim_plates = group_plates - outermost_plates
-                    cost = funsor.sum_product.sum_product(
-                        funsor.ops.logaddexp,
-                        funsor.ops.add,
-                        group_factors,
-                        plates=group_plates,
-                        eliminate=group_sum_vars | elim_plates,
-                    )
-                    # incorporate the effects of subsampling and handlers.scale through a common scale factor
-                    group_scales = {}
+                    # incorporate the effects of subsampling and handlers.scale
+                    plate_to_scale = {}
                     for name in group_names:
                         for plate, value in (
                             model_trace[name].get("plate_to_scale", {}).items()
                         ):
-                            if plate in group_scales:
-                                if value != group_scales[plate]:
+                            if plate in plate_to_scale:
+                                if value != plate_to_scale[plate]:
                                     raise ValueError(
                                         "Expected all enumerated sample sites to share a common scale factor, "
                                         f"but found different scales at plate('{plate}')."
                                     )
                             else:
-                                group_scales[plate] = value
-                    scale = (
-                        reduce(lambda a, b: a * b, group_scales.values())
-                        if group_scales
-                        else None
+                                plate_to_scale[plate] = value
+                    group_scales = tuple(
+                        [
+                            value
+                            for plate, value in plate_to_scale.items()
+                            if (plate in f.inputs) or (plate is None)
+                        ]
+                        for f in group_factors
+                    )
+                    scaled_group_factors = tuple(
+                        reduce(lambda a, b: a * b, scales, factor)
+                        for scales, factor in zip(group_scales, group_factors)
+                    )
+
+                    cost = funsor.sum_product.sum_product(
+                        funsor.ops.logaddexp,
+                        funsor.ops.add,
+                        scaled_group_factors,
+                        plates=group_plates,
+                        eliminate=group_sum_vars | elim_plates,
                     )
+                    scale = None
                     # combine deps
                     deps = frozenset().union(
                         *[model_deps[name] for name in group_names]
diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py
index 8874b2cbe..5ccc1a8ed 100644
--- a/test/contrib/test_enum_elbo.py
+++ b/test/contrib/test_enum_elbo.py
@@ -2255,3 +2255,130 @@ def actual_loss_fn(params):
 
     assert_equal(expected_loss, actual_loss, prec=1e-5)
     assert_equal(expected_grad, actual_grad, prec=1e-5)
+
+
+def test_model_enum_subsample_1():
+    data = jnp.ones(4)
+
+    @config_enumerate
+    def model(data, params):
+        locs = pyro.param("locs", params["locs"])
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with pyro.plate("b_axis", len(data)):
+            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data)
+
+    @config_enumerate
+    def model_subsample(data, params):
+        locs = pyro.param("locs", params["locs"])
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with pyro.plate("b_axis", len(data), subsample_size=2) as ind:
+            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data[ind])
+
+    def guide(data, params):
+        pass
+
+    params = {
+        "locs": jnp.array([0.0, 1.0]),
+        "probs_a": jnp.array([0.4, 0.6]),
+    }
+    transform = dist.biject_to(dist.constraints.simplex)
+    params_raw = {"locs": params["locs"], "probs_a": transform.inv(params["probs_a"])}
+
+    elbo = infer.TraceEnum_ELBO()
+
+    # Expected grads w/o subsampling
+    def expected_loss_fn(params_raw):
+        params = {
+            "locs": params_raw["locs"],
+            "probs_a": transform(params_raw["probs_a"]),
+        }
+        return elbo.loss(random.PRNGKey(0), {}, model, guide, data, params)
+
+    expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw)
+
+    # Actual grads w/ subsampling
+    def actual_loss_fn(params_raw):
+        params = {
+            "locs": params_raw["locs"],
+            "probs_a": transform(params_raw["probs_a"]),
+        }
+        return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, data, params)
+
+    actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
+
+    assert_equal(actual_loss, expected_loss, prec=1e-5)
+    assert_equal(actual_grads, expected_grads, prec=1e-5)
+
+
+def test_model_enum_subsample_2():
+    data_b = jnp.zeros(6)
+    data_c = jnp.ones(4)
+
+    @config_enumerate
+    def model(data_b, data_c, params):
+        locs = pyro.param("locs", params["locs"])
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with pyro.plate("b_axis", len(data_b)):
+            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data_b)
+
+        with pyro.plate("c_axis", len(data_c)):
+            pyro.sample("c", dist.Normal(locs[a], 1.0), obs=data_c)
+
+    @config_enumerate
+    def model_subsample(data_b, data_c, params):
+        locs = pyro.param("locs", params["locs"])
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with pyro.plate("b_axis", len(data_b), subsample_size=2) as ind:
+            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data_b[ind])
+
+        with pyro.plate("c_axis", len(data_c), subsample_size=2) as ind:
+            pyro.sample("c", dist.Normal(locs[a], 1.0), obs=data_c[ind])
+
+    def guide(data_b, data_c, params):
+        pass
+
+    params = {
+        "locs": jnp.array([0.0, 1.0]),
+        "probs_a": jnp.array([0.4, 0.6]),
+    }
+    transform = dist.biject_to(dist.constraints.simplex)
+    params_raw = {"locs": params["locs"], "probs_a": transform.inv(params["probs_a"])}
+
+    elbo = infer.TraceEnum_ELBO()
+
+    # Expected grads w/o subsampling
+    def expected_loss_fn(params_raw):
+        params = {
+            "locs": params_raw["locs"],
+            "probs_a": transform(params_raw["probs_a"]),
+        }
+        return elbo.loss(random.PRNGKey(0), {}, model, guide, data_b, data_c, params)
+
+    expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw)
+
+    # Actual grads w/ subsampling
+    def actual_loss_fn(params_raw):
+        params = {
+            "locs": params_raw["locs"],
+            "probs_a": transform(params_raw["probs_a"]),
+        }
+        return elbo.loss(
+            random.PRNGKey(0), {}, model_subsample, guide, data_b, data_c, params
+        )
+
+    actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
+
+    assert_equal(actual_loss, expected_loss, prec=1e-5)
+    assert_equal(actual_grads, expected_grads, prec=1e-5)

From 9c5196f2aba55b9a6a64a72ce2f2f74779e975e4 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev <yordabay@broadinstitute.org>
Date: Sat, 4 Feb 2023 17:32:50 +0000
Subject: [PATCH 2/7] add more tests

---
 numpyro/infer/elbo.py          |   4 +-
 test/contrib/test_enum_elbo.py | 127 +++++++++++++++++++++++++--------
 2 files changed, 98 insertions(+), 33 deletions(-)

diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index c05705801..8e3b419a8 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -970,7 +970,7 @@ def single_particle_elbo(rng_key):
                         [
                             value
                             for plate, value in plate_to_scale.items()
-                            if (plate in f.inputs) or (plate is None)
+                            if plate in f.inputs
                         ]
                         for f in group_factors
                     )
@@ -986,7 +986,7 @@ def single_particle_elbo(rng_key):
                         plates=group_plates,
                         eliminate=group_sum_vars | elim_plates,
                     )
-                    scale = None
+                    scale = plate_to_scale.get(None, None)
                     # combine deps
                     deps = frozenset().union(
                         *[model_deps[name] for name in group_names]
diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py
index 5ccc1a8ed..b16dffa89 100644
--- a/test/contrib/test_enum_elbo.py
+++ b/test/contrib/test_enum_elbo.py
@@ -2257,30 +2257,31 @@ def actual_loss_fn(params):
     assert_equal(expected_grad, actual_grad, prec=1e-5)
 
 
-def test_model_enum_subsample_1():
-    data = jnp.ones(4)
-
+@pytest.mark.parametrize("scale", [1, 10])
+def test_model_enum_subsample_1(scale):
     @config_enumerate
-    def model(data, params):
+    @handlers.scale(scale=scale)
+    def model(params):
         locs = pyro.param("locs", params["locs"])
         probs_a = pyro.param(
             "probs_a", params["probs_a"], constraint=constraints.simplex
         )
         a = pyro.sample("a", dist.Categorical(probs_a))
-        with pyro.plate("b_axis", len(data)):
-            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data)
+        with pyro.plate("b_axis", size=3):
+            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)
 
     @config_enumerate
-    def model_subsample(data, params):
+    @handlers.scale(scale=scale)
+    def model_subsample(params):
         locs = pyro.param("locs", params["locs"])
         probs_a = pyro.param(
             "probs_a", params["probs_a"], constraint=constraints.simplex
         )
         a = pyro.sample("a", dist.Categorical(probs_a))
-        with pyro.plate("b_axis", len(data), subsample_size=2) as ind:
-            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data[ind])
+        with pyro.plate("b_axis", size=3, subsample_size=2):
+            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)
 
-    def guide(data, params):
+    def guide(params):
         pass
 
     params = {
@@ -2298,7 +2299,7 @@ def expected_loss_fn(params_raw):
             "locs": params_raw["locs"],
             "probs_a": transform(params_raw["probs_a"]),
         }
-        return elbo.loss(random.PRNGKey(0), {}, model, guide, data, params)
+        return elbo.loss(random.PRNGKey(0), {}, model, guide, params)
 
     expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw)
 
@@ -2308,7 +2309,7 @@ def actual_loss_fn(params_raw):
             "locs": params_raw["locs"],
             "probs_a": transform(params_raw["probs_a"]),
         }
-        return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, data, params)
+        return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)
 
     actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
 
@@ -2316,37 +2317,37 @@ def actual_loss_fn(params_raw):
     assert_equal(actual_grads, expected_grads, prec=1e-5)
 
 
-def test_model_enum_subsample_2():
-    data_b = jnp.zeros(6)
-    data_c = jnp.ones(4)
-
+@pytest.mark.parametrize("scale", [1, 10])
+def test_model_enum_subsample_2(scale):
     @config_enumerate
-    def model(data_b, data_c, params):
+    @handlers.scale(scale=scale)
+    def model(params):
         locs = pyro.param("locs", params["locs"])
         probs_a = pyro.param(
             "probs_a", params["probs_a"], constraint=constraints.simplex
         )
         a = pyro.sample("a", dist.Categorical(probs_a))
-        with pyro.plate("b_axis", len(data_b)):
-            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data_b)
+        with pyro.plate("b_axis", size=3):
+            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)
 
-        with pyro.plate("c_axis", len(data_c)):
-            pyro.sample("c", dist.Normal(locs[a], 1.0), obs=data_c)
+        with pyro.plate("c_axis", size=6):
+            pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1)
 
     @config_enumerate
-    def model_subsample(data_b, data_c, params):
+    @handlers.scale(scale=scale)
+    def model_subsample(params):
         locs = pyro.param("locs", params["locs"])
         probs_a = pyro.param(
             "probs_a", params["probs_a"], constraint=constraints.simplex
         )
         a = pyro.sample("a", dist.Categorical(probs_a))
-        with pyro.plate("b_axis", len(data_b), subsample_size=2) as ind:
-            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=data_b[ind])
+        with pyro.plate("b_axis", size=3, subsample_size=2):
+            pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)
 
-        with pyro.plate("c_axis", len(data_c), subsample_size=2) as ind:
-            pyro.sample("c", dist.Normal(locs[a], 1.0), obs=data_c[ind])
+        with pyro.plate("c_axis", size=6, subsample_size=3):
+            pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1)
 
-    def guide(data_b, data_c, params):
+    def guide(params):
         pass
 
     params = {
@@ -2364,7 +2365,7 @@ def expected_loss_fn(params_raw):
             "locs": params_raw["locs"],
             "probs_a": transform(params_raw["probs_a"]),
         }
-        return elbo.loss(random.PRNGKey(0), {}, model, guide, data_b, data_c, params)
+        return elbo.loss(random.PRNGKey(0), {}, model, guide, params)
 
     expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw)
 
@@ -2374,11 +2375,75 @@ def actual_loss_fn(params_raw):
             "locs": params_raw["locs"],
             "probs_a": transform(params_raw["probs_a"]),
         }
-        return elbo.loss(
-            random.PRNGKey(0), {}, model_subsample, guide, data_b, data_c, params
-        )
+        return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)
 
     actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
 
     assert_equal(actual_loss, expected_loss, prec=1e-5)
     assert_equal(actual_grads, expected_grads, prec=1e-5)
+
+
+@pytest.mark.parametrize("scale", [1, 10])
+def test_model_enum_subsample_3(scale):
+    @config_enumerate
+    @handlers.scale(scale=scale)
+    def model(params):
+        locs = pyro.param("locs", params["locs"])
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        with pyro.plate("a_axis", size=3):
+            a = pyro.sample("a", dist.Categorical(probs_a))
+            with pyro.plate("b_axis", size=6):
+                pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)
+                with pyro.plate("c_axis", size=9):
+                    pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1)
+
+    @config_enumerate
+    @handlers.scale(scale=scale)
+    def model_subsample(params):
+        locs = pyro.param("locs", params["locs"])
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        with pyro.plate("a_axis", size=3, subsample_size=2):
+            a = pyro.sample("a", dist.Categorical(probs_a))
+            with pyro.plate("b_axis", size=6, subsample_size=3):
+                pyro.sample("b", dist.Normal(locs[a], 1.0), obs=0)
+                with pyro.plate("c_axis", size=9, subsample_size=4):
+                    pyro.sample("c", dist.Normal(locs[a], 1.0), obs=1)
+
+    def guide(params):
+        pass
+
+    params = {
+        "locs": jnp.array([0.0, 1.0]),
+        "probs_a": jnp.array([0.4, 0.6]),
+    }
+    transform = dist.biject_to(dist.constraints.simplex)
+    params_raw = {"locs": params["locs"], "probs_a": transform.inv(params["probs_a"])}
+
+    elbo = infer.TraceEnum_ELBO()
+
+    # Expected grads w/o subsampling
+    def expected_loss_fn(params_raw):
+        params = {
+            "locs": params_raw["locs"],
+            "probs_a": transform(params_raw["probs_a"]),
+        }
+        return elbo.loss(random.PRNGKey(0), {}, model, guide, params)
+
+    expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw)
+
+    # Actual grads w/ subsampling
+    def actual_loss_fn(params_raw):
+        params = {
+            "locs": params_raw["locs"],
+            "probs_a": transform(params_raw["probs_a"]),
+        }
+        return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)
+
+    actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
+
+    assert_equal(actual_loss, expected_loss, prec=1e-3)
+    assert_equal(actual_grads, expected_grads, prec=1e-5)

From 4ef88cc6adf8f7375bea4154b561d79c255bf514 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev <yordabay@broadinstitute.org>
Date: Sat, 4 Feb 2023 21:30:56 +0000
Subject: [PATCH 3/7] pass tests

---
 numpyro/infer/elbo.py          | 27 ++++++++++++---------------
 test/contrib/test_enum_elbo.py |  6 ++++++
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index 8e3b419a8..0deff5d46 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -965,28 +965,25 @@ def single_particle_elbo(rng_key):
                                         f"but found different scales at plate('{plate}')."
                                     )
                             else:
-                                plate_to_scale[plate] = value
-                    group_scales = tuple(
-                        [
-                            value
-                            for plate, value in plate_to_scale.items()
-                            if plate in f.inputs
-                        ]
-                        for f in group_factors
-                    )
-                    scaled_group_factors = tuple(
-                        reduce(lambda a, b: a * b, scales, factor)
-                        for scales, factor in zip(group_scales, group_factors)
-                    )
+                                plate_to_scale[plate] = to_funsor(value)
 
                     cost = funsor.sum_product.sum_product(
                         funsor.ops.logaddexp,
                         funsor.ops.add,
-                        scaled_group_factors,
+                        group_factors,
                         plates=group_plates,
                         eliminate=group_sum_vars | elim_plates,
+                        scales=plate_to_scale,
+                    )
+                    scale = reduce(
+                        funsor.ops.mul,
+                        [
+                            value
+                            for plate, value in plate_to_scale.items()
+                            if plate not in elim_plates
+                        ],
+                        funsor.Number(1.0),
                     )
-                    scale = plate_to_scale.get(None, None)
                     # combine deps
                     deps = frozenset().union(
                         *[model_deps[name] for name in group_names]
diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py
index b16dffa89..cf108c2d0 100644
--- a/test/contrib/test_enum_elbo.py
+++ b/test/contrib/test_enum_elbo.py
@@ -2385,6 +2385,12 @@ def actual_loss_fn(params_raw):
 
 @pytest.mark.parametrize("scale", [1, 10])
 def test_model_enum_subsample_3(scale):
+    #      +--------------------+
+    #      |       +----------+ |
+    #  a ----> b ----> c      | |
+    #      |       |      N=2 | |
+    #      | M=2   +----------+ |
+    #      +--------------------+
     @config_enumerate
     @handlers.scale(scale=scale)
     def model(params):

From aecebe4068a11f3e693982c151e22633e4629bbf Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev <yordabay@broadinstitute.org>
Date: Sat, 4 Feb 2023 21:58:00 +0000
Subject: [PATCH 4/7] pytest.raises

---
 numpyro/infer/elbo.py          | 34 ++++++++-----------------
 test/contrib/test_enum_elbo.py | 46 +++++++++++++++++++++++-----------
 2 files changed, 42 insertions(+), 38 deletions(-)

diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index 0deff5d46..aab355b73 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -952,38 +952,26 @@ def single_particle_elbo(rng_key):
                         *(frozenset(f.inputs) & group_plates for f in group_factors)
                     )
                     elim_plates = group_plates - outermost_plates
-                    # incorporate the effects of subsampling and handlers.scale
-                    plate_to_scale = {}
-                    for name in group_names:
-                        for plate, value in (
-                            model_trace[name].get("plate_to_scale", {}).items()
-                        ):
-                            if plate in plate_to_scale:
-                                if value != plate_to_scale[plate]:
-                                    raise ValueError(
-                                        "Expected all enumerated sample sites to share a common scale factor, "
-                                        f"but found different scales at plate('{plate}')."
-                                    )
-                            else:
-                                plate_to_scale[plate] = to_funsor(value)
-
                     cost = funsor.sum_product.sum_product(
                         funsor.ops.logaddexp,
                         funsor.ops.add,
                         group_factors,
                         plates=group_plates,
                         eliminate=group_sum_vars | elim_plates,
-                        scales=plate_to_scale,
                     )
-                    scale = reduce(
-                        funsor.ops.mul,
+                    # incorporate the effects of subsampling and handlers.scale through a common scale factor
+                    scales_set = set(
                         [
-                            value
-                            for plate, value in plate_to_scale.items()
-                            if plate not in elim_plates
-                        ],
-                        funsor.Number(1.0),
+                            model_trace[name]["scale"]
+                            for name in (group_names | group_sum_vars)
+                        ]
                     )
+                    if len(scales_set) != 1:
+                        raise ValueError(
+                            "Expected all enumerated sample sites to share a common scale, "
+                            f"but found {len(scales_set)} different scales."
+                        )
+                    scale = next(iter(scales_set))
                     # combine deps
                     deps = frozenset().union(
                         *[model_deps[name] for name in group_names]
diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py
index cf108c2d0..237fa9cf2 100644
--- a/test/contrib/test_enum_elbo.py
+++ b/test/contrib/test_enum_elbo.py
@@ -2259,6 +2259,8 @@ def actual_loss_fn(params):
 
 @pytest.mark.parametrize("scale", [1, 10])
 def test_model_enum_subsample_1(scale):
+    # Model: enumerate a
+    #  a - [-> b  ]
     @config_enumerate
     @handlers.scale(scale=scale)
     def model(params):
@@ -2311,14 +2313,22 @@ def actual_loss_fn(params_raw):
         }
         return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)
 
-    actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
+    with pytest.raises(
+        ValueError, match="Expected all enumerated sample sites to share a common scale"
+    ):
+        # This never gets run because we don't support this yet.
+        actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
 
-    assert_equal(actual_loss, expected_loss, prec=1e-5)
-    assert_equal(actual_grads, expected_grads, prec=1e-5)
+        assert_equal(actual_loss, expected_loss, prec=1e-5)
+        assert_equal(actual_grads, expected_grads, prec=1e-5)
 
 
 @pytest.mark.parametrize("scale", [1, 10])
 def test_model_enum_subsample_2(scale):
+    # Model: enumerate a
+    #  a - [-> b  ]
+    #   \
+    #    - [-> c  ]
     @config_enumerate
     @handlers.scale(scale=scale)
     def model(params):
@@ -2377,20 +2387,22 @@ def actual_loss_fn(params_raw):
         }
         return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)
 
-    actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
+    with pytest.raises(
+        ValueError, match="Expected all enumerated sample sites to share a common scale"
+    ):
+        # This never gets run because we don't support this yet.
+        actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
 
-    assert_equal(actual_loss, expected_loss, prec=1e-5)
-    assert_equal(actual_grads, expected_grads, prec=1e-5)
+        assert_equal(actual_loss, expected_loss, prec=1e-5)
+        assert_equal(actual_grads, expected_grads, prec=1e-5)
 
 
 @pytest.mark.parametrize("scale", [1, 10])
 def test_model_enum_subsample_3(scale):
-    #      +--------------------+
-    #      |       +----------+ |
-    #  a ----> b ----> c      | |
-    #      |       |      N=2 | |
-    #      | M=2   +----------+ |
-    #      +--------------------+
+    # Model: enumerate a
+    # [ a - [----> b    ]
+    # [  \  [           ]
+    # [   - [- [-> c  ] ]
     @config_enumerate
     @handlers.scale(scale=scale)
     def model(params):
@@ -2449,7 +2461,11 @@ def actual_loss_fn(params_raw):
         }
         return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)
 
-    actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
+    with pytest.raises(
+        ValueError, match="Expected all enumerated sample sites to share a common scale"
+    ):
+        # This never gets run because we don't support this yet.
+        actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
 
-    assert_equal(actual_loss, expected_loss, prec=1e-3)
-    assert_equal(actual_grads, expected_grads, prec=1e-5)
+        assert_equal(actual_loss, expected_loss, prec=1e-3)
+        assert_equal(actual_grads, expected_grads, prec=1e-5)

From 578664a5a30260d3c9afd31561a13c76405a973d Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev <yordabay@broadinstitute.org>
Date: Sat, 4 Feb 2023 22:01:13 +0000
Subject: [PATCH 5/7] improve comments

---
 test/contrib/test_enum_elbo.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py
index 237fa9cf2..7a760ddbe 100644
--- a/test/contrib/test_enum_elbo.py
+++ b/test/contrib/test_enum_elbo.py
@@ -2259,7 +2259,8 @@ def actual_loss_fn(params):
 
 @pytest.mark.parametrize("scale", [1, 10])
 def test_model_enum_subsample_1(scale):
-    # Model: enumerate a
+    # Enumerate: a
+    # Subsample: b
     #  a - [-> b  ]
     @config_enumerate
     @handlers.scale(scale=scale)
@@ -2325,7 +2326,8 @@ def actual_loss_fn(params_raw):
 
 @pytest.mark.parametrize("scale", [1, 10])
 def test_model_enum_subsample_2(scale):
-    # Model: enumerate a
+    # Enumerate: a
+    # Subsample: b, c
     #  a - [-> b  ]
     #   \
     #    - [-> c  ]
@@ -2399,7 +2401,8 @@ def actual_loss_fn(params_raw):
 
 @pytest.mark.parametrize("scale", [1, 10])
 def test_model_enum_subsample_3(scale):
-    # Model: enumerate a
+    # Enumerate: a
+    # Subsample: a, b, c
     # [ a - [----> b    ]
     # [  \  [           ]
     # [   - [- [-> c  ] ]

From ead0446ada31f69b4ad39b5bf7f1d3c72d7f254c Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev <yordabay@broadinstitute.org>
Date: Sun, 5 Feb 2023 23:45:00 +0000
Subject: [PATCH 6/7] convert scale to float

---
 numpyro/infer/elbo.py          | 16 ++++++++++------
 test/contrib/test_enum_elbo.py |  2 +-
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index aab355b73..847b235fe 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -960,12 +960,16 @@ def single_particle_elbo(rng_key):
                         eliminate=group_sum_vars | elim_plates,
                     )
                     # incorporate the effects of subsampling and handlers.scale through a common scale factor
-                    scales_set = set(
-                        [
-                            model_trace[name]["scale"]
-                            for name in (group_names | group_sum_vars)
-                        ]
-                    )
+                    scales_set = set()
+                    for name in group_names | group_sum_vars:
+                        site_scale = model_trace[name]["scale"]
+                        if site_scale is None:
+                            site_scale = 1.0
+                        if isinstance(site_scale, jnp.ndarray) and site_scale.ndim:
+                            raise ValueError(
+                                "enumeration only supports scalar handlers.scale"
+                            )
+                        scales_set.add(float(site_scale))
                     if len(scales_set) != 1:
                         raise ValueError(
                             "Expected all enumerated sample sites to share a common scale, "
diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py
index 7a760ddbe..b4d1e9f5c 100644
--- a/test/contrib/test_enum_elbo.py
+++ b/test/contrib/test_enum_elbo.py
@@ -139,7 +139,7 @@ def hand_loss_fn(params_raw):
     assert_equal(auto_grad, hand_grad, prec=1e-5)
 
 
-@pytest.mark.parametrize("scale", [1, 10])
+@pytest.mark.parametrize("scale", [1, 10, jnp.array(10)])
 def test_elbo_enumerate_2(scale):
     params = {}
     params["guide_probs_x"] = jnp.array([0.1, 0.9])

From 9cd99c6a5869913a43340a29e8afdad2fa9c026a Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev <yordabay@broadinstitute.org>
Date: Mon, 6 Feb 2023 01:09:21 +0000
Subject: [PATCH 7/7] raise error if jnp.array

---
 numpyro/infer/elbo.py          | 4 ++--
 test/contrib/test_enum_elbo.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index 847b235fe..f8c9281f6 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -965,9 +965,9 @@ def single_particle_elbo(rng_key):
                         site_scale = model_trace[name]["scale"]
                         if site_scale is None:
                             site_scale = 1.0
-                        if isinstance(site_scale, jnp.ndarray) and site_scale.ndim:
+                        if isinstance(site_scale, jnp.ndarray):
                             raise ValueError(
-                                "enumeration only supports scalar handlers.scale"
+                                "Enumeration only supports scalar handlers.scale"
                             )
                         scales_set.add(float(site_scale))
                     if len(scales_set) != 1:
diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py
index b4d1e9f5c..7a760ddbe 100644
--- a/test/contrib/test_enum_elbo.py
+++ b/test/contrib/test_enum_elbo.py
@@ -139,7 +139,7 @@ def hand_loss_fn(params_raw):
     assert_equal(auto_grad, hand_grad, prec=1e-5)
 
 
-@pytest.mark.parametrize("scale", [1, 10, jnp.array(10)])
+@pytest.mark.parametrize("scale", [1, 10])
 def test_elbo_enumerate_2(scale):
     params = {}
     params["guide_probs_x"] = jnp.array([0.1, 0.9])