Implement sparse Poisson.log_prob() #1003

Merged · 7 commits · Apr 13, 2021
numpyro/distributions/discrete.py (21 additions, 1 deletion)

@@ -29,9 +29,11 @@

 import numpy as np

+import jax
 from jax import lax
 from jax.nn import softmax, softplus
 import jax.numpy as jnp
+from jax.ops import index_add
 import jax.random as random
 from jax.scipy.special import expit, gammaln, logsumexp, xlog1py, xlogy

@@ -608,8 +610,9 @@ class Poisson(Distribution):
     support = constraints.nonnegative_integer
     is_discrete = True

-    def __init__(self, rate, validate_args=None):
+    def __init__(self, rate, *, is_sparse=False, validate_args=None):
         self.rate = rate
+        self.is_sparse = is_sparse
         super(Poisson, self).__init__(jnp.shape(rate), validate_args=validate_args)

     def sample(self, key, sample_shape=()):
@@ -620,6 +623,23 @@ def sample(self, key, sample_shape=()):
     def log_prob(self, value):
         if self._validate_args:
             self._validate_sample(value)
+        value = jax.device_get(value)
+        if (
+            self.is_sparse
+            and not isinstance(value, jax.core.Tracer)
+            and jnp.size(value) > 1
+        ):
+            shape = lax.broadcast_shapes(self.batch_shape, jnp.shape(value))
+            rate = jnp.broadcast_to(self.rate, shape).reshape(-1)
+            value = jnp.broadcast_to(value, shape).reshape(-1)
Member:

    nit: I guess we can swap the order of those two lines:

        value = jnp.broadcast_to(value, shape).reshape(-1)
        nonzero = (value > 0).nonzero()

    Btw, do we need to call .nonzero() here?

Member Author:

    Done. I think .nonzero() was required for tests to pass in an earlier commit, but now things pass without it, so I've removed it.
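As an aside on the thread above, here is a minimal sketch (not part of the PR) of why a plain boolean mask suffices on concrete arrays, and why the sparse branch is guarded with `isinstance(value, jax.core.Tracer)`: boolean indexing produces a data-dependent shape, which cannot be traced under jit.

```python
import jax
import jax.numpy as jnp

value = jnp.array([0.0, 2.0, 0.0, 3.0])
mask = value > 0

# On a concrete array, boolean-mask indexing is equivalent to
# indexing with the result of .nonzero():
assert (value[mask] == value[mask.nonzero()]).all()

# Under jit the same indexing fails, because the output shape
# would depend on the data; hence the Tracer guard above, which
# falls back to the dense computation when tracing:
try:
    jax.jit(lambda v: v[v > 0])(value)
except Exception as err:
    print(type(err).__name__)  # concrete boolean index required
```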

+            nonzero = value > 0
+            sparse_value = value[nonzero]
+            sparse_rate = rate[nonzero]
+            return index_add(
+                -rate,
+                nonzero,
+                jnp.log(sparse_rate) * sparse_value - gammaln(sparse_value + 1),
+            ).reshape(shape)
         return (jnp.log(self.rate) * value) - gammaln(value + 1) - self.rate

     @property
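For readers skimming the diff, a small usage sketch (illustrative, not from the PR; the numbers are made up). With `is_sparse=True` the log-density is assembled from the nonzero entries only, but it agrees with the dense path:

```python
import jax.numpy as jnp
import numpyro.distributions as dist

# Mostly-zero count data, the case the sparse path is designed for.
counts = jnp.array([0, 0, 3, 0, 0, 0, 1, 0])
rate = 0.5

dense = dist.Poisson(rate).log_prob(counts)
sparse = dist.Poisson(rate, is_sparse=True).log_prob(counts)

# Both compute log p(k) = k * log(rate) - rate - log(k!) elementwise;
# for the zero entries the sparse path just returns -rate.
assert jnp.allclose(dense, sparse)
```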
test/test_distributions.py (40 additions, 4 deletions)

@@ -47,9 +47,7 @@ def _identity(x):

 class T(namedtuple("TestCase", ["jax_dist", "sp_dist", "params"])):
     def __new__(cls, jax_dist, *params):
-        sp_dist = None
-        if jax_dist in _DIST_MAP:
-            sp_dist = _DIST_MAP[jax_dist]
+        sp_dist = get_sp_dist(jax_dist)
         return super(cls, T).__new__(cls, jax_dist, sp_dist, params)


@@ -98,6 +96,11 @@ def sample(self, key, sample_shape=()):
         return transform(unconstrained_samples)


+class SparsePoisson(dist.Poisson):
+    def __init__(self, rate, *, validate_args=None):
+        super().__init__(rate, is_sparse=True, validate_args=validate_args)
+
+
 _DIST_MAP = {
     dist.BernoulliProbs: lambda probs: osp.bernoulli(p=probs),
     dist.BernoulliLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)),
@@ -141,6 +144,14 @@
     _TruncatedNormal: _truncnorm_to_scipy,
 }


+def get_sp_dist(jax_dist):
+    classes = jax_dist.mro() if isinstance(jax_dist, type) else [jax_dist]
+    for cls in classes:
+        if cls in _DIST_MAP:
+            return _DIST_MAP[cls]
+
+
 CONTINUOUS = [
     T(dist.Beta, 0.2, 1.1),
     T(dist.Beta, 1.0, jnp.array([2.0, 2.0])),
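A note on why the helper walks the MRO (illustrative, not in the diff; it assumes SparsePoisson itself has no entry in _DIST_MAP, which the diff suggests): the new SparsePoisson test wrapper inherits its scipy counterpart from its base class dist.Poisson, which is in the map.

```python
# SparsePoisson.mro() is [SparsePoisson, dist.Poisson, ..., object],
# so the loop finds the dist.Poisson entry on the second iteration.
sp = get_sp_dist(SparsePoisson)
assert sp is _DIST_MAP[dist.Poisson]
```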
@@ -331,6 +342,8 @@
     T(dist.OrderedLogistic, jnp.array([-4, 3, 4, 5]), jnp.array([-1.5])),
     T(dist.Poisson, 2.0),
     T(dist.Poisson, jnp.array([2.0, 3.0, 5.0])),
+    T(SparsePoisson, 2.0),
+    T(SparsePoisson, jnp.array([2.0, 3.0, 5.0])),
     T(dist.ZeroInflatedPoisson, 0.6, 2.0),
     T(dist.ZeroInflatedPoisson, jnp.array([0.2, 0.7, 0.3]), jnp.array([2.0, 3.0, 5.0])),
 ]
@@ -652,6 +665,29 @@ def test_pathwise_gradient(jax_dist, sp_dist, params):
         assert_allclose(actual_grad[i], expected_grad, rtol=0.005)


+@pytest.mark.parametrize(
+    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
+)
+def test_jit_log_likelihood(jax_dist, sp_dist, params):

Member Author (fritzo, Apr 13, 2021):

    Whereas jitting in test_log_prob fixes params and varies the value, jitting in this test fixes the value and varies params. This test is thus closer to usage in an observe statement numpyro.sample(..., obs=data).

+    if jax_dist.__name__ in (
+        "GaussianRandomWalk",
+        "_ImproperWrapper",
+        "LKJ",
+        "LKJCholesky",
+    ):
+        pytest.xfail(reason="non-jittable params")
+
+    rng_key = random.PRNGKey(0)
+    samples = jax_dist(*params).sample(key=rng_key, sample_shape=(2, 3))
+
+    def log_likelihood(*params):
+        return jax_dist(*params).log_prob(samples)
+
+    expected = log_likelihood(*params)
+    actual = jax.jit(log_likelihood)(*params)
+    assert_allclose(actual, expected, atol=1e-5)


 @pytest.mark.parametrize(
     "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
 )
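To ground fritzo's comment above, here is a hypothetical model (not part of the PR) showing the observe-statement pattern the new test mirrors: the data stays fixed while the rate parameter varies under the jitted inference loop, so the sparse log_prob path sees concrete values rather than tracers.

```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

data = jnp.array([0.0, 0.0, 3.0, 0.0, 1.0, 0.0])  # made-up sparse counts

def model(data):
    rate = numpyro.sample("rate", dist.Gamma(1.0, 1.0))
    # The observe statement: data is fixed across jitted evaluations,
    # while rate varies with the sampler state.
    numpyro.sample("obs", dist.Poisson(rate, is_sparse=True), obs=data)

mcmc = MCMC(NUTS(model), num_warmup=200, num_samples=200)
mcmc.run(random.PRNGKey(0), data)
print(mcmc.get_samples()["rate"].mean())
```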
@@ -688,7 +724,7 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
         # old api
         low, loc, scale = params
         high = jnp.inf
-        sp_dist = _DIST_MAP[type(jax_dist.base_dist)](loc, scale)
+        sp_dist = get_sp_dist(type(jax_dist.base_dist))(loc, scale)
         expected = sp_dist.logpdf(samples) - jnp.log(
             sp_dist.cdf(high) - sp_dist.cdf(low)
         )