From ef1c155714e0d64a99e7bfcc919b9659c5f1692a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 1 Apr 2021 18:14:10 -0400 Subject: [PATCH 1/4] Add a SmoothLaplace distribution --- docs/source/distributions.rst | 7 ++++ pyro/distributions/__init__.py | 2 + pyro/distributions/smoothlaplace.py | 64 +++++++++++++++++++++++++++++ tests/distributions/conftest.py | 11 +++++ 4 files changed, 84 insertions(+) create mode 100644 pyro/distributions/smoothlaplace.py diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index fa3394a813..724397f918 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -316,6 +316,13 @@ Rejector :undoc-members: :show-inheritance: +SmoothLaplace +------------- +.. autoclass:: pyro.distributions.SmoothLaplace + :members: + :undoc-members: + :show-inheritance: + SpanningTree ------------ .. autoclass:: pyro.distributions.SpanningTree diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 8c3ee624cd..0110e72952 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -58,6 +58,7 @@ RelaxedBernoulliStraightThrough, RelaxedOneHotCategoricalStraightThrough, ) +from pyro.distributions.smoothlaplace import SmoothLaplace from pyro.distributions.spanning_tree import SpanningTree from pyro.distributions.stable import Stable from pyro.distributions.torch import __all__ as torch_dists @@ -127,6 +128,7 @@ "Rejector", "RelaxedBernoulliStraightThrough", "RelaxedOneHotCategoricalStraightThrough", + "SmoothLaplace", "SpanningTree", "Stable", "TorchDistribution", diff --git a/pyro/distributions/smoothlaplace.py b/pyro/distributions/smoothlaplace.py new file mode 100644 index 0000000000..e26464d9a9 --- /dev/null +++ b/pyro/distributions/smoothlaplace.py @@ -0,0 +1,64 @@ + +import math + +import torch +from torch.distributions import constraints +from torch.distributions.utils import broadcast_all + +from .torch_distribution import TorchDistribution + + +class SmoothLaplace(TorchDistribution): + """ + Smooth distribution with Laplace-like tail behavior. + + This distribution corresponds to the log-convex density:: + + z = (value - loc) / scale + log_prob = log(2 / pi) - log(scale) - logaddexp(z, -z) + + Like the Laplace density, this density has the heaviest possible tails + (asymptotically) while still being log-convex. Unlike the Laplace + distribution, this distribution is infinitely differentiable everywhere, + and is thus suitable for constructing Laplace approximations. + + :param loc: Location parameter. + :param scale: Scale parameter. + """ + + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.real + has_rsample = True + + def __init__(self, loc, scale, *, validate_args=None): + self.loc, self.scale = broadcast_all(loc, scale) + super().__init__(self.loc.shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(SmoothLaplace, _instance) + batch_shape = torch.Size(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + super(SmoothLaplace, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + z = (value - self.loc) / self.scale + return math.log(2 / math.pi) - self.scale.log() - torch.logaddexp(z, -z) + + def rsample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + u = self.loc.new_empty(shape).uniform_() + return self.icdf(u) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + z = (value - self.loc) / self.scale + return z.exp().atan().mul(2 / math.pi) + + def icdf(self, value): + return value.mul(math.pi / 2).tan().log().mul(self.scale).add(self.loc) diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index 0ba138f1aa..d48c158ac6 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -313,6 +313,17 @@ def scale(self): {'concentration': [0., 0., 0.], 'test_data': [1., 0., 0.]}, {'concentration': [-1., 2., 3.], 'test_data': [0., 0., 1.]}, ]), + Fixture(pyro_dist=dist.SmoothLaplace, + examples=[ + {'loc': [2.0], 'scale': [4.0], + 'test_data': [2.0]}, + {'loc': [[2.0]], 'scale': [[4.0]], + 'test_data': [[2.0]]}, + {'loc': [[[2.0]]], 'scale': [[[4.0]]], + 'test_data': [[[2.0]]]}, + {'loc': [2.0, 50.0], 'scale': [4.0, 100.0], + 'test_data': [[2.0, 50.0], [2.0, 50.0]]}, + ]), ] discrete_dists = [ From 8ee9d2903066605d4f084d336aced27a9d3d8094 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 1 Apr 2021 18:22:55 -0400 Subject: [PATCH 2/4] Rename to SoftLaplace --- pyro/distributions/{smoothlaplace.py => softlaplace.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename pyro/distributions/{smoothlaplace.py => softlaplace.py} (100%) diff --git a/pyro/distributions/smoothlaplace.py b/pyro/distributions/softlaplace.py similarity index 100% rename from pyro/distributions/smoothlaplace.py rename to pyro/distributions/softlaplace.py From d140a85498ee83100b743c1dfdfe78034c932a11 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 1 Apr 2021 18:23:21 -0400 Subject: [PATCH 3/4] Rename to SoftLaplace --- docs/source/distributions.rst | 4 ++-- pyro/distributions/__init__.py | 4 ++-- pyro/distributions/softlaplace.py | 8 +++++--- tests/distributions/conftest.py | 2 +- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 724397f918..03f91b9054 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -316,9 +316,9 @@ Rejector :undoc-members: :show-inheritance: -SmoothLaplace +SoftLaplace ------------- -.. autoclass:: pyro.distributions.SmoothLaplace +.. autoclass:: pyro.distributions.SoftLaplace :members: :undoc-members: :show-inheritance: diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 0110e72952..2f179928de 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -58,7 +58,7 @@ RelaxedBernoulliStraightThrough, RelaxedOneHotCategoricalStraightThrough, ) -from pyro.distributions.smoothlaplace import SmoothLaplace +from pyro.distributions.softlaplace import SoftLaplace from pyro.distributions.spanning_tree import SpanningTree from pyro.distributions.stable import Stable from pyro.distributions.torch import __all__ as torch_dists @@ -128,7 +128,7 @@ "Rejector", "RelaxedBernoulliStraightThrough", "RelaxedOneHotCategoricalStraightThrough", - "SmoothLaplace", + "SoftLaplace", "SpanningTree", "Stable", "TorchDistribution", diff --git a/pyro/distributions/softlaplace.py b/pyro/distributions/softlaplace.py index e26464d9a9..ac14696669 100644 --- a/pyro/distributions/softlaplace.py +++ b/pyro/distributions/softlaplace.py @@ -1,3 +1,5 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 import math @@ -8,7 +10,7 @@ from .torch_distribution import TorchDistribution -class SmoothLaplace(TorchDistribution): +class SoftLaplace(TorchDistribution): """ Smooth distribution with Laplace-like tail behavior. @@ -35,11 +37,11 @@ def __init__(self, loc, scale, *, validate_args=None): super().__init__(self.loc.shape, validate_args=validate_args) def expand(self, batch_shape, _instance=None): - new = self._get_checked_instance(SmoothLaplace, _instance) + new = self._get_checked_instance(SoftLaplace, _instance) batch_shape = torch.Size(batch_shape) new.loc = self.loc.expand(batch_shape) new.scale = self.scale.expand(batch_shape) - super(SmoothLaplace, new).__init__(batch_shape, validate_args=False) + super(SoftLaplace, new).__init__(batch_shape, validate_args=False) new._validate_args = self._validate_args return new diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index d48c158ac6..0c6797aceb 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -313,7 +313,7 @@ def scale(self): {'concentration': [0., 0., 0.], 'test_data': [1., 0., 0.]}, {'concentration': [-1., 2., 3.], 'test_data': [0., 0., 1.]}, ]), - Fixture(pyro_dist=dist.SmoothLaplace, + Fixture(pyro_dist=dist.SoftLaplace, examples=[ {'loc': [2.0], 'scale': [4.0], 'test_data': [2.0]}, From 0388c6a97935eac91a0f37b0ec0dc6b9bd145d9b Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 2 Apr 2021 12:17:43 -0400 Subject: [PATCH 4/4] Add test for agreement between .cdf() and .icdf() --- pyro/distributions/multivariate_studentt.py | 2 +- tests/distributions/test_distributions.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pyro/distributions/multivariate_studentt.py b/pyro/distributions/multivariate_studentt.py index 1e641e64df..8b6a2d77c4 100644 --- a/pyro/distributions/multivariate_studentt.py +++ b/pyro/distributions/multivariate_studentt.py @@ -34,7 +34,7 @@ def __init__(self, df, loc, scale_tril, validate_args=None): if not isinstance(df, torch.Tensor): df = loc.new_tensor(df) batch_shape = broadcast_shape(df.shape, loc.shape[:-1], scale_tril.shape[:-2]) - event_shape = (dim,) + event_shape = torch.Size((dim,)) self.df = df.expand(batch_shape) self.loc = loc.expand(batch_shape + event_shape) self._unbroadcasted_scale_tril = scale_tril diff --git a/tests/distributions/test_distributions.py b/tests/distributions/test_distributions.py index 29ae117a61..9e53619c10 100644 --- a/tests/distributions/test_distributions.py +++ b/tests/distributions/test_distributions.py @@ -145,6 +145,19 @@ def test_gof(continuous_dist): assert gof > TEST_FAILURE_RATE +def test_cdf_icdf(continuous_dist): + Dist = continuous_dist.pyro_dist + for i in range(continuous_dist.get_num_test_data()): + d = Dist(**continuous_dist.get_dist_params(i)) + if d.event_shape.numel() != 1: + continue # only valid for univariate distributions + u = torch.empty((100,) + d.shape()).uniform_() + with xfail_if_not_implemented(): + x = d.icdf(u) + u2 = d.cdf(x) + assert_equal(u, u2) + + # Distributions tests - discrete distributions def test_support_is_discrete(discrete_dist):