Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TraceGraph_ELBO implementation using provenance tracking #1412

Merged
merged 9 commits into from
May 31, 2022

Conversation

ordabayevy
Copy link
Member

@ordabayevy ordabayevy commented May 18, 2022

Porting pyro-ppl/pyro#3081 TraceGraph_ELBO implementation from Pyro that uses provenance tracking.

numpyro/infer/elbo.py Outdated Show resolved Hide resolved
"""
Returns log probabilities at each site for the guide and the model that is run against it.
"""
model_tr, guide_tr = get_importance_trace(model, guide, args, kwargs, params)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed since yesterday: Instead of manually tracing guide and then replaying/tracing model I changed this to just reuse the get_importance_trace function.

Yerdos Ordabayev added 3 commits May 28, 2022 09:22
@ordabayevy
Copy link
Member Author

@fehiepsi per our discussion added the tests back and fixed the stop_gradient issue. It should be ready for a review.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending the verification for the AIR example in Pyro.

and (not msg["fn"].has_rsample)
):
new_provenance = frozenset({msg["name"]})
old_provenance = msg["value"].aval.named_shape.get(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment that this is intended to be used in eval_provenance

surrogate = jnp.sum(
guide_site["log_prob"] * stop_gradient(downstream_cost)
)
elbo += surrogate - stop_gradient(surrogate)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: in jax it is clearer to use elbo = elbo + (...)

@ordabayevy ordabayevy requested a review from fehiepsi May 31, 2022 13:36
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Yerdos.

@ordabayevy
Copy link
Member Author

Thanks Du for reviewing my first PR to NumPyro!

@fehiepsi fehiepsi merged commit 28e38d8 into master May 31, 2022
@fehiepsi fehiepsi deleted the provenance-tracegraph-elbo branch May 31, 2022 23:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants