-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement sparse Poisson.log_prob() #1003
Merged
+61
−5
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
b9b6777
Add is_sparse flag to Poisson distribution
fritzo fbc708b
Add failing tests
fritzo 83d6c53
lint
fritzo ac962a2
Attempt to fix errors
fritzo 3a4fab6
Fix bugs, add a jit log likelihood test
fritzo 4e7ae4a
Merge branch 'master' into sparse-poisson
fritzo 1a9944e
Address review comments
fritzo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,9 +47,7 @@ def _identity(x): | |
|
||
class T(namedtuple("TestCase", ["jax_dist", "sp_dist", "params"])): | ||
def __new__(cls, jax_dist, *params): | ||
sp_dist = None | ||
if jax_dist in _DIST_MAP: | ||
sp_dist = _DIST_MAP[jax_dist] | ||
sp_dist = get_sp_dist(jax_dist) | ||
return super(cls, T).__new__(cls, jax_dist, sp_dist, params) | ||
|
||
|
||
|
@@ -98,6 +96,11 @@ def sample(self, key, sample_shape=()): | |
return transform(unconstrained_samples) | ||
|
||
|
||
class SparsePoisson(dist.Poisson): | ||
def __init__(self, rate, *, validate_args=None): | ||
super().__init__(rate, is_sparse=True, validate_args=validate_args) | ||
|
||
|
||
_DIST_MAP = { | ||
dist.BernoulliProbs: lambda probs: osp.bernoulli(p=probs), | ||
dist.BernoulliLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)), | ||
|
@@ -141,6 +144,14 @@ def sample(self, key, sample_shape=()): | |
_TruncatedNormal: _truncnorm_to_scipy, | ||
} | ||
|
||
|
||
def get_sp_dist(jax_dist): | ||
classes = jax_dist.mro() if isinstance(jax_dist, type) else [jax_dist] | ||
for cls in classes: | ||
if cls in _DIST_MAP: | ||
return _DIST_MAP[cls] | ||
|
||
|
||
CONTINUOUS = [ | ||
T(dist.Beta, 0.2, 1.1), | ||
T(dist.Beta, 1.0, jnp.array([2.0, 2.0])), | ||
|
@@ -331,6 +342,8 @@ def sample(self, key, sample_shape=()): | |
T(dist.OrderedLogistic, jnp.array([-4, 3, 4, 5]), jnp.array([-1.5])), | ||
T(dist.Poisson, 2.0), | ||
T(dist.Poisson, jnp.array([2.0, 3.0, 5.0])), | ||
T(SparsePoisson, 2.0), | ||
T(SparsePoisson, jnp.array([2.0, 3.0, 5.0])), | ||
T(dist.ZeroInflatedPoisson, 0.6, 2.0), | ||
T(dist.ZeroInflatedPoisson, jnp.array([0.2, 0.7, 0.3]), jnp.array([2.0, 3.0, 5.0])), | ||
] | ||
|
@@ -652,6 +665,29 @@ def test_pathwise_gradient(jax_dist, sp_dist, params): | |
assert_allclose(actual_grad[i], expected_grad, rtol=0.005) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL | ||
) | ||
def test_jit_log_likelihood(jax_dist, sp_dist, params): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whereas jitting in |
||
if jax_dist.__name__ in ( | ||
"GaussianRandomWalk", | ||
"_ImproperWrapper", | ||
"LKJ", | ||
"LKJCholesky", | ||
): | ||
pytest.xfail(reason="non-jittable params") | ||
|
||
rng_key = random.PRNGKey(0) | ||
samples = jax_dist(*params).sample(key=rng_key, sample_shape=(2, 3)) | ||
|
||
def log_likelihood(*params): | ||
return jax_dist(*params).log_prob(samples) | ||
|
||
expected = log_likelihood(*params) | ||
actual = jax.jit(log_likelihood)(*params) | ||
assert_allclose(actual, expected, atol=1e-5) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL | ||
) | ||
|
@@ -688,7 +724,7 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit): | |
# old api | ||
low, loc, scale = params | ||
high = jnp.inf | ||
sp_dist = _DIST_MAP[type(jax_dist.base_dist)](loc, scale) | ||
sp_dist = get_sp_dist(type(jax_dist.base_dist))(loc, scale) | ||
expected = sp_dist.logpdf(samples) - jnp.log( | ||
sp_dist.cdf(high) - sp_dist.cdf(low) | ||
) | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I guess we can swap the order of those two lines:
Btw, do we need to call
.nonzero()
here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. I think
.nonzero()
was required for tests to pass in an earlier commit, but now things pass without it, so I've removed.