possible to use numpyro as a modeling language only? #546
Comments
I think all of those (
Thank you! Indeed, that seems to do it! In case this is of interest to anyone else, here's what I think is a minimal working example of this:

```python
import numpyro
import numpyro.distributions as dist
import jax
import jax.random

def model(c):
    a = numpyro.sample('a', dist.Normal(0.0, 1.0))
    b = numpyro.sample('b', dist.Gamma(1.0))
    numpyro.sample('c', dist.Normal(a, b), obs=c)

c = 1.0
rng_key = jax.random.PRNGKey(0)
init_params, potential_fn, constrain_fn = numpyro.infer.util.initialize_model(rng_key, model, model_args=(c,))
print('init_params: ', init_params)
print('constrain_fn: ', constrain_fn(init_params))
print('potential_fn: ', potential_fn(init_params))
print('grad: ', jax.grad(potential_fn)(init_params))
```

This leads to the following results:
Actually, could I be even more greedy? Is there any way to create a function that's like

As an example of something that doesn't work, I tried the following:

```python
import numpyro
import numpyro.distributions as dist
import jax
import jax.random

def model(c):
    a = numpyro.sample('a', dist.Normal(0.0, 1.0))
    b = numpyro.sample('b', dist.Gamma(1.0))
    with numpyro.plate('N', 2, subsample_size=1):
        numpyro.sample('c', dist.Normal(a, b), obs=c)

c = jax.numpy.array([1.0, 2.0])
rng_key = jax.random.PRNGKey(0)
init_params, potential_fn, constrain_fn = numpyro.infer.util.initialize_model(rng_key, model, model_args=(c,))
print('init_params: ', init_params)
print('constrain_fn: ', constrain_fn(init_params))
print('potential_fn: ', potential_fn(init_params))
print('grad: ', jax.grad(potential_fn)(init_params))
```

My hope was that a random
Looking like `get_potential_fn` with

```python
potential_fn_gen, _ = get_potential_fn(PRNGKey(0), model, dynamic_args=True, ...)

def potential_fn(rng, params):
    # use rng to get subsample of c
    return potential_fn_gen(c_subsample)(params)
```
Thanks so much for your help! Just in case anyone else would like to see a full example of this, here's a version that uses

```python
import numpyro
import numpyro.distributions as dist
import jax
import jax.random

def model(c):
    a = numpyro.sample('a', dist.Normal(0.0, 1.0))
    b = numpyro.sample('b', dist.Gamma(1.0))
    with numpyro.plate('N', 2):
        numpyro.sample('c', dist.Normal(a, b), obs=c)

c = jax.numpy.array([1.0, 2.0])
rng_key = jax.random.PRNGKey(0)
init_params, potential_fn, constrain_fn = numpyro.infer.util.initialize_model(rng_key, model, model_args=(c,))
print('init_params: ', init_params)
print('constrain_fn: ', constrain_fn(init_params))
print('potential_fn: ', potential_fn(init_params))
print('grad: ', jax.grad(potential_fn)(init_params))
```

The results are:
Now, here's a version that uses

```python
import numpyro
import numpyro.distributions as dist
import jax
import jax.random

def model(c):
    a = numpyro.sample('a', dist.Normal(0.0, 1.0))
    b = numpyro.sample('b', dist.Gamma(1.0))
    with numpyro.plate('N', 2, subsample_size=1):
        numpyro.sample('c', dist.Normal(a, b), obs=c)

c = jax.numpy.array([1.0, 2.0])
rng_key = jax.random.PRNGKey(1)
potential_fn, constrain_fn = numpyro.infer.util.get_potential_fn(rng_key, model, dynamic_args=True, model_args=(c,))
init_params = {'a': jax.numpy.array(1.9340096), 'b': jax.numpy.array(0.8950882)}
print('init_params: ', init_params)
print('constrain_fn: ', constrain_fn(c[0])(init_params))
print('potential_fn[0]: ', potential_fn(c[0])(init_params))
print('potential_fn[1]: ', potential_fn(c[1])(init_params))
print('grad[0]: ', jax.grad(lambda params: potential_fn(c[0])(params))(init_params))
print('grad[1]: ', jax.grad(lambda params: potential_fn(c[1])(params))(init_params))
```

With the following results:
Since
@justindomke We have a forum where those tips would be more accessible to users. Thanks for your clear report! We haven't focused much on SVI (and hence subsampling) so far, so it is awesome that the MCMC utilities work out of the box for you. :D
Hi! I'm very interested in this project. However, my potential use is a bit different from what seems typical with numpyro, and I was wondering if there was a "recommended" way to go about it.

In short, what I'd like to do is use numpyro as a modeling language only, and implement my own inference algorithms. Essentially, I'd like to define a model in numpyro, then specify values for some subset of the variables in the model, then get a `log_prob(*vars)` function that would evaluate the log probability, given a configuration of all of the other variables. That's it! I'd like to do everything else myself.

I'm aware of the `log_density` function in the utilities section, but it is challenging to figure out how this might be used. If there were any examples, that would be helpful.

If I'm being greedy, even better would be a function `log_prob(*unconstrained_vars)` that evaluated the probability after all the variables were transformed to an unconstrained space, along with a `constrain(*unconstrained_vars)` function that would transform them back.

I appreciate any help!