
Implement sparse Poisson.log_prob() #1003

Merged
merged 7 commits into from
Apr 13, 2021
Conversation

@fritzo (Member) commented Apr 12, 2021

This adds an is_sparse kwarg to dist.Poisson triggering a cheaper .log_prob() computation, following the idiom of Pyro's DirichletMultinomial (pyro-ppl/pyro#1740).

If this works well, we may want to add is_sparse options to other discrete distributions used as likelihoods, e.g. GammaPoisson (aka negative binomial), DirichletMultinomial, and maybe Multinomial.
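A minimal sketch of the idea behind the sparse computation (not NumPyro's actual implementation; the function names here are hypothetical): the Poisson log-density is value * log(rate) - rate - lgamma(value + 1), and when most observed counts are zero, only the -rate term contributes at those entries, so the value-dependent terms need only be gathered over the nonzero counts.

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import gammaln

def dense_poisson_log_prob(rate, value):
    # Standard elementwise Poisson log-density.
    return value * jnp.log(rate) - rate - gammaln(value + 1)

def sparse_poisson_log_prob(rate, value):
    # Broadcast and flatten, then restrict the value-dependent terms to
    # the (typically few) nonzero counts; zero counts contribute just -rate.
    shape = lax.broadcast_shapes(jnp.shape(rate), jnp.shape(value))
    rate = jnp.broadcast_to(rate, shape).reshape(-1)
    value = jnp.broadcast_to(value, shape).reshape(-1)
    idx = jnp.nonzero(value > 0)[0]  # data-dependent shape: eager mode only
    sparse = value[idx] * jnp.log(rate[idx]) - gammaln(value[idx] + 1)
    return (-rate).at[idx].add(sparse).reshape(shape)
```

Note that jnp.nonzero produces a data-dependent shape, so a sketch like this only works outside jit; the savings show up when the fraction of nonzero counts is small.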

Questions for reviewers

  • Can you help me with the failing tests? I am new to JAX reshaping logic.
  • Is the is_sparse constructor arg ok? I slightly prefer this to a separate SparsePoisson subclass, since I've seen classes with multiple such algorithmic flags and it seems cleaner to avoid a giant n-cube class hierarchy of interacting variants.

Tested

  • refactored test_distributions.py and added a SparsePoisson test

cc @phylyc

@fehiepsi (Member)

Is the is_sparse constructor arg ok? I slightly prefer this to a separate SparsePoisson subclass

yup, both look good to me. Please go with the one you prefer. :)

@pytest.mark.parametrize(
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_jit_log_likelihood(jax_dist, sp_dist, params):
@fritzo (Member Author) Apr 13, 2021


Whereas jitting in test_log_prob fixes params and varies the value, jitting in this test fixes the value and varies params. This test is thus closer to usage in an observe statement numpyro.sample(..., obs=data).
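The distinction above can be illustrated with a hypothetical minimal example (not the test suite's actual code): with the observed value held fixed, the jitted function is recompiled-once and re-evaluated as the parameter changes, which is how a log-likelihood behaves under numpyro.sample(..., obs=data) during inference.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

def poisson_log_prob(rate, value):
    # Elementwise Poisson log-density.
    return value * jnp.log(rate) - rate - gammaln(value + 1)

# Fixed observations, playing the role of obs=data.
data = jnp.array([0.0, 2.0, 5.0])

# Compile once with the value closed over; vary only the parameter.
log_lik = jax.jit(lambda rate: poisson_log_prob(rate, data).sum())
```

In test_log_prob the roles are reversed: the parameters are closed over and the value varies across calls.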

@fehiepsi previously approved these changes Apr 13, 2021
@fehiepsi (Member) left a comment


LGTM. Thanks for supporting this and adding more useful tests!

shape = lax.broadcast_shapes(self.batch_shape, jnp.shape(value))
rate = jnp.broadcast_to(self.rate, shape).reshape(-1)
nonzero = jnp.broadcast_to(value > 0, shape).reshape(-1).nonzero()
value = jnp.broadcast_to(value, shape).reshape(-1)
Member


nit: I guess we can swap the order of those two lines:

value = jnp.broadcast_to(value, shape).reshape(-1)
nonzero = (value > 0).nonzero()

Btw, do we need to call .nonzero() here?

Member Author


Done. I think .nonzero() was required for tests to pass in an earlier commit, but now things pass without it, so I've removed it.

@fehiepsi (Member) left a comment


LGTM, thank you @fritzo!

@fehiepsi fehiepsi merged commit 3080edb into master Apr 13, 2021
@fehiepsi fehiepsi deleted the sparse-poisson branch April 13, 2021 15:55