Skip to content

Commit

Permalink
Promote instead of broadcast Wishart parameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Apr 13, 2024
1 parent 0c3cd7b commit ec86dc3
Showing 1 changed file with 14 additions and 26 deletions.
40 changes: 14 additions & 26 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -2647,36 +2647,24 @@ def __init__(
*,
validate_args=None,
):
# Determine the shapes.
batch_shape = None
event_shape = None
for x in [scale_matrix, rate_matrix, scale_tril]:
if x is not None:
batch_shape = jnp.broadcast_shapes(
jnp.shape(concentration), jnp.shape(x)[:-2]
)
event_shape = jnp.shape(x)[-2:]
break
if event_shape is None:
raise ValueError(
"One of `scale_matrix`, `rate_matrix`, or `scale_tril` must be "
"specified."
)

# Coerce to scale_tril parameter.
concentration = jnp.asarray(concentration)[..., None, None]
if scale_matrix is not None:
self.scale_matrix = jnp.broadcast_to(
scale_matrix, batch_shape + event_shape
concentration, self.scale_matrix = promote_shapes(
concentration, scale_matrix
)
self.scale_tril = jnp.linalg.cholesky(scale_matrix)
self.scale_tril = jnp.linalg.cholesky(self.scale_matrix)
elif rate_matrix is not None:
self.rate_matrix = jnp.broadcast_to(rate_matrix, batch_shape + event_shape)
self.scale_tril = cholesky_of_inverse(rate_matrix)
concentration, self.rate_matrix = promote_shapes(concentration, rate_matrix)
self.scale_tril = cholesky_of_inverse(self.rate_matrix)
elif scale_tril is not None:
self.scale_tril = scale_tril

self.concentration = jnp.broadcast_to(concentration, batch_shape)
self.scale_tril = jnp.broadcast_to(self.scale_tril, batch_shape + event_shape)
concentration, self.scale_tril = promote_shapes(
concentration, jnp.asarray(scale_tril)
)
batch_shape = lax.broadcast_shapes(
jnp.shape(concentration)[:-2], jnp.shape(self.scale_tril)[:-2]
)
event_shape = jnp.shape(self.scale_tril)[-2:]
self.concentration = concentration[..., 0, 0]
super().__init__(
batch_shape=batch_shape,
event_shape=event_shape,
Expand Down

0 comments on commit ec86dc3

Please sign in to comment.