From 8d6960f6f9e6716d1c5432d148b9ae3084e6163f Mon Sep 17 00:00:00 2001 From: fukushima_daisuke Date: Sun, 10 Nov 2024 11:52:41 +0900 Subject: [PATCH] Add ScaleKernel to get_covar_module_with_dim_scaled_prior --- botorch/models/utils/gpytorch_modules.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/botorch/models/utils/gpytorch_modules.py b/botorch/models/utils/gpytorch_modules.py index bfe5da8551..36e3c4f445 100644 --- a/botorch/models/utils/gpytorch_modules.py +++ b/botorch/models/utils/gpytorch_modules.py @@ -102,8 +102,8 @@ def get_covar_module_with_dim_scaled_prior( batch_shape: torch.Size | None = None, use_rbf_kernel: bool = True, active_dims: Sequence[int] | None = None, -) -> MaternKernel | RBFKernel: - """Returns an RBF or Matern kernel with priors +) -> ScaleKernel: + """Returns an ScaleKernel based on RBF or Matern kernel with priors from [Hvarfner2024vanilla]_. Args: @@ -130,4 +130,8 @@ def get_covar_module_with_dim_scaled_prior( # pyre-ignore[6] GPyTorch type is unnecessarily restrictive. active_dims=active_dims, ) - return base_kernel + return ScaleKernel( + base_kernel=base_kernel, + batch_shape=batch_shape, + outputscale_prior=GammaPrior(2.0, 0.15), + )