Skip to content

Commit

Permalink
Support vector lengthscales for RBF and Matern kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam Anklesaria committed Jun 21, 2024
1 parent 0ba1306 commit 8e80d6e
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 23 deletions.
5 changes: 2 additions & 3 deletions numpyro/contrib/hsgp/laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@

from __future__ import annotations

from typing import Union, get_args
from typing import get_args

from jaxlib.xla_extension import ArrayImpl
import numpy as np

import jax.numpy as jnp

ARRAY_TYPE = Union[ArrayImpl, np.ndarray]
from numpyro.contrib.hsgp.util import ARRAY_TYPE


def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl:
Expand Down
23 changes: 15 additions & 8 deletions numpyro/contrib/hsgp/spectral_densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
from jax.scipy import special

from numpyro.contrib.hsgp.laplacian import sqrt_eigenvalues
from numpyro.contrib.hsgp.util import ARRAY_TYPE


def align_param(dim, param):
return jnp.broadcast_arrays(param, jnp.zeros(dim))[0]


def spectral_density_squared_exponential(
dim: int, w: ArrayImpl, alpha: float, length: float
dim: int, w: ARRAY_TYPE, alpha: float, length: float | ARRAY_TYPE
) -> float:
"""
Spectral density of the squared exponential kernel.
Expand All @@ -44,13 +49,14 @@ def spectral_density_squared_exponential(
:return: spectral density value
:rtype: float
"""
c = alpha * (jnp.sqrt(2 * jnp.pi) * length) ** dim
e = jnp.exp(-0.5 * (length**2) * jnp.dot(w, w))
length = align_param(dim, length)
c = alpha * jnp.prod(jnp.sqrt(2 * jnp.pi) * length, axis=-1)
e = jnp.exp(-0.5 * jnp.sum(w**2 * length**2, axis=-1))
return c * e


def spectral_density_matern(
dim: int, nu: float, w: ArrayImpl, alpha: float, length: float
dim: int, nu: float, w: ArrayImpl, alpha: float, length: float | ArrayImpl
) -> float:
"""
Spectral density of the Matérn kernel.
Expand Down Expand Up @@ -79,22 +85,23 @@ def spectral_density_matern(
:return: spectral density value
:rtype: float
""" # noqa: E501
length = align_param(dim, length)
c1 = (
alpha
* (2 ** (dim))
* (jnp.pi ** (dim / 2))
* ((2 * nu) ** nu)
* special.gamma(nu + dim / 2)
)
c2 = (2 * nu / (length**2) + jnp.dot(w, w)) ** (-nu - dim / 2)
c3 = special.gamma(nu) * length ** (2 * nu)
s = jnp.sum(length**2 * w**2, axis=-1)
c2 = jnp.prod(length, axis=-1) * (2 * nu + s) ** (-nu - dim / 2)
c3 = special.gamma(nu)
return c1 * c2 / c3


# TODO support length-D kernel hyperparameters
def diag_spectral_density_squared_exponential(
alpha: float,
length: float,
length: float | list[float],
ell: float | int | list[float | int],
m: int | list[int],
dim: int,
Expand Down
10 changes: 10 additions & 0 deletions numpyro/contrib/hsgp/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Union

import numpy as np

import jax

ARRAY_TYPE = Union[jax.Array, np.ndarray] # jax.Array covers tracers
93 changes: 81 additions & 12 deletions test/contrib/hsgp/test_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,23 +74,50 @@ def synthetic_two_dim_data() -> tuple[ArrayImpl, ArrayImpl]:


@pytest.mark.parametrize(
argnames="x1, x2, length, ell",
argnames="x1, x2, length, ell, xfail",
argvalues=[
(np.array([[1.0]]), np.array([[0.0]]), np.array([1.0]), 5.0),
(np.array([[1.0]]), np.array([[0.0]]), 1.0, 5.0, False),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
np.array([1.0]),
1.0,
5.0,
False,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
np.array([1.0, 0.5]),
5.0,
False,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
np.array(
[[1.0, 0.5], [0.5, 1.0]]
), # different length scale for each point/dimension
5.0,
False,
),
(
np.array([[1.5, 1.25, 1.0]]),
np.array([[0.0, 0.0, 0.0]]),
np.array([[1.0, 0.5], [0.5, 1.0]]), # invalid length scale
5.0,
True,
),
],
ids=[
"1d",
"2d,1d-length",
"1d,scalar-length",
"2d,scalar-length",
"2d,vector-length",
"2d,matrix-length",
"2d,invalid-length",
],
)
def test_kernel_approx_squared_exponential(
x1: ArrayImpl, x2: ArrayImpl, length: ArrayImpl, ell: float
x1: ArrayImpl, x2: ArrayImpl, length: float | ArrayImpl, ell: float, xfail: bool
):
"""ensure that the approximation of the squared exponential kernel is accurate,
matching the exact kernel implementation from sklearn.
Expand All @@ -100,13 +127,26 @@ def test_kernel_approx_squared_exponential(
assert x1.shape == x2.shape
m = 100 # large enough to ensure the approximation is accurate
dim = x1.shape[-1]
if xfail:
with pytest.raises(ValueError):
diag_spectral_density_squared_exponential(1.0, length, ell, m, dim)
return
spd = diag_spectral_density_squared_exponential(1.0, length, ell, m, dim)

eig_f1 = eigenfunctions(x1, ell=ell, m=m)
eig_f2 = eigenfunctions(x2, ell=ell, m=m)
approx = (eig_f1 * eig_f2) @ spd
exact = RBF(length)(x1, x2)
assert jnp.isclose(approx, exact, rtol=1e-3)

def _exact_rbf(length):
return RBF(length)(x1, x2).squeeze(axis=-1)

if isinstance(length, float | int):
exact = _exact_rbf(length)
elif length.ndim == 1:
exact = _exact_rbf(length)
else:
exact = np.apply_along_axis(_exact_rbf, axis=0, arr=length)
assert jnp.isclose(approx, exact, rtol=1e-3).all()


@pytest.mark.parametrize(
Expand All @@ -118,14 +158,32 @@ def test_kernel_approx_squared_exponential(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
3 / 2,
np.array([1.0]),
np.array([0.25, 0.5]),
5.0,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
5 / 2,
np.array([1.0]),
np.array([0.25, 0.5]),
5.0,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
3 / 2,
np.array(
[[1.0, 0.5], [0.5, 1.0]]
), # different length scale for each point/dimension
5.0,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
5 / 2,
np.array(
[[1.0, 0.5], [0.5, 1.0]]
), # different length scale for each point/dimension
5.0,
),
],
Expand All @@ -134,6 +192,8 @@ def test_kernel_approx_squared_exponential(
"1d,nu=5/2",
"2d,nu=3/2,1d-length",
"2d,nu=5/2,1d-length",
"2d,nu=3/2,2d-length",
"2d,nu=5/2,2d-length",
],
)
def test_kernel_approx_squared_matern(
Expand All @@ -154,8 +214,17 @@ def test_kernel_approx_squared_matern(
eig_f1 = eigenfunctions(x1, ell=ell, m=m)
eig_f2 = eigenfunctions(x2, ell=ell, m=m)
approx = (eig_f1 * eig_f2) @ spd
exact = Matern(length_scale=length, nu=nu)(x1, x2)
assert jnp.isclose(approx, exact, rtol=1e-3)

def _exact_matern(length):
return Matern(length_scale=length, nu=nu)(x1, x2).squeeze(axis=-1)

if isinstance(length, float | int):
exact = _exact_matern(length)
elif length.ndim == 1:
exact = _exact_matern(length)
else:
exact = np.apply_along_axis(_exact_matern, axis=0, arr=length)
assert jnp.isclose(approx, exact, rtol=1e-3).all()


@pytest.mark.parametrize(
Expand Down

0 comments on commit 8e80d6e

Please sign in to comment.