Performance & runtime improvements to info-theoretic acquisition functions (2/N) - AcqOpt initializer #2751
base: main

Conversation
Force-pushed from 75bf7a0 to ed81a46.
Codecov Report — Attention: Patch coverage is …

```
@@            Coverage Diff             @@
##             main    #2751      +/-   ##
==========================================
- Coverage   99.99%   99.97%   -0.03%
==========================================
  Files         203      203
  Lines       18690    18726      +36
==========================================
+ Hits        18689    18721      +32
- Misses          1        5       +4
```
Force-pushed from 938d9be to f2db5ac.
Force-pushed from f2db5ac to 24781e9.
code, reshuffling of other sampling methods (that don't take an acqf)
Co-authored-by: Elizabeth Santorella <[email protected]>
Force-pushed from 24781e9 to c157b57.
@esantorella has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
improve performance and runtime of PES/JES
Force-pushed from c157b57 to 211f79b.
@esantorella has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
botorch/optim/initializers.py (Outdated)
```python
    options: dict[str, bool | float | int] | None = None,
    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
):
    options = options or {}
    device = bounds.device
    if not hasattr(acq_function, "optimal_inputs"):
        raise AttributeError(
            "gen_optimal_input_initial_conditions can only be used with "
            "an AcquisitionFunction that has an optimal_inputs attribute."
        )
    frac_random: float = options.get("frac_random", 0.0)
    if not 0 <= frac_random <= 1:
        raise ValueError(
            f"frac_random must take on values in [0, 1]. Value: {frac_random}"
        )

    batch_limit = options.get("batch_limit")
    num_optima = acq_function.optimal_inputs.shape[:-1].numel()
    suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
    X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
    num_random = round(raw_samples * frac_random)
    if num_random > 0:
        X_rnd = sample_q_batches_from_polytope(
            n=num_random,
            q=q,
            bounds=bounds,
            n_burnin=options.get("n_burnin", 10000),
            n_thinning=options.get("n_thinning", 32),
            equality_constraints=equality_constraints,
            inequality_constraints=inequality_constraints,
        )
        X = torch.cat((X, X_rnd))

    if num_random < raw_samples:
        X_perturbed = sample_points_around_best(
            acq_function=acq_function,
            n_discrete_points=q * (raw_samples - num_random),
            sigma=options.get("sample_around_best_sigma", 1e-2),
            bounds=bounds,
            best_X=suggestions,
        )
        X_perturbed = X_perturbed.view(
            raw_samples - num_random, q, bounds.shape[-1]
        ).cpu()
        X = torch.cat((X, X_perturbed))

    if options.get("sample_around_best", False):
        X_best = sample_points_around_best(
            acq_function=acq_function,
            n_discrete_points=q * raw_samples,
            sigma=options.get("sample_around_best_sigma", 1e-2),
            bounds=bounds,
        )
        X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
        X = torch.cat((X, X_best))

    with torch.no_grad():
        if batch_limit is None:
            batch_limit = X.shape[0]
        # Evaluate the acquisition function on `X` in `batch_limit`-sized chunks.
        acq_vals = torch.cat(
            [
                acq_function(x_.to(device=device)).cpu()
                for x_ in X.split(split_size=batch_limit, dim=0)
            ],
            dim=0,
        )
    idx = boltzmann_sample(
        function_values=acq_vals,
        num_samples=num_restarts,
        eta=options.get("eta", 2.0),
```
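For intuition, the selection step at the end of this hunk is softmax-weighted sampling over the evaluated candidates. A minimal, self-contained sketch of that idea in generic torch (not the actual `boltzmann_sample` helper):

```python
import torch

def boltzmann_pick(acq_vals: torch.Tensor, num_samples: int, eta: float = 2.0) -> torch.Tensor:
    """Sample candidate indices with probability increasing in acquisition value."""
    # Standardize the values so `eta` acts on a scale-free quantity.
    vals = (acq_vals - acq_vals.mean()) / acq_vals.std().clamp_min(1e-9)
    weights = torch.softmax(eta * vals, dim=0)
    # Draw without replacement so the restarts are distinct candidates.
    return torch.multinomial(weights, num_samples=num_samples, replacement=False)

acq_vals = torch.randn(64)  # stand-in for the chunked evaluations above
idx = boltzmann_pick(acq_vals, num_samples=8)
```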
By passing these individually rather than as a dict, we help static analysis tools (and people) see that the code isn't obviously wrong, and prevent unused options from being passed and silently dropped. That can be especially helpful in guarding against typos or when refactoring. You could then update the call sites to pass `**options` instead of `options` -- personally I'd pass them individually everywhere, but it may be a matter of taste.
Suggested change (replace the `options` dict with explicit keyword arguments; the rest of the hunk is repeated unchanged from the code above):

```diff
-    options: dict[str, bool | float | int] | None = None,
+    frac_random: float = 0.0,
+    batch_limit: int | None = None,
+    n_burnin: int = 10000,
+    n_thinning: int = 32,
+    sample_around_best: bool = False,
+    sample_around_best_sigma: float = 1e-2,
+    eta: float = 2.0,
     inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
     equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
 ):
```
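As a toy illustration of the reviewer's point (hypothetical `configure` function, not botorch code): unpacking with `**` lets Python and static analyzers check every key against the signature, while passing a dict through hides typos.

```python
def configure(frac_random: float = 0.0, eta: float = 2.0) -> tuple[float, float]:
    # Explicit keyword arguments: unknown options cannot be silently dropped.
    return frac_random, eta

options = {"frac_random": 0.25, "eta": 2.0}
print(configure(**options))  # (0.25, 2.0) -- keys are checked against the signature

try:
    configure(**{"frac_randm": 0.25})  # note the typo in the key
except TypeError as e:
    print(e)  # configure() got an unexpected keyword argument 'frac_randm'
```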
Sure, I'm okay doing it this way! It just seems that changing `ic_generator` alone wouldn't suffice if one (for some reason) wanted to change between them, since they would be inconsistent?
```diff
@@ -468,6 +468,91 @@ def gen_batch_initial_conditions(
     return batch_initial_conditions


def gen_optimal_input_initial_conditions(
```
Could you add a docstring explaining the behavior of this function?
Added!
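For readers of this thread, a docstring along these lines matches the behavior shown above (a sketch inferred from the code, not necessarily the exact text that was merged):

```python
def gen_optimal_input_initial_conditions(
    acq_function, bounds, q, num_restarts, raw_samples, options=None,
    inequality_constraints=None, equality_constraints=None,
):
    r"""Generate initial conditions for acquisition optimization from the
    optima sampled while constructing an information-theoretic acquisition
    function (e.g. PES/JES).

    A fraction ``frac_random`` of the ``raw_samples`` candidates is drawn
    from the (possibly constrained) search space; the remainder is obtained
    by perturbing the sampled optima stored in
    ``acq_function.optimal_inputs``. Candidates are evaluated in
    ``batch_limit``-sized chunks, and ``num_restarts`` starting points are
    selected by Boltzmann sampling on the acquisition values.
    """
```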
```python
        X = torch.cat((X, X_rnd))

    if num_random < raw_samples:
        X_perturbed = sample_points_around_best(
```
It's a bit nonintuitive that we do this even when `sample_around_best` is `False`, no?
Possibly! My thought was: since it is not actually sampling around the incumbent but around the sampled optima, I could keep it and re-use its logic. I tried to mimic the KG logic for it, which uses `frac_random` for a similar reason.
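For concreteness, the split implied by `frac_random` in the code above (illustrative numbers, mirroring `num_random = round(raw_samples * frac_random)`):

```python
raw_samples, frac_random, q = 512, 0.25, 2

num_random = round(raw_samples * frac_random)   # 128 q-batches from the polytope
num_perturbed = q * (raw_samples - num_random)  # 768 points perturbed around the
                                                # sampled optima (not the incumbent)
print(num_random, num_perturbed)                # 128 768
```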
Thanks for this! I'm looking forward to seeing the plots.
Force-pushed from 211f79b to 1ec824c.
@esantorella has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
I have not quite figured out why the test coverage is not there, since I thought I addressed it today. I will also sort out the conflicts ASAP!
@esantorella has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
A series of improvements directed towards improving the performance of PES & JES, as well as their multi-objective counterparts.

This PR adds an initializer for the acquisition function optimization, which drastically reduces the number of required forward passes (from roughly 150-250 down to ~25) by providing suggestions close to the sampled optima obtained during acquisition function construction.
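To sketch how the new initializer is meant to be wired in: a minimal example under stated assumptions (the `get_optimal_samples` and `qJointEntropySearch` helpers and the `ic_generator` hook on `optimize_acqf` are used as I understand them from the botorch docs; the data and hyperparameters are toy values):

```python
import torch
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.utils import get_optimal_samples
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.optim.initializers import gen_optimal_input_initial_conditions
from gpytorch.mlls import ExactMarginalLogLikelihood

d = 4
train_X = torch.rand(20, d, dtype=torch.double)
train_Y = -(train_X - 0.5).pow(2).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

bounds = torch.stack([torch.zeros(d), torch.ones(d)]).to(torch.double)
# Sample optima once; JES stores them as `optimal_inputs`, which the
# initializer below reuses as starting suggestions.
optimal_inputs, optimal_outputs = get_optimal_samples(
    model=model, bounds=bounds, num_optima=16
)
acqf = qJointEntropySearch(
    model=model, optimal_inputs=optimal_inputs, optimal_outputs=optimal_outputs
)
candidate, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=8,
    raw_samples=256,
    ic_generator=gen_optimal_input_initial_conditions,  # the initializer from this PR
    options={"frac_random": 0.25},
)
```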
@esantorella

[figure]
Moreover, better acquisition function values are found (PR 1's BO loop, but both acquisition optimizers are run in parallel):

It is also a lot faster:

[figure]
This does not always improve performance, however (PR 1 is more local due to `sample_around_best` dominating candidate generation, which is generally good):

[figure]
Lastly, a nice comparison to LogNEI with the introduced modifications:

[figure]
They are also now much closer in terms of runtime:

[figure]
And here is the split between posterior sampling time and acquisition optimization time:

[figure]
So apart from Michalewicz, it performs quite well now!
Related PRs
Previous one