diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index a41702a29..bce6c692f 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -29,9 +29,11 @@ import numpy as np +import jax from jax import lax from jax.nn import softmax, softplus import jax.numpy as jnp +from jax.ops import index_add import jax.random as random from jax.scipy.special import expit, gammaln, logsumexp, xlog1py, xlogy @@ -608,8 +610,9 @@ class Poisson(Distribution): support = constraints.nonnegative_integer is_discrete = True - def __init__(self, rate, validate_args=None): + def __init__(self, rate, *, is_sparse=False, validate_args=None): self.rate = rate + self.is_sparse = is_sparse super(Poisson, self).__init__(jnp.shape(rate), validate_args=validate_args) def sample(self, key, sample_shape=()): @@ -620,6 +623,23 @@ def sample(self, key, sample_shape=()): def log_prob(self, value): if self._validate_args: self._validate_sample(value) + value = jax.device_get(value) + if ( + self.is_sparse + and not isinstance(value, jax.core.Tracer) + and jnp.size(value) > 1 + ): + shape = lax.broadcast_shapes(self.batch_shape, jnp.shape(value)) + rate = jnp.broadcast_to(self.rate, shape).reshape(-1) + value = jnp.broadcast_to(value, shape).reshape(-1) + nonzero = value > 0 + sparse_value = value[nonzero] + sparse_rate = rate[nonzero] + return index_add( + -rate, + nonzero, + jnp.log(sparse_rate) * sparse_value - gammaln(sparse_value + 1), + ).reshape(shape) return (jnp.log(self.rate) * value) - gammaln(value + 1) - self.rate @property diff --git a/test/test_distributions.py b/test/test_distributions.py index 500d62d9e..7e791c0a4 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -47,9 +47,7 @@ def _identity(x): class T(namedtuple("TestCase", ["jax_dist", "sp_dist", "params"])): def __new__(cls, jax_dist, *params): - sp_dist = None - if jax_dist in _DIST_MAP: - sp_dist = _DIST_MAP[jax_dist] + sp_dist = get_sp_dist(jax_dist) return super(cls, T).__new__(cls, jax_dist, sp_dist, params) @@ -98,6 +96,11 @@ def sample(self, key, sample_shape=()): return transform(unconstrained_samples) +class SparsePoisson(dist.Poisson): + def __init__(self, rate, *, validate_args=None): + super().__init__(rate, is_sparse=True, validate_args=validate_args) + + _DIST_MAP = { dist.BernoulliProbs: lambda probs: osp.bernoulli(p=probs), dist.BernoulliLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)), @@ -141,6 +144,14 @@ def sample(self, key, sample_shape=()): _TruncatedNormal: _truncnorm_to_scipy, } + +def get_sp_dist(jax_dist): + classes = jax_dist.mro() if isinstance(jax_dist, type) else [jax_dist] + for cls in classes: + if cls in _DIST_MAP: + return _DIST_MAP[cls] + + CONTINUOUS = [ T(dist.Beta, 0.2, 1.1), T(dist.Beta, 1.0, jnp.array([2.0, 2.0])), @@ -331,6 +342,8 @@ def sample(self, key, sample_shape=()): T(dist.OrderedLogistic, jnp.array([-4, 3, 4, 5]), jnp.array([-1.5])), T(dist.Poisson, 2.0), T(dist.Poisson, jnp.array([2.0, 3.0, 5.0])), + T(SparsePoisson, 2.0), + T(SparsePoisson, jnp.array([2.0, 3.0, 5.0])), T(dist.ZeroInflatedPoisson, 0.6, 2.0), T(dist.ZeroInflatedPoisson, jnp.array([0.2, 0.7, 0.3]), jnp.array([2.0, 3.0, 5.0])), ] @@ -652,6 +665,29 @@ def test_pathwise_gradient(jax_dist, sp_dist, params): assert_allclose(actual_grad[i], expected_grad, rtol=0.005) +@pytest.mark.parametrize( + "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL +) +def test_jit_log_likelihood(jax_dist, sp_dist, params): + if jax_dist.__name__ in ( + "GaussianRandomWalk", + "_ImproperWrapper", + "LKJ", + "LKJCholesky", + ): + pytest.xfail(reason="non-jittable params") + + rng_key = random.PRNGKey(0) + samples = jax_dist(*params).sample(key=rng_key, sample_shape=(2, 3)) + + def log_likelihood(*params): + return jax_dist(*params).log_prob(samples) + + expected = log_likelihood(*params) + actual = jax.jit(log_likelihood)(*params) + assert_allclose(actual, expected, atol=1e-5) + + @pytest.mark.parametrize( "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL ) @@ -688,7 +724,7 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit): # old api low, loc, scale = params high = jnp.inf - sp_dist = _DIST_MAP[type(jax_dist.base_dist)](loc, scale) + sp_dist = get_sp_dist(type(jax_dist.base_dist))(loc, scale) expected = sp_dist.logpdf(samples) - jnp.log( sp_dist.cdf(high) - sp_dist.cdf(low) )