From 1a57d314195ab40517f4d6835e71e34bece05fa6 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 2 May 2022 08:45:12 -0400 Subject: [PATCH] Numerically stabilize ProjectedNormal.log_prob() via erfc (#3071) * Numerically stabilize ProjectedNormal.log_prob() via logaddexp * Fix conceptual error, now this NANs * Switch to erfc * Strengthen tests * lint * Strengthen test, clamp harder --- pyro/distributions/projected_normal.py | 30 +++++++++++--------- pyro/ops/special.py | 2 +- tests/common.py | 15 ++++++++++ tests/distributions/conftest.py | 5 +++- tests/distributions/test_distributions.py | 3 +- tests/distributions/test_projected_normal.py | 24 ++++++++++++++++ 6 files changed, 63 insertions(+), 16 deletions(-) create mode 100644 tests/distributions/test_projected_normal.py diff --git a/pyro/distributions/projected_normal.py b/pyro/distributions/projected_normal.py index 49856f8d49..fcd4c29212 100644 --- a/pyro/distributions/projected_normal.py +++ b/pyro/distributions/projected_normal.py @@ -126,6 +126,10 @@ def _dot(x, y): return (x[..., None, :] @ y[..., None])[..., 0, 0] +def _safe_log(x): + return x.clamp(min=torch.finfo(x.dtype).eps).log() + + @ProjectedNormal._register_log_prob(dim=2) def _log_prob_2(concentration, value): # We integrate along a ray, factorizing the integrand as a product of: @@ -139,13 +143,11 @@ def _log_prob_2(concentration, value): # This is the log of a definite integral, computed by mathematica: # Integrate[x/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}] # = (t + Sqrt[2/Pi]/E^(t^2/2) + t Erf[t/Sqrt[2]])/2 - para_part = ( - ( - t2.mul(-0.5).exp().mul((2 / math.pi) ** 0.5) - + t * (1 + (t * 0.5**0.5).erf()) - ) - .mul(0.5) - .log() + # = (Sqrt[2/Pi]/E^(t^2/2) + t (1 + Erf[t/Sqrt[2]]))/2 + # = (Sqrt[2/Pi]/E^(t^2/2) + t Erfc[-t/Sqrt[2]])/2 + para_part = _safe_log( + (t2.mul(-0.5).exp().mul((2 / math.pi) ** 0.5) + t * (t * -(0.5**0.5)).erfc()) + / 2 ) return para_part + perp_part @@ -164,10 +166,11 @@ def _log_prob_3(concentration, 
value): # This is the log of a definite integral, computed by mathematica: # Integrate[x^2/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}] # = t/(E^(t^2/2) Sqrt[2 Pi]) + ((1 + t^2) (1 + Erf[t/Sqrt[2]]))/2 - para_part = ( + # = t/(E^(t^2/2) Sqrt[2 Pi]) + ((1 + t^2) Erfc[-t/Sqrt[2]])/2 + para_part = _safe_log( t * t2.mul(-0.5).exp() / (2 * math.pi) ** 0.5 - + (1 + t2) * (1 + (t * 0.5**0.5).erf()) / 2 - ).log() + + (1 + t2) * (t * -(0.5**0.5)).erfc() / 2 + ) return para_part + perp_part @@ -185,9 +188,10 @@ def _log_prob_4(concentration, value): # This is the log of a definite integral, computed by mathematica: # Integrate[x^3/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}] # = (2 + t^2)/(E^(t^2/2) Sqrt[2 Pi]) + (t (3 + t^2) (1 + Erf[t/Sqrt[2]]))/2 - para_part = ( + # = (2 + t^2)/(E^(t^2/2) Sqrt[2 Pi]) + (t (3 + t^2) Erfc[-t/Sqrt[2]])/2 + para_part = _safe_log( (2 + t2) * t2.mul(-0.5).exp() / (2 * math.pi) ** 0.5 - + t * (3 + t2) * (1 + (t * 0.5**0.5).erf()) / 2 - ).log() + + t * (3 + t2) * (t * -(0.5**0.5)).erfc() / 2 + ) return para_part + perp_part diff --git a/pyro/ops/special.py b/pyro/ops/special.py index 7e38896df7..abd16df47a 100644 --- a/pyro/ops/special.py +++ b/pyro/ops/special.py @@ -213,6 +213,6 @@ def _log_factorial_sum(x: torch.Tensor) -> torch.Tensor: return (x + 1).lgamma().sum() key = id(x) if key not in _log_factorial_cache: - weakref.finalize(x, _log_factorial_cache.pop, key, None) + weakref.finalize(x, _log_factorial_cache.pop, key, None) # type: ignore _log_factorial_cache[key] = (x + 1).lgamma().sum() return _log_factorial_cache[key] diff --git a/tests/common.py b/tests/common.py index 28708ba8b4..9968baf136 100644 --- a/tests/common.py +++ b/tests/common.py @@ -102,6 +102,21 @@ def tensors_default_to(host): torch.set_default_tensor_type("{}.{}".format(old_module, name)) +@contextlib.contextmanager +def default_dtype(dtype): + """ + Context manager to temporarily set PyTorch default dtype. + + :param torch.dtype dtype: The temporary default dtype, e.g. torch.float or torch.double. 
+ """ + old = torch.get_default_dtype() + try: + torch.set_default_dtype(dtype) + yield + finally: + torch.set_default_dtype(old) + + @contextlib.contextmanager def freeze_rng_state(): rng_state = torch.get_rng_state() diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index 85b8aba808..1c0b314f43 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -526,10 +526,13 @@ def __init__(self, von_loc, von_conc, skewness): pyro_dist=dist.ProjectedNormal, examples=[ {"concentration": [0.0, 0.0], "test_data": [1.0, 0.0]}, + {"concentration": [0.2, 0.1], "test_data": [1.0, 0.0]}, {"concentration": [2.0, 3.0], "test_data": [0.0, 1.0]}, - {"concentration": [0.0, 0.0, 0.0], "test_data": [1.0, 0.0, 0.0]}, + {"concentration": [0.1, 0.0, 0.0], "test_data": [1.0, 0.0, 0.0]}, + {"concentration": [0.3, 0.2, 0.1], "test_data": [1.0, 0.0, 0.0]}, {"concentration": [-1.0, 2.0, 3.0], "test_data": [0.0, 0.0, 1.0]}, {"concentration": [0.0, 0.0, 0.0, 0.0], "test_data": [1.0, 0.0, 0.0, 0.0]}, + {"concentration": [0.4, 0.3, 0.2, 0.1], "test_data": [1.0, 0.0, 0.0, 0.0]}, { "concentration": [-1.0, 2.0, 0.5, -0.5], "test_data": [0.0, 1.0, 0.0, 0.0], diff --git a/tests/distributions/test_distributions.py b/tests/distributions/test_distributions.py index 3a2b4ca224..1ec7d2ae02 100644 --- a/tests/distributions/test_distributions.py +++ b/tests/distributions/test_distributions.py @@ -143,7 +143,8 @@ def test_gof(continuous_dist): num_samples = 50000 for i in range(continuous_dist.get_num_test_data()): d = Dist(**continuous_dist.get_dist_params(i)) - samples = d.sample(torch.Size([num_samples])) + with torch.random.fork_rng(): + samples = d.sample(torch.Size([num_samples])) with xfail_if_not_implemented(): probs = d.log_prob(samples).exp() diff --git a/tests/distributions/test_projected_normal.py b/tests/distributions/test_projected_normal.py new file mode 100644 index 0000000000..7827e7ce58 --- /dev/null +++ 
b/tests/distributions/test_projected_normal.py @@ -0,0 +1,24 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +import pyro.distributions as dist +from tests.common import default_dtype + + +@pytest.mark.parametrize("strength", [0, 1, 10, 100, 1000]) +@pytest.mark.parametrize("dim", [2, 3, 4]) +@pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str) +def test_log_prob(dtype, dim, strength): + with default_dtype(dtype): + concentration = torch.full((dim,), float(strength), requires_grad=True) + value = dist.ProjectedNormal(torch.zeros_like(concentration)).sample([10000]) + d = dist.ProjectedNormal(concentration) + + logp = d.log_prob(value) + assert logp.max().lt(1 + dim * strength).all() + + logp.sum().backward() + assert not torch.isnan(concentration.grad).any()