From 2f2b7e2163186e47da82b2da54b2fb0f001d9a7f Mon Sep 17 00:00:00 2001
From: Elizabeth Santorella
Date: Tue, 7 Feb 2023 19:54:21 -0800
Subject: [PATCH] Fix shape error in optimize_acqf_cyclic (#1648)

Summary:
## Motivation

Fixes https://github.com/pytorch/botorch/issues/873

In the past, `optimize_acqf` implicitly needed 3d inputs whenever there were equality constraints, or inequality constraints for which `fixed_features` did not provide the trivial solution, even though it worked with 2d inputs (no b-batches) in other cases. `optimize_acqf_cyclic` passed it 2d inputs, which would not generally work. I initially considered changing `optimize_acqf_cyclic` to pass 3d inputs, but since I found another place where 2d inputs were used, I decided to change `optimize_acqf` so that it works with 2d inputs instead.

This was not caught because the only usage of `optimize_acqf_cyclic` was in a test that mocked `optimize_acqf`, so `optimize_acqf_cyclic` was never actually run end-to-end. I changed the test for `optimize_acqf_cyclic` to be more end-to-end, at the cost of worse testing of some intermediate properties. We could keep both versions, though.

[x] Better docstring documentation of input shapes
[x] Add a singleton leading b-dimension where initial conditions are 2d

Pull Request resolved: https://github.com/pytorch/botorch/pull/1648

Test Plan:
[x] More end-to-end test of `optimize_acqf_cyclic` that doesn't stub in `optimize_acqf` (see above)
[x] More input validation, and unit tests for input validation
[x] Ran the cases that now raise errors without the new error handling in place, to confirm that they errored before as well
[x] Make `_make_linear_constraints` work with 2d inputs so that `optimize_acqf` also does (previously, `optimize_acqf` only worked in some cases)

Reviewed By: Balandat

Differential Revision: D42875942

Pulled By: esantorella

fbshipit-source-id: e3c650683a6b8d7c9e36fe1f14558db2854bab56
---
 botorch/generation/gen.py                |   5 +-
 botorch/optim/optimize.py                |  19 +++-
 botorch/optim/parameter_constraints.py   |  31 +++++-
 test/optim/test_optimize.py              | 115 ++++++++++++++-------
 test/optim/test_parameter_constraints.py | 125 +++++++++++++----------
 5 files changed, 197 insertions(+), 98 deletions(-)

diff --git a/botorch/generation/gen.py b/botorch/generation/gen.py
index f6d23ebe05..249ee8f8e6 100644
--- a/botorch/generation/gen.py
+++ b/botorch/generation/gen.py
@@ -56,7 +56,8 @@ def gen_candidates_scipy(
     using `scipy.optimize.minimize` via a numpy converter.

     Args:
-        initial_conditions: Starting points for optimization.
+        initial_conditions: Starting points for optimization, with shape
+            (b) x q x d.
         acquisition_function: Acquisition function to be used.
         lower_bounds: Minimum values for each column of initial_conditions.
         upper_bounds: Maximum values for each column of initial_conditions.
@@ -162,7 +163,7 @@ def gen_candidates_scipy(
         X=initial_conditions, lower_bounds=lower_bounds, upper_bounds=upper_bounds
     )
     constraints = make_scipy_linear_constraints(
-        shapeX=clamped_candidates.shape,
+        shapeX=shapeX,
         inequality_constraints=inequality_constraints,
         equality_constraints=equality_constraints,
     )
diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py
index 12fbad1be7..b361d8d853 100644
--- a/botorch/optim/optimize.py
+++ b/botorch/optim/optimize.py
@@ -110,7 +110,9 @@ def optimize_acqf(
     Returns:
         A two-element tuple containing

-        - a `(num_restarts) x q x d`-dim tensor of generated candidates.
+        - A tensor of generated candidates. The shape is
+          -- `q x d` if `return_best_only` is True (default)
+          -- `num_restarts x q x d` if `return_best_only` is False
         - a tensor of associated acquisition values. If `sequential=False`,
           this is a `(num_restarts)`-dim tensor of joint acquisition values
           (with explicit restart dimension if `return_best_only=False`). If
@@ -158,6 +160,19 @@ def optimize_acqf(
             "initial conditions for the case of nonlinear inequality constraints."
         )

+    d = bounds.shape[1]
+    if initial_conditions_provided:
+        if batch_initial_conditions.ndim not in (2, 3):
+            raise ValueError(
+                "batch_initial_conditions must be 2-dimensional or 3-dimensional. "
+                f"Its shape is {batch_initial_conditions.shape}."
+            )
+        if batch_initial_conditions.shape[-1] != d:
+            raise ValueError(
+                f"batch_initial_conditions.shape[-1] must be {d}. The "
+                f"shape is {batch_initial_conditions.shape}."
+            )
+
     # Sets initial condition generator ic_gen if initial conditions not provided
     if not initial_conditions_provided:
         ic_gen = kwargs.pop("ic_generator", None)
@@ -298,7 +313,7 @@ def _optimize_batch_candidates(
             logger.info(f"Generated candidate batch {i+1} of {len(batched_ics)}.")

         batch_candidates = torch.cat(batch_candidates_list)
-        batch_acq_values = torch.cat(batch_acq_values_list)
+        batch_acq_values = torch.stack(batch_acq_values_list).flatten()
         return batch_candidates, batch_acq_values, opt_warnings

     batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(timeout_sec)
diff --git a/botorch/optim/parameter_constraints.py b/botorch/optim/parameter_constraints.py
index de3f545977..0184f8c3a9 100644
--- a/botorch/optim/parameter_constraints.py
+++ b/botorch/optim/parameter_constraints.py
@@ -73,7 +73,7 @@ def make_scipy_linear_constraints(
     r"""Generate scipy constraints from torch representation.

     Args:
-        shapeX: The shape of the torch.Tensor to optimize over (i.e. `b x q x d`)
+        shapeX: The shape of the torch.Tensor to optimize over (i.e. `(b) x q x d`)
         inequality constraints: A list of tuples (indices, coefficients, rhs),
             with each tuple encoding an inequality constraint of the form
             `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`, where
@@ -219,10 +219,35 @@ def _make_linear_constraints(
             version of the input tensor `X`, returning a scalar.
         - "jac": A callable evaluating the constraint's Jacobian on `x`, a
             flattened version of the input tensor `X`, returning a numpy array.
+
+    >>> shapeX = torch.Size([3, 5, 4])
+    >>> constraints = _make_linear_constraints(
+    ...     indices=torch.tensor([1, 2]),
+    ...     coefficients=torch.tensor([-0.5, 1.3]),
+    ...     rhs=0.49,
+    ...     shapeX=shapeX,
+    ...     eq=True
+    ... )
+    >>> len(constraints)
+    15
+    >>> constraints[0].keys()
+    dict_keys(['type', 'fun', 'jac'])
+    >>> x = np.arange(60)  # a flattened version of `X`
+    >>> constraints[0]["fun"](x)
+    1.61  # 1 * -0.5 + 2 * 1.3 - 0.49
+    >>> constraints[0]["jac"](x)
+    [0., -0.5, 1.3, 0., 0., ...]
+    >>> constraints[1]["fun"](x)  # 5 * -0.5 + 6 * 1.3 - 0.49
+    4.81
     """
-    if len(shapeX) != 3:
-        raise UnsupportedError("`shapeX` must be `b x q x d`")
+    if len(shapeX) not in (2, 3):
+        raise UnsupportedError(
+            f"`shapeX` must be `(b) x q x d` (at least two-dimensional). It is "
+            f"{shapeX}."
+        )
     q, d = shapeX[-2:]
+    if len(shapeX) == 2:
+        shapeX = torch.Size([1, q, d])
     n = shapeX.numel()
     constraints: List[ScipyConstraintDict] = []
     coeffs = _arrayify(coefficients)
diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py
index f07f077fe8..2e849d2d10 100644
--- a/test/optim/test_optimize.py
+++ b/test/optim/test_optimize.py
@@ -334,24 +334,64 @@ def test_optimize_acqf_sequential_notimplemented(self):
         )

     def test_optimize_acqf_runs_given_batch_initial_conditions(self):
-        num_restarts, raw_samples, dim = 1, 1, 1
+        num_restarts, raw_samples, dim = 1, 2, 3

         opt_x = 2 / np.pi
-        # start near one (of many) optima
-        initial_conditions = (opt_x * 1.01) * torch.ones(
-            (num_restarts, raw_samples, dim)
-        )
+        # -x[i] * 1 >= -opt_x * 1.01 => x[i] <= opt_x * 1.01
+        inequality_constraints = [
+            (torch.tensor([i]), -torch.tensor([1]), -opt_x * 1.01) for i in range(dim)
+        ] + [
+            # x[i] * 1 >= opt_x * .99
+            (torch.tensor([i]), torch.tensor([1]), opt_x * 0.99)
+            for i in range(dim)
+        ]
+        q = 1
+
+        ic_shapes = [(1, 2, dim), (2, 1, dim), (1, dim)]

         torch.manual_seed(0)
-        batch_candidates, acq_value_list = optimize_acqf(
-            acq_function=SinOneOverXAcqusitionFunction(),
-            bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
-            q=1,
-            num_restarts=num_restarts,
-            raw_samples=raw_samples,
-            batch_initial_conditions=initial_conditions,
-        )
-        self.assertAlmostEqual(batch_candidates.item(), opt_x, delta=1e-5)
-        self.assertAlmostEqual(acq_value_list.item(), 1)
+        for shape in ic_shapes:
+            with self.subTest(shape=shape):
+                # start near one (of many) optima
+                initial_conditions = (opt_x * 1.01) * torch.ones(shape)
+                batch_candidates, acq_value_list = optimize_acqf(
+                    acq_function=SinOneOverXAcqusitionFunction(),
+                    bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
+                    q=q,
+                    num_restarts=num_restarts,
+                    raw_samples=raw_samples,
+                    batch_initial_conditions=initial_conditions,
+                    inequality_constraints=inequality_constraints,
+                )
+                self.assertAllClose(
+                    batch_candidates,
+                    opt_x * torch.ones_like(batch_candidates),
+                    # must be at least 50% closer to the optimum than it started
+                    atol=0.004,
+                    rtol=0.005,
+                )
+                self.assertAlmostEqual(acq_value_list.item(), 1, places=3)
+
+    def test_optimize_acqf_wrong_ic_shape_inequality_constraints(self) -> None:
+        dim = 3
+        ic_shapes = [(1, 2, dim + 1), (1, 2, dim, 1), (1, dim + 1), (1, 1), (dim,)]
+
+        for shape in ic_shapes:
+            with self.subTest(shape=shape):
+                initial_conditions = torch.ones(shape)
+                expected_error = (
+                    rf"batch_initial_conditions.shape\[-1\] must be {dim}\."
+                    if len(shape) in (2, 3)
+                    else r"batch_initial_conditions must be 2\-dimensional or "
+                )
+                with self.assertRaisesRegex(ValueError, expected_error):
+                    optimize_acqf(
+                        acq_function=MockAcquisitionFunction(),
+                        bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
+                        q=4,
+                        batch_initial_conditions=initial_conditions,
+                        num_restarts=1,
+                    )

     def test_optimize_acqf_warns_on_opt_failure(self):
         """
@@ -808,15 +848,20 @@ def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
         tkwargs = {"device": self.device}
         bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
         inequality_constraints = [
-            [torch.tensor([3]), torch.tensor([4]), torch.tensor(5)]
+            [torch.tensor([2], dtype=int), torch.tensor([4.0]), torch.tensor(5.0)]
         ]
         mock_acq_function = MockAcquisitionFunction()
         for q, dtype in itertools.product([1, 3], (torch.float, torch.double)):
-            inequality_constraints[0] = [
-                t.to(**tkwargs) for t in inequality_constraints[0]
+            tkwargs["dtype"] = dtype
+            inequality_constraints = [
+                (
+                    # indices can't be floats or doubles
+                    inequality_constraints[0][0],
+                    inequality_constraints[0][1].to(**tkwargs),
+                    inequality_constraints[0][2].to(**tkwargs),
+                )
             ]
             mock_optimize_acqf.reset_mock()
-            tkwargs["dtype"] = dtype
             bounds = bounds.to(**tkwargs)
             candidate_rvs = []
             acq_val_rvs = []
@@ -855,23 +900,23 @@ def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
                 post_processing_func=rounding_func,
                 cyclic_options={"maxiter": num_cycles},
             )
-            # check that X_pending is set correctly in cyclic optimization
-            if q > 1:
-                x_pending_call_args_list = mock_set_X_pending.call_args_list
-                idxr = torch.ones(q, dtype=torch.bool, device=self.device)
-                for i in range(len(x_pending_call_args_list) - 1):
-                    idxr[i] = 0
-                    self.assertTrue(
-                        torch.equal(
-                            x_pending_call_args_list[i][0][0], orig_candidates[idxr]
-                        )
+        # check that X_pending is set correctly in cyclic optimization
+        if q > 1:
+            x_pending_call_args_list = mock_set_X_pending.call_args_list
+            idxr = torch.ones(q, dtype=torch.bool, device=self.device)
+            for i in range(len(x_pending_call_args_list) - 1):
+                idxr[i] = 0
+                self.assertTrue(
+                    torch.equal(
+                        x_pending_call_args_list[i][0][0], orig_candidates[idxr]
                     )
-                    idxr[i] = 1
-                    orig_candidates[i] = candidate_rvs[i + 1]
-                # check reset to base_X_pending
-                self.assertIsNone(x_pending_call_args_list[-1][0][0])
-            else:
-                mock_set_X_pending.assert_not_called()
+                )
+                idxr[i] = 1
+                orig_candidates[i] = candidate_rvs[i + 1]
+            # check reset to base_X_pending
+            self.assertIsNone(x_pending_call_args_list[-1][0][0])
+        else:
+            mock_set_X_pending.assert_not_called()
         # check final candidates
         expected_candidates = (
             torch.cat(candidate_rvs[-q:], dim=0) if q > 1 else candidate_rvs[0]
diff --git a/test/optim/test_parameter_constraints.py b/test/optim/test_parameter_constraints.py
index 55ed089e5f..c3601df295 100644
--- a/test/optim/test_parameter_constraints.py
+++ b/test/optim/test_parameter_constraints.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from itertools import product
+
 import numpy as np
 import torch
 from botorch.exceptions.errors import CandidateGenerationError, UnsupportedError
@@ -45,9 +47,11 @@ def test_lin_constraint_jac(self):
         self.assertTrue(all(np.equal(res, np.array([1.0, 0.0, -2.0]))))

     def test_make_linear_constraints(self):
+        # equality constraints, 1d indices
         indices = torch.tensor([1, 2], dtype=torch.long, device=self.device)
-        shapeX = torch.Size([3, 2, 4])
-        for dtype in (torch.float, torch.double):
+        for dtype, shapeX in product(
+            (torch.float, torch.double), (torch.Size([3, 2, 4]), torch.Size([2, 4]))
+        ):
             coefficients = torch.tensor([1.0, 2.0], dtype=dtype, device=self.device)
             constraints = _make_linear_constraints(
                 indices=indices,
@@ -70,21 +74,27 @@ def test_make_linear_constraints(self):
             jac_exp = np.zeros(shapeX.numel())
             jac_exp[[-3, -2]] = [1, 2]
             self.assertTrue(np.allclose(constraints[-1]["jac"](x), jac_exp))
-        # check inequality type
-        lcs = _make_linear_constraints(
-            indices=torch.tensor([1]),
-            coefficients=torch.tensor([1.0]),
-            rhs=1.0,
-            shapeX=torch.Size([1, 1, 2]),
-            eq=False,
-        )
-        self.assertEqual(len(lcs), 1)
-        self.assertEqual(lcs[0]["type"], "ineq")
-        # check constraint across q-batch
+        # inequality constraints, 1d indices
+        for shapeX in [torch.Size([1, 1, 2]), torch.Size([1, 2])]:
+            lcs = _make_linear_constraints(
+                indices=torch.tensor([1]),
+                coefficients=torch.tensor([1.0]),
+                rhs=1.0,
+                shapeX=shapeX,
+                eq=False,
+            )
+            self.assertEqual(len(lcs), 1)
+            self.assertEqual(lcs[0]["type"], "ineq")
+
+        # constraint across q-batch (2d indices), equality constraint
         indices = torch.tensor([[0, 3], [1, 2]], dtype=torch.long, device=self.device)
-        shapeX = torch.Size([3, 2, 4])
-        for dtype in (torch.float, torch.double):
+
+        for dtype, shapeX in product(
+            (torch.float, torch.double), (torch.Size([3, 2, 4]), torch.Size([2, 4]))
+        ):
+            q, d = shapeX[-2:]
+            b = 1 if len(shapeX) == 2 else shapeX[0]
             coefficients = torch.tensor([1.0, 2.0], dtype=dtype, device=self.device)
             constraints = _make_linear_constraints(
                 indices=indices,
@@ -97,11 +107,11 @@ def test_make_linear_constraints(self):
                 all(set(c.keys()) == {"fun", "jac", "type"} for c in constraints)
             )
             self.assertTrue(all(c["type"] == "eq" for c in constraints))
-            self.assertEqual(len(constraints), shapeX[0])
+            self.assertEqual(len(constraints), b)
             x = np.random.rand(shapeX.numel())
-            offsets = [shapeX[i:].numel() for i in range(1, len(shapeX))]
+            offsets = [q * d, d]
             # rule is [i, j, k] is i * offset[0] + j * offset[1] + k
-            for i in range(shapeX[0]):
+            for i in range(b):
                 pos1 = i * offsets[0] + 3
                 pos2 = i * offsets[0] + 1 * offsets[1] + 2
                 self.assertEqual(constraints[i]["fun"](x), x[pos1] + 2 * x[pos2] - 1.0)
@@ -119,49 +129,52 @@ def test_make_linear_constraints(self):
             )

     def test_make_scipy_linear_constraints(self):
-        shapeX = torch.Size([2, 1, 4])
-        res = make_scipy_linear_constraints(
-            shapeX=shapeX, inequality_constraints=None, equality_constraints=None
-        )
-        self.assertEqual(res, [])
-        indices = torch.tensor([0, 1], dtype=torch.long, device=self.device)
-        coefficients = torch.tensor([1.5, -1.0], device=self.device)
-        cs = make_scipy_linear_constraints(
-            shapeX=shapeX,
-            inequality_constraints=[(indices, coefficients, 1.0)],
-            equality_constraints=[(indices, coefficients, 1.0)],
-        )
-        self.assertEqual(len(cs), 4)
-        self.assertTrue({c["type"] for c in cs} == {"ineq", "eq"})
-        cs = make_scipy_linear_constraints(
-            shapeX=shapeX, inequality_constraints=[(indices, coefficients, 1.0)]
-        )
-        self.assertEqual(len(cs), 2)
-        self.assertTrue(all(c["type"] == "ineq" for c in cs))
-        cs = make_scipy_linear_constraints(
-            shapeX=shapeX, equality_constraints=[(indices, coefficients, 1.0)]
-        )
-        self.assertEqual(len(cs), 2)
-        self.assertTrue(all(c["type"] == "eq" for c in cs))
+        for shapeX in [torch.Size([2, 1, 4]), torch.Size([1, 4])]:
+            b = shapeX[0] if len(shapeX) == 3 else 1
+            res = make_scipy_linear_constraints(
+                shapeX=shapeX, inequality_constraints=None, equality_constraints=None
+            )
+            self.assertEqual(res, [])
+            indices = torch.tensor([0, 1], dtype=torch.long, device=self.device)
+            coefficients = torch.tensor([1.5, -1.0], device=self.device)
+            # both inequality and equality constraints
+            cs = make_scipy_linear_constraints(
+                shapeX=shapeX,
+                inequality_constraints=[(indices, coefficients, 1.0)],
+                equality_constraints=[(indices, coefficients, 1.0)],
+            )
+            self.assertEqual(len(cs), 2 * b)
+            self.assertTrue({c["type"] for c in cs} == {"ineq", "eq"})
+            # inequality only
+            cs = make_scipy_linear_constraints(
+                shapeX=shapeX, inequality_constraints=[(indices, coefficients, 1.0)]
+            )
+            self.assertEqual(len(cs), b)
+            self.assertTrue(all(c["type"] == "ineq" for c in cs))
+            # equality only
+            cs = make_scipy_linear_constraints(
+                shapeX=shapeX, equality_constraints=[(indices, coefficients, 1.0)]
+            )
+            self.assertEqual(len(cs), b)
+            self.assertTrue(all(c["type"] == "eq" for c in cs))

-        # test that len(shapeX) < 3 raises an error
-        with self.assertRaises(UnsupportedError):
-            make_scipy_linear_constraints(
-                shapeX=torch.Size([2, 1]),
+            # test that 2-dim indices work properly
+            indices = indices.unsqueeze(0)
+            cs = make_scipy_linear_constraints(
+                shapeX=shapeX,
                 inequality_constraints=[(indices, coefficients, 1.0)],
                 equality_constraints=[(indices, coefficients, 1.0)],
             )
-        # test that 2-dim indices work properly
-        indices = indices.unsqueeze(0)
-        cs = make_scipy_linear_constraints(
-            shapeX=shapeX,
-            inequality_constraints=[(indices, coefficients, 1.0)],
-            equality_constraints=[(indices, coefficients, 1.0)],
-        )
-        self.assertEqual(len(cs), 4)
-        self.assertTrue({c["type"] for c in cs} == {"ineq", "eq"})
+            self.assertEqual(len(cs), 2 * b)
+            self.assertTrue({c["type"] for c in cs} == {"ineq", "eq"})
+
+    def test_make_scipy_linear_constraints_unsupported(self):
+        shapeX = torch.Size([2, 1, 4])
+        coefficients = torch.tensor([1.5, -1.0], device=self.device)
+
         # test that >2-dim indices raise an UnsupportedError
-        indices = indices.unsqueeze(0)
+        indices = torch.tensor([0, 1], dtype=torch.long, device=self.device)
+        indices = indices.unsqueeze(0).unsqueeze(0)
         with self.assertRaises(UnsupportedError):
             make_scipy_linear_constraints(
                 shapeX=shapeX,
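
For illustration, here is a minimal sketch of the call pattern this patch enables: 2d `batch_initial_conditions` of shape `q x d` passed to `optimize_acqf` together with an inequality constraint. This sketch is not part of the patch; the `SingleTaskGP`/`UpperConfidenceBound` setup is a stand-in for the mock acquisition functions used in the tests above.

```python
# Illustrative sketch (not from the patch): exercises the newly supported
# 2d input path of `optimize_acqf` under an inequality constraint.
import torch

from botorch.acquisition import UpperConfidenceBound
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf

d = 3
train_X = torch.rand(8, d, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
acqf = UpperConfidenceBound(SingleTaskGP(train_X, train_Y), beta=0.1)

# Constraint in botorch's (indices, coefficients, rhs) encoding:
# 1.0 * x[0] >= 0.5. Indices must be an integer tensor.
inequality_constraints = [
    (torch.tensor([0]), torch.tensor([1.0], dtype=torch.double), 0.5)
]

# 2d initial conditions of shape `q x d`; before this fix, the constrained
# code path required an explicit 3d `b x q x d` tensor.
initial_conditions = torch.rand(1, d, dtype=torch.double)

candidate, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=torch.stack([torch.zeros(d), torch.ones(d)]).to(torch.double),
    q=1,
    num_restarts=1,
    batch_initial_conditions=initial_conditions,
    inequality_constraints=inequality_constraints,
)
print(candidate.shape)  # torch.Size([1, 3]): `q x d`, since return_best_only=True
```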
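The shape handling above leans on one invariant: a `b x q x d` tensor is flattened so that element `[i, j, k]` lands at flat position `i * (q * d) + j * d + k`, and `_make_linear_constraints` emits one scipy constraint dict per (batch, q-point) pair. A sketch of that rule in plain numpy (illustrative only; the names are not botorch's), reproducing the numbers in the docstring example:

```python
# Illustrative sketch (not botorch source) of the flat-index rule used by
# `_make_linear_constraints`.
import numpy as np

b, q, d = 3, 5, 4          # matches the docstring example above
indices = [1, 2]           # feature indices within a single d-dim point
coefficients = np.array([-0.5, 1.3])
rhs = 0.49

def constraint_fun(i, j):
    # Constraint for batch i, q-point j, evaluated on the flattened x.
    flat_idxr = [i * q * d + j * d + k for k in indices]
    return lambda x: np.sum(x[flat_idxr] * coefficients, -1) - rhs

x = np.arange(b * q * d)        # flattened stand-in for X, as in the docstring
print(constraint_fun(0, 0)(x))  # 1 * -0.5 + 2 * 1.3 - 0.49 = 1.61
print(constraint_fun(0, 1)(x))  # 5 * -0.5 + 6 * 1.3 - 0.49 = 4.81
print(b * q)                    # 15 constraints in total, matching len(constraints)
```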