Numerically stabilize ProjectedNormal.log_prob() via erfc (#3071)
* Numerically stabilize ProjectedNormal.log_prob() via logaddexp

* Fix conceptual error, now this NaNs

* Switch to erfc

* Strengthen tests

* lint

* Strengthen test, clamp harder
fritzo authored May 2, 2022
1 parent 611dda1 commit 1a57d31
Showing 6 changed files with 63 additions and 16 deletions.
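Background for the change (an illustrative sketch, not part of the diff): for a strongly negative projection t, the old expression 1 + erf(t / sqrt(2)) cancels catastrophically; in float32 the erf term rounds to exactly -1, the sum becomes 0.0, and its log is -inf. The mathematically equivalent erfc(-t / sqrt(2)) is evaluated directly and keeps a meaningful tiny value.

import torch

t = torch.tensor(-6.0)  # float32 by default
old = 1 + (t * 0.5**0.5).erf()   # cancels to 0.0 in float32, so old.log() is -inf
new = (t * -(0.5**0.5)).erfc()   # roughly 1e-9, evaluated without cancellation
print(old.item(), new.item())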
30 changes: 17 additions & 13 deletions pyro/distributions/projected_normal.py
@@ -126,6 +126,10 @@ def _dot(x, y):
return (x[..., None, :] @ y[..., None])[..., 0, 0]


def _safe_log(x):
return x.clamp(min=torch.finfo(x.dtype).eps).log()


@ProjectedNormal._register_log_prob(dim=2)
def _log_prob_2(concentration, value):
# We integrate along a ray, factorizing the integrand as a product of:
@@ -139,13 +143,11 @@ def _log_prob_2(concentration, value):
     # This is the log of a definite integral, computed by mathematica:
     # Integrate[x/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
     # = (t + Sqrt[2/Pi]/E^(t^2/2) + t Erf[t/Sqrt[2]])/2
-    para_part = (
-        (
-            t2.mul(-0.5).exp().mul((2 / math.pi) ** 0.5)
-            + t * (1 + (t * 0.5**0.5).erf())
-        )
-        .mul(0.5)
-        .log()
+    # = (Sqrt[2/Pi]/E^(t^2/2) + t (1 + Erf[t/Sqrt[2]]))/2
+    # = (Sqrt[2/Pi]/E^(t^2/2) + t Erfc[-t/Sqrt[2]])/2
+    para_part = _safe_log(
+        (t2.mul(-0.5).exp().mul((2 / math.pi) ** 0.5) + t * (t * -(0.5**0.5)).erfc())
+        / 2
     )

     return para_part + perp_part
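As an illustrative sanity check (not part of the commit, and it assumes SciPy is available), the erfc form of the closed-form integral above can be compared against direct numerical quadrature:

import math

from scipy import integrate


def para_part_closed_form(t):
    # (Sqrt[2/Pi]/E^(t^2/2) + t Erfc[-t/Sqrt[2]])/2, as used in _log_prob_2
    return (
        math.sqrt(2 / math.pi) * math.exp(-0.5 * t * t)
        + t * math.erfc(-t / math.sqrt(2))
    ) / 2


def para_part_quadrature(t):
    # Integrate[x/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
    def integrand(x):
        return x * math.exp(-0.5 * (x - t) ** 2) / math.sqrt(2 * math.pi)

    value, _ = integrate.quad(integrand, 0.0, math.inf)
    return value


for t in [-3.0, -0.5, 0.0, 2.5]:
    assert abs(para_part_closed_form(t) - para_part_quadrature(t)) < 1e-7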
@@ -164,10 +166,11 @@ def _log_prob_3(concentration, value):
     # This is the log of a definite integral, computed by mathematica:
     # Integrate[x^2/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
     # = t/(E^(t^2/2) Sqrt[2 Pi]) + ((1 + t^2) (1 + Erf[t/Sqrt[2]]))/2
-    para_part = (
+    # = t/(E^(t^2/2) Sqrt[2 Pi]) + ((1 + t^2) Erfc[-t/Sqrt[2]])/2
+    para_part = _safe_log(
         t * t2.mul(-0.5).exp() / (2 * math.pi) ** 0.5
-        + (1 + t2) * (1 + (t * 0.5**0.5).erf()) / 2
-    ).log()
+        + (1 + t2) * (t * -(0.5**0.5)).erfc() / 2
+    )

     return para_part + perp_part

@@ -185,9 +188,10 @@ def _log_prob_4(concentration, value):
     # This is the log of a definite integral, computed by mathematica:
     # Integrate[x^3/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
     # = (2 + t^2)/(E^(t^2/2) Sqrt[2 Pi]) + (t (3 + t^2) (1 + Erf[t/Sqrt[2]]))/2
-    para_part = (
+    # = (2 + t^2)/(E^(t^2/2) Sqrt[2 Pi]) + (t (3 + t^2) Erfc[-t/Sqrt[2]])/2
+    para_part = _safe_log(
         (2 + t2) * t2.mul(-0.5).exp() / (2 * math.pi) ** 0.5
-        + t * (3 + t2) * (1 + (t * 0.5**0.5).erf()) / 2
-    ).log()
+        + t * (3 + t2) * (t * -(0.5**0.5)).erfc() / 2
+    )

     return para_part + perp_part
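Even with erfc, the argument of the log can underflow to exactly zero for large negative projections, which is what _safe_log (and the "clamp harder" commit) guards against. A small illustration (not part of the diff), using the dim=2 expression:

import math

import torch

t = torch.tensor(-20.0, requires_grad=True)
t2 = t * t
inner = (t2.mul(-0.5).exp().mul((2 / math.pi) ** 0.5) + t * (t * -(0.5**0.5)).erfc()) / 2
print(inner.item())  # 0.0 in float32: both terms underflow

# What _safe_log does: clamp to eps before the log, so the result stays finite
# and backpropagation yields a zero gradient instead of NaN.
eps = torch.finfo(inner.dtype).eps
safe = inner.clamp(min=eps).log()
safe.backward()
print(safe.item(), t.grad.item())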
2 changes: 1 addition & 1 deletion pyro/ops/special.py
@@ -213,6 +213,6 @@ def _log_factorial_sum(x: torch.Tensor) -> torch.Tensor:
         return (x + 1).lgamma().sum()
     key = id(x)
     if key not in _log_factorial_cache:
-        weakref.finalize(x, _log_factorial_cache.pop, key, None)
+        weakref.finalize(x, _log_factorial_cache.pop, key, None)  # type: ignore
         _log_factorial_cache[key] = (x + 1).lgamma().sum()
     return _log_factorial_cache[key]
15 changes: 15 additions & 0 deletions tests/common.py
@@ -102,6 +102,21 @@ def tensors_default_to(host):
torch.set_default_tensor_type("{}.{}".format(old_module, name))


@contextlib.contextmanager
def default_dtype(dtype):
"""
    Context manager to temporarily set the PyTorch default dtype.

    :param torch.dtype dtype: The temporary default dtype, e.g. torch.float or torch.double.
"""
old = torch.get_default_dtype()
try:
torch.set_default_dtype(dtype)
yield
finally:
torch.set_default_dtype(old)


@contextlib.contextmanager
def freeze_rng_state():
rng_state = torch.get_rng_state()
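A quick usage sketch of the new helper (illustrative only; the new test below uses it the same way):

import torch

from tests.common import default_dtype

prev = torch.get_default_dtype()
with default_dtype(torch.double):
    assert torch.zeros(()).dtype == torch.double  # double inside the block
assert torch.get_default_dtype() == prev  # restored on exit, even if the block raises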
5 changes: 4 additions & 1 deletion tests/distributions/conftest.py
@@ -526,10 +526,13 @@ def __init__(self, von_loc, von_conc, skewness):
pyro_dist=dist.ProjectedNormal,
examples=[
{"concentration": [0.0, 0.0], "test_data": [1.0, 0.0]},
{"concentration": [0.2, 0.1], "test_data": [1.0, 0.0]},
{"concentration": [2.0, 3.0], "test_data": [0.0, 1.0]},
{"concentration": [0.0, 0.0, 0.0], "test_data": [1.0, 0.0, 0.0]},
{"concentration": [0.1, 0.0, 0.0], "test_data": [1.0, 0.0, 0.0]},
{"concentration": [0.3, 0.2, 0.1], "test_data": [1.0, 0.0, 0.0]},
{"concentration": [-1.0, 2.0, 3.0], "test_data": [0.0, 0.0, 1.0]},
{"concentration": [0.0, 0.0, 0.0, 0.0], "test_data": [1.0, 0.0, 0.0, 0.0]},
{"concentration": [0.4, 0.3, 0.2, 0.1], "test_data": [1.0, 0.0, 0.0, 0.0]},
{
"concentration": [-1.0, 2.0, 0.5, -0.5],
"test_data": [0.0, 1.0, 0.0, 0.0],
3 changes: 2 additions & 1 deletion tests/distributions/test_distributions.py
@@ -143,7 +143,8 @@ def test_gof(continuous_dist):
     num_samples = 50000
     for i in range(continuous_dist.get_num_test_data()):
         d = Dist(**continuous_dist.get_dist_params(i))
-        samples = d.sample(torch.Size([num_samples]))
+        with torch.random.fork_rng():
+            samples = d.sample(torch.Size([num_samples]))
         with xfail_if_not_implemented():
             probs = d.log_prob(samples).exp()

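The fork_rng wrapper keeps the 50000 goodness-of-fit samples from advancing the global RNG stream seen by later tests. A small sketch of the guarantee (not from the diff):

import torch

torch.manual_seed(0)
before = torch.get_rng_state()
with torch.random.fork_rng():
    _ = torch.randn(50000)  # consumes random state only inside the fork
assert torch.equal(before, torch.get_rng_state())  # global CPU RNG state restored on exit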
24 changes: 24 additions & 0 deletions tests/distributions/test_projected_normal.py
@@ -0,0 +1,24 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

import pyro.distributions as dist
from tests.common import default_dtype


@pytest.mark.parametrize("strength", [0, 1, 10, 100, 1000])
@pytest.mark.parametrize("dim", [2, 3, 4])
@pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
def test_log_prob(dtype, dim, strength):
with default_dtype(dtype):
concentration = torch.full((dim,), float(strength), requires_grad=True)
value = dist.ProjectedNormal(torch.zeros_like(concentration)).sample([10000])
d = dist.ProjectedNormal(concentration)

logp = d.log_prob(value)
assert logp.max().lt(1 + dim * strength).all()

logp.sum().backward()
assert not torch.isnan(concentration.grad).any()
