inf's with TruncatedNormal #1492
Comments
That would be awesome, @adrn!!! Regarding the numpyro codebase, I think you can add a method [...]. By the way, do you have any idea how to address the similar issue for the sample method? Maybe it is enough to clip values there?
Hi, Thanks in advance!
I also had similar issues with inf's and implemented the steps as suggested by fehiepsi. @fehiepsi, is this potentially interesting for a PR? If it is, would it be OK to simply call jax.scipy.stats.norm.logcdf for the new log_cdf method in Normal, or does it need to be implemented from scratch?
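For readers wondering why a log_cdf method helps here: once the CDF is available in log space, the interval mass of a truncated distribution can be combined without ever forming a difference of values close to 1. Below is a minimal stdlib-only sketch of that idea, not the numpyro change itself. The helper names (std_normal_log_cdf, log_diff_exp, log_interval_mass) are hypothetical, and the erfc-based log-CDF is only a stand-in; in JAX code one would presumably call jax.scipy.stats.norm.logcdf as proposed above.

```python
import math


def std_normal_log_cdf(x):
    # Stand-in for a library log-CDF (hypothetical helper).
    # Only valid while erfc does not underflow, roughly x > -37 in float64.
    return math.log(0.5 * math.erfc(-x / math.sqrt(2.0)))


def log_diff_exp(log_a, log_b):
    # log(exp(log_a) - exp(log_b)), assuming log_a > log_b.
    return log_a + math.log1p(-math.exp(log_b - log_a))


def log_interval_mass(low, high):
    # log P(low < Z < high) for a standard normal Z with low < high <= 0.
    # An interval in the upper tail can be reflected to (-high, -low) first.
    return log_diff_exp(std_normal_log_cdf(high), std_normal_log_cdf(low))


if __name__ == "__main__":
    print(log_interval_mass(-12.0, -10.0))
```

The result stays finite because both log-CDF values are finite and only their ratio is exponentiated.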
Yes, that would be nice. Thanks for working on this issue, your solution looks great to me. Could you also address the issue at
Co-authored-by: Niklas Michel <[email protected]>
I've seen the discussion in #1184 and #1185, but I'm still seeing this issue with numpyro v0.10.1. Here's an MWE, comparing to scipy's scipy.stats.truncnorm implementation (arbitrary values chosen to get into the tail):
Comparing the output values:
It's possible to avoid this by special-casing the truncated normal distribution, as I recently implemented in JAX. It would be great to have this in numpyro as well.
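To illustrate the kind of special-casing meant here, the following is a minimal stdlib-only sketch, assuming a standard normal base distribution. It is not the JAX or numpyro implementation, and all function names are hypothetical. It shows how the naive normalizer cdf(high) - cdf(low) rounds to zero in the upper tail, which turns the log-density into inf, and how reflecting the interval into the lower tail avoids that. The example uses float64; with JAX's default float32 the same failure shows up much closer to the mean.

```python
import math


def norm_cdf(x):
    # Standard normal CDF via the complementary error function.
    return 0.5 * math.erfc(-x / math.sqrt(2.0))


def norm_logpdf(x):
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def log_mass_naive(low, high):
    # Both CDF values round to 1.0 far in the upper tail, so the mass is 0.
    mass = norm_cdf(high) - norm_cdf(low)
    return math.log(mass) if mass > 0.0 else -math.inf


def log_mass_tail_aware(low, high):
    # If the interval lies entirely above zero, use the symmetry
    # P(low < Z < high) = P(-high < Z < -low) to work in the lower tail,
    # where the CDF values are tiny but representable.
    if low > 0.0:
        low, high = -high, -low
    mass = norm_cdf(high) - norm_cdf(low)
    return math.log(mass) if mass > 0.0 else -math.inf


def truncnorm_logpdf(x, low, high, log_mass):
    # Log-density of a standard normal truncated to [low, high].
    if not (low <= x <= high):
        return -math.inf
    return norm_logpdf(x) - log_mass(low, high)


if __name__ == "__main__":
    print(truncnorm_logpdf(10.5, 10.0, 12.0, log_mass_naive))       # inf
    print(truncnorm_logpdf(10.5, 10.0, 12.0, log_mass_tail_aware))  # finite
```

The reflection alone still underflows once the bounds are extreme enough for the lower-tail CDF itself to round to zero, which is where a log-space formulation becomes necessary.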
Would you consider a PR to special-case TruncatedNormal? I'm not familiar with the numpyro codebase but have just started using it and am loving it. Thanks for the work and maintenance on this project!