From 90d828688fc254d90591104d4163d485923d3af8 Mon Sep 17 00:00:00 2001
From: Till Hoffmann
Date: Wed, 21 Aug 2024 18:58:49 -0400
Subject: [PATCH] Add `sign` for bijective scalar transforms and generic `cdf`/`icdf` implementation for `TransformedDistribution`s.

---
 numpyro/distributions/continuous.py   | 15 ------------
 numpyro/distributions/distribution.py | 18 ++++++++++++++
 numpyro/distributions/transforms.py   | 34 +++++++++++++++++++++++++++
 test/test_distributions.py            |  7 +++---
 test/test_transforms.py               |  3 +++
 5 files changed, 59 insertions(+), 18 deletions(-)

diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py
index 6529417465..1298894346 100644
--- a/numpyro/distributions/continuous.py
+++ b/numpyro/distributions/continuous.py
@@ -698,9 +698,6 @@ def variance(self):
         a = (self.rate / (self.concentration - 1)) ** 2 / (self.concentration - 2)
         return jnp.where(self.concentration <= 2, jnp.inf, a)
 
-    def cdf(self, x):
-        return 1 - self.base_dist.cdf(1 / x)
-
     def entropy(self):
         return (
             self.concentration
@@ -1205,9 +1202,6 @@ def mean(self):
     def variance(self):
         return (jnp.exp(self.scale**2) - 1) * jnp.exp(2 * self.loc + self.scale**2)
 
-    def cdf(self, x):
-        return self.base_dist.cdf(jnp.log(x))
-
     def entropy(self):
         return (1 + jnp.log(2 * jnp.pi)) / 2 + self.loc + jnp.log(self.scale)
 
@@ -1283,9 +1277,6 @@ def variance(self):
             - self.mean**2
         )
 
-    def cdf(self, x):
-        return self.base_dist.cdf(jnp.log(x))
-
     def entropy(self):
         log_low = jnp.log(self.low)
         log_high = jnp.log(self.high)
@@ -2162,12 +2153,6 @@ def variance(self):
     def support(self):
         return constraints.greater_than(self.scale)
 
-    def cdf(self, value):
-        return 1 - jnp.power(self.scale / value, self.alpha)
-
-    def icdf(self, q):
-        return self.scale / jnp.power(1 - q, 1 / self.alpha)
-
     def entropy(self):
         return jnp.log(self.scale / self.alpha) + 1 + 1 / self.alpha
 
diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py
index 04b213a3e9..cabf67c878 100644
--- a/numpyro/distributions/distribution.py
+++ b/numpyro/distributions/distribution.py
@@ -1105,6 +1105,24 @@ def mean(self):
     def variance(self):
         raise NotImplementedError
 
+    def cdf(self, value):
+        sign = 1
+        for transform in reversed(self.transforms):
+            sign *= transform.sign
+            value = transform.inv(value)
+        q = self.base_dist.cdf(value)
+        return jnp.where(sign < 0, 1 - q, q)
+
+    def icdf(self, q):
+        sign = 1
+        for transform in self.transforms:
+            sign *= transform.sign
+        q = jnp.where(sign < 0, 1 - q, q)
+        value = self.base_dist.icdf(q)
+        for transform in self.transforms:
+            value = transform(value)
+        return value
+
 
 class FoldedDistribution(TransformedDistribution):
     """
diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py
index 2e24881811..0aa14e04af 100644
--- a/numpyro/distributions/transforms.py
+++ b/numpyro/distributions/transforms.py
@@ -104,6 +104,15 @@ def inverse_shape(self, shape):
         """
         return shape
 
+    @property
+    def sign(self):
+        """
+        Sign of the derivative of the transform if it is bijective.
+        """
+        raise NotImplementedError(
+            f"Transform `{self.__class__.__name__}` does not implement `sign`."
+        )
+
     # Allow for pickle serialization of transforms.
     def __getstate__(self):
         attrs = {}
@@ -147,6 +156,10 @@ def domain(self):
     def codomain(self):
         return self._inv.domain
 
+    @property
+    def sign(self):
+        return self._inv.sign
+
     @property
     def inv(self):
         return self._inv
@@ -231,6 +244,10 @@ def codomain(self):
         else:
             raise NotImplementedError
 
+    @property
+    def sign(self):
+        return jnp.sign(self.scale)
+
     def __call__(self, x):
         return self.loc + self.scale * x
 
@@ -309,6 +326,13 @@ def codomain(self):
             self.parts[-1].codomain, output_event_dim - last_output_event_dim
         )
 
+    @property
+    def sign(self):
+        sign = 1
+        for transform in self.parts:
+            sign *= transform.sign
+        return sign
+
     def __call__(self, x):
         for part in self.parts:
             x = part(x)
@@ -509,6 +533,8 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
 
 
 class ExpTransform(Transform):
+    sign = 1
+
     # TODO: refine domain/codomain logic through setters, especially when
     # transforms for inverses are supported
     def __init__(self, domain=constraints.real):
@@ -550,6 +576,8 @@ def __eq__(self, other):
 
 
 class IdentityTransform(ParameterFreeTransform):
+    sign = 1
+
     def __call__(self, x):
         return x
 
@@ -912,9 +940,14 @@ def __eq__(self, other):
             return False
         return jnp.array_equal(self.exponent, other.exponent)
 
+    @property
+    def sign(self):
+        return jnp.sign(self.exponent)
+
 
 class SigmoidTransform(ParameterFreeTransform):
     codomain = constraints.unit_interval
+    sign = 1
 
     def __call__(self, x):
         return _clipped_expit(x)
@@ -1006,6 +1039,7 @@ class SoftplusTransform(ParameterFreeTransform):
 
     domain = constraints.real
     codomain = constraints.softplus_positive
+    sign = 1
 
     def __call__(self, x):
         return softplus(x)
diff --git a/test/test_distributions.py b/test/test_distributions.py
index e10fd7248b..ac89f59824 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -540,6 +540,7 @@ def get_sp_dist(jax_dist):
     T(dist.HalfNormal, 1.0),
     T(dist.HalfNormal, np.array([1.0, 2.0])),
     T(_ImproperWrapper, constraints.positive, (), (3,)),
+    T(dist.InverseGamma, np.array([3.1]), np.array([[2.0], [3.0]])),
     T(dist.InverseGamma, np.array([1.7]), np.array([[2.0], [3.0]])),
     T(dist.InverseGamma, np.array([0.5, 1.3]), np.array([[1.0], [3.0]])),
     T(dist.Kumaraswamy, 10.0, np.array([2.0, 3.0])),
@@ -1568,7 +1569,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
     samples = d.sample(key=random.PRNGKey(0), sample_shape=(100,))
     quantiles = random.uniform(random.PRNGKey(1), (100,) + d.shape())
     try:
-        rtol = 2e-3 if jax_dist in (dist.Gamma, dist.StudentT) else 1e-5
+        rtol = 2e-3 if jax_dist in (dist.Gamma, dist.LogNormal, dist.StudentT) else 1e-5
         if d.shape() == () and not d.is_discrete:
             assert_allclose(
                 jax.vmap(jax.grad(d.cdf))(samples),
@@ -1585,7 +1586,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
         assert_allclose(d.cdf(d.icdf(quantiles)), quantiles, atol=1e-5, rtol=1e-5)
         assert_allclose(d.icdf(d.cdf(samples)), samples, atol=1e-5, rtol=rtol)
     except NotImplementedError:
-        pass
+        pytest.skip("cdf/icdf not implemented")
 
     # test against scipy
     if not sp_dist:
@@ -1599,7 +1600,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
         expected_icdf = sp_dist.ppf(quantiles)
         assert_allclose(actual_icdf, expected_icdf, atol=1e-4, rtol=1e-4)
     except NotImplementedError:
-        pass
+        pytest.skip("cdf/icdf not implemented")
 
 
 @pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS + DIRECTIONAL)
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 9979592449..cecbf2ca51 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -341,6 +341,9 @@ def test_bijective_transforms(transform, shape):
         )
         slogdet = jnp.linalg.slogdet(jac)
         assert jnp.allclose(log_abs_det_jacobian, slogdet.logabsdet, atol=atol)
+        assert transform.domain.event_dim or jnp.allclose(
+            jnp.sign(jnp.diagonal(jac, axis1=-1, axis2=-2)), transform.sign
+        )
 
 
 def test_batched_recursive_linear_transform():