From d8328684d52b31e867a89c7943de553e4e9da539 Mon Sep 17 00:00:00 2001 From: Brendan Cooley Date: Fri, 31 May 2024 23:12:49 -0400 Subject: [PATCH] fix(contrib.hsgp): fix incorrect dim arg to partial spectral density fns --- numpyro/contrib/hsgp/spectral_densities.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/contrib/hsgp/spectral_densities.py b/numpyro/contrib/hsgp/spectral_densities.py index ea5e83f69..8befff885 100644 --- a/numpyro/contrib/hsgp/spectral_densities.py +++ b/numpyro/contrib/hsgp/spectral_densities.py @@ -117,7 +117,7 @@ def diag_spectral_density_squared_exponential( def _spectral_density(w): return spectral_density_squared_exponential( - dim=1, w=w, alpha=alpha, length=length + dim=dim, w=w, alpha=alpha, length=length ) sqrt_eigenvalues_ = sqrt_eigenvalues(ell=ell, m=m, dim=dim) # dim x m @@ -151,7 +151,7 @@ def diag_spectral_density_matern( """ def _spectral_density(w): - return spectral_density_matern(dim=1, nu=nu, w=w, alpha=alpha, length=length) + return spectral_density_matern(dim=dim, nu=nu, w=w, alpha=alpha, length=length) sqrt_eigenvalues_ = sqrt_eigenvalues(ell=ell, m=m, dim=dim) return vmap(_spectral_density, in_axes=-1)(sqrt_eigenvalues_)