From 4667bb967d50094d2804d0156bdc9810c251f863 Mon Sep 17 00:00:00 2001 From: Stephen Huan Date: Wed, 27 Dec 2023 23:33:00 -0800 Subject: [PATCH 1/5] feat(gpjax/kernels/base.py): add diagonal --- gpjax/kernels/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py index ff9e7f8b6..1f07295ed 100644 --- a/gpjax/kernels/base.py +++ b/gpjax/kernels/base.py @@ -63,6 +63,9 @@ def cross_covariance(self, x: Num[Array, "N D"], y: Num[Array, "M D"]): def gram(self, x: Num[Array, "N D"]): return self.compute_engine.gram(self, x) + def diagonal(self, x: Num[Array, "N D"]): + return self.compute_engine.diagonal(self, x) + def slice_input(self, x: Float[Array, "... D"]) -> Float[Array, "... Q"]: r"""Slice out the relevant columns of the input matrix. From de906f775587e8ce875f09633d785496fbc4f7b9 Mon Sep 17 00:00:00 2001 From: Stephen Huan Date: Tue, 2 Apr 2024 03:54:46 -0400 Subject: [PATCH 2/5] feat(gpjax/kernels/computations/basis_functions.py): add diagonal --- gpjax/kernels/computations/basis_functions.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/gpjax/kernels/computations/basis_functions.py b/gpjax/kernels/computations/basis_functions.py index e0693f129..d62d144ff 100644 --- a/gpjax/kernels/computations/basis_functions.py +++ b/gpjax/kernels/computations/basis_functions.py @@ -12,6 +12,7 @@ from cola import PSD from cola.ops import ( Dense, + Diagonal, LinearOperator, ) @@ -58,6 +59,20 @@ def gram(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> LinearOperator: z1 = self.compute_features(kernel, inputs) return PSD(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T))) + def diagonal(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> Diagonal: + r"""For a given kernel, compute the elementwise diagonal of the + NxN gram matrix on an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): the kernel function. + inputs (Float[Array, "N D"]): The input matrix. + + Returns + ------- + Diagonal: The computed diagonal variance entries. + """ + return super().diagonal(kernel.base_kernel, inputs) + def compute_features( self, kernel: Kernel, x: Float[Array, "N D"] ) -> Float[Array, "N L"]: From 07c5561332d149c71a62b967789f128c198858b8 Mon Sep 17 00:00:00 2001 From: Stephen Huan Date: Tue, 2 Apr 2024 03:55:25 -0400 Subject: [PATCH 3/5] test(tests/test_kernels): add diagonal tests --- tests/test_kernels/test_approximations.py | 31 ++++++++++++++++++++++- tests/test_kernels/test_nonstationary.py | 20 ++++++++++++++- tests/test_kernels/test_stationary.py | 20 ++++++++++++++- 3 files changed, 68 insertions(+), 3 deletions(-) diff --git a/tests/test_kernels/test_approximations.py b/tests/test_kernels/test_approximations.py index ca71f5901..17b60d973 100644 --- a/tests/test_kernels/test_approximations.py +++ b/tests/test_kernels/test_approximations.py @@ -1,6 +1,9 @@ from typing import Tuple -from cola.ops import Dense +from cola.ops import ( + Dense, + Diagonal, +) import jax from jax import config import jax.numpy as jnp @@ -63,6 +66,32 @@ def test_gram(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: i assert jnp.all(evals > 0) +@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) +@pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) +@pytest.mark.parametrize("n_dims", [1, 2, 5]) +@pytest.mark.parametrize("n_data", [50, 100]) +def test_diagonal(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: int): + key = jr.key(123) + x = jr.uniform(key, shape=(n_data, 1), minval=-3.0, maxval=3.0).reshape(-1, 1) + if n_dims > 1: + x = jnp.hstack([x] * n_dims) + base_kernel = kernel(active_dims=list(range(n_dims))) + approximate = RFF(base_kernel=base_kernel, num_basis_fns=num_basis_fns) + + linop = approximate.diagonal(x) + + # Check the return type + assert isinstance(linop, Diagonal) + + Kxx = linop.diag + _jitter + + # Check that the shape is correct + assert Kxx.shape == (n_data,) + + # Check that the diagonal is positive + assert jnp.all(Kxx > 0) + + @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) @pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) @pytest.mark.parametrize("n_dims", [1, 2, 5]) diff --git a/tests/test_kernels/test_nonstationary.py b/tests/test_kernels/test_nonstationary.py index 9ab84cdd1..27e4ce7c4 100644 --- a/tests/test_kernels/test_nonstationary.py +++ b/tests/test_kernels/test_nonstationary.py @@ -16,7 +16,10 @@ from itertools import product from typing import List -from cola.ops import LinearOperator +from cola.ops import ( + Diagonal, + LinearOperator, +) import jax from jax import config import jax.numpy as jnp @@ -129,6 +132,21 @@ def test_gram(self, dim: int, n: int) -> None: assert Kxx.shape == (n, n) assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0) + @pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}") + @pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}") + def test_diagonal(self, dim: int, n: int) -> None: + # Initialise kernel + kernel: AbstractKernel = self.kernel() + + # Inputs + x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) + + # Test diagonal + Kxx = kernel.diagonal(x) + assert isinstance(Kxx, Diagonal) + assert Kxx.shape == (n, n) + assert jnp.all(Kxx.diag + 1e-6 > 0.0) + @pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}") @pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}") @pytest.mark.parametrize("dim", [1, 2, 5], ids=lambda x: f"dim={x}") diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index 3e214b45b..0cb42d4a9 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -17,7 +17,10 @@ from dataclasses import is_dataclass from itertools import product -from cola.ops import LinearOperator +from cola.ops import ( + Diagonal, + LinearOperator, +) import jax from jax import config import jax.numpy as jnp @@ -133,6 +136,21 @@ def test_gram(self, dim: int, n: int) -> None: assert Kxx.shape == (n, n) assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0) + @pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}") + @pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}") + def test_diagonal(self, dim: int, n: int) -> None: + # Initialise kernel + kernel: AbstractKernel = self.kernel() + + # Inputs + x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) + + # Test diagonal + Kxx = kernel.diagonal(x) + assert isinstance(Kxx, Diagonal) + assert Kxx.shape == (n, n) + assert jnp.all(Kxx.diag + 1e-6 > 0.0) + @pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}") @pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}") @pytest.mark.parametrize("dim", [1, 2, 5], ids=lambda x: f"dim={x}") From b3e603c0b593b147cb60a3f31129606f6990041c Mon Sep 17 00:00:00 2001 From: Stephen Huan Date: Mon, 1 Jul 2024 15:45:47 -0400 Subject: [PATCH 4/5] test(tests/test_kernels): cross-consistency tests --- tests/test_kernels/test_nonstationary.py | 7 +++++++ tests/test_kernels/test_stationary.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/tests/test_kernels/test_nonstationary.py b/tests/test_kernels/test_nonstationary.py index 27e4ce7c4..a9a2f3f05 100644 --- a/tests/test_kernels/test_nonstationary.py +++ b/tests/test_kernels/test_nonstationary.py @@ -128,9 +128,11 @@ def test_gram(self, dim: int, n: int) -> None: # Test gram matrix Kxx = kernel.gram(x) + Kxx_cross = kernel.cross_covariance(x, x) assert isinstance(Kxx, LinearOperator) assert Kxx.shape == (n, n) assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0) + assert jnp.allclose(Kxx_cross, Kxx.to_dense()) @pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}") @pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}") @@ -143,9 +145,11 @@ def test_diagonal(self, dim: int, n: int) -> None: # Test diagonal Kxx = kernel.diagonal(x) + Kxx_gram = jnp.diagonal(kernel.gram(x).to_dense()) assert isinstance(Kxx, Diagonal) assert Kxx.shape == (n, n) assert jnp.all(Kxx.diag + 1e-6 > 0.0) + assert jnp.allclose(Kxx_gram, Kxx.diag) @pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}") @pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}") @@ -157,11 +161,14 @@ def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None: # Inputs a = jnp.linspace(-1.0, 1.0, n_a * dim).reshape(n_a, dim) b = jnp.linspace(3.0, 4.0, n_b * dim).reshape(n_b, dim) + c = jnp.vstack((a, b)) # Test cross-covariance Kab = kernel.cross_covariance(a, b) + Kab_gram = kernel.gram(c).to_dense()[:n_a, n_a:] assert isinstance(Kab, jnp.ndarray) assert Kab.shape == (n_a, n_b) + assert jnp.allclose(Kab, Kab_gram) def prod(inp): diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index 0cb42d4a9..cc37429d7 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -132,9 +132,11 @@ def test_gram(self, dim: int, n: int) -> None: # Test gram matrix Kxx = kernel.gram(x) + Kxx_cross = kernel.cross_covariance(x, x) assert isinstance(Kxx, LinearOperator) assert Kxx.shape == (n, n) assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0) + assert jnp.allclose(Kxx_cross, Kxx.to_dense()) @pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}") @pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}") @@ -147,9 +149,11 @@ def test_diagonal(self, dim: int, n: int) -> None: # Test diagonal Kxx = kernel.diagonal(x) + Kxx_gram = jnp.diagonal(kernel.gram(x).to_dense()) assert isinstance(Kxx, Diagonal) assert Kxx.shape == (n, n) assert jnp.all(Kxx.diag + 1e-6 > 0.0) + assert jnp.allclose(Kxx_gram, Kxx.diag) @pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}") @pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}") @@ -161,11 +165,14 @@ def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None: # Inputs a = jnp.linspace(-1.0, 1.0, n_a * dim).reshape(n_a, dim) b = jnp.linspace(3.0, 4.0, n_b * dim).reshape(n_b, dim) + c = jnp.vstack((a, b)) # Test cross-covariance Kab = kernel.cross_covariance(a, b) + Kab_gram = kernel.gram(c).to_dense()[:n_a, n_a:] assert isinstance(Kab, jnp.ndarray) assert Kab.shape == (n_a, n_b) + assert jnp.allclose(Kab, Kab_gram) def test_spectral_density(self): # Initialise kernel From e697365f8e924730013f207bb43b129f80232a7c Mon Sep 17 00:00:00 2001 From: Stephen Huan Date: Mon, 1 Jul 2024 15:46:41 -0400 Subject: [PATCH 5/5] test(tests/test_kernels/test_nonstationary): abs --- tests/test_kernels/test_nonstationary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_kernels/test_nonstationary.py b/tests/test_kernels/test_nonstationary.py index a9a2f3f05..b359b2026 100644 --- a/tests/test_kernels/test_nonstationary.py +++ b/tests/test_kernels/test_nonstationary.py @@ -241,4 +241,4 @@ def test_values_by_monte_carlo_in_special_case(self, order: int) -> None: integrands = H_a * H_b * (weights_a**order) * (weights_b**order) Kab_approx = 2.0 * jnp.mean(integrands) - assert jnp.max(Kab_approx - Kab_exact) < 1e-4 + assert jnp.max(jnp.abs(Kab_approx - Kab_exact)) < 1e-4