Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise warning for the automatic enumerate behavior #1244

Merged
merged 6 commits into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions examples/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def multinomial(annotations):
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

with numpyro.plate("item", num_items, dim=-2):
c = numpyro.sample("c", dist.Categorical(pi))
c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

with numpyro.plate("position", num_positions):
numpyro.sample("y", dist.Categorical(zeta[c]), obs=annotations)
Expand All @@ -144,7 +144,7 @@ def dawid_skene(positions, annotations):
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

with numpyro.plate("item", num_items, dim=-2):
c = numpyro.sample("c", dist.Categorical(pi))
c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

# here we use Vindex to allow broadcasting for the second index `c`
# ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
Expand All @@ -167,10 +167,18 @@ def mace(positions, annotations):
theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

with numpyro.plate("item", num_items, dim=-2):
c = numpyro.sample("c", dist.DiscreteUniform(0, num_classes - 1))
c = numpyro.sample(
"c",
dist.DiscreteUniform(0, num_classes - 1),
infer={"enumerate": "parallel"},
)

with numpyro.plate("position", num_positions):
s = numpyro.sample("s", dist.Bernoulli(1 - theta[positions]))
s = numpyro.sample(
"s",
dist.Bernoulli(1 - theta[positions]),
infer={"enumerate": "parallel"},
)
probs = jnp.where(
s[..., None] == 0, nn.one_hot(c, num_classes), epsilon[positions]
)
Expand Down Expand Up @@ -207,7 +215,7 @@ def hierarchical_dawid_skene(positions, annotations):
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

with numpyro.plate("item", num_items, dim=-2):
c = numpyro.sample("c", dist.Categorical(pi))
c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

with numpyro.plate("position", num_positions):
logits = Vindex(beta)[positions, c, :]
Expand All @@ -232,7 +240,7 @@ def item_difficulty(annotations):
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

with numpyro.plate("item", num_items, dim=-2):
c = numpyro.sample("c", dist.Categorical(pi))
c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

with handlers.reparam(config={"theta": LocScaleReparam(0)}):
theta = numpyro.sample("theta", dist.Normal(eta[c], chi[c]).to_event(1))
Expand Down Expand Up @@ -270,7 +278,7 @@ def logistic_random_effects(positions, annotations):
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

with numpyro.plate("item", num_items, dim=-2):
c = numpyro.sample("c", dist.Categorical(pi))
c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

with handlers.reparam(config={"theta": LocScaleReparam(0)}):
theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
Expand Down
12 changes: 10 additions & 2 deletions notebooks/source/discrete_imputation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,11 @@
"\n",
" # sample imputation values for A\n",
" # mask out to not add log_prob to total likelihood right now\n",
" Aimp = sample(\"A\", dist.Bernoulli(logits=eta_A).mask(False))\n",
" Aimp = sample(\n",
" \"A\",\n",
" dist.Bernoulli(logits=eta_A).mask(False),\n",
" infer={\"enumerate\": \"parallel\"},\n",
" )\n",
"\n",
" # 'manually' calculate the log_prob\n",
" log_prob = dist.Bernoulli(logits=eta_A).log_prob(Aimp)\n",
Expand Down Expand Up @@ -712,7 +716,11 @@
"\n",
" # sample imputation values for A\n",
" # mask out to not add log_prob to total likelihood right now\n",
" Aimp = sample(\"A\", dist.Bernoulli(logits=eta_A).mask(False))\n",
" Aimp = sample(\n",
" \"A\",\n",
" dist.Bernoulli(logits=eta_A).mask(False),\n",
" infer={\"enumerate\": \"parallel\"},\n",
" )\n",
"\n",
" # 'manually' calculate the log_prob\n",
" log_prob = dist.Bernoulli(logits=eta_A).log_prob(Aimp)\n",
Expand Down
6 changes: 2 additions & 4 deletions notebooks/source/model_rendering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,7 @@
" theta = numpyro.sample(\"theta\", dist.Beta(0.5, 0.5))\n",
"\n",
" with numpyro.plate(\"item\", num_items, dim=-2):\n",
" # NB: using constant logits for discrete uniform prior\n",
" # (NumPyro does not have DiscreteUniform distribution yet)\n",
" c = numpyro.sample(\"c\", dist.Categorical(logits=jnp.zeros(num_classes)))\n",
" c = numpyro.sample(\"c\", dist.DiscreteUniform(0, num_classes - 1))\n",
"\n",
" with numpyro.plate(\"position\", num_positions):\n",
" s = numpyro.sample(\"s\", dist.Bernoulli(1 - theta[positions]))\n",
Expand Down Expand Up @@ -568,7 +566,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.8.8"
}
},
"nbformat": 4,
Expand Down
21 changes: 19 additions & 2 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,11 +402,28 @@ def _get_model_transforms(model, model_args=(), model_kwargs=None):
for k, v in model_trace.items():
if v["type"] == "sample" and not v["is_observed"]:
if v["fn"].support.is_discrete:
enum_type = v["infer"].get("enumerate")
if enum_type is not None and (enum_type != "parallel"):
raise RuntimeError(
"This algorithm might only work for discrete sites with"
f" enumerate marked 'parallel'. But the site {k} is marked"
f" as '{enum_type}'."
)
has_enumerate_support = True
if not v["fn"].has_enumerate_support:
dist_name = type(v["fn"]).__name__
raise RuntimeError(
"MCMC only supports continuous sites or discrete sites "
f"with enumerate support, but got {type(v['fn']).__name__}."
"This algorithm might only work for discrete sites with"
f" enumerate support. But the {dist_name} distribution at"
f" site {k} does not have enumerate support."
)
if enum_type is None:
warnings.warn(
"Some algorithms will automatically enumerate the discrete"
f" latent site {k} of your model. In the future,"
" enumerated sites need to be marked with"
" `infer={'enumerate': 'parallel'}`.",
FutureWarning,
)
else:
support = v["fn"].support
Expand Down
90 changes: 12 additions & 78 deletions test/contrib/test_funsor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
def test_gaussian_mixture_model():
K, N = 3, 1000

@config_enumerate
def gmm(data):
mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
with numpyro.plate("num_clusters", K, dim=-1):
Expand Down Expand Up @@ -60,6 +61,7 @@ def gmm(data):


def test_bernoulli_latent_model():
@config_enumerate
def model(data):
y_prob = numpyro.sample("y_prob", dist.Beta(1.0, 1.0))
with numpyro.plate("data", data.shape[0]):
Expand All @@ -81,6 +83,7 @@ def model(data):


def test_change_point():
@config_enumerate
def model(count_data):
n_count_data = count_data.shape[0]
alpha = 1 / jnp.mean(count_data.astype(np.float32))
Expand All @@ -93,84 +96,13 @@ def model(count_data):
with numpyro.plate("data", n_count_data):
numpyro.sample("obs", dist.Poisson(lambda_), obs=count_data)

count_data = jnp.array(
[
13,
24,
8,
24,
7,
35,
14,
11,
15,
11,
22,
22,
11,
57,
11,
19,
29,
6,
19,
12,
22,
12,
18,
72,
32,
9,
7,
13,
19,
23,
27,
20,
6,
17,
13,
10,
14,
6,
16,
15,
7,
2,
15,
15,
19,
70,
49,
7,
53,
22,
21,
31,
19,
11,
1,
20,
12,
35,
17,
23,
17,
4,
2,
31,
30,
13,
27,
0,
39,
37,
5,
14,
13,
22,
]
)
# fmt: off
count_data = jnp.array([
13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57, 11, 19, 29, 6, 19, 12, 22,
12, 18, 72, 32, 9, 7, 13, 19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2,
15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 1, 20, 12, 35, 17, 23, 17, 4, 2,
31, 30, 13, 27, 0, 39, 37, 5, 14, 13, 22])
# fmt: on

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
Expand All @@ -184,6 +116,7 @@ def test_gaussian_hmm():
dim = 4
num_steps = 10

@config_enumerate
def model(data):
with numpyro.plate("states", dim):
transition = numpyro.sample("transition", dist.Dirichlet(jnp.ones(dim)))
Expand Down Expand Up @@ -586,6 +519,7 @@ def transition_fn(carry, y):
def test_missing_plate(monkeypatch):
K, N = 3, 1000

@config_enumerate
def gmm(data):
mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
# plate/to_event is missing here
Expand Down
2 changes: 1 addition & 1 deletion test/contrib/test_infer_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def model2():
@pytest.mark.parametrize("model", [model_zzxx, model2])
@pytest.mark.parametrize("temperature", [0, 1])
def test_mcmc_model_side_enumeration(model, temperature):
mcmc = infer.MCMC(infer.NUTS(model), num_warmup=0, num_samples=1)
mcmc = infer.MCMC(infer.NUTS(config_enumerate(model)), num_warmup=0, num_samples=1)
mcmc.run(random.PRNGKey(0))
mcmc_data = {
k: v[0] for k, v in mcmc.get_samples().items() if k in ["loc", "scale"]
Expand Down
1 change: 1 addition & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ def model():
init_to_uniform,
],
)
@pytest.mark.filterwarnings("ignore:.*enumerate.*:FutureWarning")
def test_discrete_helpful_error(auto_class, init_loc_fn):
def model():
p = numpyro.sample("p", dist.Beta(2.0, 2.0))
Expand Down
2 changes: 1 addition & 1 deletion test/infer/test_hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def test_block_update_partitioning(num_blocks):

def test_enum_subsample_smoke():
def model(data):
x = numpyro.sample("x", dist.Bernoulli(0.5))
x = numpyro.sample("x", dist.Bernoulli(0.5), infer={"enumerate": "parallel"})
with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-1):
batch = numpyro.subsample(data, event_dim=0)
numpyro.sample("obs", dist.Normal(x, 1), obs=batch)
Expand Down
9 changes: 9 additions & 0 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,3 +1060,12 @@ def model():
mcmc = MCMC(NUTS(subs_model), num_warmup=10, num_samples=10)
with pytest.warns(UserWarning, match="skipping initialization"):
mcmc.run(random.PRNGKey(1))


def test_discrete_site_without_infer_enumerate():
def model():
numpyro.sample("x", dist.Bernoulli(0.5))

mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
with pytest.warns(FutureWarning, match="enumerated sites"):
mcmc.run(random.PRNGKey(0))