# Implement sparse `Poisson.log_prob()` #1003

## Conversation
Yup, both look good to me. Please go with the one you prefer. :)

```python
@pytest.mark.parametrize(
    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_jit_log_likelihood(jax_dist, sp_dist, params):
```
Whereas jitting in `test_log_prob` fixes `params` and varies the `value`, jitting in this test fixes the `value` and varies `params`. This test is thus closer to usage in an observe statement `numpyro.sample(..., obs=data)`.
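To make the two jitting patterns concrete, here is a minimal sketch (illustration only, not the PR's test code) using a hand-rolled Poisson log-pmf:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

# Hand-rolled Poisson log-pmf (illustration only, not numpyro's implementation).
def poisson_log_prob(rate, value):
    return value * jnp.log(rate) - rate - gammaln(value + 1)

# Pattern 1 (as in test_log_prob): params are fixed when tracing,
# and the value is the traced argument.
rate = 2.0
jit_over_value = jax.jit(lambda value: poisson_log_prob(rate, value))

# Pattern 2 (as in this test, and in numpyro.sample(..., obs=data)):
# the observed value is fixed, and params are traced.
data = jnp.array([0.0, 1.0, 3.0])
jit_over_params = jax.jit(lambda rate: poisson_log_prob(rate, data))
```

Both compile to the same math; they differ only in which argument JAX traces through, which is what distinguishes the two tests.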
LGTM. Thanks for supporting this and adding more useful tests!
```python
shape = lax.broadcast_shapes(self.batch_shape, jnp.shape(value))
rate = jnp.broadcast_to(self.rate, shape).reshape(-1)
nonzero = jnp.broadcast_to(value > 0, shape).reshape(-1).nonzero()
value = jnp.broadcast_to(value, shape).reshape(-1)
```
nit: I guess we can swap the order of those two lines:

```python
value = jnp.broadcast_to(value, shape).reshape(-1)
nonzero = (value > 0).nonzero()
```

Btw, do we need to call `.nonzero()` here?
Done. I think `.nonzero()` was required for tests to pass in an earlier commit, but now things pass without it, so I've removed it.
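For reference, the sparse-evaluation idea being discussed can be sketched in plain NumPy/SciPy (a simplified model of the logic, not the actual numpyro code): since the Poisson log-pmf at a count of 0 is just `-rate`, only the nonzero entries need the full formula.

```python
import numpy as np
from scipy.special import gammaln

def sparse_poisson_log_prob(rate, value):
    """Simplified sketch: evaluate the full Poisson log-pmf only at nonzero counts."""
    shape = np.broadcast_shapes(np.shape(rate), np.shape(value))
    rate = np.broadcast_to(rate, shape).reshape(-1).astype(float)
    value = np.broadcast_to(value, shape).reshape(-1).astype(float)
    # Zero counts contribute log p(0) = -rate; fill that in everywhere first.
    out = -rate
    nz = value > 0
    # Only the (hopefully few) nonzero counts pay for log and gammaln.
    out[nz] = value[nz] * np.log(rate[nz]) - rate[nz] - gammaln(value[nz] + 1)
    return out.reshape(shape)
```

When most observed counts are zero, the expensive `log`/`gammaln` work is done only on the small nonzero subset, which is the cost saving the PR is after.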
LGTM, thank you @fritzo!
This adds an `is_sparse` kwarg to `dist.Poisson` triggering a cheaper `.log_prob()` computation, following the idiom of Pyro's `DirichletMultinomial` (pyro-ppl/pyro#1740). If this works well, we may want to add `is_sparse` options to other discrete distributions used as likelihoods, e.g. `GammaPoisson` aka negative binomial, `DirichletMultinomial`, and maybe `Multinomial`.

**Questions for reviewers**

- Is the `is_sparse` constructor arg ok? I slightly prefer this to a separate `SparsePoisson` subclass, since I've seen classes with multiple such algorithmic flags and it seems cleaner to avoid a giant n-cube class hierarchy of interacting variants.

**Tested**

- `SparsePoisson` test

cc @phylyc