-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TraceGraph_ELBO implementation using provenance tracking #1412
Conversation
""" | ||
Returns log probabilities at each site for the guide and the model that is run against it. | ||
""" | ||
model_tr, guide_tr = get_importance_trace(model, guide, args, kwargs, params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed since yesterday: Instead of manually tracing guide and then replaying/tracing model I changed this to just reuse the get_importance_trace
function.
@fehiepsi per our discussion added the tests back and fixed the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM pending the verification for the AIR example in Pyro.
and (not msg["fn"].has_rsample) | ||
): | ||
new_provenance = frozenset({msg["name"]}) | ||
old_provenance = msg["value"].aval.named_shape.get( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a comment that this is intended to be used in eval_provenance
numpyro/infer/elbo.py
Outdated
surrogate = jnp.sum( | ||
guide_site["log_prob"] * stop_gradient(downstream_cost) | ||
) | ||
elbo += surrogate - stop_gradient(surrogate) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: in jax it is clearer to use elbo = elbo + (...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Yerdos.
Thanks Du for reviewing my first PR to NumPyro! |
Porting pyro-ppl/pyro#3081
TraceGraph_ELBO
implementation from Pyro that uses provenance tracking.