Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add numpyro.collapse #773

Merged
merged 17 commits into from
Oct 7, 2020
Merged

Add numpyro.collapse #773

merged 17 commits into from
Oct 7, 2020

Conversation

fehiepsi
Copy link
Member

@fehiepsi fehiepsi commented Oct 3, 2020

Resolves #756.

TODO:

  • debug why the pattern is not recognized
  • add beta bernoulli xfail test

Blocked by pyro-ppl/funsor#377

numpyro/contrib/tfp/distributions.py Show resolved Hide resolved
numpyro/distributions/distribution.py Outdated Show resolved Hide resolved
numpyro/handlers.py Outdated Show resolved Hide resolved

if msg["type"] == "sample":
if msg["value"] is None:
msg["value"] = msg["name"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there no need for msg["done"] = True as in Pyro?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have msg["done"] in NumPyro. What does it do?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eb8680 can you explain msg["done"] vs msg["stop"]?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

msg["done"] indicates that msg["value"] is final and should not be changed by subsequent handlers. In this case it would equivalent to msg["value"] is not None, although it can occasionally be useful to distinguish between these situations.

Comment on lines +577 to +583
def model():
c = numpyro.sample("c", dist.Gamma(1, 1))
with handlers.collapse():
probs = numpyro.sample("probs", dist.Beta(c, 2))
with numpyro.plate("plate", len(data)):
numpyro.sample("obs", dist.Binomial(10, probs),
obs=data)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @neerajprad I believe this is basically the mobb model 😄

Copy link
Member

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@eb8680 eb8680 merged commit 78716a7 into pyro-ppl:master Oct 7, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Port poutine.collapse to NumPyro
3 participants