From 67c4f40e0994b87ce1fec14d1ffecd1e2ee6650c Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Mon, 23 Jan 2023 21:33:11 -0800 Subject: [PATCH] Introduce a `DeterministicSampler` (#1641) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1641 ... to sanction its use with `optimize_acqf`. Reviewed By: Balandat Differential Revision: D42700043 fbshipit-source-id: 0d85ff5741167801ea11c942dade653b61877289 --- botorch/sampling/__init__.py | 2 ++ botorch/sampling/deterministic.py | 34 +++++++++++++++++++++++ botorch/sampling/get_sampler.py | 4 +-- botorch/sampling/stochastic_samplers.py | 16 ----------- sphinx/source/sampling.rst | 5 ++++ test/sampling/test_deterministic.py | 24 ++++++++++++++++ test/sampling/test_stochastic_samplers.py | 5 ---- 7 files changed, 67 insertions(+), 23 deletions(-) create mode 100644 botorch/sampling/deterministic.py create mode 100644 test/sampling/test_deterministic.py diff --git a/botorch/sampling/__init__.py b/botorch/sampling/__init__.py index 9b6c758283..0c57f783ad 100644 --- a/botorch/sampling/__init__.py +++ b/botorch/sampling/__init__.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from botorch.sampling.base import MCSampler +from botorch.sampling.deterministic import DeterministicSampler from botorch.sampling.get_sampler import get_sampler from botorch.sampling.list_sampler import ListSampler from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler @@ -19,6 +20,7 @@ __all__ = [ + "DeterministicSampler", "ForkedRNGSampler", "get_sampler", "IIDNormalSampler", diff --git a/botorch/sampling/deterministic.py b/botorch/sampling/deterministic.py new file mode 100644 index 0000000000..39c59f0932 --- /dev/null +++ b/botorch/sampling/deterministic.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +A dummy sampler for use with deterministic models. +""" + +from __future__ import annotations + +from botorch.posteriors.deterministic import DeterministicPosterior +from botorch.sampling.stochastic_samplers import StochasticSampler + + +class DeterministicSampler(StochasticSampler): + r"""A sampler that simply calls `posterior.rsample`, intended to be used with + `DeterministicModel` & `DeterministicPosterior`. + + This is effectively signals that `StochasticSampler` is safe to use with + deterministic models since their output is deterministic by definition. + """ + + def _update_base_samples( + self, posterior: DeterministicPosterior, base_sampler: DeterministicSampler + ) -> None: + r"""This is a no-op since there are no base samples to update. + + Args: + posterior: The posterior for which the base samples are constructed. + base_sampler: The base sampler to retrieve the base samples from. + """ + return diff --git a/botorch/sampling/get_sampler.py b/botorch/sampling/get_sampler.py index 17da92c647..447d91a75b 100644 --- a/botorch/sampling/get_sampler.py +++ b/botorch/sampling/get_sampler.py @@ -16,13 +16,13 @@ from botorch.posteriors.torch import TorchPosterior from botorch.posteriors.transformed import TransformedPosterior from botorch.sampling.base import MCSampler +from botorch.sampling.deterministic import DeterministicSampler from botorch.sampling.list_sampler import ListSampler from botorch.sampling.normal import ( IIDNormalSampler, NormalMCSampler, SobolQMCNormalSampler, ) -from botorch.sampling.stochastic_samplers import StochasticSampler from botorch.utils.dispatcher import Dispatcher from gpytorch.distributions import MultivariateNormal from torch.distributions import Distribution @@ -112,7 +112,7 @@ def _get_sampler_deterministic( posterior: DeterministicPosterior, sample_shape: torch.Size, **kwargs: Any ) -> MCSampler: r"""Get the dummy `StochasticSampler` for the `DeterministicPosterior`.""" - return StochasticSampler(sample_shape=sample_shape, **kwargs) + return DeterministicSampler(sample_shape=sample_shape, **kwargs) @GetSampler.register(object) diff --git a/botorch/sampling/stochastic_samplers.py b/botorch/sampling/stochastic_samplers.py index 51626df7c5..ab5d43d470 100644 --- a/botorch/sampling/stochastic_samplers.py +++ b/botorch/sampling/stochastic_samplers.py @@ -13,7 +13,6 @@ import torch from botorch.posteriors import Posterior -from botorch.posteriors.deterministic import DeterministicPosterior from botorch.sampling.base import MCSampler from torch import Tensor @@ -64,18 +63,3 @@ def forward(self, posterior: Posterior) -> Tensor: The samples drawn from the posterior. """ return posterior.rsample(sample_shape=self.sample_shape) - - def _update_base_samples( - self, posterior: Posterior, base_sampler: StochasticSampler - ) -> None: - r"""Update the sampler to use the original base samples for X_baseline. - - This is used in CachedCholeskyAcquisitionFunctions to ensure consistency. - This is a no-op for DeterministicPosterior and errors out otherwise. - - Args: - posterior: The posterior for which the base samples are constructed. - base_sampler: The base sampler to retrieve the base samples from. - """ - if not isinstance(posterior, DeterministicPosterior): - super()._update_base_samples(posterior=posterior, base_sampler=base_sampler) diff --git a/sphinx/source/sampling.rst b/sphinx/source/sampling.rst index b3fd4943dd..d58685f8e2 100644 --- a/sphinx/source/sampling.rst +++ b/sphinx/source/sampling.rst @@ -12,6 +12,11 @@ Monte-Carlo Sampler API .. automodule:: botorch.sampling.base :members: +Deterministic Sampler +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.sampling.deterministic + :members: + Get Sampler Helper ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.sampling.get_sampler diff --git a/test/sampling/test_deterministic.py b/test/sampling/test_deterministic.py new file mode 100644 index 0000000000..4dedef1605 --- /dev/null +++ b/test/sampling/test_deterministic.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from botorch.sampling.deterministic import DeterministicSampler +from botorch.utils.testing import BotorchTestCase, MockPosterior + + +class TestDeterministicSampler(BotorchTestCase): + def test_deterministic_sampler(self): + # Basic usage. + samples = torch.rand(1, 2) + posterior = MockPosterior(samples=samples) + sampler = DeterministicSampler(sample_shape=torch.Size([2])) + self.assertTrue(torch.equal(samples.repeat(2, 1, 1), sampler(posterior))) + + # Test _update_base_samples. + sampler._update_base_samples( + posterior=posterior, + base_sampler=sampler, + ) diff --git a/test/sampling/test_stochastic_samplers.py b/test/sampling/test_stochastic_samplers.py index 4437a9a2b0..991c637459 100644 --- a/test/sampling/test_stochastic_samplers.py +++ b/test/sampling/test_stochastic_samplers.py @@ -7,7 +7,6 @@ from unittest import mock import torch -from botorch.posteriors.deterministic import DeterministicPosterior from botorch.posteriors.torch import TorchPosterior from botorch.sampling.stochastic_samplers import ForkedRNGSampler, StochasticSampler from botorch.utils.testing import BotorchTestCase, MockPosterior @@ -40,7 +39,3 @@ def test_stochastic_sampler(self): # Test _update_base_samples. with self.assertRaisesRegex(NotImplementedError, "_update_base_samples"): sampler._update_base_samples(posterior=posterior, base_sampler=sampler) - sampler._update_base_samples( - posterior=DeterministicPosterior(values=torch.rand(1, 2)), - base_sampler=sampler, - )