
Merge pull request #429 from stephen-huan/kernel-diagonal
feat(gpjax/kernels/base.py): add diagonal
thomaspinder authored Aug 9, 2024
2 parents b69be96 + e697365 commit 7ae0adf
Showing 5 changed files with 101 additions and 4 deletions.
3 changes: 3 additions & 0 deletions gpjax/kernels/base.py
@@ -63,6 +63,9 @@ def cross_covariance(self, x: Num[Array, "N D"], y: Num[Array, "M D"]):
def gram(self, x: Num[Array, "N D"]):
return self.compute_engine.gram(self, x)

def diagonal(self, x: Num[Array, "N D"]):
return self.compute_engine.diagonal(self, x)

def slice_input(self, x: Float[Array, "... D"]) -> Float[Array, "... Q"]:
r"""Slice out the relevant columns of the input matrix.
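The new diagonal method mirrors gram by delegating to the kernel's compute engine. A minimal usage sketch (assuming GPJax's public RBF kernel; the Diagonal return type follows from the tests below):

import jax.numpy as jnp
from gpjax.kernels import RBF

x = jnp.linspace(0.0, 1.0, 10).reshape(-1, 1)
kernel = RBF()

linop = kernel.diagonal(x)  # a cola.ops.Diagonal operator of shape (10, 10)

# The stored entries match the diagonal of the dense Gram matrix.
assert jnp.allclose(linop.diag, jnp.diagonal(kernel.gram(x).to_dense()))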
15 changes: 15 additions & 0 deletions gpjax/kernels/computations/basis_functions.py
@@ -12,6 +12,7 @@
from cola import PSD
from cola.ops import (
Dense,
Diagonal,
LinearOperator,
)

@@ -58,6 +59,20 @@ def gram(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> LinearOperator:
z1 = self.compute_features(kernel, inputs)
return PSD(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))

def diagonal(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> Diagonal:
r"""For a given kernel, compute the elementwise diagonal of the
NxN gram matrix on an input matrix of shape NxD.
Args:
kernel (AbstractKernel): the kernel function.
inputs (Float[Array, "N D"]): The input matrix.
Returns
-------
Diagonal: The computed diagonal variance entries.
"""
return super().diagonal(kernel.base_kernel, inputs)

def compute_features(
self, kernel: Kernel, x: Float[Array, "N D"]
) -> Float[Array, "N L"]:
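Note that for RFF, diagonal delegates to the exact base kernel rather than taking the diagonal of the low-rank approximation scaling * z1 @ z1.T. A sketch of what the inherited default plausibly computes (an assumption for illustration only; the pointwise vmap evaluation is not necessarily the committed implementation):

import jax
from cola import PSD
from cola.ops import Diagonal

def diagonal_sketch(kernel, inputs):
    # Evaluate k(x_i, x_i) for each row x_i of the N x D input matrix.
    entries = jax.vmap(lambda xi: kernel(xi, xi))(inputs)
    return PSD(Diagonal(entries))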
31 changes: 30 additions & 1 deletion tests/test_kernels/test_approximations.py
@@ -1,6 +1,9 @@
from typing import Tuple

-from cola.ops import Dense
+from cola.ops import (
+    Dense,
+    Diagonal,
+)
import jax
from jax import config
import jax.numpy as jnp
@@ -63,6 +66,32 @@ def test_gram(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: int):
assert jnp.all(evals > 0)


@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52])
@pytest.mark.parametrize("num_basis_fns", [2, 10, 20])
@pytest.mark.parametrize("n_dims", [1, 2, 5])
@pytest.mark.parametrize("n_data", [50, 100])
def test_diagonal(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: int):
key = jr.key(123)
x = jr.uniform(key, shape=(n_data, 1), minval=-3.0, maxval=3.0).reshape(-1, 1)
if n_dims > 1:
x = jnp.hstack([x] * n_dims)
base_kernel = kernel(active_dims=list(range(n_dims)))
approximate = RFF(base_kernel=base_kernel, num_basis_fns=num_basis_fns)

linop = approximate.diagonal(x)

# Check the return type
assert isinstance(linop, Diagonal)

Kxx = linop.diag + _jitter

# Check that the shape is correct
assert Kxx.shape == (n_data,)

# Check that the diagonal is positive
assert jnp.all(Kxx > 0)


@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52])
@pytest.mark.parametrize("num_basis_fns", [2, 10, 20])
@pytest.mark.parametrize("n_dims", [1, 2, 5])
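One subtlety the assertions rely on: a cola.ops.Diagonal built from n entries reports a square (n, n) operator shape, while its .diag attribute holds the n stored entries — hence the (n_data,) check above and the (n, n) checks in the tests below. A quick sketch of that convention:

import jax.numpy as jnp
from cola.ops import Diagonal

d = Diagonal(jnp.ones(5))
assert d.shape == (5, 5)        # square operator shape
assert d.diag.shape == (5,)     # the stored diagonal entries
assert jnp.allclose(d.to_dense(), jnp.eye(5))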
29 changes: 27 additions & 2 deletions tests/test_kernels/test_nonstationary.py
@@ -16,7 +16,10 @@
from itertools import product
from typing import List

-from cola.ops import LinearOperator
+from cola.ops import (
+    Diagonal,
+    LinearOperator,
+)
import jax
from jax import config
import jax.numpy as jnp
@@ -125,9 +128,28 @@ def test_gram(self, dim: int, n: int) -> None:

# Test gram matrix
Kxx = kernel.gram(x)
Kxx_cross = kernel.cross_covariance(x, x)
assert isinstance(Kxx, LinearOperator)
assert Kxx.shape == (n, n)
assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0)
assert jnp.allclose(Kxx_cross, Kxx.to_dense())

@pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}")
@pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}")
def test_diagonal(self, dim: int, n: int) -> None:
# Initialise kernel
kernel: AbstractKernel = self.kernel()

# Inputs
x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)

# Test diagonal
Kxx = kernel.diagonal(x)
Kxx_gram = jnp.diagonal(kernel.gram(x).to_dense())
assert isinstance(Kxx, Diagonal)
assert Kxx.shape == (n, n)
assert jnp.all(Kxx.diag + 1e-6 > 0.0)
assert jnp.allclose(Kxx_gram, Kxx.diag)

@pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}")
@pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}")
@@ -139,11 +161,14 @@ def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None:
# Inputs
a = jnp.linspace(-1.0, 1.0, n_a * dim).reshape(n_a, dim)
b = jnp.linspace(3.0, 4.0, n_b * dim).reshape(n_b, dim)
c = jnp.vstack((a, b))

# Test cross-covariance
Kab = kernel.cross_covariance(a, b)
Kab_gram = kernel.gram(c).to_dense()[:n_a, n_a:]
assert isinstance(Kab, jnp.ndarray)
assert Kab.shape == (n_a, n_b)
assert jnp.allclose(Kab, Kab_gram)


def prod(inp):
@@ -216,4 +241,4 @@ def test_values_by_monte_carlo_in_special_case(self, order: int) -> None:
integrands = H_a * H_b * (weights_a**order) * (weights_b**order)
Kab_approx = 2.0 * jnp.mean(integrands)

-assert jnp.max(Kab_approx - Kab_exact) < 1e-4
+assert jnp.max(jnp.abs(Kab_approx - Kab_exact)) < 1e-4
27 changes: 26 additions & 1 deletion tests/test_kernels/test_stationary.py
@@ -17,7 +17,10 @@
from dataclasses import is_dataclass
from itertools import product

-from cola.ops import LinearOperator
+from cola.ops import (
+    Diagonal,
+    LinearOperator,
+)
import jax
from jax import config
import jax.numpy as jnp
@@ -129,9 +132,28 @@ def test_gram(self, dim: int, n: int) -> None:

# Test gram matrix
Kxx = kernel.gram(x)
Kxx_cross = kernel.cross_covariance(x, x)
assert isinstance(Kxx, LinearOperator)
assert Kxx.shape == (n, n)
assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0)
assert jnp.allclose(Kxx_cross, Kxx.to_dense())

@pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}")
@pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}")
def test_diagonal(self, dim: int, n: int) -> None:
# Initialise kernel
kernel: AbstractKernel = self.kernel()

# Inputs
x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)

# Test diagonal
Kxx = kernel.diagonal(x)
Kxx_gram = jnp.diagonal(kernel.gram(x).to_dense())
assert isinstance(Kxx, Diagonal)
assert Kxx.shape == (n, n)
assert jnp.all(Kxx.diag + 1e-6 > 0.0)
assert jnp.allclose(Kxx_gram, Kxx.diag)

@pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}")
@pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}")
@@ -143,11 +165,14 @@ def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None:
# Inputs
a = jnp.linspace(-1.0, 1.0, n_a * dim).reshape(n_a, dim)
b = jnp.linspace(3.0, 4.0, n_b * dim).reshape(n_b, dim)
c = jnp.vstack((a, b))

# Test cross-covariance
Kab = kernel.cross_covariance(a, b)
Kab_gram = kernel.gram(c).to_dense()[:n_a, n_a:]
assert isinstance(Kab, jnp.ndarray)
assert Kab.shape == (n_a, n_b)
assert jnp.allclose(Kab, Kab_gram)

def test_spectral_density(self):
# Initialise kernel
