Skip to content

Commit

Permalink
Implement infer_shapes for Wishart and WishartCholesky.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Apr 21, 2024
1 parent a431bd9 commit d7fe610
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -2616,6 +2616,14 @@ def variance(self):
self.scale_matrix**2 + diag[..., :, None] * diag[..., None, :]
)

@staticmethod
def infer_shapes(
concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None
):
return WishartCholesky.infer_shapes(
concentration, scale_matrix, rate_matrix, scale_tril
)


class WishartCholesky(Distribution):
"""
Expand Down Expand Up @@ -2764,3 +2772,18 @@ def variance(self):
jnp.ones_like(k, shape=k.shape + (k.shape[-1],)).at[..., i, i].set(k)
)
return jnp.square(self.scale_tril) @ latent - jnp.square(self.mean)

@staticmethod
def infer_shapes(
concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None
):
assert_one_of(
scale_matrix=scale_matrix,
rate_matrix=rate_matrix,
scale_tril=scale_tril,
)
for matrix in [scale_matrix, rate_matrix, scale_tril]:
if matrix is not None:
batch_shape = lax.broadcast_shapes(concentration, matrix[:-2])
event_shape = matrix[-2:]
return batch_shape, event_shape

0 comments on commit d7fe610

Please sign in to comment.