Skip to content

Commit

Permalink
pass gen_candidates callable in optimize_acqf (#1655)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1655

see title. This will support using stochastic optimization

Differential Revision: D41629164

fbshipit-source-id: 0f31bdc3392f47546da31183fa2166bf18ec174b
  • Loading branch information
sdaulton authored and facebook-github-bot committed Feb 6, 2023
1 parent 076af96 commit f692120
Show file tree
Hide file tree
Showing 3 changed files with 336 additions and 231 deletions.
4 changes: 2 additions & 2 deletions botorch/generation/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

logger = _get_logger()

TGenCandidates = Callable[[Tensor, AcquisitionFunction, Any], Tuple[Tensor, Tensor]]


def gen_candidates_scipy(
initial_conditions: Tensor,
Expand Down Expand Up @@ -151,7 +153,6 @@ def gen_candidates_scipy(
clamped_candidates
)
return clamped_candidates, batch_acquisition

clamped_candidates = columnwise_clamp(
X=initial_conditions, lower=lower_bounds, upper=upper_bounds
)
Expand Down Expand Up @@ -359,7 +360,6 @@ def gen_candidates_torch(
clamped_candidates
)
return clamped_candidates, batch_acquisition

_clamp = partial(columnwise_clamp, lower=lower_bounds, upper=upper_bounds)
clamped_candidates = _clamp(initial_conditions).requires_grad_(True)
_optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))
Expand Down
37 changes: 30 additions & 7 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.exceptions import InputDataError, UnsupportedError
from botorch.exceptions.warnings import OptimizationWarning
from botorch.generation.gen import gen_candidates_scipy
from botorch.generation.gen import gen_candidates_scipy, TGenCandidates
from botorch.logging import logger
from botorch.optim.initializers import (
gen_batch_initial_conditions,
gen_one_shot_kg_initial_conditions,
)
from botorch.optim.stopping import ExpMAStoppingCriterion
from botorch.optim.utils import _filter_kwargs
from torch import Tensor

INIT_OPTION_KEYS = {
Expand Down Expand Up @@ -64,6 +65,7 @@ def optimize_acqf(
post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
batch_initial_conditions: Optional[Tensor] = None,
return_best_only: bool = True,
gen_candidates: TGenCandidates = gen_candidates_scipy,
sequential: bool = False,
**kwargs: Any,
) -> Tuple[Tensor, Tensor]:
Expand Down Expand Up @@ -103,6 +105,12 @@ def optimize_acqf(
this if you do not want to use default initialization strategy.
return_best_only: If False, outputs the solutions corresponding to all
random restart initializations of the optimization.
gen_candidates: A callable for generating candidates (and their associated
acquisition values) given a tensor of initial conditions and an
acquisition function. Other common inputs include lower and upper bounds
and a dictionary of options, but refer to the documentation of specific
generation functions (e.g gen_candidates_scipy and gen_candidates_torch)
for method-specific inputs. Default: `gen_candidates_scipy`
sequential: If False, uses joint optimization, otherwise uses sequential
optimization.
kwargs: Additonal keyword arguments.
Expand Down Expand Up @@ -214,6 +222,7 @@ def optimize_acqf(
sequential=False,
ic_generator=ic_gen,
timeout_sec=timeout_sec,
gen_candidates=gen_candidates,
)

candidate_list.append(candidate)
Expand Down Expand Up @@ -262,6 +271,11 @@ def optimize_acqf(
batch_limit: int = options.get(
"batch_limit", num_restarts if not nonlinear_inequality_constraints else 1
)
has_parameter_constraints = (
inequality_constraints is not None
or equality_constraints is not None
or nonlinear_inequality_constraints is not None
)

def _optimize_batch_candidates(
timeout_sec: Optional[float],
Expand All @@ -273,24 +287,33 @@ def _optimize_batch_candidates(
if timeout_sec is not None:
timeout_sec = (timeout_sec - start_time) / len(batched_ics)

scipy_kws = {
gen_kwargs = {
"acquisition_function": acq_function,
"lower_bounds": None if bounds[0].isinf().all() else bounds[0],
"upper_bounds": None if bounds[1].isinf().all() else bounds[1],
"options": {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
"inequality_constraints": inequality_constraints,
"equality_constraints": equality_constraints,
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
"fixed_features": fixed_features,
"timeout_sec": timeout_sec,
}

if has_parameter_constraints:
# only add parameter constraints to gen_kwargs if they are specified
# to avoid unnecessary warnings in _filter_kwargs
gen_kwargs.update(
{
"inequality_constraints": inequality_constraints,
"equality_constraints": equality_constraints,
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
}
)
filtered_gen_kwargs = _filter_kwargs(gen_candidates, **gen_kwargs)

for i, batched_ics_ in enumerate(batched_ics):
# optimize using random restart optimization
with warnings.catch_warnings(record=True) as ws:
warnings.simplefilter("always", category=OptimizationWarning)
batch_candidates_curr, batch_acq_values_curr = gen_candidates_scipy(
initial_conditions=batched_ics_, **scipy_kws
batch_candidates_curr, batch_acq_values_curr = gen_candidates(
initial_conditions=batched_ics_, **filtered_gen_kwargs
)
opt_warnings += ws
batch_candidates_list.append(batch_candidates_curr)
Expand Down
Loading

0 comments on commit f692120

Please sign in to comment.