diff --git a/botorch/models/fully_bayesian_multitask.py b/botorch/models/fully_bayesian_multitask.py
index b621f65d27..c262183624 100644
--- a/botorch/models/fully_bayesian_multitask.py
+++ b/botorch/models/fully_bayesian_multitask.py
@@ -395,8 +395,7 @@ def construct_inputs(
 
         Args:
             training_data: A `SupervisedDataset` or a `MultiTaskDataset`.
-            task_feature: Column index of embedded task indicator features. For details,
-                see `parse_training_data`.
+            task_feature: Column index of embedded task indicator features.
             rank: The rank of the cross-task covariance matrix.
         """
         inputs = super().construct_inputs(
diff --git a/botorch/models/model.py b/botorch/models/model.py
index 8c2be562dd..acff089e2f 100644
--- a/botorch/models/model.py
+++ b/botorch/models/model.py
@@ -42,6 +42,7 @@
 from botorch.posteriors import Posterior, PosteriorList
 from botorch.sampling.base import MCSampler
 from botorch.sampling.list_sampler import ListSampler
+from botorch.utils.containers import BotorchContainer
 from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.transforms import is_fully_bayesian
 from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
@@ -187,11 +188,28 @@ def construct_inputs(
         cls,
         training_data: SupervisedDataset,
         **kwargs: Any,
-    ) -> Dict[str, Any]:
-        r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`."""
-        from botorch.models.utils.parse_training_data import parse_training_data
+    ) -> Dict[str, Union[BotorchContainer, Tensor]]:
+        """
+        Construct `Model` keyword arguments from a `SupervisedDataset`.
+
+        Args:
+            training_data: A `SupervisedDataset`, with attributes `X`, `Y`,
+                and, optionally, `Yvar`.
+            kwargs: Ignored.
 
-        return parse_training_data(cls, training_data, **kwargs)
+        Returns:
+            A dict of keyword arguments that can be used to initialize a `Model`,
+            with keys `train_X`, `train_Y`, and, optionally, `train_Yvar`.
+        """
+        if not isinstance(training_data, SupervisedDataset):
+            raise TypeError(
+                "Expected `training_data` to be a `SupervisedDataset`, but got "
+                f"{type(training_data)}."
+            )
+        parsed_data = {"train_X": training_data.X, "train_Y": training_data.Y}
+        if training_data.Yvar is not None:
+            parsed_data["train_Yvar"] = training_data.Yvar
+        return parsed_data
 
     def transform_inputs(
         self,
diff --git a/botorch/models/multitask.py b/botorch/models/multitask.py
index f722f73f1b..606cdec7f7 100644
--- a/botorch/models/multitask.py
+++ b/botorch/models/multitask.py
@@ -283,8 +283,7 @@ def construct_inputs(
 
         Args:
             training_data: A `SupervisedDataset` or a `MultiTaskDataset`.
-            task_feature: Column index of embedded task indicator features. For details,
-                see `parse_training_data`.
+            task_feature: Column index of embedded task indicator features.
             output_tasks: A list of task indices for which to compute model
                 outputs for. If omitted, return outputs for all task indices.
             task_covar_prior: A GPyTorch `Prior` object to use as prior on
diff --git a/botorch/models/pairwise_gp.py b/botorch/models/pairwise_gp.py
index 8fbd8c5e14..0913f2bfe3 100644
--- a/botorch/models/pairwise_gp.py
+++ b/botorch/models/pairwise_gp.py
@@ -28,7 +28,6 @@
 import torch
 from botorch.acquisition.objective import PosteriorTransform
 from botorch.exceptions import UnsupportedError
-
 from botorch.exceptions.warnings import _get_single_precision_warning, InputDataWarning
 from botorch.models.likelihoods.pairwise import (
     PairwiseLikelihood,
@@ -39,6 +38,7 @@
 from botorch.models.utils.assorted import consolidate_duplicates
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.posteriors.posterior import Posterior
+from botorch.utils.datasets import RankingDataset, SupervisedDataset
 from gpytorch import settings
 from gpytorch.constraints import GreaterThan, Interval
 from gpytorch.distributions.multivariate_normal import MultivariateNormal
@@ -784,6 +784,39 @@ def batch_shape(self) -> torch.Size:
         else:
             return self.datapoints.shape[:-2]
 
+    @classmethod
+    def construct_inputs(
+        cls,
+        training_data: SupervisedDataset,
+        **kwargs: Any,
+    ) -> Dict[str, Tensor]:
+        r"""
+        Construct `PairwiseGP` keyword arguments from a `RankingDataset`.
+
+        Args:
+            training_data: A `RankingDataset` built from a `SliceContainer` of
+                datapoints and comparison indices, with rankings given by `Y`.
+            kwargs: Ignored.
+
+        Returns:
+            A dict of keyword arguments that can be used to initialize a
+            `PairwiseGP`, with keys `datapoints` and `comparisons`.
+        """
+        if not isinstance(training_data, RankingDataset):
+            raise UnsupportedError(
+                "Only `RankingDataset` is supported for `PairwiseGP`. Received "
+                f"{type(training_data)}."
+            )
+        datapoints = training_data._X.values
+        comparisons = training_data._X.indices
+        comp_order = training_data.Y
+        comparisons = torch.gather(input=comparisons, dim=-1, index=comp_order)
+
+        return {
+            "datapoints": datapoints,
+            "comparisons": comparisons,
+        }
+
     def set_train_data(
         self,
         datapoints: Optional[Tensor] = None,
diff --git a/botorch/models/utils/__init__.py b/botorch/models/utils/__init__.py
index 746e8ac216..97e65194e3 100644
--- a/botorch/models/utils/__init__.py
+++ b/botorch/models/utils/__init__.py
@@ -19,9 +19,6 @@
     validate_input_scaling,
 )
 
-# # TODO: Omitted to avoid circular dependency created by `Model.construct_inputs`
-# from botorch.models.utils.parse_training_data import parse_training_data
-
 
 __all__ = [
     "_make_X_full",
diff --git a/botorch/models/utils/parse_training_data.py b/botorch/models/utils/parse_training_data.py
deleted file mode 100644
index 349f6d0a79..0000000000
--- a/botorch/models/utils/parse_training_data.py
+++ /dev/null
@@ -1,71 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
- -r"""Parsing rules for BoTorch datasets.""" - - -from __future__ import annotations - -from typing import Any, Dict, Type - -import torch -from botorch.models.model import Model -from botorch.models.pairwise_gp import PairwiseGP -from botorch.utils.datasets import RankingDataset, SupervisedDataset -from botorch.utils.dispatcher import Dispatcher -from torch import Tensor - - -def _encoder(arg: Any) -> Type: - # Allow type variables to be passed as arguments at runtime - return arg if isinstance(arg, type) else type(arg) - - -dispatcher = Dispatcher("parse_training_data", encoder=_encoder) - - -def parse_training_data( - consumer: Any, - training_data: SupervisedDataset, - **kwargs: Any, -) -> Dict[str, Tensor]: - r"""Prepares a dataset for consumption by a given object. - - Args: - training_datas: A SupervisedDataset. - consumer: The object that will consume the parsed data, or type thereof. - - Returns: - A dictionary containing the extracted information. - """ - return dispatcher(consumer, training_data, **kwargs) - - -@dispatcher.register(Model, SupervisedDataset) -def _parse_model_supervised( - consumer: Model, dataset: SupervisedDataset, **ignore: Any -) -> Dict[str, Tensor]: - parsed_data = {"train_X": dataset.X, "train_Y": dataset.Y} - if dataset.Yvar is not None: - parsed_data["train_Yvar"] = dataset.Yvar - return parsed_data - - -@dispatcher.register(PairwiseGP, RankingDataset) -def _parse_pairwiseGP_ranking( - consumer: PairwiseGP, dataset: RankingDataset, **ignore: Any -) -> Dict[str, Tensor]: - # TODO: [T163045056] Not sure what the point of the special container is if we have - # to further process it here. We should move this logic into RankingDataset. - datapoints = dataset._X.values - comparisons = dataset._X.indices - comp_order = dataset.Y - comparisons = torch.gather(input=comparisons, dim=-1, index=comp_order) - - return { - "datapoints": datapoints, - "comparisons": comparisons, - } diff --git a/sphinx/source/models.rst b/sphinx/source/models.rst index c8a6f8fe55..20d340a050 100644 --- a/sphinx/source/models.rst +++ b/sphinx/source/models.rst @@ -158,11 +158,6 @@ Transform Utilities Utilities ------------------------------------------- -Dataset Parsing -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.models.utils.parse_training_data - :members: - GPyTorch Module Constructors ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.models.utils.gpytorch_modules diff --git a/test/models/test_model.py b/test/models/test_model.py index 0ef53b9a13..dd30c106ea 100644 --- a/test/models/test_model.py +++ b/test/models/test_model.py @@ -4,17 +4,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
 
-from unittest.mock import patch
 
 import torch
 from botorch.acquisition.objective import PosteriorTransform
 from botorch.exceptions.errors import InputDataError
 from botorch.models.deterministic import GenericDeterministicModel
 from botorch.models.model import Model, ModelDict, ModelList
-from botorch.models.utils import parse_training_data
 from botorch.posteriors.deterministic import DeterministicPosterior
 from botorch.posteriors.posterior_list import PosteriorList
+from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
+from torch import rand
 
 
 class NotSoAbstractBaseModel(Model):
@@ -55,12 +55,27 @@ def test_not_so_abstract_base_model(self):
         with self.assertRaises(NotImplementedError):
             model.subset_output([0])
 
-    def test_construct_inputs(self):
-        with patch.object(
-            parse_training_data, "parse_training_data", return_value={"a": 1}
+    def test_construct_inputs(self) -> None:
+        model = NotSoAbstractBaseModel()
+        with self.subTest("Wrong training data type"), self.assertRaisesRegex(
+            TypeError, "Expected `training_data` to be a `SupervisedDataset`, but got "
         ):
-            model = NotSoAbstractBaseModel()
-            self.assertEqual(model.construct_inputs(None), {"a": 1})
+            model.construct_inputs(training_data=None)
+
+        x = rand(3, 2)
+        y = rand(3, 1)
+        dataset = SupervisedDataset(
+            X=x, Y=y, feature_names=["a", "b"], outcome_names=["y"]
+        )
+        model_inputs = model.construct_inputs(training_data=dataset)
+        self.assertEqual(model_inputs, {"train_X": x, "train_Y": y})
+
+        yvar = rand(3, 1)
+        dataset = SupervisedDataset(
+            X=x, Y=y, Yvar=yvar, feature_names=["a", "b"], outcome_names=["y"]
+        )
+        model_inputs = model.construct_inputs(training_data=dataset)
+        self.assertEqual(model_inputs, {"train_X": x, "train_Y": y, "train_Yvar": yvar})
 
     def test_model_list(self):
         tkwargs = {"device": self.device, "dtype": torch.double}
diff --git a/test/models/test_pairwise_gp.py b/test/models/test_pairwise_gp.py
index d0e0b3dac6..3f1f2deed0 100644
--- a/test/models/test_pairwise_gp.py
+++ b/test/models/test_pairwise_gp.py
@@ -28,6 +28,8 @@
 from botorch.models.transforms.input import Normalize
 from botorch.posteriors import GPyTorchPosterior
 from botorch.sampling.pairwise_samplers import PairwiseSobolQMCNormalSampler
+from botorch.utils.containers import SliceContainer
+from botorch.utils.datasets import RankingDataset, SupervisedDataset
 from botorch.utils.testing import BotorchTestCase
 from gpytorch.kernels import RBFKernel, ScaleKernel
 from gpytorch.kernels.linear_kernel import LinearKernel
@@ -75,6 +77,33 @@ def _get_model_and_data(
         model = PairwiseGP(**model_kwargs)
         return model, model_kwargs
 
+    def test_construct_inputs(self) -> None:
+        datapoints = torch.rand(3, 2)
+        indices = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)
+        event_shape = torch.Size([2 * datapoints.shape[-1]])
+        dataset_X = SliceContainer(datapoints, indices, event_shape=event_shape)
+        dataset_Y = torch.tensor([[0, 1], [1, 0]]).expand(indices.shape)
+        dataset = RankingDataset(
+            X=dataset_X, Y=dataset_Y, feature_names=["a", "b"], outcome_names=["y"]
+        )
+        model_inputs = PairwiseGP.construct_inputs(dataset)
+        comparisons = torch.tensor([[0, 1], [2, 1]], dtype=torch.long)
+        self.assertSetEqual(set(model_inputs.keys()), {"datapoints", "comparisons"})
+        self.assertTrue(torch.equal(model_inputs["datapoints"], datapoints))
+        self.assertTrue(torch.equal(model_inputs["comparisons"], comparisons))
+
+        with self.subTest("Input other than RankingDataset"):
+            dataset = SupervisedDataset(
+                X=datapoints,
+                Y=torch.rand(3, 1),
+                feature_names=["a", "b"],
+                outcome_names=["y"],
+            )
+            with self.assertRaisesRegex(
+                UnsupportedError, "Only `RankingDataset` is supported"
+            ):
+                PairwiseGP.construct_inputs(dataset)
+
     def test_pairwise_gp(self) -> None:
         torch.manual_seed(random.randint(0, 10))
         for batch_shape, likelihood_cls in itertools.product(
diff --git a/test/models/utils/test_parse_training_data.py b/test/models/utils/test_parse_training_data.py
deleted file mode 100644
index f7fa593861..0000000000
--- a/test/models/utils/test_parse_training_data.py
+++ /dev/null
@@ -1,61 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-from botorch.models.model import Model
-from botorch.models.pairwise_gp import PairwiseGP
-from botorch.models.utils.parse_training_data import parse_training_data
-from botorch.utils.containers import SliceContainer
-from botorch.utils.datasets import RankingDataset, SupervisedDataset
-from botorch.utils.testing import BotorchTestCase
-from torch import long, rand, Size, tensor
-
-
-class TestParseTrainingData(BotorchTestCase):
-    def test_supervised(self):
-        with self.assertRaisesRegex(NotImplementedError, "Could not find signature"):
-            parse_training_data(Model, None)
-
-        dataset = SupervisedDataset(
-            X=rand(3, 2), Y=rand(3, 1), feature_names=["a", "b"], outcome_names=["y"]
-        )
-        with self.assertRaisesRegex(NotImplementedError, "Could not find signature"):
-            parse_training_data(None, dataset)
-
-        parse = parse_training_data(Model, dataset)
-        self.assertIsInstance(parse, dict)
-        self.assertTrue(torch.equal(dataset.X, parse["train_X"]))
-        self.assertTrue(torch.equal(dataset.Y, parse["train_Y"]))
-        self.assertTrue("train_Yvar" not in parse)
-
-        # Test with noise
-        dataset = SupervisedDataset(
-            X=rand(3, 2),
-            Y=rand(3, 1),
-            Yvar=rand(3, 1),
-            feature_names=["a", "b"],
-            outcome_names=["y"],
-        )
-        parse = parse_training_data(Model, dataset)
-        self.assertTrue(torch.equal(dataset.X, parse["train_X"]))
-        self.assertTrue(torch.equal(dataset.Y, parse["train_Y"]))
-        self.assertTrue(torch.equal(dataset.Yvar, parse["train_Yvar"]))
-
-    def test_pairwiseGP_ranking(self):
-        # Test parsing Ranking Dataset for PairwiseGP
-        datapoints = rand(3, 2)
-        indices = tensor([[0, 1], [1, 2]], dtype=long)
-        event_shape = Size([2 * datapoints.shape[-1]])
-        dataset_X = SliceContainer(datapoints, indices, event_shape=event_shape)
-        dataset_Y = tensor([[0, 1], [1, 0]]).expand(indices.shape)
-        dataset = RankingDataset(
-            X=dataset_X, Y=dataset_Y, feature_names=["a", "b"], outcome_names=["y"]
-        )
-        parse = parse_training_data(PairwiseGP, dataset)
-        self.assertTrue(dataset._X.values.equal(parse["datapoints"]))
-
-        comparisons = tensor([[0, 1], [2, 1]], dtype=long)
-        self.assertTrue(comparisons.equal(parse["comparisons"]))
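
Reviewer note (not part of the patch): a minimal usage sketch of the two code paths this diff consolidates into `construct_inputs`, mirroring the updated tests above. The `SingleTaskGP` usage is an assumption for illustration only; it is not touched by this diff, and any `Model` subclass accepting `train_X`/`train_Y` would do.

import torch
from botorch.models.gp_regression import SingleTaskGP  # illustrative; not part of this diff
from botorch.models.pairwise_gp import PairwiseGP
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import RankingDataset, SupervisedDataset

# Model.construct_inputs: a SupervisedDataset maps to train_X/train_Y
# (plus train_Yvar when the dataset carries observed noise).
X, Y = torch.rand(8, 2), torch.rand(8, 1)
dataset = SupervisedDataset(X=X, Y=Y, feature_names=["x1", "x2"], outcome_names=["y"])
gp = SingleTaskGP(**SingleTaskGP.construct_inputs(training_data=dataset))

# PairwiseGP.construct_inputs: a RankingDataset maps to datapoints/comparisons,
# with the comparison indices reordered according to the observed rankings.
datapoints = torch.rand(3, 2)
indices = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)
container = SliceContainer(
    datapoints, indices, event_shape=torch.Size([2 * datapoints.shape[-1]])
)
ranking = RankingDataset(
    X=container,
    Y=torch.tensor([[0, 1], [1, 0]]).expand(indices.shape),
    feature_names=["x1", "x2"],
    outcome_names=["y"],
)
pairwise_gp = PairwiseGP(**PairwiseGP.construct_inputs(training_data=ranking))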