diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py
index a38e0fe022..5a4a6f568b 100644
--- a/botorch/optim/optimize.py
+++ b/botorch/optim/optimize.py
@@ -933,7 +933,11 @@ def optimize_acqf_mixed(
     nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None,
     post_processing_func: Callable[[Tensor], Tensor] | None = None,
     batch_initial_conditions: Tensor | None = None,
+    return_best_only: bool = True,
+    gen_candidates: TGenCandidates | None = None,
     ic_generator: TGenInitialConditions | None = None,
+    timeout_sec: float | None = None,
+    retry_on_optimization_warning: bool = True,
     ic_gen_kwargs: dict | None = None,
 ) -> tuple[Tensor, Tensor]:
     r"""Optimize over a list of fixed_features and returns the best solution.
@@ -982,20 +986,38 @@ def optimize_acqf_mixed(
             transformations).
         batch_initial_conditions: A tensor to specify the initial conditions. Set
             this if you do not want to use default initialization strategy.
+        return_best_only: If False, outputs the solutions corresponding to all
+            random restart initializations of the optimization. Setting this keyword
+            to False is only allowed for `q=1`. Defaults to True.
+        gen_candidates: A callable for generating candidates (and their associated
+            acquisition values) given a tensor of initial conditions and an
+            acquisition function. Other common inputs include lower and upper bounds
+            and a dictionary of options, but refer to the documentation of specific
+            generation functions (e.g. `gen_candidates_scipy` and `gen_candidates_torch`)
+            for method-specific inputs. Default: `gen_candidates_scipy`
         ic_generator: Function for generating initial conditions. Not needed when
             `batch_initial_conditions` are provided. Defaults to
             `gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
             functions and `gen_batch_initial_conditions` otherwise. Must be specified
             for nonlinear inequality constraints.
+        timeout_sec: Max amount of time optimization can run for.
+        retry_on_optimization_warning: Whether to retry candidate generation with a new
+            set of initial conditions when it fails with an `OptimizationWarning`.
         ic_gen_kwargs: Additional keyword arguments passed to function specified by
             `ic_generator`

     Returns:
         A two-element tuple containing

-        - a `q x d`-dim tensor of generated candidates.
-        - an associated acquisition value.
+        - A tensor of generated candidates. The shape is
+            -- `q x d` if `return_best_only` is True (default),
+            -- `num_restarts x q x d` if `return_best_only` is False.
+        - A tensor of associated acquisition values of dim `num_restarts`
+            if `return_best_only=False`, else a scalar acquisition value.
""" + if not return_best_only and q > 1: + raise NotImplementedError("`return_best_only=False` is only supported for q=1.") + if not fixed_features_list: raise ValueError("fixed_features_list must be non-empty.") @@ -1010,11 +1032,12 @@ def optimize_acqf_mixed( ic_gen_kwargs = ic_gen_kwargs or {} if q == 1: + timeout_sec = timeout_sec / len(fixed_features_list) if timeout_sec else None ff_candidate_list, ff_acq_value_list = [], [] num_candidate_generation_failures = 0 for fixed_features in fixed_features_list: try: - candidate, acq_value = optimize_acqf( + candidates, acq_values = optimize_acqf( acq_function=acq_function, bounds=bounds, q=q, @@ -1028,15 +1051,19 @@ def optimize_acqf_mixed( post_processing_func=post_processing_func, batch_initial_conditions=batch_initial_conditions, ic_generator=ic_generator, - return_best_only=True, + return_best_only=False, # here we always return all candidates + # and filter later + gen_candidates=gen_candidates, + timeout_sec=timeout_sec, + retry_on_optimization_warning=retry_on_optimization_warning, **ic_gen_kwargs, ) except CandidateGenerationError: # if candidate generation fails, we skip this candidate num_candidate_generation_failures += 1 continue - ff_candidate_list.append(candidate) - ff_acq_value_list.append(acq_value) + ff_candidate_list.append(candidates) + ff_acq_value_list.append(acq_values) if len(ff_candidate_list) == 0: raise CandidateGenerationError( @@ -1051,9 +1078,17 @@ def optimize_acqf_mixed( OptimizationWarning, stacklevel=3, ) + ff_acq_values = torch.stack(ff_acq_value_list) - best = torch.argmax(ff_acq_values) - return ff_candidate_list[best], ff_acq_values[best] + max_res = torch.max(ff_acq_values, dim=-1) + best_batch_idx = torch.argmax(max_res.values) + best_batch_candidates = ff_candidate_list[best_batch_idx] + best_acq_values = ff_acq_value_list[best_batch_idx] + if not return_best_only: + return best_batch_candidates, best_acq_values + + best_idx = max_res.indices[best_batch_idx] + return best_batch_candidates[best_idx], best_acq_values[best_idx] # For batch optimization with q > 1 we do not want to enumerate all n_combos^n # possible combinations of discrete choices. 
diff --git a/botorch/optim/optimize_homotopy.py b/botorch/optim/optimize_homotopy.py
index 136897fe60..00a38ab425 100644
--- a/botorch/optim/optimize_homotopy.py
+++ b/botorch/optim/optimize_homotopy.py
@@ -5,6 +5,8 @@
 from __future__ import annotations

+import warnings
+
 from collections.abc import Callable
 from typing import Any

@@ -15,7 +17,7 @@
 from botorch.generation.gen import TGenCandidates
 from botorch.optim.homotopy import Homotopy
 from botorch.optim.initializers import TGenInitialConditions
-from botorch.optim.optimize import optimize_acqf
+from botorch.optim.optimize import optimize_acqf, optimize_acqf_mixed
 from torch import Tensor

@@ -67,14 +69,13 @@ def optimize_acqf_homotopy(
     equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
     nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None,
     fixed_features: dict[int, float] | None = None,
+    fixed_features_list: list[dict[int, float]] | None = None,
     post_processing_func: Callable[[Tensor], Tensor] | None = None,
     batch_initial_conditions: Tensor | None = None,
     gen_candidates: TGenCandidates | None = None,
-    sequential: bool = False,
     *,
     ic_generator: TGenInitialConditions | None = None,
     timeout_sec: float | None = None,
-    return_full_tree: bool = False,
     retry_on_optimization_warning: bool = True,
     **ic_gen_kwargs: Any,
 ) -> tuple[Tensor, Tensor]:
@@ -129,6 +130,10 @@ def optimize_acqf_homotopy(
             `options`.
         fixed_features: A map `{feature_index: value}` for features that should
             be fixed to a particular value during generation.
+        fixed_features_list: A list of maps `{feature_index: value}`. The i-th
+            item represents the fixed features for the i-th optimization. If
+            `fixed_features_list` is provided, `optimize_acqf_mixed` is invoked.
+            All indices (`feature_index`) should be non-negative.
         post_processing_func: A function that post-processes an optimization
             result appropriately (i.e., according to `round-trip`
             transformations).
@@ -140,37 +145,57 @@ def optimize_acqf_homotopy(
             and a dictionary of options, but refer to the documentation of specific
             generation functions (e.g gen_candidates_scipy and gen_candidates_torch)
             for method-specific inputs. Default: `gen_candidates_scipy`
-        sequential: If False, uses joint optimization, otherwise uses sequential
-            optimization.
         ic_generator: Function for generating initial conditions. Not needed when
             `batch_initial_conditions` are provided. Defaults to
             `gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
             functions and `gen_batch_initial_conditions` otherwise. Must be specified
             for nonlinear inequality constraints.
         timeout_sec: Max amount of time optimization can run for.
-        return_full_tree: Return the full tree of optimizers of the previous
-            iteration.
         retry_on_optimization_warning: Whether to retry candidate generation with a new
             set of initial conditions when it fails with an `OptimizationWarning`.
         ic_gen_kwargs: Additional keyword arguments passed to function specified by
             `ic_generator`
     """
+    if fixed_features and fixed_features_list:
+        raise ValueError(
+            "Either `fixed_features` or `fixed_features_list` can be provided, not both."
+        )
+
+    if fixed_features:
+        message = (
+            "The `fixed_features` argument is deprecated, "
+            "use `fixed_features_list` instead."
+        )
+        warnings.warn(
+            message,
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     shared_optimize_acqf_kwargs = {
         "num_restarts": num_restarts,
         "inequality_constraints": inequality_constraints,
         "equality_constraints": equality_constraints,
         "nonlinear_inequality_constraints": nonlinear_inequality_constraints,
-        "fixed_features": fixed_features,
         "return_best_only": False,  # False to make n_restarts persist through homotopy.
         "gen_candidates": gen_candidates,
-        "sequential": sequential,
         "ic_generator": ic_generator,
         "timeout_sec": timeout_sec,
-        "return_full_tree": return_full_tree,
         "retry_on_optimization_warning": retry_on_optimization_warning,
         **ic_gen_kwargs,
     }

+    if fixed_features_list and len(fixed_features_list) > 1:
+        optimization_fn = optimize_acqf_mixed
+        fixed_features_kwargs = {"fixed_features_list": fixed_features_list}
+    else:
+        optimization_fn = optimize_acqf
+        fixed_features_kwargs = {
+            "fixed_features": fixed_features_list[0]
+            if fixed_features_list
+            else fixed_features
+        }
+
     candidate_list, acq_value_list = [], []
     if q > 1:
         base_X_pending = acq_function.X_pending
@@ -181,15 +206,17 @@ def optimize_acqf_homotopy(
         homotopy.restart()

         while not homotopy.should_stop:
-            candidates, acq_values = optimize_acqf(
+            candidates, acq_values = optimization_fn(
                 acq_function=acq_function,
                 bounds=bounds,
                 q=1,
                 options=options,
                 batch_initial_conditions=candidates,
                 raw_samples=q_raw_samples,
+                **fixed_features_kwargs,
                 **shared_optimize_acqf_kwargs,
             )
+
             homotopy.step()

             # Set raw_samples to None such that pruned restarts are not repopulated
@@ -204,13 +231,14 @@ def optimize_acqf_homotopy(
         ).unsqueeze(1)

     # Optimize one more time with the final options
-    candidates, acq_values = optimize_acqf(
+    candidates, acq_values = optimization_fn(
         acq_function=acq_function,
         bounds=bounds,
         q=1,
         options=final_options,
         raw_samples=q_raw_samples,
         batch_initial_conditions=candidates,
+        **fixed_features_kwargs,
         **shared_optimize_acqf_kwargs,
     )
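For context, the new dispatch means a caller can now combine discrete fixed-feature configurations with homotopy optimization in one call. A hedged usage sketch (values mirror the tests below; the toy model and schedule are stand-ins, not part of the patch):

```python
import torch
from botorch.acquisition import PosteriorMean
from botorch.models.deterministic import GenericDeterministicModel
from botorch.optim.homotopy import (
    FixedHomotopySchedule,
    Homotopy,
    HomotopyParameter,
)
from botorch.optim.optimize_homotopy import optimize_acqf_homotopy

# Homotopy parameter annealed over three steps (stand-in schedule).
p = torch.nn.Parameter(torch.tensor(2.0))
hp = HomotopyParameter(
    parameter=p, schedule=FixedHomotopySchedule(values=[2.0, 1.0, 0.0])
)
model = GenericDeterministicModel(
    f=lambda x: 5 - (x - p).sum(dim=-1, keepdims=True) ** 2
)
acqf = PosteriorMean(model=model)

# More than one dict in `fixed_features_list` routes to `optimize_acqf_mixed`;
# a single dict (or the deprecated `fixed_features`) still uses `optimize_acqf`.
candidate, acqf_val = optimize_acqf_homotopy(
    q=1,
    acq_function=acqf,
    bounds=torch.tensor([[-10.0, -10.0], [5.0, 5.0]]),
    homotopy=Homotopy(homotopy_parameters=[hp]),
    num_restarts=2,
    raw_samples=16,
    fixed_features_list=[{0: 1.0}, {1: 3.0}],
)
```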
diff --git a/test/optim/test_homotopy.py b/test/optim/test_homotopy.py
index 859781e6d6..169e35b805 100644
--- a/test/optim/test_homotopy.py
+++ b/test/optim/test_homotopy.py
@@ -117,7 +117,7 @@ def test_optimize_acqf_homotopy(self):
         candidate, acqf_val = optimize_acqf_homotopy(
             q=1,
             acq_function=acqf,
-            bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
+            bounds=torch.tensor([[-10], [5]], **tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
             num_restarts=2,
             raw_samples=16,
@@ -132,14 +132,62 @@ def test_optimize_acqf_homotopy(self):
             f=lambda x: 5 - (x - p).sum(dim=-1, keepdims=True) ** 2
         )
         acqf = PosteriorMean(model=model)
+        # test that a warning is raised on using the `fixed_features` argument
+        message = (
+            "The `fixed_features` argument is deprecated, "
+            "use `fixed_features_list` instead."
+        )
+        with self.assertWarnsRegex(DeprecationWarning, message):
+            optimize_acqf_homotopy(
+                q=1,
+                acq_function=acqf,
+                bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs),
+                homotopy=Homotopy(homotopy_parameters=[hp]),
+                num_restarts=2,
+                raw_samples=16,
+                fixed_features=fixed_features,
+            )
+
+        candidate, acqf_val = optimize_acqf_homotopy(
+            q=1,
+            acq_function=acqf,
+            bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs),
+            homotopy=Homotopy(homotopy_parameters=[hp]),
+            num_restarts=2,
+            raw_samples=16,
+            fixed_features_list=[fixed_features],
+        )
+        self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))
+
+        # test fixed features list
+        fixed_features_list = [{0: 1.0}, {1: 3.0}]
+        model = GenericDeterministicModel(
+            f=lambda x: 5 - (x - p).sum(dim=-1, keepdims=True) ** 2
+        )
+        acqf = PosteriorMean(model=model)
+        # test raise error when fixed_features and fixed_features_list are both provided
+        with self.assertRaisesRegex(
+            ValueError,
+            "Either `fixed_features` or `fixed_features_list` can be provided",
+        ):
+            optimize_acqf_homotopy(
+                q=1,
+                acq_function=acqf,
+                bounds=torch.tensor([[-10, -10, -10], [5, 5, 5]], **tkwargs),
+                homotopy=Homotopy(homotopy_parameters=[hp]),
+                num_restarts=2,
+                raw_samples=16,
+                fixed_features_list=fixed_features_list,
+                fixed_features=fixed_features,
+            )

         candidate, acqf_val = optimize_acqf_homotopy(
             q=1,
             acq_function=acqf,
-            bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
+            bounds=torch.tensor([[-10, -10, -10], [5, 5, 5]], **tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
             num_restarts=2,
             raw_samples=16,
-            fixed_features=fixed_features,
+            fixed_features_list=fixed_features_list,
         )
         self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))

@@ -148,11 +196,11 @@ def test_optimize_acqf_homotopy(self):
         candidate, acqf_val = optimize_acqf_homotopy(
             q=3,
             acq_function=acqf,
-            bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
+            bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
             num_restarts=2,
             raw_samples=16,
-            fixed_features=fixed_features,
+            fixed_features_list=[fixed_features],
         )
         self.assertEqual(candidate.shape, torch.Size([3, 2]))
         self.assertEqual(acqf_val.shape, torch.Size([3]))
@@ -170,7 +218,7 @@ def test_optimize_acqf_homotopy(self):
         candidate, acqf_val = optimize_acqf_homotopy(
             q=1,
             acq_function=acqf,
-            bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
+            bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
             num_restarts=2,
             raw_samples=16,
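The test changes above spell out the migration path for downstream code; a short illustrative sketch (names mirror the tests, nothing here is part of the patch):

```python
# A single `fixed_features` dict becomes a one-element `fixed_features_list`.
# With one entry the homotopy helper still dispatches to `optimize_acqf`, so
# behavior is unchanged; the old keyword keeps working but warns.
fixed_features = {0: 1.0}  # illustrative: pin feature 0 to 1.0

# Old call, now emits a DeprecationWarning:
#     optimize_acqf_homotopy(..., fixed_features=fixed_features)
# New, equivalent call:
#     optimize_acqf_homotopy(..., fixed_features_list=[fixed_features])
```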
diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py
index d913882c1d..4cb541722b 100644
--- a/test/optim/test_optimize.py
+++ b/test/optim/test_optimize.py
@@ -1529,7 +1529,9 @@ def test_optimize_acqf_mixed_q1(self, mock_optimize_acqf):
         tkwargs = {"device": self.device}
         bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
         mock_acq_function = MockAcquisitionFunction()
-        for num_ff, dtype in itertools.product([1, 3], (torch.float, torch.double)):
+        for num_ff, dtype, return_best_only in itertools.product(
+            [1, 3], (torch.float, torch.double), (True, False)
+        ):
             tkwargs["dtype"] = dtype
             mock_optimize_acqf.reset_mock()
             bounds = bounds.to(**tkwargs)
@@ -1537,8 +1539,8 @@ def test_optimize_acqf_mixed_q1(self, mock_optimize_acqf):
             candidate_rvs = []
             acq_val_rvs = []
             for _ in range(num_ff):
-                candidate_rvs.append(torch.rand(1, 3, **tkwargs))
-                acq_val_rvs.append(torch.rand(1, **tkwargs))
+                candidate_rvs.append(torch.rand(num_restarts, 1, 3, **tkwargs))
+                acq_val_rvs.append(torch.rand(num_restarts, **tkwargs))
             fixed_features_list = [{i: i * 0.1} for i in range(num_ff)]
             side_effect = list(zip(candidate_rvs, acq_val_rvs))
             mock_optimize_acqf.side_effect = side_effect
@@ -1551,13 +1553,29 @@ def test_optimize_acqf_mixed_q1(self, mock_optimize_acqf):
                 num_restarts=num_restarts,
                 raw_samples=raw_samples,
                 options=options,
+                return_best_only=return_best_only,
                 post_processing_func=rounding_func,
             )
             # compute expected output
-            ff_acq_values = torch.stack(acq_val_rvs)
-            best = torch.argmax(ff_acq_values)
-            expected_candidates = candidate_rvs[best]
-            expected_acq_value = ff_acq_values[best]
+            best_acq_values = torch.tensor(
+                [torch.max(acq_values) for acq_values in acq_val_rvs]
+            )
+            best_batch_idx = torch.argmax(best_acq_values)
+
+            if return_best_only:
+                best_batch_candidates = candidate_rvs[best_batch_idx]
+                best_batch_acq_values = acq_val_rvs[best_batch_idx]
+                best_idx = torch.argmax(best_batch_acq_values)
+                expected_candidates = best_batch_candidates[best_idx]
+                expected_acq_value = best_batch_acq_values[best_idx]
+                self.assertEqual(expected_candidates.dim(), 2)
+
+            else:
+                expected_candidates = candidate_rvs[best_batch_idx]
+                expected_acq_value = acq_val_rvs[best_batch_idx]
+                self.assertEqual(expected_candidates.dim(), 3)
+                self.assertEqual(expected_acq_value.dim(), 1)
+
             self.assertTrue(torch.equal(candidates, expected_candidates))
             self.assertTrue(torch.equal(acq_value, expected_acq_value))
             # check call arguments for optimize_acqf
@@ -1572,11 +1590,14 @@ def test_optimize_acqf_mixed_q1(self, mock_optimize_acqf):
                 "inequality_constraints": None,
                 "equality_constraints": None,
                 "fixed_features": None,
+                "gen_candidates": None,
                 "post_processing_func": rounding_func,
                 "batch_initial_conditions": None,
-                "return_best_only": True,
+                "return_best_only": False,
                 "sequential": False,
                 "ic_generator": None,
+                "timeout_sec": None,
+                "retry_on_optimization_warning": True,
                 "nonlinear_inequality_constraints": None,
             }
             for i in range(len(call_args_list)):
@@ -1612,10 +1633,24 @@ def test_optimize_acqf_mixed_q2(self, mock_optimize_acqf):
         candidate_rvs, exp_candidates, acq_val_rvs = [], [], []
         # generate mock side effects and compute expected outputs
         for _ in range(q):
-            candidate_rvs_q = [torch.rand(1, 3, **tkwargs) for _ in range(num_ff)]
-            acq_val_rvs_q = [torch.rand(1, **tkwargs) for _ in range(num_ff)]
-            best = torch.argmax(torch.stack(acq_val_rvs_q))
-            exp_candidates.append(candidate_rvs_q[best])
+            candidate_rvs_q = [
+                torch.rand(num_restarts, 1, 3, **tkwargs) for _ in range(num_ff)
+            ]
+            acq_val_rvs_q = [
+                torch.rand(num_restarts, **tkwargs) for _ in range(num_ff)
+            ]
+
+            best_acq_values = torch.tensor(
+                [torch.max(acq_values) for acq_values in acq_val_rvs_q]
+            )
+            best_batch_idx = torch.argmax(best_acq_values)
+
+            best_batch_candidates = candidate_rvs_q[best_batch_idx]
+            best_batch_acq_values = acq_val_rvs_q[best_batch_idx]
+            best_idx = torch.argmax(best_batch_acq_values)
+
+            exp_candidates.append(best_batch_candidates[best_idx])
+
             candidate_rvs += candidate_rvs_q
             acq_val_rvs += acq_val_rvs_q
         side_effect = list(zip(candidate_rvs, acq_val_rvs))
@@ -1643,7 +1678,9 @@ def test_optimize_acqf_mixed_q2(self, mock_optimize_acqf):
         self.assertTrue(torch.equal(acq_value, expected_acq_value))

     def test_optimize_acqf_mixed_empty_ff(self):
-        with self.assertRaises(ValueError):
+        with self.assertRaisesRegex(
+            ValueError, expected_regex="fixed_features_list must be non-empty."
+        ):
             mock_acq_function = MockAcquisitionFunction()
             optimize_acqf_mixed(
                 acq_function=mock_acq_function,
@@ -1654,6 +1691,22 @@ def test_optimize_acqf_mixed_empty_ff(self):
                 raw_samples=10,
             )

+    def test_optimize_acqf_mixed_return_best_only_q2(self):
+        mock_acq_function = MockAcquisitionFunction()
+        with self.assertRaisesRegex(
+            NotImplementedError,
+            expected_regex="`return_best_only=False` is only supported for q=1.",
+        ):
+            optimize_acqf_mixed(
+                acq_function=mock_acq_function,
+                q=2,
+                fixed_features_list=[{0: 0.0}],
+                bounds=torch.stack([torch.zeros(3), 4 * torch.ones(3)]),
+                num_restarts=2,
+                raw_samples=10,
+                return_best_only=False,
+            )
+
     def test_optimize_acqf_one_shot_large_q(self):
         with self.assertRaises(ValueError):
             mock_acq_function = MockOneShotAcquisitionFunction()
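To round out the coverage above, a hedged end-to-end sketch of the new `return_best_only` contract on `optimize_acqf_mixed` (the toy model is a stand-in; the asserted shapes follow the updated docstring, not a verified run):

```python
import torch
from botorch.acquisition import PosteriorMean
from botorch.models.deterministic import GenericDeterministicModel
from botorch.optim.optimize import optimize_acqf_mixed

model = GenericDeterministicModel(
    f=lambda x: -(x - 0.5).sum(dim=-1, keepdim=True) ** 2
)
acqf = PosteriorMean(model=model)
bounds = torch.stack([torch.zeros(3), torch.ones(3)])
fixed_features_list = [{0: 0.0}, {0: 1.0}]

# Default: single best point, `q x d`, with a scalar acquisition value.
cand, val = optimize_acqf_mixed(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    fixed_features_list=fixed_features_list,
    num_restarts=4,
    raw_samples=16,
)
assert cand.shape == (1, 3) and val.ndim == 0

# New: keep all restarts of the winning fixed-feature configuration,
# `num_restarts x q x d`, with a `num_restarts`-dim value tensor.
cands, vals = optimize_acqf_mixed(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    fixed_features_list=fixed_features_list,
    num_restarts=4,
    raw_samples=16,
    return_best_only=False,
)
assert cands.shape == (4, 1, 3) and vals.shape == (4,)
```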