diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst
index 06b51c929..f568d458d 100644
--- a/docs/source/distributions.rst
+++ b/docs/source/distributions.rst
@@ -380,6 +380,22 @@ Weibull
     :show-inheritance:
     :member-order: bysource
 
+Wishart
+^^^^^^^
+.. autoclass:: numpyro.distributions.continuous.Wishart
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
+WishartCholesky
+^^^^^^^^^^^^^^^
+.. autoclass:: numpyro.distributions.continuous.WishartCholesky
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
 ZeroSumNormal
 ^^^^^^^^^^^^^
 .. autoclass:: numpyro.distributions.continuous.ZeroSumNormal
diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py
index d05376573..412dd1976 100644
--- a/numpyro/distributions/__init__.py
+++ b/numpyro/distributions/__init__.py
@@ -47,6 +47,8 @@
     StudentT,
     Uniform,
     Weibull,
+    Wishart,
+    WishartCholesky,
     ZeroSumNormal,
 )
 from numpyro.distributions.copula import GaussianCopula, GaussianCopulaBeta
@@ -194,6 +196,8 @@
     "Unit",
     "VonMises",
     "Weibull",
+    "Wishart",
+    "WishartCholesky",
     "ZeroInflatedDistribution",
     "ZeroInflatedPoisson",
     "ZeroInflatedNegativeBinomial2",
diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py
index 273fa4a43..8280482c0 100644
--- a/numpyro/distributions/continuous.py
+++ b/numpyro/distributions/continuous.py
@@ -55,6 +55,7 @@
 from numpyro.distributions.distribution import Distribution, TransformedDistribution
 from numpyro.distributions.transforms import (
     AffineTransform,
+    CholeskyTransform,
     CorrMatrixCholeskyTransform,
     ExpTransform,
     PowerTransform,
@@ -63,14 +64,17 @@
 )
 from numpyro.distributions.util import (
     add_diag,
+    assert_one_of,
     betainc,
     betaincinv,
     cholesky_of_inverse,
     gammaincinv,
     lazy_property,
     matrix_to_tril_vec,
+    multidigamma,
     promote_shapes,
     signed_stick_breaking_tril,
+    tri_logabsdet,
     validate_sample,
     vec_to_tril_matrix,
 )
@@ -1397,12 +1401,8 @@ def sample(self, key, sample_shape=()):
 
     def log_prob(self, values):
         n, p = self.event_shape
-        row_log_det = jnp.log(
-            jnp.diagonal(self.scale_tril_row, axis1=-2, axis2=-1)
-        ).sum(-1)
-        col_log_det = jnp.log(
-            jnp.diagonal(self.scale_tril_column, axis1=-2, axis2=-1)
-        ).sum(-1)
+        row_log_det = tri_logabsdet(self.scale_tril_row)
+        col_log_det = tri_logabsdet(self.scale_tril_column)
         log_det_term = (
             p * row_log_det + n * col_log_det + 0.5 * n * p * jnp.log(2 * jnp.pi)
         )
@@ -1489,6 +1489,11 @@ def __init__(
         scale_tril=None,
         validate_args=None,
     ):
+        assert_one_of(
+            covariance_matrix=covariance_matrix,
+            precision_matrix=precision_matrix,
+            scale_tril=scale_tril,
+        )
         if jnp.ndim(loc) == 0:
             (loc,) = promote_shapes(loc, shape=(1,))
         # temporary append a new axis to loc
@@ -1501,11 +1506,6 @@
             self.scale_tril = cholesky_of_inverse(self.precision_matrix)
         elif scale_tril is not None:
             loc, self.scale_tril = promote_shapes(loc, scale_tril)
-        else:
-            raise ValueError(
-                "One of `covariance_matrix`, `precision_matrix`, `scale_tril`"
-                " must be specified."
-            )
         batch_shape = lax.broadcast_shapes(
             jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2]
         )
@@ -1529,9 +1529,7 @@ def sample(self, key, sample_shape=()):
     @validate_sample
     def log_prob(self, value):
         M = _batch_mahalanobis(self.scale_tril, value - self.loc)
-        half_log_det = jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(
-            -1
-        )
+        half_log_det = tri_logabsdet(self.scale_tril)
         normalize_term = half_log_det + 0.5 * self.scale_tril.shape[-1] * jnp.log(
             2 * jnp.pi
         )
@@ -1562,18 +1560,21 @@ def variance(self):
     @staticmethod
     def infer_shapes(
         loc=(), covariance_matrix=None, precision_matrix=None, scale_tril=None
     ):
+        assert_one_of(
+            covariance_matrix=covariance_matrix,
+            precision_matrix=precision_matrix,
+            scale_tril=scale_tril,
+        )
         batch_shape, event_shape = loc[:-1], loc[-1:]
         for matrix in [covariance_matrix, precision_matrix, scale_tril]:
             if matrix is not None:
                 batch_shape = lax.broadcast_shapes(batch_shape, matrix[:-2])
                 event_shape = lax.broadcast_shapes(event_shape, matrix[-1:])
-                return batch_shape, event_shape
+        return batch_shape, event_shape
 
     def entropy(self):
         (n,) = self.event_shape
-        half_log_det = jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(
-            -1
-        )
+        half_log_det = tri_logabsdet(self.scale_tril)
         return n * (jnp.log(2 * np.pi) + 1) / 2 + half_log_det
 
@@ -1849,7 +1850,7 @@ def sample(self, key, sample_shape=()):
     def log_prob(self, value):
         n = self.scale_tril.shape[-1]
         Z = (
-            jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1)
+            tri_logabsdet(self.scale_tril)
             + 0.5 * n * jnp.log(self.df)
             + 0.5 * n * jnp.log(jnp.pi)
             + gammaln(0.5 * self.df)
@@ -1924,9 +1925,7 @@ def _batch_lowrank_logdet(W, D, capacitance_tril):
     where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
     the log determinant.
     """
-    return 2 * jnp.sum(
-        jnp.log(jnp.diagonal(capacitance_tril, axis1=-2, axis2=-1)), axis=-1
-    ) + jnp.log(D).sum(-1)
+    return 2 * tri_logabsdet(capacitance_tril) + jnp.log(D).sum(-1)
 
 
 def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
@@ -2615,3 +2614,260 @@ def variance(self):
         theoretical_var *= 1 - 1 / self.event_shape[axis]
 
         return jnp.broadcast_to(theoretical_var, self.batch_shape + self.event_shape)
+
+
+class Wishart(TransformedDistribution):
+    """
+    Wishart distribution for covariance matrices.
+
+    :param concentration: Positive concentration parameter analogous to the
+        concentration of a :class:`Gamma` distribution. The concentration must be larger
+        than the dimensionality of the scale matrix.
+    :param scale_matrix: Scale matrix analogous to the inverse rate of a :class:`Gamma`
+        distribution.
+    :param rate_matrix: Rate matrix analogous to the rate of a :class:`Gamma`
+        distribution.
+    :param scale_tril: Cholesky decomposition of the :code:`scale_matrix`.
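+
+    Example (illustrative)::
+
+        key = jax.random.PRNGKey(0)
+        d = Wishart(concentration=4.0, scale_matrix=jnp.eye(2))
+        cov = d.sample(key)  # a (2, 2) positive-definite matrix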
+ """ + + arg_constraints = { + "concentration": constraints.dependent(is_discrete=False), + "scale_matrix": constraints.positive_definite, + "rate_matrix": constraints.positive_definite, + "scale_tril": constraints.lower_cholesky, + } + support = constraints.positive_definite + reparametrized_params = [ + "scale_matrix", + "rate_matrix", + "scale_tril", + ] + + def __init__( + self, + concentration, + scale_matrix=None, + rate_matrix=None, + scale_tril=None, + *, + validate_args=None, + ): + base_dist = WishartCholesky( + concentration, + scale_matrix, + rate_matrix, + scale_tril, + validate_args=validate_args, + ) + super().__init__( + base_dist, CholeskyTransform().inv, validate_args=validate_args + ) + + @lazy_property + def concentration(self): + return self.base_dist.concentration + + @lazy_property + def scale_matrix(self): + return self.base_dist.scale_matrix + + @lazy_property + def rate_matrix(self): + return self.base_dist.rate_matrix + + @lazy_property + def scale_tril(self): + return self.base_dist.scale_tril + + @lazy_property + def mean(self): + return self.concentration[..., None, None] * self.scale_matrix + + @lazy_property + def variance(self): + diag = jnp.diagonal(self.scale_matrix, axis1=-1, axis2=-2) + return self.concentration[..., None, None] * ( + self.scale_matrix**2 + diag[..., :, None] * diag[..., None, :] + ) + + @staticmethod + def infer_shapes( + concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None + ): + return WishartCholesky.infer_shapes( + concentration, scale_matrix, rate_matrix, scale_tril + ) + + def entropy(self): + p = self.event_shape[-1] + return ( + (p + 1) * tri_logabsdet(self.scale_tril) + + p * (p + 1) / 2 * jnp.log(2) + + multigammaln(self.concentration / 2, p) + - (self.concentration - p - 1) / 2 * multidigamma(self.concentration / 2, p) + + self.concentration * p / 2 + ) + + +class WishartCholesky(Distribution): + """ + Cholesky factor of a Wishart distribution for covariance matrices. + + :param concentration: Positive concentration parameter analogous to the + concentration of a :class:`Gamma` distribution. The concentration must be larger + than the dimensionality of the scale matrix. + :param scale_matrix: Scale matrix analogous to the inverse rate of a :class:`Gamma` + distribution. + :param rate_matrix: Rate matrix anaologous to the rate of a :class:`Gamma` + distribution. + :param scale_tril: Cholesky decomposition of the :code:`scale_matrix`. 
+ """ + + arg_constraints = { + "concentration": constraints.dependent(is_discrete=False), + "scale_matrix": constraints.positive_definite, + "rate_matrix": constraints.positive_definite, + "scale_tril": constraints.lower_cholesky, + } + support = constraints.lower_cholesky + reparametrized_params = [ + "scale_matrix", + "rate_matrix", + "scale_tril", + ] + + def __init__( + self, + concentration, + scale_matrix=None, + rate_matrix=None, + scale_tril=None, + *, + validate_args=None, + ): + assert_one_of( + scale_matrix=scale_matrix, + rate_matrix=rate_matrix, + scale_tril=scale_tril, + ) + concentration = jnp.asarray(concentration)[..., None, None] + if scale_matrix is not None: + concentration, self.scale_matrix = promote_shapes( + concentration, scale_matrix + ) + self.scale_tril = jnp.linalg.cholesky(self.scale_matrix) + elif rate_matrix is not None: + concentration, self.rate_matrix = promote_shapes(concentration, rate_matrix) + self.scale_tril = cholesky_of_inverse(self.rate_matrix) + elif scale_tril is not None: + concentration, self.scale_tril = promote_shapes( + concentration, jnp.asarray(scale_tril) + ) + batch_shape = lax.broadcast_shapes( + jnp.shape(concentration)[:-2], jnp.shape(self.scale_tril)[:-2] + ) + event_shape = jnp.shape(self.scale_tril)[-2:] + self.concentration = concentration[..., 0, 0] + super().__init__( + batch_shape=batch_shape, + event_shape=event_shape, + validate_args=validate_args, + ) + + @validate_sample + def log_prob(self, value): + # The log density of the Wishart distribution includes a term + # t = trace(rate_matrix @ cov). Here, value = cholesky(cov) such that + # t = trace(value.T @ rate_matrix @ value) by the cyclical property of the + # trace. The rate matrix is the inverse scale matrix with Cholesky decomposition + # scale_tril. Thus, + # t = trace(value.T @ inv(scale_tril).T @ inv(scale_tril) @ value), and we can + # rewrite as t = trace(x.T @ x) for x = inv(scale_tril) @ value which we can + # obtain easily by solving a triangular system. x is again triangular such that + # trace(x @ x.T) is equal to the sum of squares of elements. + x = solve_triangular(*jnp.broadcast_arrays(self.scale_tril, value), lower=True) + trace = jnp.square(x).sum(axis=(-1, -2)) + p = value.shape[-1] + return ( + (self.concentration - p - 1) * tri_logabsdet(value) + - trace / 2 + + p * (1 - self.concentration / 2) * jnp.log(2) + - multigammaln(self.concentration / 2, p) + - self.concentration * tri_logabsdet(self.scale_tril) + # Part of the Jacobian of the Cholesky transformation. + + jnp.sum( + jnp.arange(p, 0, -1) * jnp.log(jnp.diagonal(value, axis1=-2, axis2=-1)), + axis=-1, + ) + ) + + @lazy_property + def scale_matrix(self): + return jnp.matmul(self.scale_tril, self.scale_tril.mT) + + @lazy_property + def rate_matrix(self): + identity = jnp.broadcast_to( + jnp.eye(self.scale_tril.shape[-1]), self.scale_tril.shape + ) + return cho_solve((self.scale_tril, True), identity) + + def sample(self, key, sample_shape=()): + assert is_prng_key(key) + # Sample using the Bartlett decomposition + # (https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition). 
+        rng_diag, rng_offdiag = random.split(key)
+        latent = jnp.zeros(sample_shape + self.batch_shape + self.event_shape)
+        p = self.event_shape[-1]
+        i = jnp.arange(p)
+        latent = latent.at[..., i, i].set(
+            jnp.sqrt(
+                random.chisquare(
+                    rng_diag, self.concentration[..., None] - i, latent.shape[:-1]
+                )
+            )
+        )
+        i, j = jnp.tril_indices(p, -1)
+        assert i.size == p * (p - 1) // 2
+        latent = latent.at[..., i, j].set(
+            random.normal(rng_offdiag, latent.shape[:-2] + (i.size,))
+        )
+        return jnp.matmul(*jnp.broadcast_arrays(self.scale_tril, latent))
+
+    @lazy_property
+    def mean(self):
+        # The mean follows from the Bartlett decomposition sampling. All off-diagonal
+        # elements of the latent variable have zero expectation. The diagonal entries
+        # are the expected square roots of chi^2 variables, which can be expressed in
+        # terms of gamma functions (see
+        # https://en.wikipedia.org/wiki/Chi-squared_distribution#Noncentral_moments).
+        k = self.concentration[..., None] - jnp.arange(self.scale_tril.shape[-1])
+        sqrtchi2 = jnp.sqrt(2) * jnp.exp(gammaln((k + 1) / 2) - gammaln(k / 2))
+        return self.scale_tril * sqrtchi2[..., None, :]
+
+    @lazy_property
+    def variance(self):
+        # Same construction as for the mean, except the expected squares of the lower
+        # off-diagonals are one due to the standard normal noise, and the expected
+        # squares of the diagonals equal the dof of the chi^2 variables.
+        i = jnp.arange(self.scale_tril.shape[-1])
+        k = self.concentration[..., None] - i
+        latent = jnp.tril(
+            jnp.ones_like(k, shape=k.shape + (k.shape[-1],)).at[..., i, i].set(k)
+        )
+        return jnp.square(self.scale_tril) @ latent - jnp.square(self.mean)
+
+    @staticmethod
+    def infer_shapes(
+        concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None
+    ):
+        assert_one_of(
+            scale_matrix=scale_matrix,
+            rate_matrix=rate_matrix,
+            scale_tril=scale_tril,
+        )
+        for matrix in [scale_matrix, rate_matrix, scale_tril]:
+            if matrix is not None:
+                batch_shape = lax.broadcast_shapes(concentration, matrix[:-2])
+                event_shape = matrix[-2:]
+        return batch_shape, event_shape
diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py
index 841dd2194..1add0fcba 100644
--- a/numpyro/distributions/directional.py
+++ b/numpyro/distributions/directional.py
@@ -16,6 +16,7 @@
 from numpyro.distributions import constraints
 from numpyro.distributions.distribution import Distribution
 from numpyro.distributions.util import (
+    assert_one_of,
     lazy_property,
     promote_shapes,
     safe_normalize,
@@ -349,7 +350,9 @@ def __init__(
         weighted_correlation=None,
         validate_args=None,
     ):
-        assert (correlation is None) != (weighted_correlation is None)
+        assert_one_of(
+            correlation=correlation, weighted_correlation=weighted_correlation
+        )
 
         if weighted_correlation is not None:
             correlation = weighted_correlation * jnp.sqrt(
diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py
index fc0c81dd9..0ee140620 100644
--- a/numpyro/distributions/discrete.py
+++ b/numpyro/distributions/discrete.py
@@ -37,6 +37,7 @@
 from numpyro.distributions import constraints, transforms
 from numpyro.distributions.distribution import Distribution
 from numpyro.distributions.util import (
+    assert_one_of,
     binary_cross_entropy_with_logits,
     binomial,
     categorical,
@@ -160,12 +161,11 @@ def entropy(self):
 
 
 def Bernoulli(probs=None, logits=None, *, validate_args=None):
+    assert_one_of(probs=probs, logits=logits)
     if probs is not None:
         return BernoulliProbs(probs, validate_args=validate_args)
     elif logits is not None:
         return BernoulliLogits(logits, validate_args=validate_args)
-    else:
-        raise ValueError("One of `probs` or `logits` must be specified.")
 
 
 class BinomialProbs(Distribution):
@@ -293,12 +293,11 @@ def support(self):
 
 
 def Binomial(total_count=1, probs=None, logits=None, *, validate_args=None):
+    assert_one_of(probs=probs, logits=logits)
     if probs is not None:
         return BinomialProbs(probs, total_count, validate_args=validate_args)
     elif logits is not None:
         return BinomialLogits(logits, total_count, validate_args=validate_args)
-    else:
-        raise ValueError("One of `probs` or `logits` must be specified.")
 
 
 class CategoricalProbs(Distribution):
@@ -411,12 +410,11 @@ def entropy(self):
 
 
 def Categorical(probs=None, logits=None, *, validate_args=None):
+    assert_one_of(probs=probs, logits=logits)
     if probs is not None:
         return CategoricalProbs(probs, validate_args=validate_args)
     elif logits is not None:
         return CategoricalLogits(logits, validate_args=validate_args)
-    else:
-        raise ValueError("One of `probs` or `logits` must be specified.")
 
 
 class DiscreteUniform(Distribution):
@@ -670,6 +668,7 @@ def Multinomial(
 
     :param int total_count_max: the maximum number of trials, i.e. `max(total_count)`
     """
+    assert_one_of(probs=probs, logits=logits)
     if probs is not None:
         return MultinomialProbs(
             probs,
@@ -684,8 +683,6 @@
             total_count_max=total_count_max,
             validate_args=validate_args,
         )
-    else:
-        raise ValueError("One of `probs` or `logits` must be specified.")
 
 
 class Poisson(Distribution):
@@ -837,10 +834,7 @@ def ZeroInflatedDistribution(
     :param numpy.ndarray gate: probability of extra zeros given via a Bernoulli distribution.
     :param numpy.ndarray gate_logits: logits of extra zeros given via a Bernoulli distribution.
     """
-    if (gate is None) == (gate_logits is None):
-        raise ValueError(
-            "Either `gate` or `gate_logits` must be specified, but not both."
-        )
+    assert_one_of(gate=gate, gate_logits=gate_logits)
     if gate is not None:
         return ZeroInflatedProbs(base_dist, gate, validate_args=validate_args)
     else:
@@ -947,9 +941,8 @@ def entropy(self):
 
 
 def Geometric(probs=None, logits=None, *, validate_args=None):
+    assert_one_of(probs=probs, logits=logits)
     if probs is not None:
         return GeometricProbs(probs, validate_args=validate_args)
     elif logits is not None:
         return GeometricLogits(logits, validate_args=validate_args)
-    else:
-        raise ValueError("One of `probs` or `logits` must be specified.")
diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py
index fa9e4613c..6eec9f4d3 100644
--- a/numpyro/distributions/distribution.py
+++ b/numpyro/distributions/distribution.py
@@ -509,8 +509,9 @@ def infer_shapes(cls, *args, **kwargs):
         # Assumes distribution is univariate.
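+        # None-valued shapes correspond to parameters that were left unspecified
+        # in a mutually exclusive parameterization (e.g. the matrix arguments of
+        # Wishart.infer_shapes) and are skipped below.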
         batch_shapes = []
         for name, shape in kwargs.items():
-            event_dim = cls.arg_constraints.get(name, constraints.real).event_dim
-            batch_shapes.append(shape[: len(shape) - event_dim])
+            if shape is not None:
+                event_dim = cls.arg_constraints.get(name, constraints.real).event_dim
+                batch_shapes.append(shape[: len(shape) - event_dim])
         batch_shape = lax.broadcast_shapes(*batch_shapes) if batch_shapes else ()
         event_shape = ()
         return batch_shape, event_shape
diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py
index c2e7aafc2..aca32b1f7 100644
--- a/numpyro/distributions/util.py
+++ b/numpyro/distributions/util.py
@@ -12,6 +12,7 @@
 from jax import jit, lax, random, vmap
 import jax.numpy as jnp
 from jax.scipy.linalg import solve_triangular
+from jax.scipy.special import digamma
 
 # Parameters for Transformed Rejection with Squeeze (TRS) algorithm - page 3.
 _tr_params = namedtuple(
@@ -623,6 +624,31 @@ def is_prng_key(key):
     return False
 
 
+def assert_one_of(**kwargs):
+    """
+    Assert that exactly one of the keyword arguments is not None.
+    """
+    specified = [key for key, value in kwargs.items() if value is not None]
+    if len(specified) != 1:
+        raise ValueError(
+            f"Exactly one of {list(kwargs)} must be specified; got {specified}."
+        )
+
+
+def multidigamma(a: jnp.ndarray, d: jnp.ndarray) -> jnp.ndarray:
+    """
+    Derivative of the log of the multivariate gamma function with respect to `a`.
+    """
+    return digamma(a[..., None] - 0.5 * jnp.arange(d)).sum(axis=-1)
+
+
+def tri_logabsdet(a: jnp.ndarray) -> jnp.ndarray:
+    """
+    Evaluate the `logabsdet` of a triangular matrix with positive diagonal.
+    """
+    return jnp.log(jnp.diagonal(a, axis1=-1, axis2=-2)).sum(axis=-1)
+
+
 # The is sourced from: torch.distributions.util.py
 #
 # Copyright (c) 2016- Facebook, Inc (Adam Paszke)
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 43360b74b..3487ab745 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -131,6 +131,14 @@ def _truncnorm_to_scipy(loc, scale, low, high):
     return osp.truncnorm(a, b, loc=loc, scale=scale)
 
 
+def _wishart_to_scipy(conc, scale, rate, tril):
+    jax_dist = dist.Wishart(conc, scale, rate, tril)
+    if not jnp.isscalar(jax_dist.concentration):
+        pytest.skip("scipy Wishart only supports a single scalar concentration")
+    # Cast to float explicitly because np.isscalar returns False on scalar jax arrays.
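+    # scipy.stats.wishart takes the degrees of freedom and the scale matrix, which
+    # correspond to NumPyro's concentration and scale_matrix.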
+    return osp.wishart(float(jax_dist.concentration), jax_dist.scale_matrix)
+
+
 def _TruncatedNormal(loc, scale, low, high):
     return dist.TruncatedNormal(loc=loc, scale=scale, low=low, high=high)
 
@@ -444,6 +452,7 @@ def __init__(
         c=conc,
         scale=scale,
     ),
+    dist.Wishart: _wishart_to_scipy,
     _TruncatedNormal: _truncnorm_to_scipy,
 }
 
@@ -775,6 +784,78 @@ def get_sp_dist(jax_dist):
     T(dist.Weibull, 0.2, 1.1),
     T(dist.Weibull, 2.8, np.array([2.0, 2.0])),
     T(dist.Weibull, 1.8, np.array([[1.0, 1.0], [2.0, 2.0]])),
+    T(dist.Wishart, 3, 2 * np.eye(2) + 0.1, None, None),
+    T(
+        dist.Wishart,
+        3.0,
+        None,
+        np.array([[1.0, 0.5], [0.5, 1.0]]),
+        None,
+    ),
+    T(
+        dist.Wishart,
+        np.array([4.0, 5.0]),
+        None,
+        np.array([[[1.0, 0.5], [0.5, 1.0]]]),
+        None,
+    ),
+    T(
+        dist.Wishart,
+        np.array([3.0]),
+        None,
+        None,
+        np.array([[1.0, 0.0], [0.5, 1.0]]),
+    ),
+    T(
+        dist.Wishart,
+        np.arange(3, 9, dtype=np.float32).reshape((3, 2)),
+        None,
+        None,
+        np.array([[1.0, 0.0], [0.0, 1.0]]),
+    ),
+    T(
+        dist.Wishart,
+        9.0,
+        None,
+        np.broadcast_to(np.identity(3), (2, 3, 3)),
+        None,
+    ),
+    T(dist.WishartCholesky, 3, 2 * np.eye(2) + 0.1, None, None),
+    T(
+        dist.WishartCholesky,
+        3.0,
+        None,
+        np.array([[1.0, 0.5], [0.5, 1.0]]),
+        None,
+    ),
+    T(
+        dist.WishartCholesky,
+        np.array([4.0, 5.0]),
+        None,
+        np.array([[[1.0, 0.5], [0.5, 1.0]]]),
+        None,
+    ),
+    T(
+        dist.WishartCholesky,
+        np.array([3.0]),
+        None,
+        None,
+        np.array([[1.0, 0.0], [0.5, 1.0]]),
+    ),
+    T(
+        dist.WishartCholesky,
+        np.arange(3, 9, dtype=np.float32).reshape((3, 2)),
+        None,
+        None,
+        np.array([[1.0, 0.0], [0.0, 1.0]]),
+    ),
+    T(
+        dist.WishartCholesky,
+        9.0,
+        None,
+        np.broadcast_to(np.identity(3), (2, 3, 3)),
+        None,
+    ),
     T(dist.ZeroSumNormal, 1.0, (5,)),
     T(dist.ZeroSumNormal, np.array([2.0]), (5,)),
     T(dist.ZeroSumNormal, 1.0, (4, 5)),
@@ -1120,7 +1201,13 @@ def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape):
         and not isinstance(jax_dist, dist.MultivariateStudentT)
     ):
         sp_dist = sp_dist(*params)
-        sp_samples = sp_dist.rvs(size=prepend_shape + jax_dist.batch_shape)
+        size = prepend_shape + jax_dist.batch_shape
+        # The scipy implementation of the Wishart distribution cannot handle an empty
+        # tuple as the sample size so we replace it by `1` which generates a single
+        # sample without any sample shape.
+        if isinstance(jax_dist, dist.Wishart):
+            size = size or 1
+        sp_samples = sp_dist.rvs(size=size)
         assert jnp.shape(sp_samples) == expected_shape
     elif (
         sp_dist
@@ -1148,8 +1235,15 @@ def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape):
     "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
 )
 def test_infer_shapes(jax_dist, sp_dist, params):
-    shapes = tuple(getattr(p, "shape", ()) for p in params)
-    shapes = tuple(x() if callable(x) else x for x in shapes)
+    shapes = []
+    for param in params:
+        if param is None:
+            shapes.append(None)
+            continue
+        shape = getattr(param, "shape", ())
+        if callable(shape):
+            shape = shape()
+        shapes.append(shape)
     jax_dist = jax_dist(*params)
     try:
         expected_batch_shape, expected_event_shape = type(jax_dist).infer_shapes(
@@ -1394,21 +1488,43 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
 
 
 @pytest.mark.parametrize(
     "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
 )
-def test_entropy(jax_dist, sp_dist, params):
+def test_entropy_scipy(jax_dist, sp_dist, params):
     jax_dist = jax_dist(*params)
+    try:
+        actual = jax_dist.entropy()
+    except NotImplementedError:
+        pytest.skip(reason="distribution does not implement `entropy`")
     if _is_batched_multivariate(jax_dist):
         pytest.skip("batching not allowed in multivariate distns.")
     if sp_dist is None:
         pytest.skip(reason="no corresponding scipy distribution")
+
+    sp_dist = sp_dist(*params)
+    expected = sp_dist.entropy()
+    assert_allclose(actual, expected, atol=1e-5)
+
+
+@pytest.mark.parametrize(
+    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
+)
+def test_entropy_samples(jax_dist, sp_dist, params):
+    jax_dist = jax_dist(*params)
     try:
         actual = jax_dist.entropy()
     except NotImplementedError:
         pytest.skip(reason="distribution does not implement `entropy`")
-    sp_dist = sp_dist(*params)
-    expected = sp_dist.entropy()
-    assert_allclose(actual, expected, atol=1e-5)
+
+    samples = jax_dist.sample(jax.random.key(8), (1000,))
+    neg_log_probs = -jax_dist.log_prob(samples)
+    mean = neg_log_probs.mean(axis=0)
+    stderr = neg_log_probs.std(axis=0) / jnp.sqrt(neg_log_probs.shape[0] - 1)
+    z = (actual - mean) / stderr
+
+    # Check the z-score is small or that all values are close. This happens, for
+    # example, for uniform distributions with constant log prob and hence zero stderr.
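+    # A |z| < 5 threshold keeps spurious failures negligible across the many
+    # parametrized cases while still flagging an entropy that disagrees with the
+    # Monte Carlo estimate.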
+    assert (jnp.abs(z) < 5).all() or jnp.allclose(actual, neg_log_probs, atol=1e-5)
 
 
 def test_entropy_categorical():
@@ -1481,7 +1597,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
 def test_gof(jax_dist, sp_dist, params):
     if "Improper" in jax_dist.__name__:
         pytest.skip("distribution has improper .log_prob()")
-    if "LKJ" in jax_dist.__name__:
+    if "LKJ" in jax_dist.__name__ or "Wishart" in jax_dist.__name__:
         pytest.xfail("incorrect submanifold scaling")
     if jax_dist is dist.EulerMaruyama:
         d = jax_dist(*params)
diff --git a/test/test_transforms.py b/test/test_transforms.py
index f35345193..e3ada6fd9 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -360,13 +360,13 @@ def test_batched_recursive_linear_transform():
         (constraints.circular, (3,)),
         (constraints.complex, (3,)),
         (constraints.corr_cholesky, (10, 10)),
-        (constraints.corr_matrix, (21,)),
+        (constraints.corr_matrix, (15,)),
         (constraints.greater_than(3), ()),
         (constraints.greater_than_eq(3), ()),
        (constraints.interval(8, 13), (17,)),
         (constraints.l1_ball, (4,)),
         (constraints.less_than(-1), ()),
-        (constraints.lower_cholesky, (21,)),
+        (constraints.lower_cholesky, (15,)),
         (constraints.open_interval(3, 4), ()),
         (constraints.ordered_vector, (5,)),
         (constraints.positive_definite, (6,)),
@@ -376,9 +376,9 @@
         (constraints.real_matrix, (17,)),
         (constraints.real_vector, (18,)),
         (constraints.real, (3,)),
-        (constraints.scaled_unit_lower_cholesky, (21,)),
+        (constraints.scaled_unit_lower_cholesky, (15,)),
         (constraints.simplex, (3,)),
-        (constraints.softplus_lower_cholesky, (21,)),
+        (constraints.softplus_lower_cholesky, (15,)),
         (constraints.softplus_positive, (2,)),
         (constraints.unit_interval, (4,)),
         (constraints.nonnegative, (7,)),
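A minimal usage sketch of the two new distributions, relying only on the
constructor, `sample`, `log_prob`, and `mean` APIs introduced in this diff:

```python
import jax
import jax.numpy as jnp

import numpyro.distributions as dist

key = jax.random.PRNGKey(0)

# Wishart draws positive-definite matrices; the concentration must exceed the
# dimension of the scale matrix.
w = dist.Wishart(concentration=5.0, scale_matrix=jnp.eye(3))
cov = w.sample(key)  # shape (3, 3), positive definite
print(w.log_prob(cov), w.mean)

# WishartCholesky draws the lower-triangular Cholesky factor directly, which is
# convenient when only the factor is needed downstream.
wc = dist.WishartCholesky(concentration=5.0, scale_matrix=jnp.eye(3))
tril = wc.sample(key)  # lower triangular with positive diagonal
cov2 = tril @ tril.mT  # a Wishart-distributed covariance matrix
```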