Numerically stabilize ProjectedNormal.log_prob() via erfc #3071

Merged
merged 6 commits into dev from proj-normal-logaddexp on May 2, 2022

Conversation

fritzo
Member

@fritzo fritzo commented Apr 27, 2022

This attempts to resolve @geoffwoollard's forum issue by numerically stabilizing ProjectedNormal.log_prob() using torch.erfc().

Tested

  • covered by existing unit tests
  • added new test cases
  • added new test for nans

@fritzo fritzo added the WIP label Apr 27, 2022
@fritzo
Member Author

fritzo commented Apr 27, 2022

Hmm this seems to have a mistake since dot(concentration, value) is not necessarily positive 🤔

@fritzo fritzo closed this Apr 27, 2022
@mbrubake

You can do what you're thinking in this PR if you implement your own logaddexp/logsumexp that incorporates a scaling factor. See the b parameter in the scipy version of logsumexp: https://github.com/scipy/scipy/blob/v1.8.0/scipy/special/_logsumexp.py#L7-L127. I'm not 100% certain this will fully resolve the issue, but it's easy enough to do and worth a shot.
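
For reference, a small example (mine, not from the PR) of the b/return_sign machinery being pointed to here:

import numpy as np
from scipy.special import logsumexp

# log|sum(b * exp(a))| with signed weights b; return_sign also gives the
# sign of the sum, so negative terms can be handled in log space.
a = np.array([10.0, 10.0 + 1e-8])
b = np.array([-1.0, 1.0])
log_abs, sign = logsumexp(a, b=b, return_sign=True)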

@fritzo fritzo changed the title Numerically stabilize ProjectedNormal.log_prob() via logaddexp Numerically stabilize ProjectedNormal.log_prob() via erfc Apr 27, 2022
@fritzo fritzo reopened this Apr 27, 2022
@fritzo
Member Author

fritzo commented Apr 27, 2022

@mbrubake I wasn't sure how to get the logaddexp(-,-) thing working because one side is actually negative and can't be logged. Note that's worse than division, which is supported by scipy: I believe scipy allows you to subtract two log-numbers to simulate their log-ratio, but here we can't even create the logs in the first place.

Anyway the 1+erf(t) = erfc(-t) trick seems to work 🤷
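
As a minimal illustration (my own snippet, not the PR diff) of why the identity helps when t is very negative:

import torch

t = torch.tensor([-6.0, -8.0])   # t = dot(concentration, value) can be very negative
unstable = 1 + torch.erf(t)      # erf(t) rounds to -1, so this cancels to 0
stable = torch.erfc(-t)          # mathematically equal, but evaluated without cancellation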

@mbrubake

mbrubake commented Apr 28, 2022

To see how to use logsumexp, consider any expression of the form log(sum(x)) where sum(x) > 0. This can be equivalently written as log(sum(b*exp(a))) where a = log(abs(x)) and b = sign(x). Basically, you log the absolute value of each term and keep track of its sign, relying on the fact that x = sign(x)*exp(log(abs(x))).

log(sum(b*exp(a))) can be more stably implemented using the standard log-sum-exp trick, that is, as c + log(sum(b*exp(a - c))) where c = max(a). It's still subject to catastrophic cancellation in some cases, e.g., if the two largest entries of a have similar values but their corresponding values of b have different signs, but it should be much better overall.
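
A minimal PyTorch sketch of this signed log-sum-exp (the name and code are mine, not part of the PR), mirroring scipy's b/return_sign behavior:

import torch

def signed_logsumexp(a, b, dim=-1):
    # Returns (log|sum(b * exp(a))|, sign of the sum) along dim.
    c = a.max(dim=dim, keepdim=True).values   # shift by the max for stability
    s = (b * (a - c).exp()).sum(dim=dim)      # sum of sign-weighted, shifted terms
    return c.squeeze(dim) + s.abs().log(), s.sign()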

@fritzo
Member Author

fritzo commented Apr 29, 2022

log(abs(x))

🤔 sounds like that would shrink the hole at -∞ by introducing a second hole at zero.

I think a longer-term solution is to implement the underlying special functions, I believe confluent hypergeometric functions of the first kind / Kummer's function, or maybe the iterated erfc integrals, I'm not sure.

@fritzo fritzo requested a review from martinjankowiak May 2, 2022 01:49
@fritzo
Member Author

fritzo commented May 2, 2022

@martinjankowiak would you be ok merging this? I'm sure we can do better, but this PR solves the immediate problem by eliminating NANs.

@martinjankowiak martinjankowiak merged commit 1a57d31 into dev May 2, 2022
@martinjankowiak martinjankowiak deleted the proj-normal-logaddexp branch May 2, 2022 12:45
@geoffwoollard

Thanks @fritzo !

This fix pushes the numerical stability to about t=-13.5, after which the para_part underflows to 0, and the log is -inf.

This puts restrictions on how tight the distribution can get (how large the magnitude of the concentration can be). For my application, I need it to be a tighter distribution.
[plot of the unsafe log of para_part vs t, produced by the snippet below]

import math
import matplotlib.pyplot as plt
import pandas as pd
import torch

t = torch.linspace(-16, -10, 1000)
t2 = t.square()
# para_part as computed inside log_prob, before taking the (unsafe) log
para_part_exp = ( t * t2.mul(-0.5).exp() / (2 * math.pi) ** 0.5 + (1 + t2) * (t * -(0.5**0.5)).erfc() / 2 )
pd.Series(para_part_exp.log().numpy(), index=t.numpy()).plot()
plt.xlabel('t')
plt.ylabel('para_part: unsafe log')

With @ntfrgl and @mbrubake I looked into each of the terms and we have something that uses torch.special.erfcx, which is erfc(x) scaled by exp(x**2) to protect against underflow. It will help for the 2d, 3d and 4d cases... and likely for any dimension in general. Based on the para_part integral it seems there is always an exp(-t**2/2) term to pull out, so underflow can be avoided by adding -t**2/2 to the log_prob. Note the 1/2 term in the exp is absorbed into the t here for clarity, but we see the exp(-t**2) term is a multiplicative factor: https://www.wolframalpha.com/input?i=int+x%5En*exp%28-%28x-t%29%5E2%29+dx+from+0+to+%2Binf
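
A rough sketch (my own transcription of the idea, using the para_part expression from the snippet above) of how pulling out exp(-t**2/2) via erfcx could look:

import math
import torch

def log_para_part(t):
    # erfc(-t/sqrt(2)) = exp(-t**2/2) * erfcx(-t/sqrt(2)), so the common
    # exp(-t**2/2) factor moves into log space as -t**2/2.
    bracket = t / math.sqrt(2 * math.pi) + (1 + t.square()) * torch.special.erfcx(t * -(0.5**0.5)) / 2
    return -0.5 * t.square() + bracket.log()

Some cancellation may remain inside bracket for very negative t, but the exponential underflow is gone.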

We can prepare a PR with some tests showing stability.

@fritzo
Member Author

fritzo commented May 2, 2022

@geoffwoollard that's great, erfcx sounds like a better solution than my hack. I look forward to your PR.

@geoffwoollard

geoffwoollard commented May 3, 2022

@fritzo can you explain how the Fixture for ProjectedNormal in tests/distributions/conftest.py gets tested?

pyro_dist=dist.ProjectedNormal,

I couldn't track down where they are read in... I'm just trying to figure out where to put in some new tests ... and what tests will need to be passed.

@fritzo
Member Author

fritzo commented May 3, 2022

@geoffwoollard each time you add a ProjectedNormal config to tests/distributions/conftest.py, it is read by all tests that use either the dist fixture or the continuous_dist fixture. For convenience, here are two search queries showing some of the tests using those fixtures:

% ag -s '^def.*\bdist\b'
test_conjugate.py
31:def test_mean(dist):
55:def test_variance(dist):
71:def test_log_prob_support(dist, values):

test_sine_skewed.py
52:def test_ss_multidim_log_prob(expand_shape, dist):
69:def test_ss_mle(dim, dist):

test_cuda.py
17:def test_sample(dist):
40:def test_rsample(dist):
82:def test_log_prob(dist):

test_distributions.py
21:def _log_prob_shape(dist, x_size=torch.Size()):
32:def test_support_shape(dist):
45:def test_infer_shapes(dist):
60:def test_batch_log_prob(dist):
74:def test_batch_log_prob_shape(dist):
86:def test_batch_entropy_shape(dist):
97:def test_score_errors_event_dim_mismatch(dist):
115:def test_score_errors_non_broadcastable_data_shape(dist):
244:def test_enumerate_support_shape(dist):
307:def test_expand_by(dist, sample_shape, shape_type):
320:def test_expand_new_dim(dist, sample_shape, shape_type, default):
338:def test_expand_existing_dim(dist, shape_type, default):
366:def test_subsequent_expands_ok(dist, sample_shapes, default):
392:def test_expand_error(dist, initial_shape, proposed_shape, default):
% ag '^def.*\bcontinuous_dist\b'
test_distributions.py
131:def test_support_is_not_discrete(continuous_dist):
138:def test_gof(continuous_dist):
165:def test_mean(continuous_dist):
186:def test_variance(continuous_dist):
207:def test_cdf_icdf(continuous_dist):

I've also added some distribution-specific tests in tests/distributions/test_projected_normal.py; feel free to add more custom tests there.
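
As an illustration of the kind of stress test discussed in this thread (a hypothetical sketch, not one of the existing tests), something along these lines could live in tests/distributions/test_projected_normal.py:

import torch
import pyro.distributions as dist

def test_log_prob_is_finite_for_tight_concentration():
    # Hypothetical: a value pointing opposite a large concentration makes
    # t = dot(concentration, value) very negative, the regime that underflowed.
    concentration = torch.tensor([0.0, 0.0, -20.0])
    value = torch.tensor([0.0, 0.0, 1.0])  # unit vector on the sphere
    assert torch.isfinite(dist.ProjectedNormal(concentration).log_prob(value))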
