A couple of examples of models that run in numpyro but not in bayeux. The first example runs but does not produce the correct answer. The second example does not run and raises shape errors related to the number of chains.
numpyro==0.15.0
bayeux-ml==0.1.12
```python
import jax.numpy as jnp
from jax import random
import arviz as az
import bayeux as bx
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC

N = 100
true_alpha = 1.1
true_sigma = 0.1
key = random.PRNGKey(0)
data = true_alpha + true_sigma * random.normal(key=key, shape=(N,))


def model():
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    numpyro.sample("y", dist.Normal(alpha, sigma), obs=data)


# this runs fine, samples only from alpha and sigma and recovers the parameters
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.key(0))
mcmc.print_summary()

# this does not work and seems to sample from the observed sites
bx_model = bx.Model.from_numpyro(model)
idata = bx_model.mcmc.numpyro_nuts(seed=random.key(0))

# it would also be nice to write the numpyro model as def model(data=None)
# and call bayeux as bx.Model.from_numpyro(model, data=data)
```
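For reference, the posterior both samplers should target in the first example has exactly two free parameters, `alpha` and `sigma`; `y` is observed and should only contribute a likelihood term. The sketch below writes that unnormalized log density by hand in plain NumPy so it can serve as a rough sanity check on whatever a sampler returns: with `sigma` fixed, the mode of `alpha` should sit very close to the sample mean of `data`. The helpers `normal_logpdf` and `log_posterior` are hypothetical names for illustration, not part of numpyro or bayeux, and the data are redrawn with NumPy's RNG, so the exact values differ from the JAX draws above.

```python
import numpy as np

N = 100
true_alpha, true_sigma = 1.1, 0.1
rng = np.random.default_rng(0)  # NumPy RNG: draws differ from the JAX example
data = true_alpha + true_sigma * rng.standard_normal(N)


def normal_logpdf(x, loc, scale):
    """Elementwise log density of Normal(loc, scale)."""
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def log_posterior(alpha, sigma):
    """Unnormalized log posterior over the two latent sites only (hypothetical helper)."""
    if sigma <= 0:
        return -np.inf  # HalfNormal support is sigma > 0
    log_prior = normal_logpdf(alpha, 0.0, 3.0)
    # HalfNormal(1) density is twice the Normal(0, 1) density on sigma > 0
    log_prior += np.log(2.0) + normal_logpdf(sigma, 0.0, 1.0)
    # "y" is observed: it enters only through the likelihood, never as a parameter
    log_lik = normal_logpdf(data, alpha, sigma).sum()
    return log_prior + log_lik


# Crude grid search over alpha with sigma held at its true value
alphas = np.linspace(0.5, 1.5, 1001)
scores = np.array([log_posterior(a, true_sigma) for a in alphas])
alpha_mode = alphas[np.argmax(scores)]
print(alpha_mode, data.mean())
```

With N = 100 and a Normal(0, 3) prior, the likelihood dominates, so a correct sampler's `alpha` draws should concentrate near the sample mean rather than drift with the observed values.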
```python
import jax.numpy as jnp
from jax import random
import arviz as az
import bayeux as bx
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC

N = 100
true_alpha = 1.1
true_sigma = 0.1
true_beta = 0.8
key = random.PRNGKey(0)
x = jnp.linspace(0, 1, N)
data = true_alpha + true_sigma * random.normal(key=key, shape=(N,)) + true_beta * x


def model():
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    beta = numpyro.sample("beta", dist.Normal(0, 3))
    mu = alpha + beta * x
    numpyro.sample("y", dist.Normal(mu, sigma), obs=data)


# this runs fine and recovers the parameters
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=2)
mcmc.run(random.key(0))
mcmc.print_summary()

# this does not work
bx_model = bx.Model.from_numpyro(model)
idata = bx_model.mcmc.numpyro_nuts(seed=random.key(0), num_chains=2)
# mul got incompatible shapes for broadcasting: (2,), (100,).
# issue with multiple chains
```
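The error message is consistent with the model body being evaluated with a leading chain axis on the parameters: with `num_chains=2`, `beta` would have shape `(2,)` while `x` has shape `(100,)`, and those two shapes do not broadcast. This is an assumption about where the shapes come from, not a confirmed diagnosis of bayeux internals. The NumPy sketch below reproduces the same mismatch with hypothetical per-chain parameter values and shows two shape-safe alternatives: giving the parameters a trailing axis, or evaluating one chain at a time.

```python
import numpy as np

N, num_chains = 100, 2
x = np.linspace(0, 1, N)

# Hypothetical per-chain parameter values, one entry per chain
alpha = np.array([1.0, 1.2])
beta = np.array([0.7, 0.9])

# Evaluating the model body as written fails: (2,) and (100,) do not broadcast
try:
    mu = alpha + beta * x
    broadcast_failed = False
except ValueError:
    broadcast_failed = True

# Option 1: add a trailing axis so each chain's parameters broadcast against x
mu_batched = alpha[:, None] + beta[:, None] * x  # shape (num_chains, N)

# Option 2: evaluate the unbatched body once per chain, the same result that
# mapping over the chain axis (e.g. with jax.vmap) would produce
mu_looped = np.stack([a + b * x for a, b in zip(alpha, beta)])  # shape (num_chains, N)

print(broadcast_failed, mu_batched.shape)
```

Both options yield one `mu` vector per chain; the second keeps the model body written for scalar parameters, which is how the plain numpyro `MCMC` run above is able to handle `num_chains=2`.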