Commit c157b57

Added entropy search acquisition function initializer, ensured unittest coverage

hvarfner committed Feb 25, 2025
1 parent 6b3002b commit c157b57
Showing 4 changed files with 263 additions and 53 deletions.
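In practical terms, the optimize.py change below makes optimize_acqf dispatch to the new initializer whenever it is handed a qJointEntropySearch acquisition function. A minimal end-to-end sketch of that path — the toy model, data, and sampled optima here are illustrative assumptions, not part of this commit:

import torch
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf

# Toy setup (assumed): a GP on the 2D unit cube and placeholder posterior
# samples of the optimum, e.g. as produced by get_optimal_samples.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

jes = qJointEntropySearch(
    model=model,
    optimal_inputs=torch.rand(16, 2, dtype=torch.double),
    optimal_outputs=torch.rand(16, 1, dtype=torch.double),
)

# Because get_ic_generator now returns gen_optimal_input_initial_conditions
# for qJointEntropySearch, the restart candidates are seeded near the
# sampled optima rather than purely at random.
candidate, acq_value = optimize_acqf(
    acq_function=jes,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=2,
    num_restarts=10,
    raw_samples=64,
)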
194 changes: 141 additions & 53 deletions botorch/optim/initializers.py
@@ -468,6 +468,91 @@ def gen_batch_initial_conditions(
return batch_initial_conditions


def gen_optimal_input_initial_conditions(
acq_function: AcquisitionFunction,
bounds: Tensor,
q: int,
num_restarts: int,
raw_samples: int,
fixed_features: dict[int, float] | None = None,
options: dict[str, bool | float | int] | None = None,
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
) -> Tensor:
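    r"""Generate initial conditions for acquisition functions, such as JES,
    that expose sampled optima via an ``optimal_inputs`` attribute.

    A ``frac_random`` fraction of the ``raw_samples`` candidate q-batches is
    drawn at random (respecting any constraints); the rest are perturbations
    of the sampled optimal inputs. ``num_restarts`` of these candidates are
    then selected by Boltzmann sampling on their acquisition values.
    """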
options = options or {}
device = bounds.device
if not hasattr(acq_function, "optimal_inputs"):
raise AttributeError(
"gen_optimal_input_initial_conditions can only be used with "
"an AcquisitionFunction that has an optimal_inputs attribute."
)
frac_random: float = options.get("frac_random", 0.0)
if not 0 <= frac_random <= 1:
raise ValueError(
f"frac_random must take on values in (0,1). Value: {frac_random}"
)

batch_limit = options.get("batch_limit")
num_optima = acq_function.optimal_inputs.shape[:-1].numel()
suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
num_random = round(raw_samples * frac_random)
if num_random > 0:
X_rnd = sample_q_batches_from_polytope(
n=num_random,
q=q,
bounds=bounds,
n_burnin=options.get("n_burnin", 10000),
n_thinning=options.get("n_thinning", 32),
equality_constraints=equality_constraints,
inequality_constraints=inequality_constraints,
)
X = torch.cat((X, X_rnd))

if num_random < raw_samples:
X_perturbed = sample_points_around_best(
acq_function=acq_function,
n_discrete_points=q * (raw_samples - num_random),
sigma=options.get("sample_around_best_sigma", 1e-2),
bounds=bounds,
best_X=suggestions,
)
X_perturbed = X_perturbed.view(
raw_samples - num_random, q, bounds.shape[-1]
).cpu()
X = torch.cat((X, X_perturbed))

if options.get("sample_around_best", False):
X_best = sample_points_around_best(
acq_function=acq_function,
n_discrete_points=q * raw_samples,
sigma=options.get("sample_around_best_sigma", 1e-2),
bounds=bounds,
)
X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
X = torch.cat((X, X_best))

with torch.no_grad():
if batch_limit is None:
batch_limit = X.shape[0]
        # Evaluate the acquisition function on `X` in `batch_limit`-sized
        # chunks.
acq_vals = torch.cat(
[
acq_function(x_.to(device=device)).cpu()
for x_ in X.split(split_size=batch_limit, dim=0)
],
dim=0,
)
idx = boltzmann_sample(
function_values=acq_vals,
num_samples=num_restarts,
eta=options.get("eta", 2.0),
)
# set the respective initial conditions to the sampled optimizers
return X[idx]


def gen_one_shot_kg_initial_conditions(
acq_function: qKnowledgeGradient,
bounds: Tensor,
@@ -1136,6 +1221,7 @@ def sample_points_around_best(
best_pct: float = 5.0,
subset_sigma: float = 1e-1,
prob_perturb: float | None = None,
best_X: Tensor | None = None,
) -> Tensor | None:
r"""Find best points and sample nearby points.
@@ -1154,60 +1240,62 @@
An optional `n_discrete_points x d`-dim tensor containing the
sampled points. This is None if no baseline points are found.
"""
    if best_X is None:
        X = get_X_baseline(acq_function=acq_function)
        if X is None:
            return
        with torch.no_grad():
            try:
                posterior = acq_function.model.posterior(X)
            except AttributeError:
                warnings.warn(
                    "Failed to sample around previous best points.",
                    BotorchWarning,
                    stacklevel=3,
                )
                return
            mean = posterior.mean
            while mean.ndim > 2:
                # take average over batch dims
                mean = mean.mean(dim=0)
            try:
                f_pred = acq_function.objective(mean)
            # Some acquisition functions do not have an objective
            # and for some acquisition functions the objective is None
            except (AttributeError, TypeError):
                f_pred = mean
            if hasattr(acq_function, "maximize"):
                # make sure that the optimization direction is set properly
                if not acq_function.maximize:
                    f_pred = -f_pred
            try:
                # handle constraints for EHVI-based acquisition functions
                constraints = acq_function.constraints
                if constraints is not None:
                    neg_violation = -torch.stack(
                        [c(mean).clamp_min(0.0) for c in constraints], dim=-1
                    ).sum(dim=-1)
                    feas = neg_violation == 0
                    if feas.any():
                        f_pred[~feas] = float("-inf")
                    else:
                        # set objective equal to negative violation
                        f_pred = neg_violation
            except AttributeError:
                pass
            if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
                # multi-objective
                # find pareto set
                is_pareto = is_non_dominated(f_pred)
                best_X = X[is_pareto]
            else:
                if f_pred.shape[-1] == 1:
                    f_pred = f_pred.squeeze(-1)
                n_best = max(1, round(X.shape[0] * best_pct / 100))
                # the view() is to ensure that best_idcs is not a scalar tensor
                best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
                best_X = X[best_idcs]

use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
n_trunc_normal_points = (
n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points
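The final selection above delegates to boltzmann_sample, which draws num_restarts indices with probability increasing in acquisition value rather than taking a hard top-k. A rough standalone sketch of that idea, assuming a plain softmax-weighted draw through torch.multinomial (which the new test also mocks); the real helper's standardization and edge-case handling may differ:

import torch

def boltzmann_pick(function_values: torch.Tensor, num_samples: int, eta: float) -> torch.Tensor:
    """Draw `num_samples` candidate indices with probability ∝ exp(eta * value)."""
    # Standardize so eta has a scale-free effect on the draw's sharpness.
    vals = function_values - function_values.mean()
    std = function_values.std().clamp_min(1e-9)
    weights = torch.softmax(eta * vals / std, dim=0)
    # Sample without replacement so each restart gets a distinct q-batch.
    return torch.multinomial(weights, num_samples, replacement=False)

acq_vals = torch.tensor([0.1, 0.5, 2.0, 0.3, 1.2])
idx = boltzmann_pick(acq_vals, num_samples=3, eta=2.0)
# eta -> 0 approaches uniform sampling; large eta concentrates the draw on
# the best candidates, mirroring the role of options["eta"] above.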
4 changes: 4 additions & 0 deletions botorch/optim/optimize.py
@@ -20,6 +20,7 @@
AcquisitionFunction,
OneShotAcquisitionFunction,
)
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
qHypervolumeKnowledgeGradient,
@@ -33,6 +34,7 @@
gen_batch_initial_conditions,
gen_one_shot_hvkg_initial_conditions,
gen_one_shot_kg_initial_conditions,
gen_optimal_input_initial_conditions,
TGenInitialConditions,
)
from botorch.optim.stopping import ExpMAStoppingCriterion
@@ -174,6 +176,8 @@ def get_ic_generator(self) -> TGenInitialConditions:
return gen_one_shot_kg_initial_conditions
elif isinstance(self.acq_function, qHypervolumeKnowledgeGradient):
return gen_one_shot_hvkg_initial_conditions
elif isinstance(self.acq_function, qJointEntropySearch):
return gen_optimal_input_initial_conditions
return gen_batch_initial_conditions


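The initializer can also be invoked directly, outside optimize_acqf. The sketch below mirrors the base case of the new unit test, with a toy model and placeholder optimizer samples standing in for real JES inputs:

import torch
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.models import SingleTaskGP
from botorch.optim.initializers import gen_optimal_input_initial_conditions

X = torch.rand(4, 2, dtype=torch.double)
model = SingleTaskGP(X, X.sum(dim=-1, keepdim=True))
jes = qJointEntropySearch(
    model=model,
    optimal_inputs=torch.rand(5, 2, dtype=torch.double),
    optimal_outputs=torch.rand(5, 1, dtype=torch.double),
)

# With the default frac_random=0.0, all raw q-batches are perturbations of
# the sampled optimal inputs; num_restarts of them survive the Boltzmann draw.
ics = gen_optimal_input_initial_conditions(
    acq_function=jes,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=3,
    num_restarts=10,
    raw_samples=16,
)
assert ics.shape == torch.Size([10, 3, 2])  # num_restarts x q x d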
107 changes: 107 additions & 0 deletions test/optim/test_initializers.py
@@ -13,6 +13,7 @@
import torch
from botorch.acquisition.analytic import PosteriorMean
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.acquisition.monte_carlo import (
qExpectedImprovement,
@@ -34,6 +35,7 @@
gen_batch_initial_conditions,
gen_one_shot_hvkg_initial_conditions,
gen_one_shot_kg_initial_conditions,
gen_optimal_input_initial_conditions,
gen_value_function_initial_conditions,
initialize_q_batch,
initialize_q_batch_nonneg,
@@ -47,6 +49,7 @@
)
from botorch.sampling.normal import IIDNormalSampler
from botorch.utils.sampling import manual_seed, unnormalize
from botorch.utils.test_helpers import get_model
from botorch.utils.testing import (
_get_max_violation_of_bounds,
_get_max_violation_of_constraints,
@@ -1074,6 +1077,110 @@ def test_gen_one_shot_kg_initial_conditions(self):
)
self.assertTrue(torch.all(ics[..., -n_value:, :] == 1))

def test_gen_optimal_input_initial_conditions(self):
num_restarts = 10
raw_samples = 16
q = 3
for dtype in (torch.float, torch.double):
model = get_model(
torch.rand(4, 2, dtype=dtype), torch.rand(4, 1, dtype=dtype)
)
optimal_inputs = torch.rand(5, 2, dtype=dtype)
optimal_outputs = torch.rand(5, 1, dtype=dtype)
jes = qJointEntropySearch(
model=model,
optimal_inputs=optimal_inputs,
optimal_outputs=optimal_outputs,
)
bounds = torch.tensor([[0, 0], [1, 1]], device=self.device, dtype=dtype)
# base case
ics = gen_optimal_input_initial_conditions(
acq_function=jes,
bounds=bounds,
q=q,
num_restarts=num_restarts,
raw_samples=raw_samples,
)
self.assertEqual(ics.shape, torch.Size([num_restarts, q, 2]))

            # Since we use sample_around_best, this should generate enough
            # points despite num_restarts being larger than raw_samples.
ics = gen_optimal_input_initial_conditions(
acq_function=jes,
bounds=bounds,
q=q,
num_restarts=15,
raw_samples=8,
options={"frac_random": 0.2, "sample_around_best": True},
)
self.assertEqual(ics.shape, torch.Size([15, q, 2]))

# test option error
with self.assertRaises(ValueError):
gen_optimal_input_initial_conditions(
acq_function=jes,
bounds=bounds,
q=1,
num_restarts=num_restarts,
raw_samples=raw_samples,
options={"frac_random": 2.0},
)

ei = qExpectedImprovement(model, 99.9)
with self.assertRaisesRegex(
AttributeError,
"gen_optimal_input_initial_conditions can only be used with "
"an AcquisitionFunction that has an optimal_inputs attribute.",
):
gen_optimal_input_initial_conditions(
acq_function=ei,
bounds=bounds,
q=1,
num_restarts=num_restarts,
raw_samples=raw_samples,
options={"frac_random": 2.0},
)
# test generation logic
random_ics = torch.rand(raw_samples // 2, q, 2)
suggested_ics = torch.rand(raw_samples // 2 * q, 2)
with ExitStack() as es:
mock_random_ics = es.enter_context(
mock.patch(
"botorch.optim.initializers.sample_q_batches_from_polytope",
return_value=random_ics,
)
)
mock_suggested_ics = es.enter_context(
mock.patch(
"botorch.optim.initializers.sample_points_around_best",
return_value=suggested_ics,
)
)
mock_choose = es.enter_context(
mock.patch(
"torch.multinomial",
return_value=torch.arange(0, 10),
)
)

ics = gen_optimal_input_initial_conditions(
acq_function=jes,
bounds=bounds,
q=q,
num_restarts=num_restarts,
raw_samples=raw_samples,
options={"frac_random": 0.5},
)

mock_suggested_ics.assert_called_once()
mock_random_ics.assert_called_once()
mock_choose.assert_called_once()

expected_result = torch.cat(
(random_ics, suggested_ics.view(raw_samples // 2, q, 2)[0:2])
)
self.assertTrue(torch.equal(ics, expected_result))


class TestGenOneShotHVKGInitialConditions(BotorchTestCase):
def test_gen_one_shot_hvkg_initial_conditions(self):