From d3a2bac87cc1f0fdfc1a4f3e20ee3fe676ff2c7e Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 22 Jan 2025 15:52:43 -0800 Subject: [PATCH] Support mixing SAAS & SingleTaskGP models in ModelListGP (#2693) Summary: Adds support for broadcasting MVNs produced by the underlying models to enable mixing together SAAS & SingleTaskGP models within a ModelListGP. Reviewed By: sdaulton Differential Revision: D68503063 --- botorch/models/gpytorch.py | 33 ++++++++++++++++++++ test/models/test_model_list_gp_regression.py | 29 ++++++++++++++++- 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py index 39052e600b..b6b490125f 100644 --- a/botorch/models/gpytorch.py +++ b/botorch/models/gpytorch.py @@ -719,6 +719,7 @@ def posterior( interleaved=False, ) else: + mvns = self._broadcast_mvns(mvns=mvns) mvn = ( mvns[0] if len(mvns) == 1 @@ -738,6 +739,38 @@ def posterior( def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model: raise NotImplementedError() + def _broadcast_mvns(self, mvns: list[MultivariateNormal]) -> MultivariateNormal: + """Broadcasts the batch shapes of the given MultivariateNormals. + + The MVNs will have a batch shape of `input_batch_shape x model_batch_shape`. + If the model batch shapes are broadcastable, we will broadcast the mvns to + a batch shape of `input_batch_shape x self.batch_shape`. + + Args: + mvns: A list of MultivariateNormals. + + Returns: + A list of MultivariateNormals with broadcasted batch shapes. + """ + mvn_batch_shapes = {mvn.batch_shape for mvn in mvns} + if len(mvn_batch_shapes) == 1: + # All MVNs have the same batch shape. We can return as is. + return mvns + # This call will error out if they're not broadcastable. + # If they're broadcastable, it'll log a warning. + target_model_shape = self.batch_shape + max_batch = max(mvn_batch_shapes, key=len) + max_len = len(max_batch) + input_batch_len = max_len - len(target_model_shape) + for i in range(len(mvns)): # Loop over index since we modify contents. + while len(mvns[i].batch_shape) < max_len: + # MVN is missing batch dimensions. Unsqueeze as needed. + mvns[i] = mvns[i].unsqueeze(input_batch_len) + if mvns[i].batch_shape != max_batch: + # Expand to match the batch shapes. + mvns[i] = mvns[i].expand(max_batch) + return mvns + class MultiTaskGPyTorchModel(GPyTorchModel, ABC): r"""Abstract base class for multi-task models based on GPyTorch models. diff --git a/test/models/test_model_list_gp_regression.py b/test/models/test_model_list_gp_regression.py index 1396c47c39..ad4aeda5c0 100644 --- a/test/models/test_model_list_gp_regression.py +++ b/test/models/test_model_list_gp_regression.py @@ -12,7 +12,8 @@ from botorch.acquisition.objective import ScalarizedPosteriorTransform from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError from botorch.exceptions.warnings import OptimizationWarning -from botorch.fit import fit_gpytorch_mll +from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll +from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP from botorch.models.gp_regression import SingleTaskGP from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import MultiTaskGP @@ -733,3 +734,29 @@ def test_fantasize_with_outcome_transform_fixed_noise(self) -> None: self.assertTrue( torch.equal(fm_i.train_inputs[0][0][-1], X[1 - i]) ) + + def test_with_different_batch_shapes(self) -> None: + # Tests that we can mix single task and SAAS models together. + tkwargs = {"device": self.device, "dtype": torch.double} + m1 = SaasFullyBayesianSingleTaskGP( + train_X=torch.rand(10, 2, **tkwargs), train_Y=torch.rand(10, 1, **tkwargs) + ) + fit_fully_bayesian_model_nuts(m1, warmup_steps=0, num_samples=8, thinning=1) + m2 = SingleTaskGP( + train_X=torch.rand(10, 2, **tkwargs), train_Y=torch.rand(10, 1, **tkwargs) + ) + m = ModelListGP(m1, m2) + with self.assertWarnsRegex(UserWarning, "Component models of"): + self.assertEqual(m.batch_shape, torch.Size([8])) + # Non-batched evaluation. + with self.assertWarnsRegex(UserWarning, "Component models of"): + post = m.posterior(torch.rand(1, 2)) + self.assertEqual(post.batch_shape, torch.Size([8])) + self.assertEqual(post.rsample(torch.Size([2])).shape, torch.Size([2, 8, 1, 2])) + # Batched evaluation. + with self.assertWarnsRegex(UserWarning, "Component models of"): + post = m.posterior(torch.rand(5, 1, 2)) + self.assertEqual(post.batch_shape, torch.Size([5, 8])) + self.assertEqual( + post.rsample(torch.Size([2])).shape, torch.Size([2, 5, 8, 1, 2]) + )