Skip to content

Commit

Permalink
Support mixing SAAS & SingleTaskGP models in ModelListGP
Browse files Browse the repository at this point in the history
Summary: Adds support for broadcasting MVNs produced by the underlying models to enable mixing together SAAS & SingleTaskGP models within a ModelListGP.

Differential Revision: D68503063
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jan 22, 2025
1 parent 7b803bd commit 562ce08
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
33 changes: 33 additions & 0 deletions botorch/models/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,7 @@ def posterior(
interleaved=False,
)
else:
mvns = self._broadcast_mvns(mvns=mvns)
mvn = (
mvns[0]
if len(mvns) == 1
Expand All @@ -738,6 +739,38 @@ def posterior(
def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
raise NotImplementedError()

def _broadcast_mvns(self, mvns: list[MultivariateNormal]) -> MultivariateNormal:
"""Broadcasts the batch shapes of the given MultivariateNormals.
The MVNs will have a batch shape of `input_batch_shape x model_batch_shape`.
If the model batch shapes are broadcastable, we will broadcast the mvns to
a batch shape of `input_batch_shape x self.batch_shape`.
Args:
mvns: A list of MultivariateNormals.
Returns:
A list of MultivariateNormals with broadcasted batch shapes.
"""
mvn_batch_shapes = {mvn.batch_shape for mvn in mvns}
if len(mvn_batch_shapes) == 1:
# All MVNs have the same batch shape. We can return as is.
return mvns
# This call will error out of they're not broadcastable.
# If they're broadcastable, it'll log a warning.
target_model_shape = self.batch_shape
max_batch = max(mvn_batch_shapes, key=len)
max_len = len(max_batch)
input_batch_len = max_len - len(target_model_shape)
for i in range(len(mvns)): # Loop over index since we modify contents.
while len(mvns[i].batch_shape) < max_len:
# MVN is missing batch dimensions. Unsqueeze as needed.
mvns[i] = mvns[i].unsqueeze(input_batch_len)
if mvns[i].batch_shape != max_batch:
# Expand to match the batch shapes.
mvns[i] = mvns[i].expand(max_batch)
return mvns


class MultiTaskGPyTorchModel(GPyTorchModel, ABC):
r"""Abstract base class for multi-task models based on GPyTorch models.
Expand Down
29 changes: 28 additions & 1 deletion test/models/test_model_list_gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError
from botorch.exceptions.warnings import OptimizationWarning
from botorch.fit import fit_gpytorch_mll
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import MultiTaskGP
Expand Down Expand Up @@ -733,3 +734,29 @@ def test_fantasize_with_outcome_transform_fixed_noise(self) -> None:
self.assertTrue(
torch.equal(fm_i.train_inputs[0][0][-1], X[1 - i])
)

def test_with_different_batch_shapes(self) -> None:
# Tests that we can mix single task and SAAS models together.
tkwargs = {"device": self.device, "dtype": torch.double}
m1 = SaasFullyBayesianSingleTaskGP(
train_X=torch.rand(10, 2, **tkwargs), train_Y=torch.rand(10, 1, **tkwargs)
)
fit_fully_bayesian_model_nuts(m1, warmup_steps=0, num_samples=8, thinning=1)
m2 = SingleTaskGP(
train_X=torch.rand(10, 2, **tkwargs), train_Y=torch.rand(10, 1, **tkwargs)
)
m = ModelListGP(m1, m2)
with self.assertWarnsRegex(UserWarning, "Component models of"):
self.assertEqual(m.batch_shape, torch.Size([8]))
# Non-batched evaluation.
with self.assertWarnsRegex(UserWarning, "Component models of"):
post = m.posterior(torch.rand(1, 2))
self.assertEqual(post.batch_shape, torch.Size([8]))
self.assertEqual(post.rsample(torch.Size([2])).shape, torch.Size([2, 8, 1, 2]))
# Batched evaluation.
with self.assertWarnsRegex(UserWarning, "Component models of"):
post = m.posterior(torch.rand(5, 1, 2))
self.assertEqual(post.batch_shape, torch.Size([5, 8]))
self.assertEqual(
post.rsample(torch.Size([2])).shape, torch.Size([2, 5, 8, 1, 2])
)

0 comments on commit 562ce08

Please sign in to comment.