Remove inf's in TruncatedNormal log_prob & sample (#1492) #1581
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This commit fixes #1492 by removing inf return values in log_prob and sample method of the TruncatedNormal distribution, when it is truncated in the tail of the distribution.
Changes made to remove inf's in
log_prob
method of TruncatedNormal distribution:log_cdf
method to the Normal distributionlog_prob
calculation for the truncated Normallog_cdf
, if it is availablelogsumexp
function to subtract requiredlog_cdf
values more numerically stablelog_cdf
method is not available, it falls back to original implementation withcdf
Changes made to remove inf's in
sample
method of TruncatedNormal distribution:icdf
to open interval (0,1) using theclamp_probs
utilityTests added:
log_cdf
method by comparing to JAX implementation and tocdf
method (test_normal_log_cdf)log_prob
implementation by comparing to JAX (test_truncated_normal_log_prob_in_tail)sample
method in tail of distribution to check, if any inf's are returned (test_sample_truncated_normal_in_tail)