Skip to content

Commit

Permalink
Use regular inheritance instead of dispatcher to special-case PairwiseGP logic (pytorch#2176)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#2176

* Replace dispatcher with inheritance: Previously, `Model.construct_inputs` called a dispatcher, `parse_training_data`, which dispatched to one function when the model was a `PairwiseGP` and the dataset was a `RankingDataset`, and to another for the general case of a `Model` with a `SupervisedDataset`. The same behavior can be achieved with inheritance, putting the general case in `Model.construct_inputs` and the special case for `PairwiseGP` in `PairwiseGP.construct_inputs`. This allows the `parse_training_data` dispatcher to be removed.
* Raise an error when `PairwiseGP.construct_inputs` is given an input that is not a `RankingDataset`: Previously, if `PairwiseGP.construct_inputs` was called with a dataset that was not a `RankingDataset`, it would fall back to the general behavior of `Model.construct_inputs`, producing inputs that are not valid for `PairwiseGP`. This now raises an `UnsupportedError`.

Reviewed By: saitcakmak

Differential Revision: D53011931

fbshipit-source-id: d697df3fffeaa5c9213a52a1dd39afd3b7a0ee25
  • Loading branch information
esantorella authored and stefanpricopie committed Feb 27, 2024
1 parent e30dfe3 commit 40151dd
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 156 deletions.
3 changes: 1 addition & 2 deletions botorch/models/fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,7 @@ def construct_inputs(
Args:
training_data: A `SupervisedDataset` or a `MultiTaskDataset`.
task_feature: Column index of embedded task indicator features. For details,
see `parse_training_data`.
task_feature: Column index of embedded task indicator features.
rank: The rank of the cross-task covariance matrix.
"""
inputs = super().construct_inputs(
Expand Down
26 changes: 22 additions & 4 deletions botorch/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from botorch.posteriors import Posterior, PosteriorList
from botorch.sampling.base import MCSampler
from botorch.sampling.list_sampler import ListSampler
from botorch.utils.containers import BotorchContainer
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.transforms import is_fully_bayesian
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
Expand Down Expand Up @@ -187,11 +188,28 @@ def construct_inputs(
cls,
training_data: SupervisedDataset,
**kwargs: Any,
) -> Dict[str, Any]:
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`."""
from botorch.models.utils.parse_training_data import parse_training_data
) -> Dict[str, Union[BotorchContainer, Tensor]]:
"""
Construct `Model` keyword arguments from a `SupervisedDataset`.
Args:
training_data: A `SupervisedDataset`, with attributes `X`,
`Y`, and, optionally, `Yvar`.
kwargs: Ignored.
return parse_training_data(cls, training_data, **kwargs)
Returns:
A dict of keyword arguments that can be used to initialize a `Model`,
with keys `train_X`, `train_Y`, and, optionally, `train_Yvar`.
"""
if not isinstance(training_data, SupervisedDataset):
raise TypeError(
"Expected `training_data` to be a `SupervisedDataset`, but got "
f"{type(training_data)}."
)
parsed_data = {"train_X": training_data.X, "train_Y": training_data.Y}
if training_data.Yvar is not None:
parsed_data["train_Yvar"] = training_data.Yvar
return parsed_data

def transform_inputs(
self,
Expand Down
3 changes: 1 addition & 2 deletions botorch/models/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ def construct_inputs(
Args:
training_data: A `SupervisedDataset` or a `MultiTaskDataset`.
task_feature: Column index of embedded task indicator features. For details,
see `parse_training_data`.
task_feature: Column index of embedded task indicator features.
output_tasks: A list of task indices for which to compute model
outputs for. If omitted, return outputs for all task indices.
task_covar_prior: A GPyTorch `Prior` object to use as prior on
Expand Down
35 changes: 34 additions & 1 deletion botorch/models/pairwise_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions import UnsupportedError

from botorch.exceptions.warnings import _get_single_precision_warning, InputDataWarning
from botorch.models.likelihoods.pairwise import (
PairwiseLikelihood,
Expand All @@ -39,6 +38,7 @@
from botorch.models.utils.assorted import consolidate_duplicates
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from botorch.utils.datasets import RankingDataset, SupervisedDataset
from gpytorch import settings
from gpytorch.constraints import GreaterThan, Interval
from gpytorch.distributions.multivariate_normal import MultivariateNormal
Expand Down Expand Up @@ -784,6 +784,39 @@ def batch_shape(self) -> torch.Size:
else:
return self.datapoints.shape[:-2]

@classmethod
def construct_inputs(
    cls,
    training_data: SupervisedDataset,
    **kwargs: Any,
) -> Dict[str, Tensor]:
    r"""
    Construct `PairwiseGP` keyword arguments from a `RankingDataset`.

    Args:
        training_data: A `RankingDataset`. The `values` of its `_X`
            container are returned as `datapoints`; the `indices` of its
            `_X` container are reordered along the last dimension by its
            `Y` to produce `comparisons`.
        kwargs: Ignored.

    Returns:
        A dict of keyword arguments that can be used to initialize a
        `PairwiseGP`, with keys `datapoints` and `comparisons`.

    Raises:
        UnsupportedError: If `training_data` is not a `RankingDataset`.
    """
    # Fail loudly rather than building inputs that are meaningless for a
    # pairwise model from a generic dataset.
    if not isinstance(training_data, RankingDataset):
        raise UnsupportedError(
            "Only `RankingDataset` is supported for `PairwiseGP`. Received "
            f"{type(training_data)}."
        )
    datapoints = training_data._X.values
    comparisons = training_data._X.indices
    comp_order = training_data.Y
    # Reorder the entries of each comparison according to `Y`, i.e.
    # comparisons[..., j] = indices[..., Y[..., j]].
    comparisons = torch.gather(input=comparisons, dim=-1, index=comp_order)

    return {
        "datapoints": datapoints,
        "comparisons": comparisons,
    }

def set_train_data(
self,
datapoints: Optional[Tensor] = None,
Expand Down
3 changes: 0 additions & 3 deletions botorch/models/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
validate_input_scaling,
)

# # TODO: Omitted to avoid circular dependency created by `Model.construct_inputs`
# from botorch.models.utils.parse_training_data import parse_training_data


__all__ = [
"_make_X_full",
Expand Down
71 changes: 0 additions & 71 deletions botorch/models/utils/parse_training_data.py

This file was deleted.

5 changes: 0 additions & 5 deletions sphinx/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,6 @@ Transform Utilities
Utilities
-------------------------------------------

Dataset Parsing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.models.utils.parse_training_data
:members:

GPyTorch Module Constructors
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.models.utils.gpytorch_modules
Expand Down
29 changes: 22 additions & 7 deletions test/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from unittest.mock import patch

import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.errors import InputDataError
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.model import Model, ModelDict, ModelList
from botorch.models.utils import parse_training_data
from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.posteriors.posterior_list import PosteriorList
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
from torch import rand


class NotSoAbstractBaseModel(Model):
Expand Down Expand Up @@ -55,12 +55,27 @@ def test_not_so_abstract_base_model(self):
with self.assertRaises(NotImplementedError):
model.subset_output([0])

def test_construct_inputs(self):
with patch.object(
parse_training_data, "parse_training_data", return_value={"a": 1}
def test_construct_inputs(self) -> None:
model = NotSoAbstractBaseModel()
with self.subTest("Wrong training data type"), self.assertRaisesRegex(
TypeError, "Expected `training_data` to be a `SupervisedDataset`, but got "
):
model = NotSoAbstractBaseModel()
self.assertEqual(model.construct_inputs(None), {"a": 1})
model.construct_inputs(training_data=None)

x = rand(3, 2)
y = rand(3, 1)
dataset = SupervisedDataset(
X=x, Y=y, feature_names=["a", "b"], outcome_names=["y"]
)
model_inputs = model.construct_inputs(training_data=dataset)
self.assertEqual(model_inputs, {"train_X": x, "train_Y": y})

yvar = rand(3, 1)
dataset = SupervisedDataset(
X=x, Y=y, Yvar=yvar, feature_names=["a", "b"], outcome_names=["y"]
)
model_inputs = model.construct_inputs(training_data=dataset)
self.assertEqual(model_inputs, {"train_X": x, "train_Y": y, "train_Yvar": yvar})

def test_model_list(self):
tkwargs = {"device": self.device, "dtype": torch.double}
Expand Down
29 changes: 29 additions & 0 deletions test/models/test_pairwise_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from botorch.models.transforms.input import Normalize
from botorch.posteriors import GPyTorchPosterior
from botorch.sampling.pairwise_samplers import PairwiseSobolQMCNormalSampler
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import RankingDataset, SupervisedDataset
from botorch.utils.testing import BotorchTestCase
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.kernels.linear_kernel import LinearKernel
Expand Down Expand Up @@ -75,6 +77,33 @@ def _get_model_and_data(
model = PairwiseGP(**model_kwargs)
return model, model_kwargs

def test_construct_inputs(self) -> None:
    """Check `PairwiseGP.construct_inputs` on valid and invalid datasets.

    A `RankingDataset` must yield exactly the `datapoints` and
    `comparisons` keys, with each row of indices reordered by `Y`; a plain
    `SupervisedDataset` must be rejected with an `UnsupportedError`.
    """
    points = torch.rand(3, 2)
    pair_indices = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)
    ranking_X = SliceContainer(
        points,
        pair_indices,
        event_shape=torch.Size([2 * points.shape[-1]]),
    )
    ranking_Y = torch.tensor([[0, 1], [1, 0]]).expand(pair_indices.shape)
    ranking_data = RankingDataset(
        X=ranking_X,
        Y=ranking_Y,
        feature_names=["a", "b"],
        outcome_names=["y"],
    )

    parsed = PairwiseGP.construct_inputs(ranking_data)

    # The second row of indices, [1, 2], is flipped by its order [1, 0].
    expected_comparisons = torch.tensor([[0, 1], [2, 1]], dtype=torch.long)
    self.assertSetEqual(set(parsed.keys()), {"datapoints", "comparisons"})
    self.assertTrue(torch.equal(parsed["datapoints"], points))
    self.assertTrue(torch.equal(parsed["comparisons"], expected_comparisons))

    with self.subTest("Input other than RankingDataset"):
        plain_data = SupervisedDataset(
            X=points,
            Y=torch.rand(3, 1),
            feature_names=["a", "b"],
            outcome_names=["y"],
        )
        with self.assertRaisesRegex(
            UnsupportedError, "Only `RankingDataset` is supported"
        ):
            PairwiseGP.construct_inputs(plain_data)

def test_pairwise_gp(self) -> None:
torch.manual_seed(random.randint(0, 10))
for batch_shape, likelihood_cls in itertools.product(
Expand Down
61 changes: 0 additions & 61 deletions test/models/utils/test_parse_training_data.py

This file was deleted.

0 comments on commit 40151dd

Please sign in to comment.