Orthogonal Additive Kernels (pytorch#1869)
Summary:
Pull Request resolved: pytorch#1869

OAKs were introduced in [Additive Gaussian Processes Revisited](https://arxiv.org/pdf/2206.09861.pdf), but were limited to Gaussian kernels with Gaussian data densities (which required applying normalizing flows to the actual input data to make it look Gaussian). This commit introduces a generalization of OAKs that works with arbitrary base kernels by leveraging Gauss-Legendre quadrature rules for the associated one-dimensional integrals.

OAKs could be more sample-efficient than canonical kernels in higher dimensions, and allow for more efficient relevance determination, because dimensions or interactions of dimensions can be pruned by setting their associated coefficients -- not just their lengthscales -- to zero.

Reviewed By: Balandat

Differential Revision: D45217852

fbshipit-source-id: 182ebab917a541029bd1f514044404147b84ed81
1 parent 7c0cdf6 · commit 6d329a8 · 3 changed files with 419 additions and 0 deletions.
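
To make the quadrature-based orthogonalization concrete before the diff itself, here is a minimal, self-contained sketch (illustrative only; none of these names appear in the commit). It orthogonalizes a one-dimensional RBF kernel with Gauss-Legendre nodes and checks that the result integrates to (numerically) zero over [0, 1]:

import numpy as np

# Gauss-Legendre nodes and weights on [-1, 1], rescaled to [0, 1]
z, w = np.polynomial.legendre.leggauss(32)
z, w = (z + 1) / 2, w / 2

def k(x, y):
    # a simple RBF base kernel; any base kernel works here
    return np.exp(-((x[:, None] - y[None, :]) ** 2) / 0.2)

norm = w @ k(z, z) @ w  # int int k(z, z') dz dz'

def k_ortho(x, y):
    # k(x, y) - (int k(x, z) dz) * (int k(y, z) dz) / norm
    return k(x, y) - np.outer(k(x, z) @ w, k(y, z) @ w) / norm

x = np.array([0.3])
print(k_ortho(x, z) @ w)  # ~0: the orthogonalized kernel integrates to zero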
@@ -0,0 +1,296 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import List, Optional, Tuple

import numpy
import torch
from botorch.exceptions.errors import UnsupportedError
from gpytorch.constraints import Interval, Positive
from gpytorch.kernels import Kernel

from torch import nn, Tensor

_positivity_constraint = Positive()


class OrthogonalAdditiveKernel(Kernel):
r"""Orthogonal Additive Kernels (OAKs) were introduced in [Lu2022additive]_, though | ||
only for the case of Gaussian base kernels with a Gaussian input data distribution. | ||
The implementation here generalizes OAKs to arbitrary base kernels by using a | ||
Gauss-Legendre quadrature approximation to the required one-dimensional integrals | ||
involving the base kernels. | ||
.. [Lu2022additive] | ||
X. Lu, A. Boukouvalas, and J. Hensman. Additive Gaussian processes revisited. | ||
Proceedings of the 39th International Conference on Machine Learning. Jul 2022. | ||
""" | ||

    def __init__(
        self,
        base_kernel: Kernel,
        dim: int,
        quad_deg: int = 32,
        second_order: bool = False,
        batch_shape: Optional[torch.Size] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        coeff_constraint: Interval = _positivity_constraint,
    ):
        """
        Args:
            base_kernel: The kernel to orthogonalize and evaluate in `forward`.
            dim: Input dimensionality of the kernel.
            quad_deg: Number of integration nodes for orthogonalization.
            second_order: Toggles second-order interactions. If True, both the time
                and space complexity of evaluating the kernel are quadratic in `dim`.
            batch_shape: Optional batch shape for the kernel and its parameters.
            dtype: Initialization dtype for required Tensors.
            device: Initialization device for required Tensors.
            coeff_constraint: Constraint on the coefficients of the additive kernel.
        """
        super().__init__(batch_shape=batch_shape)
        self.base_kernel = base_kernel
        # integration nodes, weights for [0, 1]
        tkwargs = {"dtype": dtype, "device": device}
        z, w = leggauss(deg=quad_deg, a=0, b=1, **tkwargs)
        self.z = z.unsqueeze(-1).expand(quad_deg, dim)  # deg x dim
        self.w = w.unsqueeze(-1)
        self.register_parameter(
            name="raw_offset",
            parameter=nn.Parameter(torch.zeros(self.batch_shape, **tkwargs)),
        )
        # initializing the raw coefficients at -log(d) starts the constrained
        # coefficients near 1/d (Positive() uses a softplus transform)
        log_d = math.log(dim)
        self.register_parameter(
            name="raw_coeffs_1",
            parameter=nn.Parameter(
                torch.zeros(*self.batch_shape, dim, **tkwargs) - log_d
            ),
        )
        self.register_parameter(
            name="raw_coeffs_2",
            parameter=nn.Parameter(
                torch.zeros(*self.batch_shape, int(dim * (dim - 1) / 2), **tkwargs)
                - 2 * log_d
            )
            if second_order
            else None,
        )
        if second_order:
            self._rev_triu_indices = torch.tensor(
                _reverse_triu_indices(dim),
                device=device,
                dtype=torch.long,  # index tensors must be integer-typed
            )
            # zero tensor for construction of upper-triangular coefficient matrix
            self._quad_zero = torch.zeros(
                tuple(1 for _ in range(len(self.batch_shape) + 1)), **tkwargs
            ).expand(*self.batch_shape, 1)
        self.coeff_constraint = coeff_constraint
        self.dim = dim

    def k(self, x1, x2) -> Tensor:
        """Evaluates the kernel matrix base_kernel(x1, x2) on each input dimension
        independently.

        Args:
            x1: `batch_shape x n1 x d`-dim Tensor in `[0, 1]^dim`.
            x2: `batch_shape x n2 x d`-dim Tensor in `[0, 1]^dim`.

        Returns:
            A `batch_shape x d x n1 x n2`-dim Tensor of kernel matrices.
        """
        return self.base_kernel(x1, x2, last_dim_is_batch=True).to_dense()

    @property
    def offset(self) -> Tensor:
        """Returns the `batch_shape`-dim Tensor of zeroth-order coefficients."""
        return self.coeff_constraint.transform(self.raw_offset)

    @property
    def coeffs_1(self) -> Tensor:
        """Returns the `batch_shape x d`-dim Tensor of first-order coefficients."""
        return self.coeff_constraint.transform(self.raw_coeffs_1)

    @property
    def coeffs_2(self) -> Optional[Tensor]:
        """Returns the upper-triangular tensor of second-order coefficients.

        NOTE: We only keep track of the upper-triangular part of the raw second-order
        coefficients, since the effect of the lower-triangular part would be identical,
        and exclude the diagonal, since it is associated with first-order effects only.
        While we could further exploit this structure in the forward pass, the
        associated indexing and temporary allocations make it significantly less
        efficient than the einsum-based implementation below.

        Returns:
            A `batch_shape x d x d`-dim Tensor of second-order coefficients.
        """
        if self.raw_coeffs_2 is not None:
            C2 = self.coeff_constraint.transform(self.raw_coeffs_2)
            # append zero and gather to form the upper-triangular coefficient matrix
            C2 = torch.cat((C2, self._quad_zero), dim=-1)  # batch_shape x (d(d-1)/2+1)
            C2 = C2.index_select(-1, self._rev_triu_indices)
            return C2.reshape(*self.batch_shape, self.dim, self.dim)
        else:
            return None

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
    ) -> Tensor:
        """Computes the kernel matrix k(x1, x2).

        Args:
            x1: `batch_shape x n1 x d`-dim Tensor in `[0, 1]^dim`.
            x2: `batch_shape x n2 x d`-dim Tensor in `[0, 1]^dim`.
            diag: If True, only returns the diagonal of the kernel matrix.
            last_dim_is_batch: Not supported by this kernel.

        Returns:
            A `batch_shape x n1 x n2`-dim Tensor of kernel matrices.
        """
        if last_dim_is_batch:
            raise UnsupportedError(
                "OrthogonalAdditiveKernel does not support `last_dim_is_batch`."
            )
        K_ortho = self._orthogonal_base_kernels(x1, x2)  # batch_shape x d x n1 x n2

        # contracting over d, leading to a `batch_shape x n x n`-dim tensor, i.e.:
        # K1 = torch.sum(self.coeffs_1[..., None, None] * K_ortho, dim=-3)
        K1 = torch.einsum(self.coeffs_1, [..., 0], K_ortho, [..., 0, 1, 2], [..., 1, 2])
        # adding the non-batch dimensions to offset
        K = K1 + self.offset[..., None, None]
        if self.coeffs_2 is not None:
            # Computing the tensor of second-order interactions K2.
            # NOTE: K2 here is equivalent to:
            # K2 = K_ortho.unsqueeze(-4) * K_ortho.unsqueeze(-3)  # d x d x n x n
            # K2 = (self.coeffs_2[..., None, None] * K2).sum(dim=(-4, -3))
            # but avoids forming the `batch_shape x d x d x n x n`-dim tensor in memory.
            # Reducing over the dimensions with the O(d^2) quadratic terms:
            K2 = torch.einsum(
                K_ortho,
                [..., 0, 2, 3],
                K_ortho,
                [..., 1, 2, 3],
                self.coeffs_2,
                [..., 0, 1],
                [..., 2, 3],  # i.e. contracting over the first two non-batch dims
            )
            K = K + K2

        return K if not diag else K.diag()  # poor man's diag (TODO)

    def _orthogonal_base_kernels(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Evaluates the set of `d` orthogonalized base kernels on (x1, x2).

        Note that even if the base kernel is positive, the orthogonalized versions
        can - and usually do - take negative values.

        Args:
            x1: `batch_shape x n1 x d`-dim inputs to the kernel.
            x2: `batch_shape x n2 x d`-dim inputs to the kernel.

        Returns:
            A `batch_shape x d x n1 x n2`-dim Tensor.
        """
        _check_hypercube(x1, "x1")
        if x1 is not x2:
            _check_hypercube(x2, "x2")
        Kx1x2 = self.k(x1, x2)  # d x n x n
        # Overwriting allocated quadrature tensors with fitting dtype and device
        # self.z, self.w = self.z.to(x1), self.w.to(x1)
        # include normalization constant in weights
        w = self.w / self.normalizer().sqrt()
        Skx1 = self.k(x1, self.z) @ w  # batch_shape x d x n
        Skx2 = Skx1 if (x1 is x2) else self.k(x2, self.z) @ w  # d x n
        # this is a tensor of kernel matrices of orthogonal 1d kernels
        K_ortho = (Kx1x2 - Skx1 @ Skx2.transpose(-2, -1)).to_dense()  # d x n x n
        return K_ortho

    def normalizer(self, eps: float = 1e-6) -> Tensor:
        """Integrates the `d` orthogonalized base kernels over `[0, 1] x [0, 1]`.

        NOTE: If the module is in train mode, this needs to re-compute the normalizer
        each time because the underlying parameters might have changed.

        Args:
            eps: Minimum value constraint on the normalizers. Avoids division by zero.

        Returns:
            A `d`-dim tensor of normalization constants.
        """
        if self.training or getattr(self, "_normalizer", None) is None:
            self._normalizer = (self.w.T @ self.k(self.z, self.z) @ self.w).clamp(eps)
        return self._normalizer


def leggauss(
    deg: int,
    a: float = -1.0,
    b: float = 1.0,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
    """Computes Gauss-Legendre quadrature nodes and weights. Wraps
    `numpy.polynomial.legendre.leggauss` and returns Torch Tensors.

    Args:
        deg: Number of sample points and weights. Integrates polynomials of degree
            `2 * deg - 1` exactly.
        a, b: Lower and upper bound of integration domain.
        dtype: Desired floating point type of the returned Tensors.
        device: Desired device type of the returned Tensors.

    Returns:
        A tuple of Gauss-Legendre quadrature nodes and weights of length deg.
    """
    dtype = dtype if dtype is not None else torch.get_default_dtype()
    x, w = numpy.polynomial.legendre.leggauss(deg=deg)
    x = torch.as_tensor(x, dtype=dtype, device=device)
    w = torch.as_tensor(w, dtype=dtype, device=device)
    if not (a == -1 and b == 1):  # need to rescale for a different domain
        x = (b - a) * (x + 1) / 2 + a
        w = w * ((b - a) / 2)
    return x, w
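
# Example (illustrative): `x, w = leggauss(5, a=0.0, b=1.0)` gives weights summing
# to 1 and integrates polynomials up to degree 9 exactly over [0, 1]; for instance,
# `(w * x**4).sum()` equals 1/5 up to floating point error.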


def _check_hypercube(x: Tensor, name: str) -> None:
    """Raises a `ValueError` if any element of `x` is not in [0, 1].

    Args:
        x: Tensor to be checked.
        name: Name of the Tensor for the error message.
    """
    if (x < 0).any() or (x > 1).any():
        raise ValueError(name + " is not in hypercube [0, 1]^d.")


def _reverse_triu_indices(d: int) -> List[int]:
    """Computes a list of indices which, upon indexing a `d * (d - 1) / 2 + 1`-dim
    Tensor whose last element is zero, will lead to a vectorized representation of
    an upper-triangular matrix whose diagonal is set to zero and whose super-diagonal
    elements are set to the `d * (d - 1) / 2` values in the original tensor.

    NOTE: This is a helper function for Orthogonal Additive Kernels, and allows the
    implementation to only register `d * (d - 1) / 2` parameters to model the second
    order interactions, instead of the full `d^2` redundant terms.

    Args:
        d: Dimensionality that gives rise to the `d * (d - 1) / 2` quadratic terms.

    Returns:
        A list of integer indices in `[0, d * (d - 1) / 2]`. See above for details.
    """
    indices = []
    j = 0
    d2 = int(d * (d - 1) / 2)
    for i in range(d):
        indices.extend(d2 for _ in range(i + 1))  # indexing zero (sub-diagonal)
        indices.extend(range(j, j + d - i - 1))  # indexing coeffs (super-diagonal)
        j += d - i - 1
    return indices
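
For reference, a short usage sketch of the new kernel. The import path below is an assumption about where the module lands in the tree (it is not shown in this view), and the orthogonality check uses the internal `_orthogonal_base_kernels` helper:

import torch
from gpytorch.kernels import MaternKernel
# assumed import path -- adjust to wherever the module lives
from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel

d = 4
oak = OrthogonalAdditiveKernel(
    base_kernel=MaternKernel(nu=2.5),
    dim=d,
    quad_deg=32,
    second_order=True,
)
X = torch.rand(8, d)  # inputs must lie in the unit hypercube [0, 1]^d
K = oak(X, X).to_dense()  # 8 x 8 kernel matrix
print(K.shape)

# each orthogonalized 1d base kernel integrates to ~0 over [0, 1]
K_ortho = oak._orthogonal_base_kernels(X, oak.z)  # d x 8 x quad_deg
print((K_ortho @ oak.w).abs().max())  # small, up to quadrature/float error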