Skip to content

Commit

Permalink
Remove inf's in TruncatedNormal log_prob & sample (#1492) (#1581)
Browse files Browse the repository at this point in the history
Co-authored-by: Niklas Michel <[email protected]>
  • Loading branch information
nikmich1 and Niklas Michel authored May 5, 2023
1 parent c870ce8 commit d63dae4
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 5 deletions.
4 changes: 4 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
xlog1py,
xlogy,
)
from jax.scipy.stats import norm as jax_norm

from numpyro.distributions import constraints
from numpyro.distributions.discrete import _to_logits_bernoulli
Expand Down Expand Up @@ -2077,6 +2078,9 @@ def cdf(self, value):
scaled = (value - self.loc) / self.scale
return ndtr(scaled)

def log_cdf(self, value):
return jax_norm.logcdf(value, loc=self.loc, scale=self.scale)

def icdf(self, q):
return self.loc + self.scale * ndtri(q)

Expand Down
25 changes: 20 additions & 5 deletions numpyro/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
clamp_probs,
is_prng_key,
lazy_property,
promote_shapes,
Expand Down Expand Up @@ -249,6 +250,23 @@ def _tail_prob_at_high(self):
sign = jnp.where(loc >= self.low, 1.0, -1.0)
return self.base_dist.cdf(loc - sign * (loc - self.high))

@lazy_property
def _log_diff_tail_probs(self):
# use log_cdf method, if available, to avoid inf's in log_prob
# fall back to cdf, if log_cdf not available
log_cdf = getattr(self.base_dist, "log_cdf", None)
if callable(log_cdf):
return logsumexp(
a=jnp.stack([log_cdf(self.high), log_cdf(self.low)], axis=-1),
axis=-1,
b=jnp.array([1, -1]), # subtract low from high
)

else:
loc = self.base_dist.loc
sign = jnp.where(loc >= self.low, 1.0, -1.0)
return jnp.log(sign * (self._tail_prob_at_high - self._tail_prob_at_low))

def sample(self, key, sample_shape=()):
assert is_prng_key(key)
dtype = jnp.result_type(float)
Expand All @@ -266,7 +284,7 @@ def sample(self, key, sample_shape=()):
loc = self.base_dist.loc
sign = jnp.where(loc >= self.low, 1.0, -1.0)
return (1 - sign) * loc + sign * self.base_dist.icdf(
(1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
clamp_probs((1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high)
)

@validate_sample
Expand All @@ -276,10 +294,7 @@ def log_prob(self, value):
# cdf(high) - cdf(low) = as-is
# if low > loc
# cdf(high) - cdf(low) = cdf(2 * loc - low) - cdf(2 * loc - high)
sign = jnp.where(self.base_dist.loc >= self.low, 1.0, -1.0)
return self.base_dist.log_prob(value) - jnp.log(
sign * (self._tail_prob_at_high - self._tail_prob_at_low)
)
return self.base_dist.log_prob(value) - self._log_diff_tail_probs

def tree_flatten(self):
base_flatten, base_aux = self.base_dist.tree_flatten()
Expand Down
44 changes: 44 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import expit, logsumexp
from jax.scipy.stats import norm as jax_norm, truncnorm as jax_truncnorm
from jax.tree_util import tree_map

import numpyro.distributions as dist
Expand Down Expand Up @@ -2758,3 +2759,46 @@ def f(x):
x = dist.Multinomial(10, probs).sample(key)
y = jax.jit(f)(x)
assert_allclose(x, y, rtol=1e-6)


def test_normal_log_cdf():
# test if log_cdf method agrees with jax.scipy.stats.norm.logcdf
# and if exp(log_cdf) agrees with cdf
loc = jnp.array([[0.0, -10.0, 20.0]])
scale = jnp.array([[1, 5, 7]])
values = jnp.linspace(-5, 5, 100).reshape(-1, 1)
numpyro_log_cdf = dist.Normal(loc=loc, scale=scale).log_cdf(values)
numpyro_cdf = dist.Normal(loc=loc, scale=scale).cdf(values)
jax_log_cdf = jax_norm.logcdf(loc=loc, scale=scale, x=values)
assert_allclose(numpyro_log_cdf, jax_log_cdf)
assert_allclose(jnp.exp(numpyro_log_cdf), numpyro_cdf, rtol=1e-6)


@pytest.mark.parametrize(
"value",
[
-15.0,
jnp.array([[-15.0], [-10.0], [-5.0]]),
jnp.array([[[-15.0], [-10.0], [-5.0]], [[-14.0], [-9.0], [-4.0]]]),
],
)
def test_truncated_normal_log_prob_in_tail(value):
# define set of distributions truncated in tail of distribution
loc = 1.35
scale = jnp.geomspace(0.01, 1, 10)
low, high = (-20, -1.0)
a, b = (low - loc) / scale, (high - loc) / scale # rescale for jax input

numpyro_log_prob = dist.TruncatedNormal(loc, scale, low=low, high=high).log_prob(
value
)
jax_log_prob = jax_truncnorm.logpdf(value, loc=loc, scale=scale, a=a, b=b)
assert_allclose(numpyro_log_prob, jax_log_prob, rtol=1e-06)


def test_sample_truncated_normal_in_tail():
# test, if samples from distributions truncated in
# tail of distribution returns any inf's
tail_dist = dist.TruncatedNormal(loc=0, scale=1, low=-16, high=-15)
samples = tail_dist.sample(random.PRNGKey(0), sample_shape=(10_000,))
assert ~jnp.isinf(samples).any()

0 comments on commit d63dae4

Please sign in to comment.