infinite width bnn kernel (pytorch#2366)
Summary:
Pull Request resolved: pytorch#2366

Add infinite-width BNN kernel to BoTorch. Formulas from [Lee et al., 2017, "Deep Neural Networks as Gaussian Processes"](https://arxiv.org/abs/1711.00165). Notebook with tutorial in next diff!

Reviewed By: sdaulton

Differential Revision: D58208973
1 parent 0bdd4b2 · commit 1143db8
Showing 4 changed files with 357 additions and 1 deletion.
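For orientation before the diff: a minimal sketch of how the new kernel might be used once this commit lands. Only `InfiniteWidthBNNKernel` and its `depth` argument come from this diff; the toy data and the choice of `SingleTaskGP` as the surrounding model are illustrative assumptions, not part of the change.

import torch
from botorch.models import SingleTaskGP
from botorch.models.kernels.infinite_width_bnn import InfiniteWidthBNNKernel

# Toy training data, assumed for illustration only.
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = torch.randn(20, 1, dtype=torch.double)

# depth is the number of hidden layers of the equivalent infinite-width
# ReLU network (see the kernel's docstring in the diff below).
model = SingleTaskGP(train_X, train_Y, covar_module=InfiniteWidthBNNKernel(depth=3))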
@@ -0,0 +1,180 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Optional, Tuple

import torch
from gpytorch.constraints import Positive
from gpytorch.kernels import Kernel
from torch import Tensor

class InfiniteWidthBNNKernel(Kernel):
    r"""Infinite-width BNN kernel.

    Defines the GP kernel which is equivalent to performing exact Bayesian
    inference on a fully-connected deep neural network with ReLU activations
    and i.i.d. priors in the infinite-width limit.
    See [Cho2009kernel]_ and [Lee2018deep]_ for details.

    .. [Cho2009kernel]
        Y. Cho, and L. Saul. Kernel methods for deep learning.
        Advances in Neural Information Processing Systems 22. 2009.

    .. [Lee2018deep]
        J. Lee, Y. Bahri, R. Novak, S. Schoenholz, J. Pennington, and J. Dickstein.
        Deep Neural Networks as Gaussian Processes.
        International Conference on Learning Representations. 2018.
    """

    has_lengthscale = False

    def __init__(
        self,
        depth: int = 3,
        batch_shape: Optional[torch.Size] = None,
        active_dims: Optional[Tuple[int, ...]] = None,
        acos_eps: float = 1e-7,
        device: Optional[torch.device] = None,
    ) -> None:
        r"""
        Args:
            depth: Depth of the neural network.
            batch_shape: This will set a separate weight/bias var for each batch.
                It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf{x}`
                is a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
            active_dims: Compute the covariance of only a few input dimensions.
                The ints correspond to the indices of the dimensions.
            acos_eps: A small positive value to restrict acos inputs to
                :math:`[-1 + \epsilon, 1 - \epsilon]`.
            device: Device for parameters.
        """
        super().__init__(batch_shape=batch_shape, active_dims=active_dims)
        self.depth = depth
        self.acos_eps = acos_eps

        self.register_parameter(
            "raw_weight_var",
            torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, 1, device=device)),
        )
        self.register_constraint("raw_weight_var", Positive())

        self.register_parameter(
            "raw_bias_var",
            torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, 1, device=device)),
        )
        self.register_constraint("raw_bias_var", Positive())

    @property
    def weight_var(self) -> Tensor:
        return self.raw_weight_var_constraint.transform(self.raw_weight_var)

    @weight_var.setter
    def weight_var(self, value) -> None:
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_weight_var)
        self.initialize(
            raw_weight_var=self.raw_weight_var_constraint.inverse_transform(value)
        )

    @property
    def bias_var(self) -> Tensor:
        return self.raw_bias_var_constraint.transform(self.raw_bias_var)

    @bias_var.setter
    def bias_var(self, value) -> None:
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_bias_var)
        self.initialize(
            raw_bias_var=self.raw_bias_var_constraint.inverse_transform(value)
        )

    def _initialize_var(self, x: Tensor) -> Tensor:
        """Computes the initial variance of x for layer 0."""
        return (
            self.weight_var * torch.sum(x * x, dim=-1, keepdim=True) / x.shape[-1]
            + self.bias_var
        )

    def _update_var(self, K: Tensor, x: Tensor) -> Tensor:
        """Computes the updated variance of x for the next layer."""
        return self.weight_var * K / 2 + self.bias_var

    def k(self, x1: Tensor, x2: Tensor) -> Tensor:
        r"""
        For single-layer infinite-width neural networks with i.i.d. priors,
        the covariance between outputs can be computed by
        :math:`K^0(x, x')=\sigma_b^2+\sigma_w^2\frac{x \cdot x'}{d_\text{input}}`.

        For deeper networks, we can recursively define the covariance as
        :math:`K^l(x, x')=\sigma_b^2+\sigma_w^2
        F_\phi(K^{l-1}(x, x'), K^{l-1}(x, x), K^{l-1}(x', x'))`
        where :math:`F_\phi` is a deterministic function based on the
        activation function :math:`\phi`.

        For ReLU activations, this yields the arc-cosine kernel, which can be
        computed analytically.

        Args:
            x1: `batch_shape x n1 x d`-dim Tensor
            x2: `batch_shape x n2 x d`-dim Tensor
        """
        K_12 = (
            self.weight_var * (x1.matmul(x2.transpose(-2, -1)) / x1.shape[-1])
            + self.bias_var
        )

        for layer in range(self.depth):
            if layer == 0:
                K_11 = self._initialize_var(x1)
                K_22 = self._initialize_var(x2)
            else:
                K_11 = self._update_var(K_11, x1)
                K_22 = self._update_var(K_22, x2)

            sqrt_term = torch.sqrt(K_11.matmul(K_22.transpose(-2, -1)))

            fraction = K_12 / sqrt_term
            fraction = torch.clamp(
                fraction, min=-1 + self.acos_eps, max=1 - self.acos_eps
            )

            theta = torch.acos(fraction)
            theta_term = torch.sin(theta) + (torch.pi - theta) * fraction

            K_12 = (
                self.weight_var / (2 * torch.pi) * sqrt_term * theta_term
                + self.bias_var
            )

        return K_12

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: Optional[bool] = False,
        last_dim_is_batch: Optional[bool] = False,
        **params,
    ) -> Tensor:
        """
        Args:
            x1: `batch_shape x n1 x d`-dim Tensor
            x2: `batch_shape x n2 x d`-dim Tensor
            diag: If True, only returns the diagonal of the kernel matrix.
            last_dim_is_batch: Not supported by this kernel.
        """
        if last_dim_is_batch:
            raise RuntimeError("last_dim_is_batch not supported by this kernel.")

        if diag:
            K = self._initialize_var(x1)
            for _ in range(self.depth):
                K = self._update_var(K, x1)
            return K.squeeze(-1)
        else:
            return self.k(x1, x2)
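For reference, the closed form that the loop in `k` implements for ReLU activations. This merely restates the computation above in the docstring's notation, writing :math:`\sigma_w^2` for `weight_var` and :math:`\sigma_b^2` for `bias_var`:

\theta^{l} = \arccos\left(\frac{K^{l-1}(x, x')}{\sqrt{K^{l-1}(x, x)\,K^{l-1}(x', x')}}\right),
\qquad
K^{l}(x, x') = \sigma_b^2 + \frac{\sigma_w^2}{2\pi}
\sqrt{K^{l-1}(x, x)\,K^{l-1}(x', x')}\,
\bigl(\sin\theta^{l} + (\pi - \theta^{l})\cos\theta^{l}\bigr).

On the diagonal, :math:`\theta^{l} = 0`, so the update collapses to :math:`K^{l}(x, x) = \sigma_b^2 + \sigma_w^2 K^{l-1}(x, x)/2`, which is exactly what `_update_var` computes and what the `diag=True` branch of `forward` iterates.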
@@ -0,0 +1,171 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from botorch.models.kernels.infinite_width_bnn import InfiniteWidthBNNKernel
from botorch.utils.testing import BotorchTestCase
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase

class TestInfiniteWidthBNNKernel(BotorchTestCase, BaseKernelTestCase):
    def create_kernel_no_ard(self, **kwargs):
        return InfiniteWidthBNNKernel(**kwargs)

    def test_properties(self):
        with self.subTest():
            kernel = InfiniteWidthBNNKernel(3)
            bias_var_init = torch.tensor(0.2)
            kernel.initialize(bias_var=bias_var_init)
            actual_value = bias_var_init.view_as(kernel.bias_var)
            self.assertLess(torch.linalg.norm(kernel.bias_var - actual_value), 1e-5)
        with self.subTest():
            kernel = InfiniteWidthBNNKernel(3)
            weight_var_init = torch.tensor(0.2)
            kernel.initialize(weight_var=weight_var_init)
            actual_value = weight_var_init.view_as(kernel.weight_var)
            self.assertLess(torch.linalg.norm(kernel.weight_var - actual_value), 1e-5)
        with self.subTest():
            kernel = InfiniteWidthBNNKernel(5, batch_shape=torch.Size([2]))
            bias_var_init = torch.tensor([0.2, 0.01])
            kernel.initialize(bias_var=bias_var_init)
            actual_value = bias_var_init.view_as(kernel.bias_var)
            self.assertLess(torch.linalg.norm(kernel.bias_var - actual_value), 1e-5)
        with self.subTest():
            kernel = InfiniteWidthBNNKernel(3, batch_shape=torch.Size([2]))
            weight_var_init = torch.tensor([1.0, 2.0])
            kernel.initialize(weight_var=weight_var_init)
            actual_value = weight_var_init.view_as(kernel.weight_var)
            self.assertLess(torch.linalg.norm(kernel.weight_var - actual_value), 1e-5)
        with self.subTest():
            kernel = InfiniteWidthBNNKernel(3, batch_shape=torch.Size([2]))
            x = torch.randn(3, 2)
            with self.assertRaises(RuntimeError):
                kernel(x, x, last_dim_is_batch=True).to_dense()

    def test_forward_0(self):
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}
            x1 = torch.tensor([[0.1, 0.2], [1.2, 0.4], [2.4, 0.3]]).to(**tkwargs)
            x2 = torch.tensor([[4.1, 2.3], [3.9, 0.0]]).to(**tkwargs)
            weight_var = 1.0
            bias_var = 0.1
            kernel = InfiniteWidthBNNKernel(0, device=self.device).initialize(
                weight_var=weight_var, bias_var=bias_var
            )
            kernel.eval()
            expected = (
                weight_var * (x1.matmul(x2.transpose(-2, -1)) / x1.shape[-1]) + bias_var
            ).to(**tkwargs)
            res = kernel(x1, x2).to_dense()
            self.assertAllClose(res, expected)

    def test_forward_0_batch(self):
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}
            x1 = torch.tensor(
                [
                    [
                        [0.4960, 0.7680, 0.0880],
                        [0.1320, 0.3070, 0.6340],
                        [0.4900, 0.8960, 0.4550],
                        [0.6320, 0.3480, 0.4010],
                        [0.0220, 0.1680, 0.2930],
                    ],
                    [
                        [0.5180, 0.6970, 0.8000],
                        [0.1610, 0.2820, 0.6810],
                        [0.9150, 0.3970, 0.8740],
                        [0.4190, 0.5520, 0.9520],
                        [0.0360, 0.1850, 0.3730],
                    ],
                ]
            ).to(**tkwargs)
            x2 = torch.tensor(
                [
                    [[0.3050, 0.9320, 0.1750], [0.2690, 0.1500, 0.0310]],
                    [[0.2080, 0.9290, 0.7230], [0.7420, 0.5260, 0.2430]],
                ]
            ).to(**tkwargs)
            weight_var = torch.tensor([1.0, 2.0]).to(**tkwargs)
            bias_var = torch.tensor([0.1, 0.5]).to(**tkwargs)
            kernel = InfiniteWidthBNNKernel(
                0, batch_shape=[2], device=self.device
            ).initialize(weight_var=weight_var, bias_var=bias_var)
            kernel.eval()
            expected = torch.tensor(
                [
                    [
                        [0.3942, 0.1838],
                        [0.2458, 0.1337],
                        [0.4547, 0.1934],
                        [0.2958, 0.1782],
                        [0.1715, 0.1134],
                    ],
                    [
                        [1.3891, 1.1303],
                        [1.0252, 0.7889],
                        [1.2940, 1.2334],
                        [1.3588, 1.0551],
                        [0.7994, 0.6431],
                    ],
                ]
            ).to(**tkwargs)
            res = kernel(x1, x2).to_dense()
            self.assertAllClose(res, expected, 0.0001, 0.0001)

    def test_forward_2(self):
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}
            x1 = torch.tensor(
                [
                    [
                        [0.4960, 0.7680, 0.0880],
                        [0.1320, 0.3070, 0.6340],
                        [0.4900, 0.8960, 0.4550],
                        [0.6320, 0.3480, 0.4010],
                        [0.0220, 0.1680, 0.2930],
                    ],
                    [
                        [0.5180, 0.6970, 0.8000],
                        [0.1610, 0.2820, 0.6810],
                        [0.9150, 0.3970, 0.8740],
                        [0.4190, 0.5520, 0.9520],
                        [0.0360, 0.1850, 0.3730],
                    ],
                ]
            ).to(**tkwargs)
            x2 = torch.tensor(
                [
                    [[0.3050, 0.9320, 0.1750], [0.2690, 0.1500, 0.0310]],
                    [[0.2080, 0.9290, 0.7230], [0.7420, 0.5260, 0.2430]],
                ]
            ).to(**tkwargs)
            weight_var = 1.0
            bias_var = 0.1
            kernel = InfiniteWidthBNNKernel(2, device=self.device).initialize(
                weight_var=weight_var, bias_var=bias_var
            )
            kernel.eval()
            expected = torch.tensor(
                [
                    [
                        [0.2488, 0.1985],
                        [0.2178, 0.1872],
                        [0.2641, 0.2036],
                        [0.2286, 0.1962],
                        [0.1983, 0.1793],
                    ],
                    [
                        [0.2869, 0.2564],
                        [0.2429, 0.2172],
                        [0.2820, 0.2691],
                        [0.2837, 0.2498],
                        [0.2160, 0.1986],
                    ],
                ]
            ).to(**tkwargs)
            res = kernel(x1, x2).to_dense()
            self.assertAllClose(res, expected, 0.0001, 0.0001)
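Not covered by the tests above, but a quick local consistency check one might run: the `diag=True` fast path of `forward` should agree, up to the `acos_eps` clamping, with the diagonal of the full kernel matrix. This snippet is an illustrative sketch, not part of the diff:

import torch
from botorch.models.kernels.infinite_width_bnn import InfiniteWidthBNNKernel

kernel = InfiniteWidthBNNKernel(depth=3).initialize(weight_var=1.0, bias_var=0.1)
kernel.eval()
x = torch.rand(5, 3, dtype=torch.double)
# Full n x n matrix via k(), and the fast diagonal via the diag branch of forward().
full = kernel(x, x).to_dense()
diag = kernel(x, x, diag=True)
assert torch.allclose(full.diagonal(dim1=-2, dim2=-1), diag, atol=1e-4)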