Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support total_count=0 in multinomial #1000

Merged
merged 1 commit into from
Apr 8, 2021

Conversation

fehiepsi
Copy link
Member

@fehiepsi fehiepsi commented Apr 8, 2021

Fixes #983. And oh yeah, this is PR #1000. 💯

assert not isinstance(
n, jax.core.Tracer
), "The total count parameter `n` should not be a jax abstract array."
n_max = int(np.max(jax.device_get(n)))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This covers more usage cases while raising an early error if n is an abstract array.

@fritzo fritzo merged commit b73e85f into pyro-ppl:master Apr 8, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

dist.DirichletMultinomial can't handle zero_total count
2 participants