
Performance & runtime improvements to info-theoretic acquisition functions (2/N) - AcqOpt initializer #2751

Open · wants to merge 6 commits into main from es_initializer

Conversation

@hvarfner (Contributor) commented Feb 20, 2025

A series of improvements aimed at the performance of PES & JES, as well as their multi-objective counterparts.

This PR adds an initializer for the acquisition function optimization, which drastically reduces the number of required forward passes (from roughly 150-250 down to about 25) by providing initial suggestions close to the sampled optima obtained during acquisition function construction.
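For reference, a minimal usage sketch (illustrative, not part of this PR) of how the new initializer might be wired into optimize_acqf via its ic_generator hook; the toy model, bounds, and option values below are assumptions for demonstration:

# Minimal sketch (illustrative, not from this PR): wire the new initializer
# into optimize_acqf via its `ic_generator` hook.
import torch

from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.utils import get_optimal_samples
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.optim.initializers import gen_optimal_input_initial_conditions

train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
bounds = torch.stack([torch.zeros(3), torch.ones(3)]).to(torch.double)

# The sampled optima produced during acquisition construction are what the
# initializer later perturbs to build its initial conditions.
optimal_inputs, optimal_outputs = get_optimal_samples(
    model=model, bounds=bounds, num_optima=8
)
acqf = qJointEntropySearch(
    model=model, optimal_inputs=optimal_inputs, optimal_outputs=optimal_outputs
)

candidate, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=8,
    raw_samples=256,
    ic_generator=gen_optimal_input_initial_conditions,
    # frac_random controls how many raw samples are drawn uniformly; the rest
    # are perturbations of the sampled optima (see the initializer below).
    options={"frac_random": 0.25},
)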

@esantorella
Moreover, better acquisition function values are found (the BO loop from PR 1, with both acquisition optimizations run in parallel):
[plot: found_acquisition_value]

It is also a lot faster:
[plot: opt_time]

This does not always improve performance, however (PR 1 is more local because sample_around_best dominates candidate generation, which is generally good):
[plot: regret]

Lastly, a comparison to LogNEI with the introduced modifications:
[plot: jes_logei]

They are now much closer in terms of runtime:
[plot: total_runtime]

And here is the split between posterior sampling time and acquisition optimization time:
[plot: runtime]

So apart from Michalewicz, it does pretty well now!

Related PRs

Previous one

@facebook-github-bot added the CLA Signed label on Feb 20, 2025

codecov bot commented Feb 20, 2025

Codecov Report

Attention: Patch coverage is 94.52055% with 4 lines in your changes missing coverage. Please review.

Project coverage is 99.97%. Comparing base (9a7c517) to head (ed81a46).

Files with missing lines          Patch %    Missing
botorch/optim/initializers.py     95.71%     3
botorch/optim/optimize.py         66.66%     1
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2751      +/-   ##
==========================================
- Coverage   99.99%   99.97%   -0.03%     
==========================================
  Files         203      203              
  Lines       18690    18726      +36     
==========================================
+ Hits        18689    18721      +32     
- Misses          1        5       +4     


@hvarfner hvarfner force-pushed the es_initializer branch 4 times, most recently from 938d9be to f2db5ac Compare February 21, 2025 13:25
@hvarfner hvarfner changed the title Performance & runtime improvements to info-theoretic acquisition functions (1/N) - AcqOpt initializer Performance & runtime improvements to info-theoretic acquisition functions (2/N) - AcqOpt initializer Feb 21, 2025
@facebook-github-bot (Contributor) commented:

@esantorella has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented:

@esantorella has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Comment on lines 478 to 588
    options: dict[str, bool | float | int] | None = None,
    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
):
    options = options or {}
    device = bounds.device
    if not hasattr(acq_function, "optimal_inputs"):
        raise AttributeError(
            "gen_optimal_input_initial_conditions can only be used with "
            "an AcquisitionFunction that has an optimal_inputs attribute."
        )
    frac_random: float = options.get("frac_random", 0.0)
    if not 0 <= frac_random <= 1:
        raise ValueError(
            f"frac_random must take on values in (0,1). Value: {frac_random}"
        )

    batch_limit = options.get("batch_limit")
    num_optima = acq_function.optimal_inputs.shape[:-1].numel()
    suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
    X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
    num_random = round(raw_samples * frac_random)
    if num_random > 0:
        X_rnd = sample_q_batches_from_polytope(
            n=num_random,
            q=q,
            bounds=bounds,
            n_burnin=options.get("n_burnin", 10000),
            n_thinning=options.get("n_thinning", 32),
            equality_constraints=equality_constraints,
            inequality_constraints=inequality_constraints,
        )
        X = torch.cat((X, X_rnd))

    if num_random < raw_samples:
        X_perturbed = sample_points_around_best(
            acq_function=acq_function,
            n_discrete_points=q * (raw_samples - num_random),
            sigma=options.get("sample_around_best_sigma", 1e-2),
            bounds=bounds,
            best_X=suggestions,
        )
        X_perturbed = X_perturbed.view(
            raw_samples - num_random, q, bounds.shape[-1]
        ).cpu()
        X = torch.cat((X, X_perturbed))

    if options.get("sample_around_best", False):
        X_best = sample_points_around_best(
            acq_function=acq_function,
            n_discrete_points=q * raw_samples,
            sigma=options.get("sample_around_best_sigma", 1e-2),
            bounds=bounds,
        )
        X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
        X = torch.cat((X, X_best))

    with torch.no_grad():
        if batch_limit is None:
            batch_limit = X.shape[0]
        # Evaluate the acquisition function on `X_rnd` using `batch_limit`
        # sized chunks.
        acq_vals = torch.cat(
            [
                acq_function(x_.to(device=device)).cpu()
                for x_ in X.split(split_size=batch_limit, dim=0)
            ],
            dim=0,
        )
        idx = boltzmann_sample(
            function_values=acq_vals,
            num_samples=num_restarts,
            eta=options.get("eta", 2.0),
@esantorella (Member) commented Feb 25, 2025

By passing these individually rather than as a dict, we help static analysis tools (and people) see that the code isn't obviously wrong, and prevent unused options from being passed and silently dropped. That can be especially helpful in guarding against typos or when refactoring.

You could then update the call sites to pass **options instead of options -- personally I'd pass them individually everywhere, but it may be a matter of taste.
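As a hypothetical illustration of that point (reusing the acqf and bounds names from the sketch in the description above; this is not part of the diff), the call-site change amounts to unpacking the same keys as keyword arguments:

# Hypothetical call-site comparison (not from the diff). With the options dict,
# a misspelled key is silently ignored; with explicit keyword arguments it
# raises a TypeError instead.
ics = gen_optimal_input_initial_conditions(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=8,
    raw_samples=256,
    options={"frac_random": 0.25, "eta": 2.0},  # current signature
)

ics = gen_optimal_input_initial_conditions(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=8,
    raw_samples=256,
    **{"frac_random": 0.25, "eta": 2.0},  # suggested signature: pass **options
)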

Suggested change

-    options: dict[str, bool | float | int] | None = None,
+    frac_random: float = 0.0,
+    batch_limit: int | None = None,
+    n_burnin: int = 10000,
+    n_thinning: int = 32,
+    sample_around_best: bool = False,
+    sample_around_best_sigma: float = 1e-2,
+    eta: float = 2.0,
     inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
     equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
 ):

(the remainder of the function body in the suggestion repeats the snippet above verbatim)

@hvarfner (Contributor Author) replied:

Sure, I'm okay doing it this way! It just seems that changing ic_generator alone wouldn't suffice if one (for some reason) wanted to switch between them, since their signatures would be inconsistent?

@@ -468,6 +468,91 @@ def gen_batch_initial_conditions(
return batch_initial_conditions


def gen_optimal_input_initial_conditions(
@esantorella (Member):

Could you add a docstring explaining the behavior of this function?

@hvarfner (Contributor Author):

Added!

        X = torch.cat((X, X_rnd))

    if num_random < raw_samples:
        X_perturbed = sample_points_around_best(
@esantorella (Member):

It's a bit nonintuitive that we do this even when sample_around_best is False, no?

@hvarfner (Contributor Author):

Possibly! My thought was, since it is not actually sampling around the incumbent but around the sampled optima, I could keep it and re-use its logic. I tried to mimic the KG logic for it, and that uses frac_random for a similar reason.
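For context, a rough sketch of the KG pattern being referred to (illustrative only, reusing the model and bounds from the sketch in the description; the one-shot KG initializer likewise reads frac_random from options):

# Illustrative KG analogue (not from this PR): the one-shot KG initializer
# also uses a `frac_random` option to control how much of the initialization
# is purely random versus derived from promising points.
from botorch.acquisition import qKnowledgeGradient
from botorch.optim import optimize_acqf
from botorch.optim.initializers import gen_one_shot_kg_initial_conditions

kg = qKnowledgeGradient(model=model, num_fantasies=32)
kg_candidate, kg_value = optimize_acqf(
    acq_function=kg,
    bounds=bounds,
    q=1,
    num_restarts=8,
    raw_samples=256,
    ic_generator=gen_one_shot_kg_initial_conditions,
    options={"frac_random": 0.25},
)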

@esantorella (Member) commented:

Thanks for this! I'm looking forward to seeing the plots.

@facebook-github-bot (Contributor) commented:

@esantorella has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@hvarfner (Contributor Author) commented:

I have not quite figured out why the test coverage is not there, since I thought I addressed it today. I will also figure out the conflicts ASAP!

@facebook-github-bot (Contributor) commented:

@esantorella has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
