Skip to content

Commit

Permalink
fix(contrib.hsgp): fix incorrect dim arg to partial spectral density …
Browse files Browse the repository at this point in the history
…fns (#1808)
  • Loading branch information
brendancooley authored Jun 1, 2024
1 parent e8216d7 commit 401e364
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions numpyro/contrib/hsgp/spectral_densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def diag_spectral_density_squared_exponential(

def _spectral_density(w):
return spectral_density_squared_exponential(
dim=1, w=w, alpha=alpha, length=length
dim=dim, w=w, alpha=alpha, length=length
)

sqrt_eigenvalues_ = sqrt_eigenvalues(ell=ell, m=m, dim=dim) # dim x m
Expand Down Expand Up @@ -151,7 +151,7 @@ def diag_spectral_density_matern(
"""

def _spectral_density(w):
return spectral_density_matern(dim=1, nu=nu, w=w, alpha=alpha, length=length)
return spectral_density_matern(dim=dim, nu=nu, w=w, alpha=alpha, length=length)

sqrt_eigenvalues_ = sqrt_eigenvalues(ell=ell, m=m, dim=dim)
return vmap(_spectral_density, in_axes=-1)(sqrt_eigenvalues_)
Expand Down

0 comments on commit 401e364

Please sign in to comment.