inf's with TruncatedNormal #1492

Closed
adrn opened this issue Oct 26, 2022 · 4 comments · Fixed by #1581
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

@adrn
Contributor

adrn commented Oct 26, 2022

I've seen the discussion in #1184 and #1185, but I'm still seeing this issue with numpyro v0.10.1. Here's an MWE comparing against scipy's scipy.stats.truncnorm implementation:

import numpyro
numpyro.enable_x64()
import numpyro.distributions as dist
import jax.numpy as jnp
from scipy.stats import truncnorm

loc = 1.35
scale = jnp.geomspace(0.01, 1, 10)
low, high = (-20, -1.0)
a, b = (low - loc) / scale, (high - loc) / scale  # scipy's truncnorm takes standardized bounds
x = -15.

scipy_val = truncnorm.logpdf(x, loc=loc, scale=scale, a=a, b=b)
numpyro_val = dist.TruncatedNormal(loc, scale, low=low, high=high).log_prob(x)

(arbitrary values chosen to get into the tail)

Comparing the output values:

>>> scipy_val
array([-1.30898994e+06, -4.70421167e+05, -1.69055833e+05, -6.07514028e+04,
       -2.18294637e+04, -7.84229794e+03, -2.81622225e+03, -1.01058785e+03,
       -3.62302368e+02, -1.29911728e+02])
>>> numpyro_val
DeviceArray([            inf,             inf,             inf,
                         inf, -21829.46367826,  -7842.29793866,
              -2816.22224529,  -1010.58784742,   -362.30236837,
               -129.91172764], dtype=float64)
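The +inf entries are consistent with the normalizing constant log(Phi(b) - Phi(a)) underflowing: for the four smallest scales both standardized bounds sit so far in the lower tail that Phi(a) and Phi(b) round to 0.0 in float64, the log becomes -inf, and subtracting it from the base log-density yields +inf. The sketch below reproduces this with jax.scipy.stats.norm only, under the assumption that numpyro's normalizer behaves like a plain log of a CDF difference; it is not numpyro's code. It contrasts that with the same normalizer computed from log-CDF values.

```python
import jax

jax.config.update("jax_enable_x64", True)  # enable float64, as numpyro.enable_x64() does in the MWE

import jax.numpy as jnp
from jax.scipy.stats import norm

# Same inputs as the MWE above.
loc = 1.35
scale = jnp.geomspace(0.01, 1, 10)
low, high = -20.0, -1.0
a, b = (low - loc) / scale, (high - loc) / scale
x = -15.0

base_log_prob = norm.logpdf(x, loc, scale)

# Naive normalizer log(Phi(b) - Phi(a)): for the smallest scales both CDF
# values underflow to 0.0, so the log is -inf and the log-density becomes +inf.
naive_log_z = jnp.log(norm.cdf(b) - norm.cdf(a))
naive_log_prob = base_log_prob - naive_log_z

# Log-space normalizer: log(Phi(b) - Phi(a)) = logPhi(b) + log(1 - exp(d)),
# with d = logPhi(a) - logPhi(b) <= 0 because a < b.
log_cdf_a, log_cdf_b = norm.logcdf(a), norm.logcdf(b)
stable_log_z = log_cdf_b + jnp.log(-jnp.expm1(log_cdf_a - log_cdf_b))
stable_log_prob = base_log_prob - stable_log_z

print(naive_log_prob)
print(stable_log_prob)
```

If the mechanism is as described, the first four entries of naive_log_prob are inf, while stable_log_prob should track the scipy and jax.scipy.stats.truncnorm values reported in this issue.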

It's possible to avoid this by special-casing the truncated normal distribution, as I recently implemented in JAX; it would be great to have this in numpyro as well.

from jax.scipy.stats import truncnorm as jax_truncnorm
jax_val = jax_truncnorm.logpdf(x, loc=loc, scale=scale, a=a, b=b)
print(jax_val)

DeviceArray([-1.30898994e+06, -4.70421167e+05, -1.69055833e+05,
             -6.07514028e+04, -2.18294637e+04, -7.84229794e+03,
             -2.81622225e+03, -1.01058785e+03, -3.62302368e+02,
             -1.29911728e+02], dtype=float64)

Would you consider a PR to special-case TruncatedNormal? I'm not familiar with the numpyro codebase but have just started using it and am loving it - thanks for the work and maintenance on this project!

fehiepsi added the enhancement (New feature or request) label Oct 29, 2022
@fehiepsi
Member

fehiepsi commented Oct 29, 2022

> consider a PR to special-case TruncatedNormal?

That would be awesome, @adrn!!! Regarding the numpyro codebase, I think you can add a log_cdf method to Normal and use it to compute something like _log_diff_tail_probs, which computes the log of self._tail_prob_at_high - self._tail_prob_at_low. You can add a _log_diff_tail_probs method to the TwoSidedTruncatedDistribution class that uses a more numerically stable formula if a log_cdf method is available in base_dist. Please let me know if this is not clear.
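One way to read the suggested _log_diff_tail_probs for a normal base distribution is as a function returning log(F(high) - F(low)) computed from log-CDF values. The standalone function below is a hypothetical sketch: log_diff_tail_probs is an illustrative name, not numpyro's method, and it calls jax.scipy.stats.norm.logcdf in place of a Normal.log_cdf method. It also reflects the interval to the lower tail when it lies entirely above loc, because log-CDF values near 0 would otherwise cancel.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.stats import norm


def log_diff_tail_probs(low, high, loc, scale):
    """Hypothetical sketch: log(P(X <= high) - P(X <= low)) for X ~ Normal(loc, scale)."""
    a = (low - loc) / scale
    b = (high - loc) / scale
    # By symmetry Phi(b) - Phi(a) == Phi(-a) - Phi(-b). When the whole interval
    # is above loc, use the reflected form so both log-CDFs stay far from 0.
    flip = a > 0
    upper = jnp.where(flip, -a, b)
    lower = jnp.where(flip, -b, a)
    log_cdf_upper = norm.logcdf(upper)
    log_cdf_lower = norm.logcdf(lower)
    # log(e^u - e^l) = u + log(1 - e^(l - u)), with l <= u.
    return log_cdf_upper + jnp.log(-jnp.expm1(log_cdf_lower - log_cdf_upper))


print(log_diff_tail_probs(-20.0, -1.0, 1.35, 0.01))  # MWE interval at the smallest scale
print(log_diff_tail_probs(40.0, 41.0, 0.0, 1.0))     # interval deep in the upper tail
```

Inside numpyro this would presumably take base_dist.log_cdf and live on TwoSidedTruncatedDistribution, as suggested above; the reflection step relies on the normal's symmetry and would need another route (for example a log survival function) for other base distributions.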

By the way, do you have any idea how to address a similar issue in the sample method? Maybe it is enough to clip the values there?

@saraelshawa

Hi,
I'm wondering if there are any updates on this?
I'm currently running into inf's when trying to sample from dist.TruncatedNormal().

Thanks in advance!

fehiepsi added the help wanted (Extra attention is needed) label Feb 21, 2023
@nikmich1
Contributor

I also had similar issues with inf's and implemented the steps as suggested by fehiepsi.

@fehiepsi is this potentially interesting for a PR? If so, would it be OK to simply call jax.scipy.stats.norm.logcdf for the new log_cdf method in Normal, or does it need to be implemented from scratch?
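The call in question is jax.scipy.stats.norm.logcdf(x, loc=0, scale=1). A minimal sketch of a log_cdf method delegating to it is below; NormalSketch is a hypothetical stand-in for numpyro's Normal, used only to show why evaluating the log-CDF directly matters far in the tail.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.stats import norm


class NormalSketch:
    """Hypothetical stand-in for a Normal distribution with a log_cdf method."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def cdf(self, value):
        return norm.cdf(value, loc=self.loc, scale=self.scale)

    def log_cdf(self, value):
        # Delegate to JAX, which evaluates log Phi directly instead of
        # taking the log of a CDF value that has already underflowed.
        return norm.logcdf(value, loc=self.loc, scale=self.scale)


d = NormalSketch(0.0, 1.0)
print(jnp.log(d.cdf(-50.0)))  # -inf: Phi(-50) underflows to 0.0 in float64
print(d.log_cdf(-50.0))       # finite, about -1254.83
```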

@fehiepsi
Member

> would it be ok to simply call jax.scipy.stats.norm.logcdf

Yes, that would be nice. Thanks for working on this issue; your solution looks great to me. Could you also address the issue in the sample method by clipping the prob value to the open (0, 1) interval? You can use the clamp_probs utility for it.
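A minimal inverse-CDF sampling sketch along those lines: clamp_to_open_unit is a hypothetical helper written in the spirit of numpyro's clamp_probs (its exact bounds here are an assumption), and sample_truncated_normal is illustrative, not numpyro's sample method. Clipping the probability keeps jax.scipy.special.ndtri from returning infinities, and a final clip keeps draws inside [low, high].

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import random
from jax.scipy.special import ndtri
from jax.scipy.stats import norm


def clamp_to_open_unit(p):
    # Hypothetical helper: keep p strictly inside (0, 1) so ndtri stays finite.
    finfo = jnp.finfo(jnp.result_type(p, float))
    return jnp.clip(p, finfo.tiny, 1.0 - finfo.eps)


def sample_truncated_normal(key, loc, scale, low, high, shape=()):
    """Inverse-CDF sampler sketch with the probability clipped to (0, 1)."""
    a = (low - loc) / scale
    b = (high - loc) / scale
    u = random.uniform(key, shape)
    cdf_a, cdf_b = norm.cdf(a), norm.cdf(b)
    # Map u in [0, 1) onto [Phi(a), Phi(b)], then keep it away from 0 and 1.
    p = clamp_to_open_unit(cdf_a + u * (cdf_b - cdf_a))
    z = ndtri(p)
    # Final clip: finite draws inside [low, high], even if the CDFs underflowed.
    return jnp.clip(loc + scale * z, low, high)


samples = sample_truncated_normal(random.PRNGKey(0), 1.35, 0.01, -20.0, -1.0, (1000,))
print(samples.min(), samples.max())
```

This only removes the infs: when Phi(a) and Phi(b) both underflow, as at the smallest scale in the MWE, every draw collapses onto a bound instead of following the true truncated distribution, so exact tail sampling would need a different method.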

4 participants