From c157b57017858d52ecfe4eb05e55090d72f6b45d Mon Sep 17 00:00:00 2001
From: hvarfner
Date: Tue, 25 Feb 2025 11:37:04 +0100
Subject: [PATCH] Added entropy search acquisition function initializer,
 ensured unittest coverage

---
 botorch/optim/initializers.py   | 194 +++++++++++++++++++++++---------
 botorch/optim/optimize.py       |   4 +
 test/optim/test_initializers.py | 107 ++++++++++++++++++
 test/optim/test_optimize.py     |  11 ++
 4 files changed, 263 insertions(+), 53 deletions(-)

diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py
index e5e81f3dcc..d8d3acbdbe 100644
--- a/botorch/optim/initializers.py
+++ b/botorch/optim/initializers.py
@@ -468,6 +468,91 @@ def gen_batch_initial_conditions(
     return batch_initial_conditions
 
 
+def gen_optimal_input_initial_conditions(
+    acq_function: AcquisitionFunction,
+    bounds: Tensor,
+    q: int,
+    num_restarts: int,
+    raw_samples: int,
+    fixed_features: dict[int, float] | None = None,
+    options: dict[str, bool | float | int] | None = None,
+    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+) -> Tensor:
+    r"""Generate initial conditions by perturbing the sampled optima of the model.
+
+    A fraction ``frac_random`` of the raw samples is drawn at random (subject
+    to any constraints), and the rest are sampled around the acquisition
+    function's ``optimal_inputs``. The final `num_restarts x q x d` initial
+    conditions are selected by Boltzmann sampling on the acquisition values.
+    """
+    options = options or {}
+    device = bounds.device
+    if not hasattr(acq_function, "optimal_inputs"):
+        raise AttributeError(
+            "gen_optimal_input_initial_conditions can only be used with "
+            "an AcquisitionFunction that has an optimal_inputs attribute."
+        )
+    frac_random: float = options.get("frac_random", 0.0)
+    if not 0 <= frac_random <= 1:
+        raise ValueError(
+            f"frac_random must take on values in [0, 1]. Value: {frac_random}"
+        )
+
+    batch_limit = options.get("batch_limit")
+    num_optima = acq_function.optimal_inputs.shape[:-1].numel()
+    suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
+    X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
+    num_random = round(raw_samples * frac_random)
+    if num_random > 0:
+        X_rnd = sample_q_batches_from_polytope(
+            n=num_random,
+            q=q,
+            bounds=bounds,
+            n_burnin=options.get("n_burnin", 10000),
+            n_thinning=options.get("n_thinning", 32),
+            equality_constraints=equality_constraints,
+            inequality_constraints=inequality_constraints,
+        )
+        X = torch.cat((X, X_rnd))
+
+    if num_random < raw_samples:
+        X_perturbed = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * (raw_samples - num_random),
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+            best_X=suggestions,
+        )
+        X_perturbed = X_perturbed.view(
+            raw_samples - num_random, q, bounds.shape[-1]
+        ).cpu()
+        X = torch.cat((X, X_perturbed))
+
+    if options.get("sample_around_best", False):
+        X_best = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * raw_samples,
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+        )
+        X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
+        X = torch.cat((X, X_best))
+
+    with torch.no_grad():
+        if batch_limit is None:
+            batch_limit = X.shape[0]
+        # Evaluate the acquisition function on `X` using `batch_limit`
+        # sized chunks.
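+        # Each chunk is moved to `device` for evaluation and the values are
+        # gathered on the CPU to keep peak device memory bounded.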
+        acq_vals = torch.cat(
+            [
+                acq_function(x_.to(device=device)).cpu()
+                for x_ in X.split(split_size=batch_limit, dim=0)
+            ],
+            dim=0,
+        )
+    idx = boltzmann_sample(
+        function_values=acq_vals,
+        num_samples=num_restarts,
+        eta=options.get("eta", 2.0),
+    )
+    # select the sampled candidates as the initial conditions
+    return X[idx]
+
+
 def gen_one_shot_kg_initial_conditions(
     acq_function: qKnowledgeGradient,
     bounds: Tensor,
@@ -1136,6 +1221,7 @@ def sample_points_around_best(
     best_pct: float = 5.0,
     subset_sigma: float = 1e-1,
     prob_perturb: float | None = None,
+    best_X: Tensor | None = None,
 ) -> Tensor | None:
     r"""Find best points and sample nearby points.
 
@@ -1154,60 +1240,62 @@ def sample_points_around_best(
         An optional `n_discrete_points x d`-dim tensor containing the
             sampled points. This is None if no baseline points are found.
     """
-    X = get_X_baseline(acq_function=acq_function)
-    if X is None:
-        return
-    with torch.no_grad():
-        try:
-            posterior = acq_function.model.posterior(X)
-        except AttributeError:
-            warnings.warn(
-                "Failed to sample around previous best points.",
-                BotorchWarning,
-                stacklevel=3,
-            )
+    if best_X is None:
+        X = get_X_baseline(acq_function=acq_function)
+        if X is None:
             return
-        mean = posterior.mean
-        while mean.ndim > 2:
-            # take average over batch dims
-            mean = mean.mean(dim=0)
-        try:
-            f_pred = acq_function.objective(mean)
-        # Some acquisition functions do not have an objective
-        # and for some acquisition functions the objective is None
-        except (AttributeError, TypeError):
-            f_pred = mean
-        if hasattr(acq_function, "maximize"):
-            # make sure that the optimiztaion direction is set properly
-            if not acq_function.maximize:
-                f_pred = -f_pred
-        try:
-            # handle constraints for EHVI-based acquisition functions
-            constraints = acq_function.constraints
-            if constraints is not None:
-                neg_violation = -torch.stack(
-                    [c(mean).clamp_min(0.0) for c in constraints], dim=-1
-                ).sum(dim=-1)
-                feas = neg_violation == 0
-                if feas.any():
-                    f_pred[~feas] = float("-inf")
-                else:
-                    # set objective equal to negative violation
-                    f_pred = neg_violation
-        except AttributeError:
-            pass
-        if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
-            # multi-objective
-            # find pareto set
-            is_pareto = is_non_dominated(f_pred)
-            best_X = X[is_pareto]
-        else:
-            if f_pred.shape[-1] == 1:
-                f_pred = f_pred.squeeze(-1)
-            n_best = max(1, round(X.shape[0] * best_pct / 100))
-            # the view() is to ensure that best_idcs is not a scalar tensor
-            best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
-            best_X = X[best_idcs]
+        with torch.no_grad():
+            try:
+                posterior = acq_function.model.posterior(X)
+            except AttributeError:
+                warnings.warn(
+                    "Failed to sample around previous best points.",
+                    BotorchWarning,
+                    stacklevel=3,
+                )
+                return
+            mean = posterior.mean
+            while mean.ndim > 2:
+                # take average over batch dims
+                mean = mean.mean(dim=0)
+            try:
+                f_pred = acq_function.objective(mean)
+            # Some acquisition functions do not have an objective
+            # and for some acquisition functions the objective is None
+            except (AttributeError, TypeError):
+                f_pred = mean
+            if hasattr(acq_function, "maximize"):
+                # make sure that the optimization direction is set properly
+                if not acq_function.maximize:
+                    f_pred = -f_pred
+            try:
+                # handle constraints for EHVI-based acquisition functions
+                constraints = acq_function.constraints
+                if constraints is not None:
+                    neg_violation = -torch.stack(
+                        [c(mean).clamp_min(0.0) for c in constraints], dim=-1
+                    ).sum(dim=-1)
+                    feas = neg_violation == 0
+                    if feas.any():
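+                        # mask infeasible candidates so they are never selected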
+                        f_pred[~feas] = float("-inf")
+                    else:
+                        # set objective equal to negative violation
+                        f_pred = neg_violation
+            except AttributeError:
+                pass
+            if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
+                # multi-objective
+                # find pareto set
+                is_pareto = is_non_dominated(f_pred)
+                best_X = X[is_pareto]
+            else:
+                if f_pred.shape[-1] == 1:
+                    f_pred = f_pred.squeeze(-1)
+                n_best = max(1, round(X.shape[0] * best_pct / 100))
+                # the view() is to ensure that best_idcs is not a scalar tensor
+                best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
+                best_X = X[best_idcs]
+
     use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
     n_trunc_normal_points = (
         n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points
diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py
index 6f3a5876a9..8d56ddd0d4 100644
--- a/botorch/optim/optimize.py
+++ b/botorch/optim/optimize.py
@@ -20,6 +20,7 @@
     AcquisitionFunction,
     OneShotAcquisitionFunction,
 )
+from botorch.acquisition.joint_entropy_search import qJointEntropySearch
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
 from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
     qHypervolumeKnowledgeGradient,
@@ -33,6 +34,7 @@
     gen_batch_initial_conditions,
     gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
+    gen_optimal_input_initial_conditions,
     TGenInitialConditions,
 )
 from botorch.optim.stopping import ExpMAStoppingCriterion
@@ -174,6 +176,8 @@ def get_ic_generator(self) -> TGenInitialConditions:
             return gen_one_shot_kg_initial_conditions
         elif isinstance(self.acq_function, qHypervolumeKnowledgeGradient):
             return gen_one_shot_hvkg_initial_conditions
+        elif isinstance(self.acq_function, qJointEntropySearch):
+            return gen_optimal_input_initial_conditions
         return gen_batch_initial_conditions
 
 
diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py
index 187a09d7f3..5d2117a069 100644
--- a/test/optim/test_initializers.py
+++ b/test/optim/test_initializers.py
@@ -13,6 +13,7 @@
 import torch
 from botorch.acquisition.analytic import PosteriorMean
 from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
+from botorch.acquisition.joint_entropy_search import qJointEntropySearch
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
 from botorch.acquisition.monte_carlo import (
     qExpectedImprovement,
@@ -34,6 +35,7 @@
     gen_batch_initial_conditions,
     gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
+    gen_optimal_input_initial_conditions,
     gen_value_function_initial_conditions,
     initialize_q_batch,
     initialize_q_batch_nonneg,
@@ -47,6 +49,7 @@
 )
 from botorch.sampling.normal import IIDNormalSampler
 from botorch.utils.sampling import manual_seed, unnormalize
+from botorch.utils.test_helpers import get_model
 from botorch.utils.testing import (
     _get_max_violation_of_bounds,
     _get_max_violation_of_constraints,
@@ -1074,6 +1077,110 @@ def test_gen_one_shot_kg_initial_conditions(self):
             )
             self.assertTrue(torch.all(ics[..., -n_value:, :] == 1))
 
+    def test_gen_optimal_input_initial_conditions(self):
+        num_restarts = 10
+        raw_samples = 16
+        q = 3
+        for dtype in (torch.float, torch.double):
+            model = get_model(
+                torch.rand(4, 2, dtype=dtype), torch.rand(4, 1, dtype=dtype)
+            )
+            optimal_inputs = torch.rand(5, 2, dtype=dtype)
+            optimal_outputs = torch.rand(5, 1, dtype=dtype)
+            jes = qJointEntropySearch(
+                model=model,
+                optimal_inputs=optimal_inputs,
+                optimal_outputs=optimal_outputs,
+            )
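+            # bounds of the 2D unit-cube test domain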
+            bounds = torch.tensor([[0, 0], [1, 1]], device=self.device, dtype=dtype)
+            # base case
+            ics = gen_optimal_input_initial_conditions(
+                acq_function=jes,
+                bounds=bounds,
+                q=q,
+                num_restarts=num_restarts,
+                raw_samples=raw_samples,
+            )
+            self.assertEqual(ics.shape, torch.Size([num_restarts, q, 2]))
+
+            # since we do sample_around_best, this should generate enough points
+            # despite num_restarts being larger than raw_samples
+            ics = gen_optimal_input_initial_conditions(
+                acq_function=jes,
+                bounds=bounds,
+                q=q,
+                num_restarts=15,
+                raw_samples=8,
+                options={"frac_random": 0.2, "sample_around_best": True},
+            )
+            self.assertEqual(ics.shape, torch.Size([15, q, 2]))
+
+            # test option error
+            with self.assertRaises(ValueError):
+                gen_optimal_input_initial_conditions(
+                    acq_function=jes,
+                    bounds=bounds,
+                    q=1,
+                    num_restarts=num_restarts,
+                    raw_samples=raw_samples,
+                    options={"frac_random": 2.0},
+                )
+
+            ei = qExpectedImprovement(model, 99.9)
+            with self.assertRaisesRegex(
+                AttributeError,
+                "gen_optimal_input_initial_conditions can only be used with "
+                "an AcquisitionFunction that has an optimal_inputs attribute.",
+            ):
+                gen_optimal_input_initial_conditions(
+                    acq_function=ei,
+                    bounds=bounds,
+                    q=1,
+                    num_restarts=num_restarts,
+                    raw_samples=raw_samples,
+                    options={"frac_random": 2.0},
+                )
+            # test generation logic
+            random_ics = torch.rand(raw_samples // 2, q, 2)
+            suggested_ics = torch.rand(raw_samples // 2 * q, 2)
+            with ExitStack() as es:
+                mock_random_ics = es.enter_context(
+                    mock.patch(
+                        "botorch.optim.initializers.sample_q_batches_from_polytope",
+                        return_value=random_ics,
+                    )
+                )
+                mock_suggested_ics = es.enter_context(
+                    mock.patch(
+                        "botorch.optim.initializers.sample_points_around_best",
+                        return_value=suggested_ics,
+                    )
+                )
+                mock_choose = es.enter_context(
+                    mock.patch(
+                        "torch.multinomial",
+                        return_value=torch.arange(0, 10),
+                    )
+                )
+
+                ics = gen_optimal_input_initial_conditions(
+                    acq_function=jes,
+                    bounds=bounds,
+                    q=q,
+                    num_restarts=num_restarts,
+                    raw_samples=raw_samples,
+                    options={"frac_random": 0.5},
+                )
+
+                mock_suggested_ics.assert_called_once()
+                mock_random_ics.assert_called_once()
+                mock_choose.assert_called_once()
+
+                expected_result = torch.cat(
+                    (random_ics, suggested_ics.view(raw_samples // 2, q, 2)[0:2])
+                )
+                self.assertTrue(torch.equal(ics, expected_result))
+
 
 class TestGenOneShotHVKGInitialConditions(BotorchTestCase):
     def test_gen_one_shot_hvkg_initial_conditions(self):
diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py
index 8d8be47ea0..95f476b79a 100644
--- a/test/optim/test_optimize.py
+++ b/test/optim/test_optimize.py
@@ -18,6 +18,7 @@
     AcquisitionFunction,
     OneShotAcquisitionFunction,
 )
+from botorch.acquisition.joint_entropy_search import qJointEntropySearch
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
 from botorch.acquisition.monte_carlo import qExpectedImprovement
 from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
@@ -32,6 +33,7 @@
 from botorch.optim.initializers import (
     gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
+    gen_optimal_input_initial_conditions,
 )
 from botorch.optim.optimize import (
     _combine_initial_conditions,
@@ -2068,6 +2070,15 @@ def test_get_ic_generator(self):
         ic_generator = opt_inputs.get_ic_generator()
         self.assertIs(ic_generator, gen_one_shot_kg_initial_conditions)
 
+        acqf = qJointEntropySearch(
+            model=m1, optimal_inputs=torch.rand(5, 3), optimal_outputs=torch.rand(5, 1)
+        )
+        opt_inputs = OptimizeAcqfInputs(
+            acq_function=acqf,
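+            # qJointEntropySearch should map to the optimal-input initializer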
+            bounds=bounds, q=1, num_restarts=1, **kwargs
+        )
+        ic_generator = opt_inputs.get_ic_generator()
+        self.assertIs(ic_generator, gen_optimal_input_initial_conditions)
+
 
 def my_gen():
     pass
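
Usage sketch (not part of the patch): with the dispatch added to
`get_ic_generator`, `optimize_acqf` now picks up the new initializer for
`qJointEntropySearch` automatically. A minimal, hypothetical example: `model`,
`optimal_inputs`, and `optimal_outputs` are assumed to come from a fitted
BoTorch model and a prior optimizer-sampling step:

    import torch
    from botorch.acquisition.joint_entropy_search import qJointEntropySearch
    from botorch.optim import optimize_acqf

    jes = qJointEntropySearch(
        model=model,  # assumption: a fitted BoTorch model
        optimal_inputs=optimal_inputs,  # assumption: `num_optima x d` sampled optimizers
        optimal_outputs=optimal_outputs,  # assumption: `num_optima x 1` optimal values
    )
    candidate, acq_value = optimize_acqf(
        acq_function=jes,
        bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
        q=1,
        num_restarts=10,
        raw_samples=64,
        # "frac_random" sets the share of purely random raw samples; the
        # remainder is perturbed around the sampled optimizers
        options={"frac_random": 0.2},
    )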