forked from pytorch/botorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce a
DeterministicSampler
(pytorch#1641)
Summary: Pull Request resolved: pytorch#1641 ... to sanction its use with `optimize_acqf`. Reviewed By: Balandat Differential Revision: D42700043 fbshipit-source-id: 0d85ff5741167801ea11c942dade653b61877289
- Loading branch information
1 parent
85bbf98
commit 67c4f40
Showing
7 changed files
with
67 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
r""" | ||
A dummy sampler for use with deterministic models. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from botorch.posteriors.deterministic import DeterministicPosterior | ||
from botorch.sampling.stochastic_samplers import StochasticSampler | ||
|
||
|
||
class DeterministicSampler(StochasticSampler): | ||
r"""A sampler that simply calls `posterior.rsample`, intended to be used with | ||
`DeterministicModel` & `DeterministicPosterior`. | ||
This is effectively signals that `StochasticSampler` is safe to use with | ||
deterministic models since their output is deterministic by definition. | ||
""" | ||
|
||
def _update_base_samples( | ||
self, posterior: DeterministicPosterior, base_sampler: DeterministicSampler | ||
) -> None: | ||
r"""This is a no-op since there are no base samples to update. | ||
Args: | ||
posterior: The posterior for which the base samples are constructed. | ||
base_sampler: The base sampler to retrieve the base samples from. | ||
""" | ||
return |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from botorch.sampling.deterministic import DeterministicSampler | ||
from botorch.utils.testing import BotorchTestCase, MockPosterior | ||
|
||
|
||
class TestDeterministicSampler(BotorchTestCase): | ||
def test_deterministic_sampler(self): | ||
# Basic usage. | ||
samples = torch.rand(1, 2) | ||
posterior = MockPosterior(samples=samples) | ||
sampler = DeterministicSampler(sample_shape=torch.Size([2])) | ||
self.assertTrue(torch.equal(samples.repeat(2, 1, 1), sampler(posterior))) | ||
|
||
# Test _update_base_samples. | ||
sampler._update_base_samples( | ||
posterior=posterior, | ||
base_sampler=sampler, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters