From abe786ab35faa5fff9076c00948f2ae1e79ec727 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Wed, 26 Jul 2023 20:41:08 -0700 Subject: [PATCH] Decoupled Acquisition Function (#1948) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1948 Introduce an abstract class for decoupled acquisition functions. A decoupled acquisition function is one where one may intend to evaluate a design on only a subset of the outcomes. Typically this would be handled by fantasizing, where one would fantasize as to what the partial observation would be if one were to evaluate a design on the subset of outcomes (e.g. you only fantasize at those outcomes). Reviewed By: esantorella Differential Revision: D47710904 fbshipit-source-id: e61b3555c5fd93b53990ce3af299650bbb5341e1 --- botorch/acquisition/__init__.py | 2 + botorch/acquisition/decoupled.py | 163 +++++++++++++++++++++++++++++ sphinx/source/acquisition.rst | 5 + test/acquisition/test_decoupled.py | 135 ++++++++++++++++++++++++ 4 files changed, 305 insertions(+) create mode 100644 botorch/acquisition/decoupled.py create mode 100644 test/acquisition/test_decoupled.py diff --git a/botorch/acquisition/__init__.py b/botorch/acquisition/__init__.py index 276cbd3aa0..fd0efa77a9 100644 --- a/botorch/acquisition/__init__.py +++ b/botorch/acquisition/__init__.py @@ -28,6 +28,7 @@ GenericCostAwareUtility, InverseCostWeightedUtility, ) +from botorch.acquisition.decoupled import DecoupledAcquisitionFunction from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction from botorch.acquisition.input_constructors import get_acqf_input_constructor from botorch.acquisition.knowledge_gradient import ( @@ -78,6 +79,7 @@ "AnalyticAcquisitionFunction", "AnalyticExpectedUtilityOfBestOption", "ConstrainedExpectedImprovement", + "DecoupledAcquisitionFunction", "ExpectedImprovement", "LogExpectedImprovement", "LogNoisyExpectedImprovement", diff --git a/botorch/acquisition/decoupled.py b/botorch/acquisition/decoupled.py new file 
class DecoupledAcquisitionFunction(AcquisitionFunction, ABC):
    r"""Abstract base class for decoupled acquisition functions.

    A decoupled acquisition function is one where one may intend to
    evaluate a design on only a subset of the outcomes. Typically this
    would be handled by fantasizing, where one would fantasize as to what
    the partial observation would be if one were to evaluate a design on
    the subset of outcomes (e.g. you only fantasize at those outcomes).

    The `X_evaluation_mask` specifies which outcomes should be evaluated
    for each design. `X_evaluation_mask` is `q x m`, where there are `q`
    design points in the batch and `m` outcomes. In the asynchronous case,
    where there are `n'` pending points, we need to track which outcomes
    each pending point should be evaluated on. In this case, we concatenate
    `X_evaluation_mask` with `X_pending_evaluation_mask` to obtain the full
    evaluation mask.

    This abstract class handles generating and updating an evaluation mask,
    which is a boolean tensor indicating which outcomes a given design is
    being evaluated on. The evaluation mask has shape `(q + n') x m`: the
    first `q` rows belong to the new candidates to be generated and the
    remaining `n'` rows belong to the pending points.

    If `X(_pending)_evaluation_mask` is None, it is assumed that
    `X(_pending)` will be evaluated on all outcomes.
    """

    def __init__(
        self, model: ModelList, X_evaluation_mask: Optional[Tensor] = None, **kwargs
    ) -> None:
        r"""Initialize.

        Args:
            model: A `ModelList`; any other model type is rejected.
            X_evaluation_mask: A `q x m`-dim boolean tensor
                indicating which outcomes the decoupled acquisition
                function should generate new candidates for.
            kwargs: Forwarded unchanged to the parent initializer.

        Raises:
            ValueError: If `model` is not a `ModelList`.
            BotorchTensorDimensionError: If `X_evaluation_mask` is given
                and is not a 2-dim tensor whose last dimension equals
                `model.num_outputs`.
        """
        if not isinstance(model, ModelList):
            raise ValueError(f"{self.__class__.__name__} requires using a ModelList.")
        super().__init__(model=model, **kwargs)
        # `num_outputs` must exist before the `X_evaluation_mask` property
        # setter runs, because the setter validates against it.
        self.num_outputs = model.num_outputs
        self.X_evaluation_mask = X_evaluation_mask
        self.X_pending_evaluation_mask = None
        self.X_pending = None

    @property
    def X_evaluation_mask(self) -> Optional[Tensor]:
        r"""Get the evaluation indices for the new candidate."""
        return self._X_evaluation_mask

    @X_evaluation_mask.setter
    def X_evaluation_mask(self, X_evaluation_mask: Optional[Tensor] = None) -> None:
        r"""Set the evaluation indices for the new candidate.

        Raises:
            BotorchTensorDimensionError: If the mask is not None and is not
                a 2-dim tensor whose last dimension equals `num_outputs`.
        """
        if X_evaluation_mask is not None:
            # TODO: Add batch support
            if (
                X_evaluation_mask.ndim != 2
                or X_evaluation_mask.shape[-1] != self.num_outputs
            ):
                raise BotorchTensorDimensionError(
                    "Expected X_evaluation_mask to be `q x m`, but got shape"
                    f" {shape_to_str(X_evaluation_mask.shape)}."
                )
        self._X_evaluation_mask = X_evaluation_mask

    def set_X_pending(
        self,
        X_pending: Optional[Tensor] = None,
        X_pending_evaluation_mask: Optional[Tensor] = None,
    ) -> None:
        r"""Informs the AF about pending design points for different outcomes.

        The pending points and their mask are always updated together, and
        only after all validation has passed, so a rejected call leaves the
        previously stored pending state untouched.

        Args:
            X_pending: A `n' x d` Tensor with `n'` `d`-dim design points that have
                been submitted for evaluation but have not yet been evaluated.
            X_pending_evaluation_mask: A `n' x m`-dim tensor of booleans indicating
                for which outputs the pending point is being evaluated on. If
                `X_pending_evaluation_mask` is `None`, it is assumed that
                `X_pending` will be evaluated on all outcomes.

        Raises:
            BotorchTensorDimensionError: If `X_pending_evaluation_mask` is given
                together with `X_pending` but is not of shape `n' x m`.
            ValueError: If `X_pending` is given without a mask while
                `self.X_evaluation_mask` is not None.
        """
        if X_pending is None:
            # No pending points: both attributes are stored exactly as passed.
            self.X_pending = X_pending
            self.X_pending_evaluation_mask = X_pending_evaluation_mask
            return

        if X_pending.requires_grad:
            warnings.warn(
                "Pending points require a gradient but the acquisition function"
                " will not provide a gradient to these points.",
                BotorchWarning,
            )

        if X_pending_evaluation_mask is not None:
            if (
                X_pending_evaluation_mask.ndim != 2
                or X_pending_evaluation_mask.shape[0] != X_pending.shape[0]
                or X_pending_evaluation_mask.shape[1] != self.num_outputs
            ):
                raise BotorchTensorDimensionError(
                    f"Expected `X_pending_evaluation_mask` of shape "
                    f"`{X_pending.shape[0]} x {self.num_outputs}`, but "
                    f"got {shape_to_str(X_pending_evaluation_mask.shape)}."
                )
        elif self.X_evaluation_mask is not None:
            raise ValueError(
                "If `self.X_evaluation_mask` is not None, then "
                "`X_pending_evaluation_mask` must be provided."
            )

        # Inputs are consistent; commit both pieces of state together. The
        # mask is assigned even when it is None so that a mask left over
        # from an earlier call cannot be paired with the new pending points.
        self.X_pending = X_pending.detach().clone()
        self.X_pending_evaluation_mask = X_pending_evaluation_mask

    def construct_evaluation_mask(self, X: Tensor) -> Optional[Tensor]:
        r"""Construct the boolean evaluation mask for X and X_pending

        Args:
            X: A `batch_shape x n x d`-dim tensor of designs.

        Returns:
            A `(n + n') x m`-dim tensor of booleans indicating which outputs
            should be evaluated, with the `n` rows for `X` first and the `n'`
            rows for the pending points last. If no pending evaluation mask
            is stored, `self.X_evaluation_mask` is returned as is (possibly
            `None`).

        Raises:
            BotorchTensorDimensionError: If a pending mask is stored and
                `self.X_evaluation_mask` is set but its number of rows
                differs from `X.shape[-2]`.
        """
        if self.X_pending_evaluation_mask is None:
            return self.X_evaluation_mask

        X_evaluation_mask = self.X_evaluation_mask
        if X_evaluation_mask is None:
            # evaluate all objectives for X
            X_evaluation_mask = torch.ones(
                X.shape[-2], self.num_outputs, dtype=torch.bool, device=X.device
            )
        elif X_evaluation_mask.shape[0] != X.shape[-2]:
            raise BotorchTensorDimensionError(
                "Expected the -2 dimension of X and X_evaluation_mask to match."
            )
        # stack the mask for X on top of the mask for the pending points
        return torch.cat([X_evaluation_mask, self.X_pending_evaluation_mask], dim=-2)
class DummyDecoupledAcquisitionFunction(DecoupledAcquisitionFunction):
    """Minimal concrete subclass so the abstract base class can be instantiated."""

    def forward(self, X):
        # Only the mask bookkeeping is under test; forward is never called.
        pass


class TestDecoupledAcquisitionFunction(BotorchTestCase):
    """Tests for the evaluation-mask handling of DecoupledAcquisitionFunction."""

    def test_decoupled_acquisition_function(self):
        # The abstract base class itself cannot be instantiated. Match only
        # the prefix of the message: the wording after the class name is
        # produced by the interpreter and differs between Python versions.
        msg = "Can't instantiate abstract class DecoupledAcquisitionFunction"
        with self.assertRaisesRegex(TypeError, msg):
            DecoupledAcquisitionFunction()
        # test raises error if model is not ModelList
        msg = "DummyDecoupledAcquisitionFunction requires using a ModelList."
        model = SingleTaskGP(
            torch.rand(1, 3, device=self.device), torch.rand(1, 2, device=self.device)
        )
        with self.assertRaisesRegex(ValueError, msg):
            DummyDecoupledAcquisitionFunction(model=model)
        m = SingleTaskGP(
            torch.rand(1, 3, device=self.device), torch.rand(1, 1, device=self.device)
        )
        model = ModelListGP(m, m)
        # basic test
        af = DummyDecoupledAcquisitionFunction(model=model)
        self.assertIs(af.model, model)
        self.assertIsNone(af.X_evaluation_mask)
        self.assertIsNone(af.X_pending)
        # test set X_evaluation_mask
        # test wrong number of outputs
        eval_mask = torch.randint(0, 2, (2, 3), device=self.device).bool()
        msg = (
            "Expected X_evaluation_mask to be `q x m`, but got shape"
            f" {shape_to_str(eval_mask.shape)}."
        )
        with self.assertRaisesRegex(BotorchTensorDimensionError, msg):
            af.X_evaluation_mask = eval_mask
        # test more than 2 dimensions
        eval_mask.unsqueeze_(0)
        msg = (
            "Expected X_evaluation_mask to be `q x m`, but got shape"
            f" {shape_to_str(eval_mask.shape)}."
        )
        with self.assertRaisesRegex(BotorchTensorDimensionError, msg):
            af.X_evaluation_mask = eval_mask

        # set eval_mask
        eval_mask = eval_mask[0, :, :2]
        af.X_evaluation_mask = eval_mask
        self.assertIs(af.X_evaluation_mask, eval_mask)

        # test set_X_pending
        X_pending = torch.rand(1, 1, device=self.device)
        msg = (
            "If `self.X_evaluation_mask` is not None, then "
            "`X_pending_evaluation_mask` must be provided."
        )
        with self.assertRaisesRegex(ValueError, msg):
            af.set_X_pending(X_pending=X_pending)
        af.X_evaluation_mask = None
        X_pending = X_pending.requires_grad_(True)
        with warnings.catch_warnings(record=True) as ws, settings.debug(True):
            af.set_X_pending(X_pending)
        # Compare tensors explicitly instead of relying on the truthiness
        # of an elementwise `==` result.
        self.assertTrue(torch.equal(af.X_pending, X_pending))
        self.assertEqual(sum(issubclass(w.category, BotorchWarning) for w in ws), 1)
        self.assertIsNone(af.X_evaluation_mask)

        # test setting X_pending with X_pending_evaluation_mask
        X_pending = torch.rand(3, 1, device=self.device)
        # test raises exception
        # wrong number of outputs, wrong number of dims, wrong number of rows
        for shape in ([3, 1], [1, 3, 2], [1, 2]):
            eval_mask = torch.randint(0, 2, shape, device=self.device).bool()
            msg = (
                f"Expected `X_pending_evaluation_mask` of shape `{X_pending.shape[0]} "
                f"x {model.num_outputs}`, but got "
                f"{shape_to_str(eval_mask.shape)}."
            )

            with self.assertRaisesRegex(BotorchTensorDimensionError, msg):
                af.set_X_pending(
                    X_pending=X_pending, X_pending_evaluation_mask=eval_mask
                )
        eval_mask = torch.randint(0, 2, (3, 2), device=self.device).bool()
        af.set_X_pending(X_pending=X_pending, X_pending_evaluation_mask=eval_mask)
        self.assertTrue(torch.equal(af.X_pending, X_pending))
        self.assertIs(af.X_pending_evaluation_mask, eval_mask)

        # test construct_evaluation_mask
        # X_evaluation_mask is None
        X = torch.rand(4, 5, 2, device=self.device)
        X_eval_mask = af.construct_evaluation_mask(X=X)
        expected_eval_mask = torch.cat(
            [torch.ones(X.shape[1:], dtype=torch.bool, device=self.device), eval_mask],
            dim=0,
        )
        self.assertTrue(torch.equal(X_eval_mask, expected_eval_mask))
        # test X_evaluation_mask is not None
        # test wrong shape
        af.X_evaluation_mask = torch.zeros(1, 2, dtype=torch.bool, device=self.device)
        msg = "Expected the -2 dimension of X and X_evaluation_mask to match."
        with self.assertRaisesRegex(BotorchTensorDimensionError, msg):
            af.construct_evaluation_mask(X=X)
        af.X_evaluation_mask = torch.randint(0, 2, (5, 2), device=self.device).bool()
        X_eval_mask = af.construct_evaluation_mask(X=X)
        expected_eval_mask = torch.cat([af.X_evaluation_mask, eval_mask], dim=0)
        self.assertTrue(torch.equal(X_eval_mask, expected_eval_mask))

        # test setting X_pending as None
        af.set_X_pending(X_pending=None, X_pending_evaluation_mask=None)
        self.assertIsNone(af.X_pending)
        self.assertIsNone(af.X_pending_evaluation_mask)

        # test construct_evaluation_mask when X_pending is None
        self.assertTrue(
            torch.equal(af.construct_evaluation_mask(X=X), af.X_evaluation_mask)
        )