numpyro models run in numpyro but not using bayeux #51

Open
theorashid opened this issue Jun 11, 2024 · 0 comments

Here are two models that run in numpyro but not in bayeux. The first example runs but does not produce the correct answer. The second example does not run at all and fails with shape errors related to the number of chains.

Versions: numpyro==0.15.0, bayeux-ml==0.1.12

import jax.numpy as jnp
from jax import random

import arviz as az
import bayeux as bx
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC

N = 100
true_alpha = 1.1
true_sigma = 0.1

key = random.PRNGKey(0)
data = true_alpha + true_sigma * random.normal(key=key, shape=(N,))

def model():
	alpha = numpyro.sample("alpha", dist.Normal(0, 3))
	sigma = numpyro.sample("sigma", dist.HalfNormal(1))
	numpyro.sample("y", dist.Normal(alpha, sigma), obs=data)


# this runs fine, samples only from alpha and sigma and recovers the parameters
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.key(0))
mcmc.print_summary()

# this runs but gives the wrong answer: it seems to also sample the observed site "y"
bx_model = bx.Model.from_numpyro(model)
idata = bx_model.mcmc.numpyro_nuts(seed=random.key(0))

# it would also be nice to write the numpyro model as def model(data=None)
# and call bayeux as bx.Model.from_numpyro(model, data=data)

import jax.numpy as jnp
from jax import random

import arviz as az
import bayeux as bx
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC

N = 100
true_alpha = 1.1
true_sigma = 0.1
true_beta = 0.8

key = random.PRNGKey(0)
x = jnp.linspace(0, 1, N)
data = true_alpha + true_sigma * random.normal(key=key, shape=(N,)) + true_beta * x

def model():
	alpha = numpyro.sample("alpha", dist.Normal(0, 3))
	sigma = numpyro.sample("sigma", dist.HalfNormal(1))
	beta = numpyro.sample("beta", dist.Normal(0, 3))
	mu = alpha + beta * x
	numpyro.sample("y", dist.Normal(mu, sigma), obs=data)

# this runs fine and recovers the parameters
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=2)
mcmc.run(random.key(0))
mcmc.print_summary()

# this does not work
bx_model = bx.Model.from_numpyro(model)
idata = bx_model.mcmc.numpyro_nuts(seed=random.key(0), num_chains=2)

# fails with: mul got incompatible shapes for broadcasting: (2,), (100,).
# looks like an issue with multiple chains
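The shapes in the error hint at the cause: `(2,)` matches `num_chains=2`, so it looks like the model is evaluated with a leading chain axis on `alpha` and `beta`, and `beta * x` then tries to broadcast `(2,)` against `(100,)`. Below is a minimal NumPy sketch of that hypothesis (the per-chain values are made up, and `jax.numpy` follows the same broadcasting rules), together with the usual batch-friendly rewrite that adds a trailing axis with `[..., None]`.

```python
import numpy as np

N = 100
x = np.linspace(0, 1, N)

# Hypothetical per-chain draws for 2 chains: shape (2,).
alpha = np.array([1.0, 1.2])
beta = np.array([0.7, 0.9])

# Naive expression from the model: (2,) * (100,) cannot broadcast.
try:
    alpha + beta * x
    naive_failed = False
except ValueError:
    naive_failed = True

# Batch-friendly version: a trailing axis turns (2,) into (2, 1),
# which broadcasts against (100,) to give one mean vector per chain.
mu = alpha[..., None] + beta[..., None] * x

# The same expression still works for unbatched scalar parameters:
# a 0-d array becomes shape (1,), which broadcasts to (100,).
mu_single = np.asarray(1.1)[..., None] + np.asarray(0.8)[..., None] * x

print(naive_failed, mu.shape, mu_single.shape)  # True (2, 100) (100,)
```

Writing the model as `mu = alpha[..., None] + beta[..., None] * x` (and using `sigma[..., None]` in the likelihood) might sidestep the error, but that is a guess at the cause rather than a fix in bayeux, and it is untested here.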