Skip to content

Commit

Permalink
Update sample_all_priors to support wider set of priors (#2371)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2371

Addresses #780

Previously, this would pass in `closure(module).shape` as the `sample_shape`, which only worked if the prior was a univariate distribution. `Distribution.sample` produces samples of shape `Distribution._extended_shape(sample_shape) = sample_shape + Distribution._extended_shape()`, so we can calculate the `sample_shape` required to support both univariate and multivariate / batched priors.

Reviewed By: dme65

Differential Revision: D58377495

fbshipit-source-id: 17510505012838a3fe670492656be4d13bc0db5e
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jun 11, 2024
1 parent f3dd493 commit d753706
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
7 changes: 6 additions & 1 deletion botorch/optim/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,12 @@ def sample_all_priors(model: GPyTorchModel, max_retries: int = 100) -> None:
)
for i in range(max_retries):
try:
setting_closure(module, prior.sample(closure(module).shape))
# Set sample shape, so that the prior samples have the same shape
# as `closure(module)` without having to be repeated.
closure_shape = closure(module).shape
prior_shape = prior._extended_shape()
sample_shape = closure_shape[: -len(prior_shape)]
setting_closure(module, prior.sample(sample_shape=sample_shape))
break
except NotImplementedError:
warn(
Expand Down
26 changes: 25 additions & 1 deletion test/optim/utils/test_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
)
from botorch.utils.testing import BotorchTestCase
from gpytorch.constraints import GreaterThan
from gpytorch.kernels import RBFKernel
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.kernels.scale_kernel import ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.priors import UniformPrior
from gpytorch.priors.prior import Prior
from gpytorch.priors.torch_priors import GammaPrior
from gpytorch.priors.torch_priors import GammaPrior, NormalPrior


class DummyPrior(Prior):
Expand Down Expand Up @@ -244,3 +245,26 @@ def test_sample_all_priors(self):
original_state_dict = dict(deepcopy(mll.model.state_dict()))
with self.assertRaises(RuntimeError):
sample_all_priors(model)

def test_with_multivariate_prior(self) -> None:
# This is modified from https://github.com/pytorch/botorch/issues/780.
for batch in (torch.Size([]), torch.Size([3])):
model = SingleTaskGP(
train_X=torch.randn(*batch, 2, 2),
train_Y=torch.randn(*batch, 2, 1),
covar_module=RBFKernel(
ard_num_dims=2,
batch_shape=batch,
lengthscale_prior=NormalPrior(
# Make this almost singular for easy comparison below.
torch.tensor([[1.0, 1.0]]),
torch.tensor(1e-10),
),
),
)
# Check that the lengthscale is replaced with the sampled values.
original_lengthscale = model.covar_module.lengthscale
sample_all_priors(model)
new_lengthscale = model.covar_module.lengthscale
self.assertFalse(torch.allclose(original_lengthscale, new_lengthscale))
self.assertAllClose(new_lengthscale, torch.ones(*batch, 1, 2))

0 comments on commit d753706

Please sign in to comment.