Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove inf's in TruncatedNormal log_prob & sample (#1492) #1581

Merged
merged 1 commit into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
xlog1py,
xlogy,
)
from jax.scipy.stats import norm as jax_norm

from numpyro.distributions import constraints
from numpyro.distributions.discrete import _to_logits_bernoulli
Expand Down Expand Up @@ -2077,6 +2078,9 @@ def cdf(self, value):
scaled = (value - self.loc) / self.scale
return ndtr(scaled)

def log_cdf(self, value):
    """Return the log of the cumulative distribution function at ``value``.

    The computation is delegated to ``jax.scipy.stats.norm.logcdf``,
    parameterized by this distribution's ``loc`` and ``scale``.

    :param value: point(s) at which to evaluate the log-CDF.
    :return: ``log(cdf(value))`` as computed by ``jax_norm.logcdf``.
    """
    params = {"loc": self.loc, "scale": self.scale}
    return jax_norm.logcdf(value, **params)

def icdf(self, q):
    """Return the inverse CDF (quantile function) evaluated at ``q``.

    The standard-normal quantile from ``ndtri`` is rescaled by ``scale``
    and shifted by ``loc``.

    :param q: probability value(s) at which to evaluate the quantile.
    :return: ``loc + scale * ndtri(q)``.
    """
    standard_quantile = ndtri(q)
    return self.loc + self.scale * standard_quantile

Expand Down
25 changes: 20 additions & 5 deletions numpyro/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
clamp_probs,
is_prng_key,
lazy_property,
promote_shapes,
Expand Down Expand Up @@ -249,6 +250,23 @@ def _tail_prob_at_high(self):
sign = jnp.where(loc >= self.low, 1.0, -1.0)
return self.base_dist.cdf(loc - sign * (loc - self.high))

@lazy_property
def _log_diff_tail_probs(self):
    """Log of the base-distribution probability mass inside ``[low, high]``.

    This is the normalizer subtracted in ``log_prob``.  When the base
    distribution exposes a callable ``log_cdf`` the difference of the two
    CDF values is formed in log space (``logsumexp`` with weights
    ``[1, -1]``) to avoid inf's in ``log_prob``; otherwise it falls back
    to the plain CDF-based tail probabilities.

    In both branches the interval is reflected about ``loc`` when
    ``low > loc`` (the same symmetry trick used by the CDF-based tail
    probabilities):

        cdf(high) - cdf(low) = cdf(2 * loc - low) - cdf(2 * loc - high)

    so the two CDF values being subtracted always sit in the lower tail,
    where they are small and distinguishable, instead of both being
    numerically equal to one.

    :return: ``log(cdf(high) - cdf(low))`` of the base distribution.
    """
    loc = self.base_dist.loc
    sign = jnp.where(loc >= self.low, 1.0, -1.0)

    log_cdf = getattr(self.base_dist, "log_cdf", None)
    if not callable(log_cdf):
        # fall back to cdf, if log_cdf not available
        return jnp.log(sign * (self._tail_prob_at_high - self._tail_prob_at_low))

    # Pick the (possibly reflected) end points so that `upper >= lower`
    # and both lie on the lower-tail side of `loc` whenever low > loc.
    # jnp.where also broadcasts loc/low/high to a common shape, which
    # keeps the two stacked operands shape-compatible.
    in_lower_half = sign > 0
    upper = jnp.where(in_lower_half, self.high, 2 * loc - self.low)
    lower = jnp.where(in_lower_half, self.low, 2 * loc - self.high)
    return logsumexp(
        a=jnp.stack([log_cdf(upper), log_cdf(lower)], axis=-1),
        axis=-1,
        b=jnp.array([1, -1]),  # subtract the smaller CDF from the larger
    )

def sample(self, key, sample_shape=()):
assert is_prng_key(key)
dtype = jnp.result_type(float)
Expand All @@ -266,7 +284,7 @@ def sample(self, key, sample_shape=()):
loc = self.base_dist.loc
sign = jnp.where(loc >= self.low, 1.0, -1.0)
return (1 - sign) * loc + sign * self.base_dist.icdf(
(1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
clamp_probs((1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high)
)

@validate_sample
Expand All @@ -276,10 +294,7 @@ def log_prob(self, value):
# cdf(high) - cdf(low) = as-is
# if low > loc
# cdf(high) - cdf(low) = cdf(2 * loc - low) - cdf(2 * loc - high)
sign = jnp.where(self.base_dist.loc >= self.low, 1.0, -1.0)
return self.base_dist.log_prob(value) - jnp.log(
sign * (self._tail_prob_at_high - self._tail_prob_at_low)
)
return self.base_dist.log_prob(value) - self._log_diff_tail_probs

def tree_flatten(self):
base_flatten, base_aux = self.base_dist.tree_flatten()
Expand Down
44 changes: 44 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import expit, logsumexp
from jax.scipy.stats import norm as jax_norm, truncnorm as jax_truncnorm
from jax.tree_util import tree_map

import numpyro.distributions as dist
Expand Down Expand Up @@ -2758,3 +2759,46 @@ def f(x):
x = dist.Multinomial(10, probs).sample(key)
y = jax.jit(f)(x)
assert_allclose(x, y, rtol=1e-6)


def test_normal_log_cdf():
    """Check ``Normal.log_cdf`` against JAX's ``norm.logcdf`` and ``Normal.cdf``."""
    # Parameters form a row and evaluation points a column, so every
    # (value, parameter) combination is exercised through broadcasting.
    loc = jnp.array([[0.0, -10.0, 20.0]])
    scale = jnp.array([[1, 5, 7]])
    values = jnp.linspace(-5, 5, 100).reshape(-1, 1)

    normal = dist.Normal(loc=loc, scale=scale)
    actual_log_cdf = normal.log_cdf(values)
    reference_log_cdf = jax_norm.logcdf(x=values, loc=loc, scale=scale)

    # log_cdf must match the JAX reference ...
    assert_allclose(actual_log_cdf, reference_log_cdf)
    # ... and exponentiating it must recover the plain cdf.
    assert_allclose(jnp.exp(actual_log_cdf), normal.cdf(values), rtol=1e-6)


@pytest.mark.parametrize(
    "value",
    [
        -15.0,
        jnp.array([[-15.0], [-10.0], [-5.0]]),
        jnp.array([[[-15.0], [-10.0], [-5.0]], [[-14.0], [-9.0], [-4.0]]]),
    ],
)
def test_truncated_normal_log_prob_in_tail(value):
    """Compare ``TruncatedNormal.log_prob`` with ``jax.scipy.stats.truncnorm.logpdf``.

    The truncation interval ``[-20, -1]`` lies below ``loc = 1.35`` for a
    range of scales, i.e. in the tail of the base distribution.
    """
    loc = 1.35
    scale = jnp.geomspace(0.01, 1, 10)
    low, high = -20, -1.0

    truncated = dist.TruncatedNormal(loc, scale, low=low, high=high)
    actual = truncated.log_prob(value)

    # jax's truncnorm expects the bounds rescaled to standard units
    std_low = (low - loc) / scale
    std_high = (high - loc) / scale
    expected = jax_truncnorm.logpdf(
        value, loc=loc, scale=scale, a=std_low, b=std_high
    )

    assert_allclose(actual, expected, rtol=1e-06)


def test_sample_truncated_normal_in_tail():
    """Samples from a tail-truncated normal must all be finite.

    The interval ``[-16, -15]`` lies in the tail of a standard normal.
    The check uses ``isfinite`` rather than ``~isinf`` so that NaN
    samples are rejected as well as infinite ones.
    """
    tail_dist = dist.TruncatedNormal(loc=0, scale=1, low=-16, high=-15)
    samples = tail_dist.sample(random.PRNGKey(0), sample_shape=(10_000,))
    assert jnp.isfinite(samples).all()