Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use regular inheritance instead of dispatcher to special-case PairwiseGP logic #2176

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions botorch/models/fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,7 @@ def construct_inputs(
Args:
training_data: A `SupervisedDataset` or a `MultiTaskDataset`.
task_feature: Column index of embedded task indicator features. For details,
see `parse_training_data`.
task_feature: Column index of embedded task indicator features.
rank: The rank of the cross-task covariance matrix.
"""
inputs = super().construct_inputs(
Expand Down
26 changes: 22 additions & 4 deletions botorch/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from botorch.posteriors import Posterior, PosteriorList
from botorch.sampling.base import MCSampler
from botorch.sampling.list_sampler import ListSampler
from botorch.utils.containers import BotorchContainer
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.transforms import is_fully_bayesian
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
Expand Down Expand Up @@ -187,11 +188,28 @@ def construct_inputs(
cls,
training_data: SupervisedDataset,
**kwargs: Any,
) -> Dict[str, Any]:
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`."""
from botorch.models.utils.parse_training_data import parse_training_data
) -> Dict[str, Union[BotorchContainer, Tensor]]:
"""
Construct `Model` keyword arguments from a `SupervisedDataset`.

Args:
training_data: A `SupervisedDataset`, with attributes `X`,
`Y`, and, optionally, `Yvar`.
kwargs: Ignored.

return parse_training_data(cls, training_data, **kwargs)
Returns:
A dict of keyword arguments that can be used to initialize a `Model`,
with keys `train_X`, `train_Y`, and, optionally, `train_Yvar`.
"""
if not isinstance(training_data, SupervisedDataset):
raise TypeError(
"Expected `training_data` to be a `SupervisedDataset`, but got "
f"{type(training_data)}."
)
parsed_data = {"train_X": training_data.X, "train_Y": training_data.Y}
if training_data.Yvar is not None:
parsed_data["train_Yvar"] = training_data.Yvar
return parsed_data

def transform_inputs(
self,
Expand Down
3 changes: 1 addition & 2 deletions botorch/models/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ def construct_inputs(
Args:
training_data: A `SupervisedDataset` or a `MultiTaskDataset`.
task_feature: Column index of embedded task indicator features. For details,
see `parse_training_data`.
task_feature: Column index of embedded task indicator features.
output_tasks: A list of task indices for which to compute model
outputs for. If omitted, return outputs for all task indices.
task_covar_prior: A GPyTorch `Prior` object to use as prior on
Expand Down
35 changes: 34 additions & 1 deletion botorch/models/pairwise_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions import UnsupportedError

from botorch.exceptions.warnings import _get_single_precision_warning, InputDataWarning
from botorch.models.likelihoods.pairwise import (
PairwiseLikelihood,
Expand All @@ -39,6 +38,7 @@
from botorch.models.utils.assorted import consolidate_duplicates
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from botorch.utils.datasets import RankingDataset, SupervisedDataset
from gpytorch import settings
from gpytorch.constraints import GreaterThan, Interval
from gpytorch.distributions.multivariate_normal import MultivariateNormal
Expand Down Expand Up @@ -784,6 +784,39 @@ def batch_shape(self) -> torch.Size:
else:
return self.datapoints.shape[:-2]

@classmethod
def construct_inputs(
    cls,
    training_data: SupervisedDataset,
    **kwargs: Any,
) -> Dict[str, Tensor]:
    r"""Build `PairwiseGP` constructor keyword arguments from a `RankingDataset`.

    Args:
        training_data: A `RankingDataset`. The `values` of its underlying
            `_X` container are used as the datapoints, the `indices` of that
            container are used as the comparison candidates, and its `Y` is
            used as the gather index that orders each comparison.
        kwargs: Ignored.

    Returns:
        A dict of keyword arguments that can be used to initialize a
        `PairwiseGP`, with the keys `datapoints` and `comparisons`.

    Raises:
        UnsupportedError: If `training_data` is not a `RankingDataset`.
    """
    if not isinstance(training_data, RankingDataset):
        raise UnsupportedError(
            "Only `RankingDataset` is supported for `PairwiseGP`. Received "
            f"{type(training_data)}."
        )
    # The private container is read directly because both its `values` and
    # its `indices` are needed separately.
    container = training_data._X
    points = container.values
    # Reorder the entries of each comparison along the last dim according to Y.
    ordered_comparisons = torch.gather(
        input=container.indices, dim=-1, index=training_data.Y
    )
    return {"datapoints": points, "comparisons": ordered_comparisons}

def set_train_data(
self,
datapoints: Optional[Tensor] = None,
Expand Down
3 changes: 0 additions & 3 deletions botorch/models/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
validate_input_scaling,
)

# # TODO: Omitted to avoid circular dependency created by `Model.construct_inputs`
# from botorch.models.utils.parse_training_data import parse_training_data


__all__ = [
"_make_X_full",
Expand Down
71 changes: 0 additions & 71 deletions botorch/models/utils/parse_training_data.py

This file was deleted.

5 changes: 0 additions & 5 deletions sphinx/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,6 @@ Transform Utilities
Utilities
-------------------------------------------

Dataset Parsing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.models.utils.parse_training_data
:members:

GPyTorch Module Constructors
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.models.utils.gpytorch_modules
Expand Down
29 changes: 22 additions & 7 deletions test/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from unittest.mock import patch

import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.errors import InputDataError
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.model import Model, ModelDict, ModelList
from botorch.models.utils import parse_training_data
from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.posteriors.posterior_list import PosteriorList
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
from torch import rand


class NotSoAbstractBaseModel(Model):
Expand Down Expand Up @@ -55,12 +55,27 @@ def test_not_so_abstract_base_model(self):
with self.assertRaises(NotImplementedError):
model.subset_output([0])

def test_construct_inputs(self):
with patch.object(
parse_training_data, "parse_training_data", return_value={"a": 1}
def test_construct_inputs(self) -> None:
model = NotSoAbstractBaseModel()
with self.subTest("Wrong training data type"), self.assertRaisesRegex(
TypeError, "Expected `training_data` to be a `SupervisedDataset`, but got "
):
model = NotSoAbstractBaseModel()
self.assertEqual(model.construct_inputs(None), {"a": 1})
model.construct_inputs(training_data=None)

x = rand(3, 2)
y = rand(3, 1)
dataset = SupervisedDataset(
X=x, Y=y, feature_names=["a", "b"], outcome_names=["y"]
)
model_inputs = model.construct_inputs(training_data=dataset)
self.assertEqual(model_inputs, {"train_X": x, "train_Y": y})

yvar = rand(3, 1)
dataset = SupervisedDataset(
X=x, Y=y, Yvar=yvar, feature_names=["a", "b"], outcome_names=["y"]
)
model_inputs = model.construct_inputs(training_data=dataset)
self.assertEqual(model_inputs, {"train_X": x, "train_Y": y, "train_Yvar": yvar})

def test_model_list(self):
tkwargs = {"device": self.device, "dtype": torch.double}
Expand Down
29 changes: 29 additions & 0 deletions test/models/test_pairwise_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from botorch.models.transforms.input import Normalize
from botorch.posteriors import GPyTorchPosterior
from botorch.sampling.pairwise_samplers import PairwiseSobolQMCNormalSampler
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import RankingDataset, SupervisedDataset
from botorch.utils.testing import BotorchTestCase
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.kernels.linear_kernel import LinearKernel
Expand Down Expand Up @@ -75,6 +77,33 @@ def _get_model_and_data(
model = PairwiseGP(**model_kwargs)
return model, model_kwargs

def test_construct_inputs(self) -> None:
    """Check `PairwiseGP.construct_inputs` on ranking and non-ranking data."""
    datapoints = torch.rand(3, 2)
    pair_indices = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)
    # Y is used as a gather index into each index pair, so the `[1, 0]` row
    # flips the second pair from (1, 2) to (2, 1).
    ranking_dataset = RankingDataset(
        X=SliceContainer(
            datapoints,
            pair_indices,
            event_shape=torch.Size([2 * datapoints.shape[-1]]),
        ),
        Y=torch.tensor([[0, 1], [1, 0]]).expand(pair_indices.shape),
        feature_names=["a", "b"],
        outcome_names=["y"],
    )
    model_inputs = PairwiseGP.construct_inputs(ranking_dataset)
    expected = torch.tensor([[0, 1], [2, 1]], dtype=torch.long)
    self.assertSetEqual(set(model_inputs.keys()), {"datapoints", "comparisons"})
    self.assertTrue(torch.equal(model_inputs["datapoints"], datapoints))
    self.assertTrue(torch.equal(model_inputs["comparisons"], expected))

    # Any dataset that is not a `RankingDataset` must be rejected.
    with self.subTest("Input other than RankingDataset"):
        supervised_dataset = SupervisedDataset(
            X=datapoints,
            Y=torch.rand(3, 1),
            feature_names=["a", "b"],
            outcome_names=["y"],
        )
        with self.assertRaisesRegex(
            UnsupportedError, "Only `RankingDataset` is supported"
        ):
            PairwiseGP.construct_inputs(supervised_dataset)

def test_pairwise_gp(self) -> None:
torch.manual_seed(random.randint(0, 10))
for batch_shape, likelihood_cls in itertools.product(
Expand Down
61 changes: 0 additions & 61 deletions test/models/utils/test_parse_training_data.py

This file was deleted.