From e6cb26b1638de174964610ec3bc80cdf193db12e Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Thu, 6 Jul 2023 06:56:17 -0700 Subject: [PATCH] prior-guided acquisition function (#1920) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1920 This adds an acquisition function wrapper for prior-guided AFs. Differential Revision: D47248296 fbshipit-source-id: 78f0f8253a25eaf1c6cfb5d9f36bdfc7642e3e82 --- botorch/acquisition/__init__.py | 2 + botorch/acquisition/prior_guided.py | 79 +++++++++++++++++++++++++++ sphinx/source/acquisition.rst | 9 ++- test/acquisition/test_prior_guided.py | 68 +++++++++++++++++++++++ 4 files changed, 156 insertions(+), 2 deletions(-) create mode 100644 botorch/acquisition/prior_guided.py create mode 100644 test/acquisition/test_prior_guided.py diff --git a/botorch/acquisition/__init__.py b/botorch/acquisition/__init__.py index e3184b39f9..7ff09b30c5 100644 --- a/botorch/acquisition/__init__.py +++ b/botorch/acquisition/__init__.py @@ -61,6 +61,7 @@ AnalyticExpectedUtilityOfBestOption, PairwiseBayesianActiveLearningByDisagreement, ) +from botorch.acquisition.prior_guided import PriorGuidedAcquisitionFunction from botorch.acquisition.proximal import ProximalAcquisitionFunction from botorch.acquisition.utils import get_acquisition_function @@ -78,6 +79,7 @@ "PairwiseBayesianActiveLearningByDisagreement", "PairwiseMCPosteriorVariance", "PosteriorMean", + "PriorGuidedAcquisitionFunction", "ProbabilityOfImprovement", "ProximalAcquisitionFunction", "UpperConfidenceBound", diff --git a/botorch/acquisition/prior_guided.py b/botorch/acquisition/prior_guided.py new file mode 100644 index 0000000000..72d3407631 --- /dev/null +++ b/botorch/acquisition/prior_guided.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +""" +Prior-Guided Acquisition Functions + +References + +.. [Hvarfner2022] + C. Hvarfner, D. Stoll, A. Souza, M. Lindauer, F. Hutter, L. Nardi. PiBO: + Augmenting Acquisition Functions with User Beliefs for Bayesian Optimization. + ICLR 2022. +""" +from __future__ import annotations + +from typing import Optional + +from botorch.acquisition.acquisition import AcquisitionFunction +from torch import Tensor + +from torch.nn import Module + + +class PriorGuidedAcquisitionFunction(AcquisitionFunction): + r"""Class for weighting acquisition functions by a prior distribution. + + See [Hvarfner2022]_ for details. + """ + + def __init__( + self, + acq_function: AcquisitionFunction, + prior_module: Module, + log: bool = False, + prior_exponent: float = 1.0, + ) -> None: + r"""Initialize the prior-guided acquisition function. + + Args: + acq_function: The base acquisition function. + prior_module: A Module that computes the probability + (or log probability) for the provided inputs. + log: A boolean that should be true if the acquisition function emits a + log-transformed value and the prior module emits a log probability. + prior_exponent: The exponent applied to the prior. This can be used + for example to decay the effect the prior over time as in + [Hvarfner2022]_. + """ + Module.__init__(self) + self.acq_func = acq_function + self.prior_module = prior_module + self._log = log + self._prior_exponent = prior_exponent + + @property + def X_pending(self): + r"""Return the `X_pending` of the base acquisition function.""" + try: + return self.acq_func.X_pending + except (ValueError, AttributeError): + raise ValueError( + f"Base acquisition function {type(self.acq_func).__name__} " + "does not have an `X_pending` attribute." + ) + + @X_pending.setter + def X_pending(self, X_pending: Optional[Tensor]): + r"""Sets the `X_pending` of the base acquisition function.""" + self.acq_func.X_pending = X_pending + + def forward(self, X: Tensor) -> Tensor: + r"""Compute the acquisition function weighted by the prior.""" + if self._log: + return self.acq_func(X) + self.prior_module(X) * self._prior_exponent + return self.acq_func(X) * self.prior_module(X).pow(self._prior_exponent) diff --git a/sphinx/source/acquisition.rst b/sphinx/source/acquisition.rst index 79f529826a..3aae7e277b 100644 --- a/sphinx/source/acquisition.rst +++ b/sphinx/source/acquisition.rst @@ -65,7 +65,7 @@ Multi-Objective Analytic Acquisition Functions .. automodule:: botorch.acquisition.multi_objective.analytic :members: :exclude-members: MultiObjectiveAnalyticAcquisitionFunction - + Multi-Objective Joint Entropy Search Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.joint_entropy_search @@ -86,7 +86,7 @@ Multi-Objective Multi-Fidelity Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.multi_fidelity :members: - + Multi-Objective Predictive Entropy Search Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search @@ -175,6 +175,11 @@ Penalized Acquisition Function Wrapper .. automodule:: botorch.acquisition.penalized :members: +Prior-Guided Acquisition Function Wrapper +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.acquisition.prior_guided + :members: + Proximal Acquisition Function Wrapper ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.proximal diff --git a/test/acquisition/test_prior_guided.py b/test/acquisition/test_prior_guided.py new file mode 100644 index 0000000000..fe48449d68 --- /dev/null +++ b/test/acquisition/test_prior_guided.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product + +import torch +from botorch.acquisition.analytic import ExpectedImprovement +from botorch.acquisition.monte_carlo import qExpectedImprovement +from botorch.acquisition.prior_guided import PriorGuidedAcquisitionFunction +from botorch.models import SingleTaskGP +from botorch.utils.testing import BotorchTestCase +from torch.nn import Module + + +class DummyPrior(Module): + def forward(self, X): + p = torch.distributions.Normal(0, 1) + # sum over d and q dimensions + return p.log_prob(X).sum(dim=-1).sum(dim=-1).exp() + + +class TestPriorGuidedAcquisitionFunction(BotorchTestCase): + def test_prior_guided_acquisition_function(self): + prior = DummyPrior() + for dtype in (torch.float, torch.double): + train_X = torch.rand(5, 3, dtype=dtype, device=self.device) + train_Y = train_X.norm(dim=-1, keepdim=True) + model = SingleTaskGP(train_X, train_Y).eval() + qEI = qExpectedImprovement(model, best_f=0.0) + for batch_shape, q, use_log, exponent in product( + ([], [2]), (1, 2), (False, True), (1.0, 2.0) + ): + af = PriorGuidedAcquisitionFunction( + acq_function=qEI, + prior_module=prior, + log=use_log, + prior_exponent=exponent, + ) + test_X = torch.rand(*batch_shape, q, 3, dtype=dtype, device=self.device) + with torch.no_grad(): + val = af(test_X) + prob = prior(test_X) + ei = qEI(test_X) + if use_log: + expected_val = prob * exponent + ei + else: + expected_val = prob.pow(exponent) * ei + self.assertTrue(torch.equal(val, expected_val)) + # test set_X_pending + X_pending = torch.rand(2, 3, dtype=dtype, device=self.device) + af.X_pending = X_pending + self.assertTrue(torch.equal(X_pending, af.acq_func.X_pending)) + self.assertTrue(torch.equal(X_pending, af.X_pending)) + # test exception when base AF does not support X_pending + ei = ExpectedImprovement(model, best_f=0.0) + af = PriorGuidedAcquisitionFunction( + acq_function=ei, + prior_module=prior, + ) + msg = ( + "Base acquisition function ExpectedImprovement " + "does not have an `X_pending` attribute." + ) + with self.assertRaisesRegex(ValueError, msg): + af.X_pending