Here's a reproducible example taken nearly directly from the Gaussian Mixture Model tutorial. The failure seems to be specific to AutoContinuous guides: AutoNormal works, but AutoDiagonalNormal fails.
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, TraceEnum_ELBO, autoguide
from numpyro.handlers import block, seed

data = jnp.array([0.0, 1.0, 10.0, 11.0, 12.0])
K = 2  # Fixed number of components.

def model(data):
    # Global variables.
    weights = numpyro.sample("weights", dist.Dirichlet(0.5 * jnp.ones(K)))
    scale = numpyro.sample("scale", dist.LogNormal(0.0, 2.0))
    with numpyro.plate("components", K):
        locs = numpyro.sample("locs", dist.Normal(0.0, 10.0))
    with numpyro.plate("data", len(data)):
        # Local variables.
        assignment = numpyro.sample("assignment", dist.Categorical(weights),
                                    infer={"enumerate": "parallel"})
        numpyro.sample("obs", dist.Normal(locs[assignment], scale), obs=data)

# this works
guide = autoguide.AutoNormal(block(seed(model, rng_seed=0), hide=['assignment']))
svi = SVI(model, guide, numpyro.optim.Adam(0.003), TraceEnum_ELBO())
svi_result = svi.run(random.PRNGKey(0), 100, data)

# this fails
guide = autoguide.AutoDiagonalNormal(block(seed(model, rng_seed=0), hide=['assignment']))
svi = SVI(model, guide, numpyro.optim.Adam(0.003), TraceEnum_ELBO())
svi_result = svi.run(random.PRNGKey(0), 100, data)
Here's the associated stack trace.
File .../jax/lib/python3.11/site-packages/numpyro/contrib/funsor/enum_messenger.py
    426 if msg["kwargs"]["dim_type"] in (DimType.GLOBAL, DimType.VISIBLE):
    427     for name in msg["args"][0].inputs:
    428         self._saved_globals += (
--> 429             (name, _DIM_STACK.global_frame.name_to_dim[name]),
    430         )
KeyError: 'components'
If I replace the components plate with locs = numpyro.sample("locs", dist.Normal(0.0, 10.0).expand((K,)).to_event(1)), I get the KeyError on the 'data' plate.