infinite width bnn kernel (pytorch#2366)
Summary:
Pull Request resolved: pytorch#2366

Add infinite-width BNN kernel to BoTorch

Formulas from [Lee et al, 2017 "Deep Neural Networks as Gaussian Processes"](https://arxiv.org/abs/1711.00165)

Notebook with tutorial in next diff!
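
A minimal usage sketch (illustrative only, not part of this diff; wrapping the kernel in a ScaleKernel, the toy data, and passing it to SingleTaskGP via covar_module are assumptions rather than the tutorial's exact flow):

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.kernels import InfiniteWidthBNNKernel
from gpytorch.kernels import ScaleKernel
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy training data (illustrative only).
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True).sin()

# depth = number of hidden layers of the equivalent infinite-width ReLU network.
ibnn_kernel = ScaleKernel(InfiniteWidthBNNKernel(depth=3))
model = SingleTaskGP(train_X, train_Y, covar_module=ibnn_kernel)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)

# Posterior at new points.
test_X = torch.rand(5, 3, dtype=torch.double)
posterior = model.posterior(test_X)
print(posterior.mean.shape)  # torch.Size([5, 1])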

Reviewed By: sdaulton

Differential Revision: D58208973
Lily Li authored and facebook-github-bot committed Jun 12, 2024
1 parent 0bdd4b2 commit 1143db8
Showing 4 changed files with 357 additions and 1 deletion.
2 changes: 2 additions & 0 deletions botorch/models/kernels/__init__.py
@@ -7,6 +7,7 @@
from botorch.models.kernels.categorical import CategoricalKernel
from botorch.models.kernels.downsampling import DownsamplingKernel
from botorch.models.kernels.exponential_decay import ExponentialDecayKernel
from botorch.models.kernels.infinite_width_bnn import InfiniteWidthBNNKernel
from botorch.models.kernels.linear_truncated_fidelity import (
LinearTruncatedFidelityKernel,
)
@@ -16,5 +17,6 @@
"CategoricalKernel",
"DownsamplingKernel",
"ExponentialDecayKernel",
"InfiniteWidthBNNKernel",
"LinearTruncatedFidelityKernel",
]
180 changes: 180 additions & 0 deletions botorch/models/kernels/infinite_width_bnn.py
@@ -0,0 +1,180 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Optional, Tuple

import torch
from gpytorch.constraints import Positive
from gpytorch.kernels import Kernel
from torch import Tensor


class InfiniteWidthBNNKernel(Kernel):
r"""Infinite-width BNN kernel.
Defines the GP kernel which is equivalent to performing exact Bayesian
inference on a fully-connected deep neural network with ReLU activations
and i.i.d. priors in the infinite-width limit.
See [Cho2009kernel]_ and [Lee2018deep]_ for details.
.. [Cho2009kernel]
Y. Cho, and L. Saul. Kernel methods for deep learning.
Advances in Neural Information Processing Systems 22. 2009.
.. [Lee2018deep]
J. Lee, Y. Bahri, R. Novak, S. Schoenholz, J. Pennington, and J. Sohl-Dickstein.
Deep Neural Networks as Gaussian Processes.
International Conference on Learning Representations. 2018.
"""

has_lengthscale = False

def __init__(
self,
depth: int = 3,
batch_shape: Optional[torch.Size] = None,
active_dims: Optional[Tuple[int, ...]] = None,
acos_eps: float = 1e-7,
device: Optional[torch.device] = None,
) -> None:
r"""
Args:
depth: Depth of neural network.
batch_shape: This will set a separate weight/bias var for each batch.
It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf{x}` is
a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
active_dims: Compute the covariance of only a few input dimensions.
The ints correspond to the indices of the dimensions.
acos_eps: A small positive value used to restrict acos inputs to
:math:`[-1 + \epsilon, 1 - \epsilon]`.
device: Device for parameters.
"""
super().__init__(batch_shape=batch_shape, active_dims=active_dims)
self.depth = depth
self.acos_eps = acos_eps

self.register_parameter(
"raw_weight_var",
torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, 1, device=device)),
)
self.register_constraint("raw_weight_var", Positive())

self.register_parameter(
"raw_bias_var",
torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, 1, device=device)),
)
self.register_constraint("raw_bias_var", Positive())

@property
def weight_var(self) -> Tensor:
return self.raw_weight_var_constraint.transform(self.raw_weight_var)

@weight_var.setter
def weight_var(self, value) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_weight_var)
self.initialize(
raw_weight_var=self.raw_weight_var_constraint.inverse_transform(value)
)

@property
def bias_var(self) -> Tensor:
return self.raw_bias_var_constraint.transform(self.raw_bias_var)

@bias_var.setter
def bias_var(self, value) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_bias_var)
self.initialize(
raw_bias_var=self.raw_bias_var_constraint.inverse_transform(value)
)

def _initialize_var(self, x: Tensor) -> Tensor:
"""Computes the initial variance of x for layer 0"""
return (
self.weight_var * torch.sum(x * x, dim=-1, keepdim=True) / x.shape[-1]
+ self.bias_var
)

def _update_var(self, K: Tensor, x: Tensor) -> Tensor:
"""Computes the updated variance of x for next layer"""
return self.weight_var * K / 2 + self.bias_var

def k(self, x1: Tensor, x2: Tensor) -> Tensor:
r"""
For single-layer infinite-width neural networks with i.i.d. priors,
the covariance between outputs can be computed by
:math:`K^0(x, x')=\sigma_b^2+\sigma_w^2\frac{x \cdot x'}{d_\text{input}}`.
For deeper networks, we can recursively define the covariance as
:math:`K^l(x, x')=\sigma_b^2+\sigma_w^2
F_\phi(K^{l-1}(x, x'), K^{l-1}(x, x), K^{l-1}(x', x'))`
where :math:`F_\phi` is a deterministic function based on the
activation function :math:`\phi`.
For ReLU activations, this yields the arc-cosine kernel, which can be computed
analytically.
Args:
x1: `batch_shape x n1 x d`-dim Tensor
x2: `batch_shape x n2 x d`-dim Tensor
"""
K_12 = (
self.weight_var * (x1.matmul(x2.transpose(-2, -1)) / x1.shape[-1])
+ self.bias_var
)

for layer in range(self.depth):
if layer == 0:
K_11 = self._initialize_var(x1)
K_22 = self._initialize_var(x2)
else:
K_11 = self._update_var(K_11, x1)
K_22 = self._update_var(K_22, x2)

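# ReLU arc-cosine update ([Cho2009kernel]_): with
#   theta = arccos(K_12 / sqrt(K_11 * K_22)),
# the next-layer covariance is
#   K_12 <- sigma_b^2 + sigma_w^2 / (2 * pi) * sqrt(K_11 * K_22)
#           * (sin(theta) + (pi - theta) * cos(theta)).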
sqrt_term = torch.sqrt(K_11.matmul(K_22.transpose(-2, -1)))

fraction = K_12 / sqrt_term
fraction = torch.clamp(
fraction, min=-1 + self.acos_eps, max=1 - self.acos_eps
)

theta = torch.acos(fraction)
theta_term = torch.sin(theta) + (torch.pi - theta) * fraction

K_12 = (
self.weight_var / (2 * torch.pi) * sqrt_term * theta_term
+ self.bias_var
)

return K_12

def forward(
self,
x1: Tensor,
x2: Tensor,
diag: bool = False,
last_dim_is_batch: bool = False,
**params,
) -> Tensor:
"""
Args:
x1: `batch_shape x n1 x d`-dim Tensor
x2: `batch_shape x n2 x d`-dim Tensor
diag: If True, only returns the diagonal of the kernel matrix.
last_dim_is_batch: Not supported by this kernel.
"""
if last_dim_is_batch:
raise RuntimeError("last_dim_is_batch not supported by this kernel.")

if diag:
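# Only the per-point variances K(x, x) are needed, so run the layer
# recursion on the diagonal instead of forming the full n1 x n1 matrix.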
K = self._initialize_var(x1)
for _ in range(self.depth):
K = self._update_var(K, x1)
return K.squeeze(-1)
else:
return self.k(x1, x2)
5 changes: 4 additions & 1 deletion sphinx/source/models.rst
@@ -114,6 +114,9 @@ Kernels
.. automodule:: botorch.models.kernels.exponential_decay
.. autoclass:: ExponentialDecayKernel

.. automodule:: botorch.models.kernels.infinite_width_bnn
.. autoclass:: InfiniteWidthBNNKernel

.. automodule:: botorch.models.kernels.linear_truncated_fidelity
.. autoclass:: LinearTruncatedFidelityKernel

@@ -177,4 +180,4 @@ Inducing Point Allocators
Other Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.models.utils.assorted
:members:
:members:
171 changes: 171 additions & 0 deletions test/models/kernels/test_infinite_width_bnn.py
@@ -0,0 +1,171 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from botorch.models.kernels.infinite_width_bnn import InfiniteWidthBNNKernel
from botorch.utils.testing import BotorchTestCase
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase


class TestInfiniteWidthBNNKernel(BotorchTestCase, BaseKernelTestCase):
def create_kernel_no_ard(self, **kwargs):
return InfiniteWidthBNNKernel(**kwargs)

def test_properties(self):
with self.subTest():
kernel = InfiniteWidthBNNKernel(3)
bias_var_init = torch.tensor(0.2)
kernel.initialize(bias_var=bias_var_init)
actual_value = bias_var_init.view_as(kernel.bias_var)
self.assertLess(torch.linalg.norm(kernel.bias_var - actual_value), 1e-5)
with self.subTest():
kernel = InfiniteWidthBNNKernel(3)
weight_var_init = torch.tensor(0.2)
kernel.initialize(weight_var=weight_var_init)
actual_value = weight_var_init.view_as(kernel.weight_var)
self.assertLess(torch.linalg.norm(kernel.weight_var - actual_value), 1e-5)
with self.subTest():
kernel = InfiniteWidthBNNKernel(5, batch_shape=torch.Size([2]))
bias_var_init = torch.tensor([0.2, 0.01])
kernel.initialize(bias_var=bias_var_init)
actual_value = bias_var_init.view_as(kernel.bias_var)
self.assertLess(torch.linalg.norm(kernel.bias_var - actual_value), 1e-5)
with self.subTest():
kernel = InfiniteWidthBNNKernel(3, batch_shape=torch.Size([2]))
weight_var_init = torch.tensor([1.0, 2.0])
kernel.initialize(weight_var=weight_var_init)
actual_value = weight_var_init.view_as(kernel.weight_var)
self.assertLess(torch.linalg.norm(kernel.weight_var - actual_value), 1e-5)
with self.subTest():
kernel = InfiniteWidthBNNKernel(3, batch_shape=torch.Size([2]))
x = torch.randn(3, 2)
with self.assertRaises(RuntimeError):
kernel(x, x, last_dim_is_batch=True).to_dense()

def test_forward_0(self):
for dtype in (torch.float, torch.double):
tkwargs = {"device": self.device, "dtype": dtype}
x1 = torch.tensor([[0.1, 0.2], [1.2, 0.4], [2.4, 0.3]]).to(**tkwargs)
x2 = torch.tensor([[4.1, 2.3], [3.9, 0.0]]).to(**tkwargs)
weight_var = 1.0
bias_var = 0.1
kernel = InfiniteWidthBNNKernel(0, device=self.device).initialize(
weight_var=weight_var, bias_var=bias_var
)
kernel.eval()
expected = (
weight_var * (x1.matmul(x2.transpose(-2, -1)) / x1.shape[-1]) + bias_var
).to(**tkwargs)
res = kernel(x1, x2).to_dense()
self.assertAllClose(res, expected)

def test_forward_0_batch(self):
for dtype in (torch.float, torch.double):
tkwargs = {"device": self.device, "dtype": dtype}
x1 = torch.tensor(
[
[
[0.4960, 0.7680, 0.0880],
[0.1320, 0.3070, 0.6340],
[0.4900, 0.8960, 0.4550],
[0.6320, 0.3480, 0.4010],
[0.0220, 0.1680, 0.2930],
],
[
[0.5180, 0.6970, 0.8000],
[0.1610, 0.2820, 0.6810],
[0.9150, 0.3970, 0.8740],
[0.4190, 0.5520, 0.9520],
[0.0360, 0.1850, 0.3730],
],
]
).to(**tkwargs)
x2 = torch.tensor(
[
[[0.3050, 0.9320, 0.1750], [0.2690, 0.1500, 0.0310]],
[[0.2080, 0.9290, 0.7230], [0.7420, 0.5260, 0.2430]],
]
).to(**tkwargs)
weight_var = torch.tensor([1.0, 2.0]).to(**tkwargs)
bias_var = torch.tensor([0.1, 0.5]).to(**tkwargs)
kernel = InfiniteWidthBNNKernel(
0, batch_shape=[2], device=self.device
).initialize(weight_var=weight_var, bias_var=bias_var)
kernel.eval()
expected = torch.tensor(
[
[
[0.3942, 0.1838],
[0.2458, 0.1337],
[0.4547, 0.1934],
[0.2958, 0.1782],
[0.1715, 0.1134],
],
[
[1.3891, 1.1303],
[1.0252, 0.7889],
[1.2940, 1.2334],
[1.3588, 1.0551],
[0.7994, 0.6431],
],
]
).to(**tkwargs)
res = kernel(x1, x2).to_dense()
self.assertAllClose(res, expected, 0.0001, 0.0001)

def test_forward_2(self):
for dtype in (torch.float, torch.double):
tkwargs = {"device": self.device, "dtype": dtype}
x1 = torch.tensor(
[
[
[0.4960, 0.7680, 0.0880],
[0.1320, 0.3070, 0.6340],
[0.4900, 0.8960, 0.4550],
[0.6320, 0.3480, 0.4010],
[0.0220, 0.1680, 0.2930],
],
[
[0.5180, 0.6970, 0.8000],
[0.1610, 0.2820, 0.6810],
[0.9150, 0.3970, 0.8740],
[0.4190, 0.5520, 0.9520],
[0.0360, 0.1850, 0.3730],
],
]
).to(**tkwargs)
x2 = torch.tensor(
[
[[0.3050, 0.9320, 0.1750], [0.2690, 0.1500, 0.0310]],
[[0.2080, 0.9290, 0.7230], [0.7420, 0.5260, 0.2430]],
]
).to(**tkwargs)
weight_var = 1.0
bias_var = 0.1
kernel = InfiniteWidthBNNKernel(2, device=self.device).initialize(
weight_var=weight_var, bias_var=bias_var
)
kernel.eval()
expected = torch.tensor(
[
[
[0.2488, 0.1985],
[0.2178, 0.1872],
[0.2641, 0.2036],
[0.2286, 0.1962],
[0.1983, 0.1793],
],
[
[0.2869, 0.2564],
[0.2429, 0.2172],
[0.2820, 0.2691],
[0.2837, 0.2498],
[0.2160, 0.1986],
],
]
).to(**tkwargs)
res = kernel(x1, x2).to_dense()
self.assertAllClose(res, expected, 0.0001, 0.0001)
