Skip to content

Commit

Permalink
Introduce a DeterministicSampler (pytorch#1641)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1641

... to sanction its use with `optimize_acqf`.

Reviewed By: Balandat

Differential Revision: D42700043

fbshipit-source-id: 0d85ff5741167801ea11c942dade653b61877289
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jan 24, 2023
1 parent 85bbf98 commit 67c4f40
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 23 deletions.
2 changes: 2 additions & 0 deletions botorch/sampling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from botorch.sampling.base import MCSampler
from botorch.sampling.deterministic import DeterministicSampler
from botorch.sampling.get_sampler import get_sampler
from botorch.sampling.list_sampler import ListSampler
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
Expand All @@ -19,6 +20,7 @@


__all__ = [
"DeterministicSampler",
"ForkedRNGSampler",
"get_sampler",
"IIDNormalSampler",
Expand Down
34 changes: 34 additions & 0 deletions botorch/sampling/deterministic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
A dummy sampler for use with deterministic models.
"""

from __future__ import annotations

from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.sampling.stochastic_samplers import StochasticSampler


class DeterministicSampler(StochasticSampler):
r"""A sampler that simply calls `posterior.rsample`, intended to be used with
`DeterministicModel` & `DeterministicPosterior`.
This is effectively signals that `StochasticSampler` is safe to use with
deterministic models since their output is deterministic by definition.
"""

def _update_base_samples(
self, posterior: DeterministicPosterior, base_sampler: DeterministicSampler
) -> None:
r"""This is a no-op since there are no base samples to update.
Args:
posterior: The posterior for which the base samples are constructed.
base_sampler: The base sampler to retrieve the base samples from.
"""
return
4 changes: 2 additions & 2 deletions botorch/sampling/get_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
from botorch.posteriors.torch import TorchPosterior
from botorch.posteriors.transformed import TransformedPosterior
from botorch.sampling.base import MCSampler
from botorch.sampling.deterministic import DeterministicSampler
from botorch.sampling.list_sampler import ListSampler
from botorch.sampling.normal import (
IIDNormalSampler,
NormalMCSampler,
SobolQMCNormalSampler,
)
from botorch.sampling.stochastic_samplers import StochasticSampler
from botorch.utils.dispatcher import Dispatcher
from gpytorch.distributions import MultivariateNormal
from torch.distributions import Distribution
Expand Down Expand Up @@ -112,7 +112,7 @@ def _get_sampler_deterministic(
posterior: DeterministicPosterior, sample_shape: torch.Size, **kwargs: Any
) -> MCSampler:
r"""Get the dummy `StochasticSampler` for the `DeterministicPosterior`."""
return StochasticSampler(sample_shape=sample_shape, **kwargs)
return DeterministicSampler(sample_shape=sample_shape, **kwargs)


@GetSampler.register(object)
Expand Down
16 changes: 0 additions & 16 deletions botorch/sampling/stochastic_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import torch
from botorch.posteriors import Posterior
from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.sampling.base import MCSampler
from torch import Tensor

Expand Down Expand Up @@ -64,18 +63,3 @@ def forward(self, posterior: Posterior) -> Tensor:
The samples drawn from the posterior.
"""
return posterior.rsample(sample_shape=self.sample_shape)

def _update_base_samples(
self, posterior: Posterior, base_sampler: StochasticSampler
) -> None:
r"""Update the sampler to use the original base samples for X_baseline.
This is used in CachedCholeskyAcquisitionFunctions to ensure consistency.
This is a no-op for DeterministicPosterior and errors out otherwise.
Args:
posterior: The posterior for which the base samples are constructed.
base_sampler: The base sampler to retrieve the base samples from.
"""
if not isinstance(posterior, DeterministicPosterior):
super()._update_base_samples(posterior=posterior, base_sampler=base_sampler)
5 changes: 5 additions & 0 deletions sphinx/source/sampling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ Monte-Carlo Sampler API
.. automodule:: botorch.sampling.base
:members:

Deterministic Sampler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.sampling.deterministic
:members:

Get Sampler Helper
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.sampling.get_sampler
Expand Down
24 changes: 24 additions & 0 deletions test/sampling/test_deterministic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from botorch.sampling.deterministic import DeterministicSampler
from botorch.utils.testing import BotorchTestCase, MockPosterior


class TestDeterministicSampler(BotorchTestCase):
def test_deterministic_sampler(self):
# Basic usage.
samples = torch.rand(1, 2)
posterior = MockPosterior(samples=samples)
sampler = DeterministicSampler(sample_shape=torch.Size([2]))
self.assertTrue(torch.equal(samples.repeat(2, 1, 1), sampler(posterior)))

# Test _update_base_samples.
sampler._update_base_samples(
posterior=posterior,
base_sampler=sampler,
)
5 changes: 0 additions & 5 deletions test/sampling/test_stochastic_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from unittest import mock

import torch
from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.posteriors.torch import TorchPosterior
from botorch.sampling.stochastic_samplers import ForkedRNGSampler, StochasticSampler
from botorch.utils.testing import BotorchTestCase, MockPosterior
Expand Down Expand Up @@ -40,7 +39,3 @@ def test_stochastic_sampler(self):
# Test _update_base_samples.
with self.assertRaisesRegex(NotImplementedError, "_update_base_samples"):
sampler._update_base_samples(posterior=posterior, base_sampler=sampler)
sampler._update_base_samples(
posterior=DeterministicPosterior(values=torch.rand(1, 2)),
base_sampler=sampler,
)

0 comments on commit 67c4f40

Please sign in to comment.