From bdd7f8ee051e16458b829945043b3d73ea1a2566 Mon Sep 17 00:00:00 2001
From: Till Hoffmann
Date: Sat, 13 Apr 2024 11:04:02 -0400
Subject: [PATCH] Implement `infer_shapes` for `Wishart` and `WishartCholesky`.

---
Reviewer notes (free-form area; not part of the commit message):

  * `Wishart.infer_shapes` forwards its four arguments unchanged to
    `WishartCholesky.infer_shapes`, so both classes share one
    implementation.
  * `WishartCholesky.infer_shapes` first calls `assert_one_of` on
    `scale_matrix`, `rate_matrix` and `scale_tril`, then takes the first
    of those three that is not None. It returns
    `(lax.broadcast_shapes(concentration, matrix[:-2]), matrix[-2:])`,
    i.e. the batch shape comes from broadcasting `concentration` against
    everything but the last two entries, and the event shape is the last
    two entries.
  * The arguments are presumably shape tuples rather than arrays (the
    default `concentration=()` and the slicing suggest this) -- confirm
    against the base-class `infer_shapes` contract.
  * NOTE(review): this patch was restored from a copy whose newlines had
    been collapsed. The leading whitespace of the unchanged context lines
    was reconstructed from the code's nesting; confirm it against the
    target file, or apply with `git apply --ignore-whitespace`.

 numpyro/distributions/continuous.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py
index 0e1d5f515..a9cf6613e 100644
--- a/numpyro/distributions/continuous.py
+++ b/numpyro/distributions/continuous.py
@@ -2616,6 +2616,14 @@ def variance(self):
             self.scale_matrix**2 + diag[..., :, None] * diag[..., None, :]
         )
 
+    @staticmethod
+    def infer_shapes(
+        concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None
+    ):
+        return WishartCholesky.infer_shapes(
+            concentration, scale_matrix, rate_matrix, scale_tril
+        )
+
 
 class WishartCholesky(Distribution):
     """
@@ -2764,3 +2772,18 @@ def variance(self):
             jnp.ones_like(k, shape=k.shape + (k.shape[-1],)).at[..., i, i].set(k)
         )
         return jnp.square(self.scale_tril) @ latent - jnp.square(self.mean)
+
+    @staticmethod
+    def infer_shapes(
+        concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None
+    ):
+        assert_one_of(
+            scale_matrix=scale_matrix,
+            rate_matrix=rate_matrix,
+            scale_tril=scale_tril,
+        )
+        for matrix in [scale_matrix, rate_matrix, scale_tril]:
+            if matrix is not None:
+                batch_shape = lax.broadcast_shapes(concentration, matrix[:-2])
+                event_shape = matrix[-2:]
+                return batch_shape, event_shape