diff --git a/numpyro/contrib/hsgp/spectral_densities.py b/numpyro/contrib/hsgp/spectral_densities.py index ea5e83f69..8befff885 100644 --- a/numpyro/contrib/hsgp/spectral_densities.py +++ b/numpyro/contrib/hsgp/spectral_densities.py @@ -117,7 +117,7 @@ def diag_spectral_density_squared_exponential( def _spectral_density(w): return spectral_density_squared_exponential( - dim=1, w=w, alpha=alpha, length=length + dim=dim, w=w, alpha=alpha, length=length ) sqrt_eigenvalues_ = sqrt_eigenvalues(ell=ell, m=m, dim=dim) # dim x m @@ -151,7 +151,7 @@ def diag_spectral_density_matern( """ def _spectral_density(w): - return spectral_density_matern(dim=1, nu=nu, w=w, alpha=alpha, length=length) + return spectral_density_matern(dim=dim, nu=nu, w=w, alpha=alpha, length=length) sqrt_eigenvalues_ = sqrt_eigenvalues(ell=ell, m=m, dim=dim) return vmap(_spectral_density, in_axes=-1)(sqrt_eigenvalues_)