Skip to content

Commit

Permalink
wip(hsgp_nd): support/test n-d approximations
Browse files Browse the repository at this point in the history
  • Loading branch information
brendancooley committed May 21, 2024
1 parent 4ea24f0 commit 812520f
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 76 deletions.
12 changes: 8 additions & 4 deletions numpyro/contrib/hsgp/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def hsgp_squared_exponential(
ell: float,
m: int,
non_centered: bool = True,
dim: int = 1, # TODO infer from data?
) -> ArrayImpl:
"""
Hilbert space Gaussian process approximation using the squared exponential kernel.
Expand All @@ -92,13 +91,16 @@ def hsgp_squared_exponential(
:return: the low-rank approximation linear model
:rtype: ArrayImpl
"""
dim = x.shape[-1] if x.ndim > 1 else 1
phi = eigenfunctions(x=x, ell=ell, m=m)
spd = jnp.sqrt(
diag_spectral_density_squared_exponential(
alpha=alpha, length=length, ell=ell, m=m, dim=dim
)
)
return linear_approximation(phi=phi, spd=spd, m=m, non_centered=non_centered)
return linear_approximation(
phi=phi, spd=spd, m=phi.shape[-1], non_centered=non_centered
)


def hsgp_matern(
Expand All @@ -109,7 +111,6 @@ def hsgp_matern(
ell: float,
m: int,
non_centered: bool = True,
dim: int = 1, # TODO infer from data?
):
"""
Hilbert space Gaussian process approximation using the Matérn kernel.
Expand Down Expand Up @@ -137,13 +138,16 @@ def hsgp_matern(
:return: the low-rank approximation linear model
:rtype: ArrayImpl
"""
dim = x.shape[-1] if x.ndim > 1 else 1
phi = eigenfunctions(x=x, ell=ell, m=m)
spd = jnp.sqrt(
diag_spectral_density_matern(
nu=nu, alpha=alpha, length=length, ell=ell, m=m, dim=dim
)
)
return linear_approximation(phi=phi, spd=spd, m=m, non_centered=non_centered)
return linear_approximation(
phi=phi, spd=spd, m=phi.shape[-1], non_centered=non_centered
)


def hsgp_periodic_non_centered(
Expand Down
66 changes: 29 additions & 37 deletions numpyro/contrib/hsgp/laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
This module contains functions for computing eigenvalues and eigenfunctions of the laplace operator.
"""


from jaxlib.xla_extension import ArrayImpl

import jax
import jax.numpy as jnp


def eigen_indices(m: list[int] | int, dim: int) -> ArrayImpl:
"""Returns the indices of the first `m_star x D` eigenvalues of the laplacian operator in `[-ell, ell]`.
"""Returns the indices of the first `m_star x D` eigenvalues of the laplacian operator.
.. math::
Expand All @@ -23,7 +22,7 @@ def eigen_indices(m: list[int] | int, dim: int) -> ArrayImpl:
1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
:param Sequence[int] | int m: The number of eigenvalues to compute in each dimension.
:param Sequence[int] | int m: The number of desired eigenvalue indices in each dimension.
If an integer, the same number of eigenvalues is computed in each dimension.
:param int dim: The dimension of the space.
Expand All @@ -41,38 +40,20 @@ def eigen_indices(m: list[int] | int, dim: int) -> ArrayImpl:
>>> m = 10
>>> S = eigen_indices(m, 1)
>>> assert S.shape == (1, m)
>>> S.T
Array([[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10]], dtype=int32)
>>> S
Array([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)
>>> m = 10
>>> S = eigen_indices(m, 2)
>>> assert S.shape == (2, 100)
>>> m = [2, 2, 3]
>>> m = [2, 2, 3] # Ruitort-Mayol et al eq (10)
>>> S = eigen_indices(m, 3)
>>> assert S.shape == (3, 12)
>>> S.T
Array([[1, 1, 1],
[1, 1, 2],
[1, 1, 3],
[1, 2, 1],
[1, 2, 2],
[1, 2, 3],
[2, 1, 1],
[2, 1, 2],
[2, 1, 3],
[2, 2, 1],
[2, 2, 2],
[2, 2, 3]], dtype=int32)
>>> S
Array([[1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2],
[1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2],
[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]], dtype=int32)
"""
if isinstance(m, int):
Expand All @@ -92,7 +73,8 @@ def sqrt_eigenvalues(
ell: float | list[float], m: list[int] | int, dim: int
) -> ArrayImpl:
"""
The first `m` square root of eigenvalues of the laplacian operator in `[-ell, ell]`. See Eq. (56) in [1].
The first `dim x m_star` square root of eigenvalues of the laplacian operator in
`[-ell_1, ell_1] x ... x [-ell_D, ell_D]`. See Eq. (56) in [1].
**References:**
Expand All @@ -101,7 +83,8 @@ def sqrt_eigenvalues(
:param Sequence[float] | float ell: The length of the interval in each dimension divided by 2.
If a float, the same length is used in each dimension.
:param int m: The number of eigenvalues to compute.
:param list[int] | int m: The number of eigenvalues to compute in each dimension.
If an integer, the same number of eigenvalues is computed in each dimension.
:param int dim: The dimension of the space.
:returns: An array of the first `m` square root of eigenvalues.
Expand All @@ -112,10 +95,12 @@ def sqrt_eigenvalues(
return S * jnp.pi / 2 / ell_ # dim x prod(m) array of eigenvalues


def eigenfunctions(x: ArrayImpl, ell: float | list[float], m: int) -> ArrayImpl:
def eigenfunctions(
x: ArrayImpl, ell: float | list[float], m: int | list[int]
) -> ArrayImpl:
"""
The first `m` eigenfunctions of the laplacian operator in `[-ell, ell]`
evaluated at `x`. See Eq. (56) in [1].
The first `m_star` eigenfunctions of the laplacian operator in `[-ell_1, ell_1] x ... x [-ell_D, ell_D]`
evaluated at values of `x`. See Eq. (56) in [1].
**Example:**
Expand All @@ -134,17 +119,24 @@ def eigenfunctions(x: ArrayImpl, ell: float | list[float], m: int) -> ArrayImpl:
>>> assert basis.shape == (n, m)
# TODO add batched test
>>> x = jnp.ones((n, 3)) # 2d input
>>> basis = eigenfunctions(x=x, ell=1.2, m=[2, 2, 3])
>>> assert basis.shape == (n, 12)
**References:**
1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
Stat Comput 30, 419-446 (2020)
:param ArrayImpl x: The points at which to evaluate the eigenfunctions.
:param float ell: The length of the interval divided by 2.
:param int m: The number of eigenfunctions to compute.
:returns: An array of the first `m` eigenfunctions evaluated at `x`.
If x is 1D the problem is assumed unidimensional.
Otherwise, the dimension of the input space is inferred as the last dimension of x.
Other dimensions are treated as batch dimensions.
:param float | list[float] ell: The length of the interval in each dimension divided by 2.
If a float, the same length is used in each dimension.
:param int | list[int] m: The number of eigenvalues to compute in each dimension.
If an integer, the same number of eigenvalues is computed in each dimension.
:returns: An array of the first `m_star` eigenfunctions evaluated at `x`.
:rtype: ArrayImpl
"""
if x.ndim == 1:
Expand Down
Loading

0 comments on commit 812520f

Please sign in to comment.