From f0cbaa90f7e1d445d36f90fa8a6b3317a97a2a8d Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Mon, 24 Jun 2024 08:56:09 -0500 Subject: [PATCH] Use broadcast_shapes to align params --- numpyro/contrib/hsgp/spectral_densities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/contrib/hsgp/spectral_densities.py b/numpyro/contrib/hsgp/spectral_densities.py index d1787a722..4762d5340 100644 --- a/numpyro/contrib/hsgp/spectral_densities.py +++ b/numpyro/contrib/hsgp/spectral_densities.py @@ -17,7 +17,7 @@ def align_param(dim, param): - return jnp.broadcast_arrays(param, jnp.zeros(dim))[0] + return jnp.broadcast_to(param, jnp.broadcast_shapes(jnp.shape(param), (dim,))) def spectral_density_squared_exponential(