Skip to content

Commit

Permalink
Raise warning for the automatic enumerate behavior (#1244)
Browse files Browse the repository at this point in the history
* Raise Future Warning for sites without enumerated support

* fix typo on enumerate key

* add test to capture the warning

* Run black

* Fix all failing test due to the new future warning
  • Loading branch information
fehiepsi authored Dec 6, 2021
1 parent b149769 commit 9987eb7
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 95 deletions.
22 changes: 15 additions & 7 deletions examples/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def multinomial(annotations):
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

with numpyro.plate("item", num_items, dim=-2):
c = numpyro.sample("c", dist.Categorical(pi))
c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

with numpyro.plate("position", num_positions):
numpyro.sample("y", dist.Categorical(zeta[c]), obs=annotations)
Expand All @@ -144,7 +144,7 @@ def dawid_skene(positions, annotations):
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

with numpyro.plate("item", num_items, dim=-2):
c = numpyro.sample("c", dist.Categorical(pi))
c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

# here we use Vindex to allow broadcasting for the second index `c`
# ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
Expand All @@ -167,10 +167,18 @@ def mace(positions, annotations):
theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

with numpyro.plate("item", num_items, dim=-2):
c = numpyro.sample("c", dist.DiscreteUniform(0, num_classes - 1))
c = numpyro.sample(
"c",
dist.DiscreteUniform(0, num_classes - 1),
infer={"enumerate": "parallel"},
)

with numpyro.plate("position", num_positions):
s = numpyro.sample("s", dist.Bernoulli(1 - theta[positions]))
s = numpyro.sample(
"s",
dist.Bernoulli(1 - theta[positions]),
infer={"enumerate": "parallel"},
)
probs = jnp.where(
s[..., None] == 0, nn.one_hot(c, num_classes), epsilon[positions]
)
Expand Down Expand Up @@ -207,7 +215,7 @@ def hierarchical_dawid_skene(positions, annotations):
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

with numpyro.plate("item", num_items, dim=-2):
c = numpyro.sample("c", dist.Categorical(pi))
c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

with numpyro.plate("position", num_positions):
logits = Vindex(beta)[positions, c, :]
Expand All @@ -232,7 +240,7 @@ def item_difficulty(annotations):
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

with numpyro.plate("item", num_items, dim=-2):
c = numpyro.sample("c", dist.Categorical(pi))
c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

with handlers.reparam(config={"theta": LocScaleReparam(0)}):
theta = numpyro.sample("theta", dist.Normal(eta[c], chi[c]).to_event(1))
Expand Down Expand Up @@ -270,7 +278,7 @@ def logistic_random_effects(positions, annotations):
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

with numpyro.plate("item", num_items, dim=-2):
c = numpyro.sample("c", dist.Categorical(pi))
c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

with handlers.reparam(config={"theta": LocScaleReparam(0)}):
theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
Expand Down
12 changes: 10 additions & 2 deletions notebooks/source/discrete_imputation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,11 @@
"\n",
" # sample imputation values for A\n",
" # mask out to not add log_prob to total likelihood right now\n",
" Aimp = sample(\"A\", dist.Bernoulli(logits=eta_A).mask(False))\n",
" Aimp = sample(\n",
" \"A\",\n",
" dist.Bernoulli(logits=eta_A).mask(False),\n",
" infer={\"enumerate\": \"parallel\"},\n",
" )\n",
"\n",
" # 'manually' calculate the log_prob\n",
" log_prob = dist.Bernoulli(logits=eta_A).log_prob(Aimp)\n",
Expand Down Expand Up @@ -712,7 +716,11 @@
"\n",
" # sample imputation values for A\n",
" # mask out to not add log_prob to total likelihood right now\n",
" Aimp = sample(\"A\", dist.Bernoulli(logits=eta_A).mask(False))\n",
" Aimp = sample(\n",
" \"A\",\n",
" dist.Bernoulli(logits=eta_A).mask(False),\n",
" infer={\"enumerate\": \"parallel\"},\n",
" )\n",
"\n",
" # 'manually' calculate the log_prob\n",
" log_prob = dist.Bernoulli(logits=eta_A).log_prob(Aimp)\n",
Expand Down
6 changes: 2 additions & 4 deletions notebooks/source/model_rendering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,7 @@
" theta = numpyro.sample(\"theta\", dist.Beta(0.5, 0.5))\n",
"\n",
" with numpyro.plate(\"item\", num_items, dim=-2):\n",
" # NB: using constant logits for discrete uniform prior\n",
" # (NumPyro does not have DiscreteUniform distribution yet)\n",
" c = numpyro.sample(\"c\", dist.Categorical(logits=jnp.zeros(num_classes)))\n",
" c = numpyro.sample(\"c\", dist.DiscreteUniform(0, num_classes - 1))\n",
"\n",
" with numpyro.plate(\"position\", num_positions):\n",
" s = numpyro.sample(\"s\", dist.Bernoulli(1 - theta[positions]))\n",
Expand Down Expand Up @@ -568,7 +566,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.8.8"
}
},
"nbformat": 4,
Expand Down
21 changes: 19 additions & 2 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,11 +402,28 @@ def _get_model_transforms(model, model_args=(), model_kwargs=None):
for k, v in model_trace.items():
if v["type"] == "sample" and not v["is_observed"]:
if v["fn"].support.is_discrete:
enum_type = v["infer"].get("enumerate")
if enum_type is not None and (enum_type != "parallel"):
raise RuntimeError(
"This algorithm might only work for discrete sites with"
f" enumerate marked 'parallel'. But the site {k} is marked"
f" as '{enum_type}'."
)
has_enumerate_support = True
if not v["fn"].has_enumerate_support:
dist_name = type(v["fn"]).__name__
raise RuntimeError(
"MCMC only supports continuous sites or discrete sites "
f"with enumerate support, but got {type(v['fn']).__name__}."
"This algorithm might only work for discrete sites with"
f" enumerate support. But the {dist_name} distribution at"
f" site {k} does not have enumerate support."
)
if enum_type is None:
warnings.warn(
"Some algorithms will automatically enumerate the discrete"
f" latent site {k} of your model. In the future,"
" enumerated sites need to be marked with"
" `infer={'enumerate': 'parallel'}`.",
FutureWarning,
)
else:
support = v["fn"].support
Expand Down
90 changes: 12 additions & 78 deletions test/contrib/test_funsor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
def test_gaussian_mixture_model():
K, N = 3, 1000

@config_enumerate
def gmm(data):
mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
with numpyro.plate("num_clusters", K, dim=-1):
Expand Down Expand Up @@ -60,6 +61,7 @@ def gmm(data):


def test_bernoulli_latent_model():
@config_enumerate
def model(data):
y_prob = numpyro.sample("y_prob", dist.Beta(1.0, 1.0))
with numpyro.plate("data", data.shape[0]):
Expand All @@ -81,6 +83,7 @@ def model(data):


def test_change_point():
@config_enumerate
def model(count_data):
n_count_data = count_data.shape[0]
alpha = 1 / jnp.mean(count_data.astype(np.float32))
Expand All @@ -93,84 +96,13 @@ def model(count_data):
with numpyro.plate("data", n_count_data):
numpyro.sample("obs", dist.Poisson(lambda_), obs=count_data)

count_data = jnp.array(
[
13,
24,
8,
24,
7,
35,
14,
11,
15,
11,
22,
22,
11,
57,
11,
19,
29,
6,
19,
12,
22,
12,
18,
72,
32,
9,
7,
13,
19,
23,
27,
20,
6,
17,
13,
10,
14,
6,
16,
15,
7,
2,
15,
15,
19,
70,
49,
7,
53,
22,
21,
31,
19,
11,
1,
20,
12,
35,
17,
23,
17,
4,
2,
31,
30,
13,
27,
0,
39,
37,
5,
14,
13,
22,
]
)
# fmt: off
count_data = jnp.array([
13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57, 11, 19, 29, 6, 19, 12, 22,
12, 18, 72, 32, 9, 7, 13, 19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2,
15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 1, 20, 12, 35, 17, 23, 17, 4, 2,
31, 30, 13, 27, 0, 39, 37, 5, 14, 13, 22])
# fmt: on

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
Expand All @@ -184,6 +116,7 @@ def test_gaussian_hmm():
dim = 4
num_steps = 10

@config_enumerate
def model(data):
with numpyro.plate("states", dim):
transition = numpyro.sample("transition", dist.Dirichlet(jnp.ones(dim)))
Expand Down Expand Up @@ -586,6 +519,7 @@ def transition_fn(carry, y):
def test_missing_plate(monkeypatch):
K, N = 3, 1000

@config_enumerate
def gmm(data):
mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
# plate/to_event is missing here
Expand Down
2 changes: 1 addition & 1 deletion test/contrib/test_infer_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def model2():
@pytest.mark.parametrize("model", [model_zzxx, model2])
@pytest.mark.parametrize("temperature", [0, 1])
def test_mcmc_model_side_enumeration(model, temperature):
mcmc = infer.MCMC(infer.NUTS(model), num_warmup=0, num_samples=1)
mcmc = infer.MCMC(infer.NUTS(config_enumerate(model)), num_warmup=0, num_samples=1)
mcmc.run(random.PRNGKey(0))
mcmc_data = {
k: v[0] for k, v in mcmc.get_samples().items() if k in ["loc", "scale"]
Expand Down
1 change: 1 addition & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ def model():
init_to_uniform,
],
)
@pytest.mark.filterwarnings("ignore:.*enumerate.*:FutureWarning")
def test_discrete_helpful_error(auto_class, init_loc_fn):
def model():
p = numpyro.sample("p", dist.Beta(2.0, 2.0))
Expand Down
2 changes: 1 addition & 1 deletion test/infer/test_hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def test_block_update_partitioning(num_blocks):

def test_enum_subsample_smoke():
def model(data):
x = numpyro.sample("x", dist.Bernoulli(0.5))
x = numpyro.sample("x", dist.Bernoulli(0.5), infer={"enumerate": "parallel"})
with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-1):
batch = numpyro.subsample(data, event_dim=0)
numpyro.sample("obs", dist.Normal(x, 1), obs=batch)
Expand Down
9 changes: 9 additions & 0 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,3 +1060,12 @@ def model():
mcmc = MCMC(NUTS(subs_model), num_warmup=10, num_samples=10)
with pytest.warns(UserWarning, match="skipping initialization"):
mcmc.run(random.PRNGKey(1))


def test_discrete_site_without_infer_enumerate():
def model():
numpyro.sample("x", dist.Bernoulli(0.5))

mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
with pytest.warns(FutureWarning, match="enumerated sites"):
mcmc.run(random.PRNGKey(0))

0 comments on commit 9987eb7

Please sign in to comment.