Numerically stabilize ProjectedNormal.log_prob() via erfc #3071
Conversation
Hmm, this seems to have a mistake since …
You can do what you're thinking in this PR if you implement your own logaddexp/logsumexp which incorporates a scaling factor. See the …
@mbrubake I wasn't sure how to get the … Anyway, the …
To see how to use logsumexp, consider any expression of the form …
🤔 sounds like that would shrink the hole at -∞ by introducing a second hole at zero. I think a longer-term solution is to implement the underlying special functions, I believe confluent hypergeometric functions of the first kind / Kummer's function, or maybe the iterated erfc integrals, I'm not sure.
@martinjankowiak would you be ok merging this? I'm sure we can do better, but this PR solves the immediate problem by eliminating NaNs.
Thanks @fritzo! This fix pushes the numerical stability to about t = -13.5, after which para_part underflows to 0 and its log is -inf. This restricts how tight the distribution can get (how large the magnitude of the concentration can be). For my application, I need a tighter distribution.

```python
t = torch.linspace(-16, -10, 1000)
t2 = t.square()
para_part_exp = (
    t * t2.mul(-0.5).exp() / (2 * math.pi) ** 0.5
    + (1 + t2) * (t * -(0.5 ** 0.5)).erfc() / 2
)
pd.Series(para_part_exp.log(), index=t.numpy()).plot()
plt.xlabel('t')
plt.ylabel('para_part: unsafe log')
```

With @ntfrgl and @mbrubake I looked into each of the terms and we have something that uses … We can prepare a PR with some tests showing stability.
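For reference, one way to remove this underflow (a sketch of my own using `torch.special.erfcx`, the scaled complementary error function available in PyTorch ≥ 1.11; not necessarily the variant the commenters settled on) uses the identity erfc(x) = erfcx(x)·exp(-x²), which lets the -t²/2 factor be pulled outside the log entirely:

```python
import math
import torch

def log_para_part_stable(t):
    """log of  t*exp(-t^2/2)/sqrt(2*pi) + (1+t^2)*erfc(-t/sqrt(2))/2,
    rewritten via erfc(x) = erfcx(x)*exp(-x^2) so that exp(-t^2/2)
    never appears on its own (sketch; intended for t <= 0)."""
    t2 = t.square()
    bracket = (
        t / math.sqrt(2 * math.pi)
        + (1 + t2) * torch.special.erfcx(-t * 0.5 ** 0.5) / 2
    )
    return -0.5 * t2 + bracket.log()

def log_para_part_naive(t):
    """Direct evaluation, as in the snippet above; underflows for large |t|."""
    t2 = t.square()
    para = (
        t * t2.mul(-0.5).exp() / (2 * math.pi) ** 0.5
        + (1 + t2) * (t * -(0.5 ** 0.5)).erfc() / 2
    )
    return para.log()

t = torch.tensor([-5.0, -40.0], dtype=torch.float64)
print(log_para_part_naive(t))   # second entry is -inf: erfc underflows
print(log_para_part_stable(t))  # both entries finite
```

Some cancellation remains inside `bracket` for very negative t (both terms grow like |t| while their sum shrinks like 1/|t|), and `erfcx(-t/√2)` overflows for large positive t, so this is a sketch of the idea rather than a drop-in fix.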
@geoffwoollard that's great, …
@fritzo can you explain how the fixture for ProjectedNormal in tests/distributions/conftest.py gets tested? (pyro/tests/distributions/conftest.py, line 526 in d6df231)

I couldn't track down where they are read in. I'm just trying to figure out where to put some new tests, and what tests will need to pass.
@geoffwoollard each time you add a …, it gets run through every generic test that takes a `dist` or `continuous_dist` argument:

```
% ag -s '^def.*\bdist\b'
test_conjugate.py
31:def test_mean(dist):
55:def test_variance(dist):
71:def test_log_prob_support(dist, values):
test_sine_skewed.py
52:def test_ss_multidim_log_prob(expand_shape, dist):
69:def test_ss_mle(dim, dist):
test_cuda.py
17:def test_sample(dist):
40:def test_rsample(dist):
82:def test_log_prob(dist):
test_distributions.py
21:def _log_prob_shape(dist, x_size=torch.Size()):
32:def test_support_shape(dist):
45:def test_infer_shapes(dist):
60:def test_batch_log_prob(dist):
74:def test_batch_log_prob_shape(dist):
86:def test_batch_entropy_shape(dist):
97:def test_score_errors_event_dim_mismatch(dist):
115:def test_score_errors_non_broadcastable_data_shape(dist):
244:def test_enumerate_support_shape(dist):
307:def test_expand_by(dist, sample_shape, shape_type):
320:def test_expand_new_dim(dist, sample_shape, shape_type, default):
338:def test_expand_existing_dim(dist, shape_type, default):
366:def test_subsequent_expands_ok(dist, sample_shapes, default):
392:def test_expand_error(dist, initial_shape, proposed_shape, default):

% ag '^def.*\bcontinuous_dist\b'
test_distributions.py
131:def test_support_is_not_discrete(continuous_dist):
138:def test_gof(continuous_dist):
165:def test_mean(continuous_dist):
186:def test_variance(continuous_dist):
207:def test_cdf_icdf(continuous_dist):
```

I've also added some distribution-specific tests in tests/distributions/test_projected_normal.py; feel free to add more custom tests there.
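The mechanism behind this (a minimal sketch with hypothetical example distributions, not Pyro's actual conftest.py contents) is pytest's parametrized fixtures: conftest.py defines a `dist` fixture parametrized over all registered examples, so every test function that accepts a `dist` argument automatically runs once per example.

```python
# Sketch of the conftest.py mechanism (EXAMPLES here is hypothetical):
# a parametrized fixture fans each test out over every registered example.
import pytest
import torch

EXAMPLES = [
    {"name": "Normal", "dist": torch.distributions.Normal(0.0, 1.0)},
    {"name": "Gamma", "dist": torch.distributions.Gamma(2.0, 2.0)},
]

@pytest.fixture(params=EXAMPLES, ids=lambda e: e["name"])
def dist(request):
    return request.param["dist"]

def test_log_prob_is_finite(dist):  # runs once per EXAMPLES entry
    x = dist.sample((10,))
    assert torch.isfinite(dist.log_prob(x)).all()
```

Under this pattern, adding a new entry to the examples list is all that is needed for the generic test suite to pick it up.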
This attempts to resolve @geoffwoollard's forum issue by numerically stabilizing ProjectedNormal.log_prob() using torch.erfc().

Tested