From 1d04d311fa7798c82575ec6cab409aa5979a7899 Mon Sep 17 00:00:00 2001
From: Max Balandat
Date: Tue, 4 Jun 2024 17:45:00 -0700
Subject: [PATCH] Update polytope sampling code and add thinning capability (#2358)

Summary:
This set of changes does the following:
* adds an `n_thinning` argument to `sample_polytope` and `HitAndRunPolytopeSampler`; changes the defaults for `HitAndRunPolytopeSampler` args to `n_burnin=200` and `n_thinning=20`
* changes `HitAndRunPolytopeSampler` to take the `seed` arg in its constructor and removes the arg from the `draw()` method (the method on the base class is adjusted accordingly). As a result, if a `HitAndRunPolytopeSampler` is instantiated with the same args and seed, the sequence of `draw()`s will be deterministic. `DelaunayPolytopeSampler` is stateless and so retains its existing behavior.
* normalizes the (inequality and equality) constraints in `HitAndRunPolytopeSampler` to avoid the same issue as https://github.com/pytorch/botorch/issues/1225. If `bounds` are not provided, emits a warning that this normalization cannot be performed (doing so would require vertex enumeration of the constraint polytope, which is NP-hard and too costly).
* introduces `normalize_dense_linear_constraints` to normalize constraints given in dense format to the unit cube
* removes `normalize_linear_constraints`; `normalize_sparse_linear_constraints` is to be used instead
* simplifies some of the testing code

Note: This change is in preparation for fixing https://github.com/facebook/Ax/issues/2373

Test Plan:
Ran a stress test to make sure this doesn't cause flaky tests: https://www.internalfb.com/intern/testinfra/testconsole/testrun/3940649908470083/

Differential Revision: D58068753

Pulled By: Balandat
---
 botorch/optim/initializers.py   |  12 +-
 botorch/utils/sampling.py       | 197 +++++++++++++++++++++++---------
 test/optim/test_initializers.py |   6 +-
 test/utils/test_sampling.py     | 160 ++++++++++++++------------
 4 files changed, 240 insertions(+), 135 deletions(-)

diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py
index fbb3b0dcd8..44c152baf5 100644
--- a/botorch/optim/initializers.py
+++ b/botorch/optim/initializers.py
@@ -180,7 +180,7 @@ def sample_q_batches_from_polytope(
     q: int,
     bounds: Tensor,
     n_burnin: int,
-    thinning: int,
+    n_thinning: int,
     seed: int,
     inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
     equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
@@ -192,8 +192,8 @@ def sample_q_batches_from_polytope(
         q: Number of samples per q-batch
         bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
         n_burnin: The number of burn-in samples for the Markov chain sampler.
-        thinning: The amount of thinning (number of steps to take between
-            returning samples).
+        n_thinning: The amount of thinning. The sampler will return every
+            `n_thinning`-th sample (after burn-in).
         seed: The random seed.
        inequality_constraints: A list of tuples (indices, coefficients, rhs),
            with each tuple encoding an inequality constraint of the form
@@ -225,7 +225,7 @@ def sample_q_batches_from_polytope(
             ),
             seed=seed,
             n_burnin=n_burnin,
-            thinning=thinning * q,
+            n_thinning=n_thinning * q,
         )
     else:
         samples = get_polytope_samples(
@@ -235,7 +235,7 @@ def sample_q_batches_from_polytope(
             equality_constraints=equality_constraints,
             seed=seed,
             n_burnin=n_burnin,
-            thinning=thinning,
+            n_thinning=n_thinning,
         )
     return samples.view(n, q, -1).cpu()
@@ -367,7 +367,7 @@ def gen_batch_initial_conditions(
                 q=q,
                 bounds=bounds,
                 n_burnin=options.get("n_burnin", 10000),
-                thinning=options.get("thinning", 32),
+                n_thinning=options.get("n_thinning", 32),
                 seed=seed,
                 equality_constraints=equality_constraints,
                 inequality_constraints=inequality_constraints,
diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py
index c29ab2ddaa..9936269906 100644
--- a/botorch/utils/sampling.py
+++ b/botorch/utils/sampling.py
@@ -16,14 +16,17 @@

 from __future__ import annotations

+import warnings
+
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Generator, Iterable, List, Optional, Tuple, TYPE_CHECKING
+from typing import Any, Generator, Iterable, List, Optional, Tuple, TYPE_CHECKING, Union

 import numpy as np
 import scipy
 import torch
 from botorch.exceptions.errors import BotorchError
+from botorch.exceptions.warnings import UserInputWarning
 from botorch.sampling.qmc import NormalQMCEngine
 from botorch.utils.transforms import unnormalize
 from scipy.spatial import Delaunay, HalfspaceIntersection
@@ -63,7 +66,7 @@ def draw_sobol_samples(
     bounds: Tensor,
     n: int,
     q: int,
-    batch_shape: Optional[Iterable[int], torch.Size] = None,
+    batch_shape: Optional[Union[Iterable[int], torch.Size]] = None,
     seed: Optional[int] = None,
 ) -> Tensor:
     r"""Draw qMC samples from the box defined by bounds.
@@ -218,6 +221,7 @@ def sample_polytope(
     x0: Tensor,
     n: int = 10000,
     n0: int = 100,
+    n_thinning: int = 1,
     seed: Optional[int] = None,
 ) -> Tensor:
     r"""
@@ -225,15 +229,17 @@ def sample_polytope(
     described via inequality constraints A*x<=b.

     Args:
-        A: A Tensor describing inequality constraints
-            so that all samples satisfy Ax<=b.
-        b: A Tensor describing the inequality constraints
-            so that all samples satisfy Ax<=b.
+        A: A `m x d`-dim Tensor describing inequality constraints
+            so that all samples satisfy `Ax <= b`.
+        b: A `m`-dim Tensor describing the inequality constraints
+            so that all samples satisfy `Ax <= b`.
         x0: A `d`-dim Tensor representing a starting point of the chain
             satisfying the constraints.
         n: The number of resulting samples kept in the output.
-        n0: The number of burn-in samples. The chain will produce
-            n+n0 samples but the first n0 samples are not saved.
+        n0: The number of burn-in samples. The chain will produce
+            `n0 + n * n_thinning` samples in total; the first `n0`
+            samples are not saved.
+        n_thinning: The amount of thinning. This function will return every
+            `n_thinning`-th sample from the chain (after burn-in).
         seed: The seed for the sampler. If omitted, use a random seed.

     Returns:
@@ -252,7 +258,7 @@
     A = A[non_zero_rows]
     b = b[non_zero_rows]

-    n_tot = n + n0
+    n_tot = n0 + n * n_thinning
     seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()
     with manual_seed(seed=seed):
         rands = torch.rand(n_tot, dtype=A.dtype, device=A.device)
@@ -292,11 +298,13 @@
        # If ar < 0 at the boundary, alpha >= 0.
        if w_eq_0.logical_and(ar < 0).any():
            alpha_min = max(alpha_min, 0.0)
-        # alpha~Unif[alpha_min, alpha_max]
+        # alpha ~ Uniform[alpha_min, alpha_max]
        alpha = alpha_min + rnd * (alpha_max - alpha_min)
        x = x + alpha * r
-        if i >= n0:  # save samples after burn-in period
-            out[i - n0] = x.squeeze()
+        if (k := i - n0) >= 0:  # save samples after burn-in period
+            idx, rem = divmod(k, n_thinning)
+            if rem == 0:
+                out[idx] = x.squeeze()
    return out
@@ -561,12 +569,11 @@ def find_interior_point(self) -> Tensor:
    # -------- Abstract methods to be implemented by subclasses -------- #

    @abstractmethod
-    def draw(self, n: int = 1, seed: Optional[int] = None) -> Tensor:
+    def draw(self, n: int = 1) -> Tensor:
        r"""Draw samples from the polytope.

        Args:
            n: The number of samples.
-            seed: The random seed.

        Returns:
            A `n x d` Tensor of samples from the polytope.
@@ -583,7 +590,9 @@ def __init__(
        equality_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        bounds: Optional[Tensor] = None,
        interior_point: Optional[Tensor] = None,
-        n_burnin: int = 0,
+        n_burnin: int = 200,
+        n_thinning: int = 20,
+        seed: Optional[int] = None,
    ) -> None:
        r"""A sampler for sampling from a polytope using a hit-and-run algorithm.
@@ -596,46 +605,105 @@ def __init__(
                `C @ x = d`, where `C` is a `n_eq_con x d`-dim Tensor and `d` is a
                `n_eq_con x 1`-dim Tensor with `n_eq_con` the number of equalities.
            bounds: A `2 x d`-dim tensor of box bounds, where `inf` (`-inf`) means
-                that the respective dimension is unbounded from above (below).
+                that the respective dimension is unbounded from above (below). If
+                omitted, no bounds (in addition to the above constraints) are applied.
            interior_point: A `d x 1`-dim Tensor representing a point in the
                (relative) interior of the polytope. If omitted, determined
                automatically by solving a Linear Program.
-            n_burnin: The number of burn in samples.
+            n_burnin: The number of burn-in samples. The sampler will discard
+                `n_burnin` samples before returning the first sample.
+            n_thinning: The amount of thinning. The sampler will return every
+                `n_thinning`-th sample (after burn-in). This may need to be increased
+                for sets of constraints that are difficult to satisfy (i.e., cases in
+                which the volume of the constraint polytope is small relative to that
+                of its bounding box).
+            seed: The random seed.
        """
+        if inequality_constraints is None and bounds is None:
+            raise BotorchError(
+                "HitAndRunPolytopeSampler requires either inequality constraints "
+                "or bounds."
+ ) + # Normalize constraints to avoid the following issue: + # https://github.com/pytorch/botorch/issues/1225 + offset, scale = None, None + if inequality_constraints or equality_constraints: + if bounds is None: + warnings.warn( + "HitAndRunPolytopeSampler did not receive `bounds`, which can " + "lead to non-uniform sampling if the parameter ranges are very " + "different (see https://github.com/pytorch/botorch/issues/1225).", + UserInputWarning, + stacklevel=3, + ) + else: + if inequality_constraints: + inequality_constraints = normalize_dense_linear_constraints( + bounds=bounds, constraints=inequality_constraints + ) + if equality_constraints: + equality_constraints = normalize_dense_linear_constraints( + bounds=bounds, constraints=equality_constraints + ) + lower, upper = bounds + offset = lower + scale = upper - lower + if interior_point is not None: + # If provided, we also need to normalize the interior point + interior_point = (interior_point - offset[:, None]) / scale[:, None] + bounds = torch.zeros_like(bounds) + bounds[1, :] = 1.0 + super().__init__( inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, bounds=bounds, interior_point=interior_point, ) - self.n_burnin = n_burnin - - def draw(self, n: int = 1, seed: Optional[int] = None) -> Tensor: + self.n_burnin: int = n_burnin + self.n_thinning: int = n_thinning + self.num_samples_generated: int = 0 + self._seed: Optional[int] = seed + self._offset: Optional[Tensor] = offset + self._scale: Optional[Tensor] = scale + + def draw(self, n: int = 1) -> Tensor: r"""Draw samples from the polytope. Args: n: The number of samples. - seed: The random seed. Returns: A `n x d` Tensor of samples from the polytope. """ + # There are two layers of normalization. In the outer layer, the space + # has been normalized to the unit cube. In the inner layer, we remove + # any equality constraints and sample on the subspace defined by those + # equality constraints, with an additional shift to normalize the interior + # point to the origin. Below, after sampling in that inner layer, we have + # to reverse both layers of normalization. 
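+        # Concretely: if `z` denotes the samples returned by `sample_polytope`
+        # below, the inner layer is undone via `x = x0 + nullC @ z` (the
+        # `init_shift` line), and the outer layer is undone at the end via
+        # `samples = offset + scale * samples`.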
transformed_samples = sample_polytope( - # run this on the cpu + # Run this on the cpu since there is a lot of looping going on A=self.new_A.cpu(), b=(self.b - self.A @ self.x0).cpu(), x0=torch.zeros((self.nullC.size(1), 1), dtype=self.A.dtype), n=n, - n0=self.n_burnin, - seed=seed, + n0=self.n_burnin if self.num_samples_generated == 0 else 0, + n_thinning=self.n_thinning, + seed=self._seed, ).to(self.b) + # Update the seed for the next call in a deterministic fashion + if self._seed is not None: + self._seed += n + # Unnormalize the inner layer init_shift = self.x0.transpose(-1, -2) samples = init_shift + transformed_samples @ self.nullC.transpose(-1, -2) - # keep the last element of the resulting chain as - # the beginning of the next chain + # Keep the last element as the beginning of the next chain self.x0 = samples[-1].reshape(-1, 1) - # reset counter so there is no burn-in for subsequent samples - self.n_burnin = 0 + # Unnormalize the outer layer + if self._scale is not None: + samples = self._offset + self._scale * samples + self.num_samples_generated += n return samples @@ -760,10 +828,10 @@ def draw(self, n: int = 1, seed: Optional[int] = None) -> Tensor: return samples -def normalize_linear_constraints( +def normalize_sparse_linear_constraints( bounds: Tensor, constraints: List[Tuple[Tensor, Tensor, float]] ) -> List[Tuple[Tensor, Tensor, float]]: - r"""Normalize linear constraints to the unit cube. + r"""Normalize sparse linear constraints to the unit cube. Args: bounds (Tensor): A `2 x d`-dim tensor containing the box bounds. @@ -773,7 +841,6 @@ def normalize_linear_constraints( `\sum_i (X[indices[i]] * coefficients[i]) >= rhs` or `\sum_i (X[indices[i]] * coefficients[i]) = rhs`. """ - new_constraints = [] for index, coefficient, rhs in constraints: lower, upper = bounds[:, index] @@ -784,21 +851,51 @@ def normalize_linear_constraints( return new_constraints +def normalize_dense_linear_constraints( + bounds: Tensor, + constraints: Tuple[Tensor, Tensor], +) -> Tuple[Tensor, Tensor]: + r"""Normalize dense linear constraints to the unit cube. + + Args: + bounds: A `2 x d`-dim tensor containing the box bounds. + constraints: A tensor tuple `(A, b)` describing constraints + `A @ x (<)= b`, where `A` is a `n_con x d`-dim Tensor and + `b` is a `n_con x 1`-dim Tensor, with `n_con` the number of + constraints and `d` the dimension of the sample space. + + Returns: + A tensor tuple `(A_nlz, b_nlz)` of normalized constraints. + """ + lower, upper = bounds + A, b = constraints + A_nlz = (upper - lower) * A + b_nlz = b - (A @ lower).unsqueeze(-1) + return A_nlz, b_nlz + + def get_polytope_samples( n: int, bounds: Tensor, inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None, equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None, seed: Optional[int] = None, - thinning: int = 32, n_burnin: int = 10_000, + n_thinning: int = 32, ) -> Tensor: r"""Sample from polytope defined by box bounds and (in)equality constraints. This uses a hit-and-run Markov chain sampler. - TODO: make this method return the sampler object, to avoid doing burn-in - every time we draw samples. + NOTE: Much of the functionality of this method has been moved into + `HitAndRunPolytopeSampler`. If you want to repeatedly draw samples, you should + use `HitAndRunPolytopeSampler` directly in order to avoid repeatedly running + a burn-in of the chain. 
    To do so, you need to convert the sparse constraint
    format that `get_polytope_samples` expects to the dense constraint format that
    `HitAndRunPolytopeSampler` expects. This can be done via the
    `sparse_to_dense_constraints` method (but remember to adjust the constraint
    from the `Ax >= b` format expected here to the `Ax <= b` format expected by
    `PolytopeSampler` by multiplying both `A` and `b` by -1).

    Args:
        n: The number of samples.
@@ -810,46 +907,34 @@
        with each tuple encoding an inequality constraint of the form
            `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
        seed: The random seed.
-        thinning: The amount of thinning.
        n_burnin: The number of burn-in samples for the Markov chain sampler.
+        n_thinning: The amount of thinning. This function will return every
+            `n_thinning`-th sample from the chain (after burn-in).

    Returns:
        A `n x d`-dim tensor of samples.
    """
-    # create tensors representing linear inequality constraints
-    # of the form Ax >= b.
    if inequality_constraints:
-        # normalize_linear_constraints is called to solve this issue:
-        # https://github.com/pytorch/botorch/issues/1225
-        constraints = normalize_linear_constraints(bounds, inequality_constraints)
-
        A, b = sparse_to_dense_constraints(
            d=bounds.shape[-1],
-            constraints=constraints,
+            constraints=inequality_constraints,
        )
-        # Note the inequality constraints are of the form Ax >= b,
+        # Note that the inequality constraints are of the form Ax >= b,
        # but PolytopeSampler expects inequality constraints of the
        # form Ax <= b, so we multiply by -1 below.
-        dense_inequality_constraints = -A, -b
-    else:
-        dense_inequality_constraints = None
+        inequality_constraints = -A, -b
    if equality_constraints:
-        constraints = normalize_linear_constraints(bounds, equality_constraints)
-        dense_equality_constraints = sparse_to_dense_constraints(
-            d=bounds.shape[-1], constraints=constraints
+        equality_constraints = sparse_to_dense_constraints(
+            d=bounds.shape[-1], constraints=equality_constraints
        )
-    else:
-        dense_equality_constraints = None
-    normalized_bounds = torch.zeros_like(bounds)
-    normalized_bounds[1, :] = 1.0
    polytope_sampler = HitAndRunPolytopeSampler(
-        bounds=normalized_bounds,
-        inequality_constraints=dense_inequality_constraints,
-        equality_constraints=dense_equality_constraints,
+        bounds=bounds,
+        inequality_constraints=inequality_constraints,
+        equality_constraints=equality_constraints,
        n_burnin=n_burnin,
+        n_thinning=n_thinning,
+        seed=seed,
    )
-    samples = polytope_sampler.draw(n=n * thinning, seed=seed)[::thinning]
-    return bounds[0] + samples * (bounds[1] - bounds[0])
+    return polytope_sampler.draw(n=n)


def sparse_to_dense_constraints(
diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py
index 0f2c1bdfdb..a0c84eb422 100644
--- a/test/optim/test_initializers.py
+++ b/test/optim/test_initializers.py
@@ -495,7 +495,7 @@ def test_gen_batch_initial_conditions_sample_q_batches_from_polytope(self):
                q=q,
                bounds=bounds,
                n_burnin=10000,
-                thinning=32,
+                n_thinning=32,
                seed=42,
                inequality_constraints=inequalities,
                equality_constraints=equalities,
@@ -571,8 +571,8 @@ def test_gen_batch_initial_conditions_constraints(self):
                        "alpha": 0.1,
                        "seed": seed,
                        "init_batch_limit": init_batch_limit,
-                        "thinning": 2,
                        "n_burnin": 3,
+                        "n_thinning": 2,
                    },
                    inequality_constraints=inequality_constraints,
                    equality_constraints=equality_constraints,
@@ -669,8 +669,8 @@ def test_gen_batch_initial_conditions_interpoint_constraints(self):
                    "alpha": 0.1,
                    "seed": seed,
                    "init_batch_limit": None,
-                    "thinning":
2, "n_burnin": 3, + "n_thinning": 2, }, inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, diff --git a/test/utils/test_sampling.py b/test/utils/test_sampling.py index 534978c4d6..8af3398f13 100644 --- a/test/utils/test_sampling.py +++ b/test/utils/test_sampling.py @@ -15,6 +15,7 @@ import numpy as np import torch from botorch.exceptions.errors import BotorchError +from botorch.exceptions.warnings import UserInputWarning from botorch.models.gp_regression import SingleTaskGP from botorch.sampling.pathwise import draw_matheron_paths from botorch.utils.sampling import ( @@ -26,7 +27,7 @@ get_polytope_samples, HitAndRunPolytopeSampler, manual_seed, - normalize_linear_constraints, + normalize_sparse_linear_constraints, optimize_posterior_samples, PolytopeSampler, sample_hypersphere, @@ -37,6 +38,25 @@ from botorch.utils.testing import BotorchTestCase +def _get_constraints(device: torch.device, dtype: torch.dtype): + bounds = torch.zeros(2, 3, device=device, dtype=dtype) + bounds[1] = 1.0 + A = torch.tensor( + [ + [-1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, -1.0, 0.0], + [0.0, 0.0, -1.0], + [0.0, 4.0, 1.0], + ], + device=device, + dtype=dtype, + ) + b = torch.tensor([[0.0], [1.0], [0.0], [0.0], [1.0]], device=device, dtype=dtype) + x0 = torch.tensor([[0.1], [0.1], [0.1]], device=device, dtype=dtype) + return bounds, A, b, x0 + + class TestManualSeed(BotorchTestCase): def test_manual_seed(self): initial_state = torch.random.get_rng_state() @@ -165,7 +185,7 @@ def test_sparse_to_dense_constraints(self): expected_b = torch.tensor([[3.0]], **tkwargs) self.assertTrue(torch.equal(b, expected_b)) - def test_normalize_linear_constraints(self): + def test_normalize_sparse_linear_constraints(self): tkwargs = {"device": self.device} for dtype in (torch.float, torch.double): tkwargs["dtype"] = dtype @@ -179,7 +199,7 @@ def test_normalize_linear_constraints(self): bounds = torch.tensor( [[0.1, 0.3, 0.1, 30.0], [0.6, 0.7, 0.7, 700.0]], **tkwargs ) - new_constraints = normalize_linear_constraints(bounds, constraints) + new_constraints = normalize_sparse_linear_constraints(bounds, constraints) expected_coefficients = torch.tensor([0.4000, 0.6000, 0.5000], **tkwargs) self.assertTrue( torch.allclose(new_constraints[0][1], expected_coefficients) @@ -187,7 +207,7 @@ def test_normalize_linear_constraints(self): expected_rhs = 0.5 self.assertAlmostEqual(new_constraints[0][-1], expected_rhs) - def test_normalize_linear_constraints_wrong_dtype(self): + def test_normalize_sparse_linear_constraints_wrong_dtype(self): for dtype in (torch.float, torch.double): with self.subTest(dtype=dtype): tkwargs = {"device": self.device, "dtype": dtype} @@ -201,7 +221,7 @@ def test_normalize_linear_constraints_wrong_dtype(self): bounds = torch.zeros(2, 4, **tkwargs) msg = "tensors used as indices must be long, byte or bool tensors" with self.assertRaises(IndexError, msg=msg): - normalize_linear_constraints(bounds, constraints) + normalize_sparse_linear_constraints(bounds, constraints) def test_find_interior_point(self): # basic problem: 1 <= x_1 <= 2, 2 <= x_2 <= 3 @@ -254,8 +274,8 @@ def test_get_polytope_samples(self): inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, seed=0, - thinning=3, n_burnin=2, + n_thinning=3, ) (A, b) = sparse_to_dense_constraints( d=4, constraints=inequality_constraints @@ -267,7 +287,8 @@ def test_get_polytope_samples(self): inequality_constraints=dense_inequality_constraints, equality_constraints=dense_equality_constraints, 
                    n_burnin=2,
-                ).draw(15, seed=0)[::3]
+                    n_thinning=3,
+                    seed=0,
+                ).draw(5)
            self.assertTrue(torch.equal(samps, expected_samps))

            # test no equality constraints
@@ -277,15 +298,16 @@ def test_get_polytope_samples(self):
                bounds=bounds,
                inequality_constraints=inequality_constraints,
                seed=0,
-                thinning=3,
                n_burnin=2,
+                n_thinning=3,
            )
            with manual_seed(0):
                expected_samps = HitAndRunPolytopeSampler(
                    bounds=bounds,
                    inequality_constraints=dense_inequality_constraints,
                    n_burnin=2,
-                ).draw(15, seed=0)[::3]
+                    n_thinning=3,
+                    seed=0,
+                ).draw(5)
            self.assertTrue(torch.equal(samps, expected_samps))

            # test no inequality constraints
@@ -295,15 +317,16 @@ def test_get_polytope_samples(self):
                bounds=bounds,
                equality_constraints=equality_constraints,
                seed=0,
-                thinning=3,
                n_burnin=2,
+                n_thinning=3,
            )
            with manual_seed(0):
                expected_samps = HitAndRunPolytopeSampler(
                    bounds=bounds,
                    equality_constraints=dense_equality_constraints,
                    n_burnin=2,
-                ).draw(15, seed=0)[::3]
+                    n_thinning=3,
+                    seed=0,
+                ).draw(5)
            self.assertTrue(torch.equal(samps, expected_samps))

    def test_sample_polytope_infeasible(self) -> None:
@@ -338,80 +361,63 @@ def test_sample_polytope_boundary(self) -> None:
class PolytopeSamplerTestBase(ABC):
    sampler_class: Type[PolytopeSampler]
    sampler_kwargs: Dict[str, Any] = {}
-
-    def setUp(self):
-        super().setUp()
-        self.bounds = torch.zeros(2, 3, device=self.device)
-        self.bounds[1] = 1
-        self.A = torch.tensor(
-            [
-                [-1.0, 0.0, 0.0],
-                [1.0, 0.0, 0.0],
-                [0.0, -1.0, 0.0],
-                [0.0, 0.0, -1.0],
-                [0.0, 4.0, 1.0],
-            ],
-            device=self.device,
-        )
-        self.b = torch.tensor([[0.0], [1.0], [0.0], [0.0], [1.0]], device=self.device)
-        self.x0 = torch.tensor([0.1, 0.1, 0.1], device=self.device).unsqueeze(-1)
+    constructor_seed_kwarg: Dict[str, int] = {}
+    draw_seed_kwarg: Dict[str, int] = {}

    def test_sample_polytope(self):
        for dtype in (torch.float, torch.double):
-            A = self.A.to(dtype)
-            b = self.b.to(dtype)
-            x0 = self.x0.to(dtype)
-            bounds = self.bounds.to(dtype)
-            for interior_point in [x0, None]:
+            bounds, A, b, x0 = _get_constraints(device=self.device, dtype=dtype)
+            for interior_point in (x0, None):
                sampler = self.sampler_class(
                    inequality_constraints=(A, b),
                    bounds=bounds,
                    interior_point=interior_point,
                    **self.sampler_kwargs,
                )
-                samples = sampler.draw(n=10, seed=1)
-                self.assertEqual(((A @ samples.t() - b) > 0).sum().item(), 0)
+                samples = sampler.draw(n=10)
+                self.assertTrue(torch.all(A @ samples.t() - b <= 0).item())
                self.assertTrue((samples <= bounds[1]).all())
                self.assertTrue((samples >= bounds[0]).all())
-                # make sure we can draw mulitple samples
+                # make sure we can draw multiple samples
                more_samples = sampler.draw(n=5)
-                self.assertEqual(((A @ more_samples.t() - b) > 0).sum().item(), 0)
+                self.assertTrue(torch.all(A @ more_samples.t() - b <= 0).item())
                self.assertTrue((more_samples <= bounds[1]).all())
                self.assertTrue((more_samples >= bounds[0]).all())
+                # the samples should all be unique
+                all_samples = torch.cat([samples, more_samples], dim=0)
+                self.assertEqual(
+                    len(all_samples), len(torch.unique(all_samples, dim=0))
+                )

    def test_sample_polytope_with_seed(self):
        for dtype in (torch.float, torch.double):
-            A = self.A.to(dtype)
-            b = self.b.to(dtype)
-            x0 = self.x0.to(dtype)
-            bounds = self.bounds.to(dtype)
-            for interior_point in [x0, None]:
+            bounds, A, b, x0 = _get_constraints(device=self.device, dtype=dtype)
+            for interior_point in (x0, None):
                sampler1 = self.sampler_class(
                    inequality_constraints=(A, b),
                    bounds=bounds,
                    interior_point=interior_point,
                    **self.sampler_kwargs,
+
**self.constructor_seed_kwarg, ) sampler2 = self.sampler_class( inequality_constraints=(A, b), bounds=bounds, interior_point=interior_point, **self.sampler_kwargs, + **self.constructor_seed_kwarg, ) - samples1 = sampler1.draw(n=10, seed=42) - samples2 = sampler2.draw(n=10, seed=42) + samples1 = sampler1.draw(n=10, **self.draw_seed_kwarg) + samples2 = sampler2.draw(n=10, **self.draw_seed_kwarg) self.assertTrue(torch.allclose(samples1, samples2)) def test_sample_polytope_with_eq_constraints(self): for dtype in (torch.float, torch.double): - A = self.A.to(dtype) - b = self.b.to(dtype) - x0 = self.x0.to(dtype) - bounds = self.bounds.to(dtype) + bounds, A, b, x0 = _get_constraints(device=self.device, dtype=dtype) C = torch.tensor([[1.0, -1, 0.0]], device=self.device, dtype=dtype) d = torch.zeros(1, 1, device=self.device, dtype=dtype) - for interior_point in [x0, None]: + for interior_point in (x0, None): sampler = self.sampler_class( inequality_constraints=(A, b), equality_constraints=(C, d), @@ -419,11 +425,9 @@ def test_sample_polytope_with_eq_constraints(self): interior_point=interior_point, **self.sampler_kwargs, ) - samples = sampler.draw(n=10, seed=1) - inequality_satisfied = ((A @ samples.t() - b) > 0).sum().item() == 0 - equality_satisfied = (C @ samples.t() - d).abs().sum().item() < 1e-6 - self.assertTrue(inequality_satisfied) - self.assertTrue(equality_satisfied) + samples = sampler.draw(n=10) + self.assertTrue(torch.all(A @ samples.t() - b <= 0).item()) + self.assertLessEqual((C @ samples.t() - d).abs().sum().item(), 1e-6) self.assertTrue((samples <= bounds[1]).all()) self.assertTrue((samples >= bounds[0]).all()) # test no inequality constraints @@ -433,9 +437,8 @@ def test_sample_polytope_with_eq_constraints(self): interior_point=interior_point, **self.sampler_kwargs, ) - samples = sampler.draw(n=10, seed=1) - equality_satisfied = (C @ samples.t() - d).abs().sum().item() < 1e-6 - self.assertTrue(equality_satisfied) + samples = sampler.draw(n=10) + self.assertLessEqual((C @ samples.t() - d).abs().sum().item(), 1e-6) self.assertTrue((samples <= bounds[1]).all()) self.assertTrue((samples >= bounds[0]).all()) # test no inequality constraints or bounds @@ -456,8 +459,10 @@ def test_sample_polytope_1d(self): x0 = torch.tensor([[0.1], [0.1]], device=self.device, dtype=dtype) C = torch.tensor([[1.0, -1.0]], device=self.device, dtype=dtype) d = torch.tensor([[0.0]], device=self.device, dtype=dtype) - bounds = self.bounds[:, :2].to(dtype=dtype) - for interior_point in [x0, None]: + bounds = torch.tensor( + [[0.0, 0.0], [1.0, 1.0]], device=self.device, dtype=dtype + ) + for interior_point in (x0, None): sampler = self.sampler_class( inequality_constraints=(A, b), equality_constraints=(C, d), @@ -465,11 +470,9 @@ def test_sample_polytope_1d(self): interior_point=interior_point, **self.sampler_kwargs, ) - samples = sampler.draw(n=10, seed=1) - inequality_satisfied = ((A @ samples.t() - b) > 0).sum().item() == 0 - equality_satisfied = (C @ samples.t() - d).abs().sum().item() < 1e-6 - self.assertTrue(inequality_satisfied) - self.assertTrue(equality_satisfied) + samples = sampler.draw(n=10) + self.assertTrue(torch.all(A @ samples.t() - b <= 0).item()) + self.assertLessEqual((C @ samples.t() - d).abs().sum().item(), 1e-6) self.assertTrue((samples <= bounds[1]).all()) self.assertTrue((samples >= bounds[0]).all()) @@ -481,14 +484,18 @@ def test_initial_point(self): dtype=dtype, ) b = torch.tensor([[0.0], [-1.0], [1.0]], device=self.device, dtype=dtype) - x0 = self.x0.to(dtype) + x0 = 
torch.tensor([[0.1], [0.1], [0.1]], device=self.device, dtype=dtype) + bounds = torch.zeros(2, 3, device=self.device, dtype=dtype) + bounds[1] = 1.0 # testing for infeasibility of the initial point and # infeasibility of the original LP (status 2 of the linprog output). - for interior_point in [x0, None]: + for interior_point in (x0, None): with self.assertRaises(ValueError): self.sampler_class( - inequality_constraints=(A, b), interior_point=interior_point + inequality_constraints=(A, b), + bounds=bounds, + interior_point=interior_point, ) class Result: @@ -499,16 +506,28 @@ class Result: with mock.patch("scipy.optimize.linprog") as mock_linprog: mock_linprog.return_value = Result() with self.assertRaises(ValueError): - self.sampler_class(inequality_constraints=(A, b)) + self.sampler_class( + inequality_constraints=(A, b), + bounds=bounds, + ) class TestHitAndRunPolytopeSampler(PolytopeSamplerTestBase, BotorchTestCase): sampler_class = HitAndRunPolytopeSampler - sampler_kwargs = {"n_burnin": 2} + sampler_kwargs = {"n_burnin": 2, "n_thinning": 2} + constructor_seed_kwarg = {"seed": 33125612} + + def test_normalization_warning(self): + _, A, b, x0 = _get_constraints(device=self.device, dtype=torch.double) + with self.assertWarnsRegex( + UserInputWarning, "HitAndRunPolytopeSampler did not receive `bounds`" + ): + HitAndRunPolytopeSampler(inequality_constraints=(A, b), interior_point=x0) class TestDelaunayPolytopeSampler(PolytopeSamplerTestBase, BotorchTestCase): sampler_class = DelaunayPolytopeSampler + draw_seed_kwarg = {"seed": 33125612} def test_sample_polytope_unbounded(self): A = torch.tensor( @@ -516,12 +535,13 @@ def test_sample_polytope_unbounded(self): device=self.device, ) b = torch.tensor([[0.0], [0.0], [0.0], [1.0]], device=self.device) + x0 = torch.tensor([[0.1], [0.1], [0.1]], device=self.device) with self.assertRaises(ValueError): with warnings.catch_warnings(): warnings.simplefilter("ignore") self.sampler_class( inequality_constraints=(A, b), - interior_point=self.x0, + interior_point=x0, **self.sampler_kwargs, )
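Usage sketch (illustrative, not part of the diff): with the changes above, a single seeded `HitAndRunPolytopeSampler` pays the burn-in cost only once across repeated `draw()` calls, and sparse constraints in the `Ax >= b` format used by `get_polytope_samples` must be converted (and sign-flipped) to the dense `Ax <= b` format the sampler expects. The dimension and the constraint `x_0 + x_1 >= 0.5` below are made up for illustration.

```python
import torch

from botorch.utils.sampling import (
    HitAndRunPolytopeSampler,
    sparse_to_dense_constraints,
)

d = 3
bounds = torch.zeros(2, d)
bounds[1] = 1.0  # providing bounds also enables constraint normalization

# One sparse constraint, x_0 + x_1 >= 0.5, as (indices, coefficients, rhs).
A, b = sparse_to_dense_constraints(
    d=d, constraints=[(torch.tensor([0, 1]), torch.tensor([1.0, 1.0]), 0.5)]
)

# sparse_to_dense_constraints yields `Ax >= b`; negate both sides to get the
# `Ax <= b` form expected by PolytopeSampler subclasses.
sampler = HitAndRunPolytopeSampler(
    inequality_constraints=(-A, -b),
    bounds=bounds,
    n_burnin=200,  # new default
    n_thinning=20,  # new default
    seed=1234,
)

first = sampler.draw(n=4)  # burn-in happens once, here
second = sampler.draw(n=4)  # continues the chain with no extra burn-in
```

Re-instantiating the sampler with identical arguments and seed reproduces the same sequence of draws, which is the determinism the new `seed` constructor argument provides.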