Performance & runtime improvements to info-theoretic acquisition functions (0/N) - Restructuring of sampling methods (#2753)

Summary:
Reshuffling of sampling methods that are not directly related to acquisition function optimization (i.e., that don't take the acquisition function as an argument), based on [this discussion](#2748 (comment)). Given the goal of removing code duplication specifically related to the optimization of info-theoretic acquisition functions, these seemed like sensible moves!

Pull Request resolved: #2753

Test Plan:
Moved the existing unit tests and added a new one for `boltzmann_sample`, which is used throughout this change and again in subsequent PRs.
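
For readers new to the helper: the call sites in the diff below suggest a Boltzmann-style reweighting of function values. A minimal sketch of the idea, assuming standardization plus a softmax (the name `boltzmann_sample_sketch` is hypothetical; the real `boltzmann_sample` lives in `botorch/utils/sampling.py` and may differ):

```python
import torch
from torch import Tensor


def boltzmann_sample_sketch(
    function_values: Tensor,
    num_samples: int,
    eta: float = 1.0,
    replacement: bool = False,
) -> Tensor:
    """Draw indices with probability proportional to exp(eta * standardized value)."""
    # Standardize so that `eta` acts as a scale-free temperature; the softmax
    # keeps the weights finite even for large `eta`.
    z = (function_values - function_values.mean(-1, keepdim=True)) / function_values.std(
        -1, keepdim=True
    )
    return torch.multinomial(
        torch.softmax(eta * z, dim=-1), num_samples, replacement=replacement
    )
```

Higher `eta` concentrates the draws on the best values, while `eta -> 0` approaches uniform sampling.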

## Related PRs

First of a series, like [this one](#2748).

Reviewed By: esantorella

Differential Revision: D70131981

Pulled By: saitcakmak

fbshipit-source-id: 48dd86e7e06006054294d7cd8b9a3d318b0b0ad1
hvarfner authored and facebook-github-bot committed Feb 25, 2025
1 parent 0be800e commit 78c04e2
Showing 4 changed files with 352 additions and 266 deletions.
180 changes: 24 additions & 156 deletions botorch/optim/initializers.py
@@ -17,7 +17,6 @@

 import warnings
 from collections.abc import Callable
-from math import ceil
 from typing import Optional, Union

 import torch
@@ -43,14 +42,15 @@
 from botorch.optim.utils import fix_features, get_X_baseline
 from botorch.utils.multi_objective.pareto import is_non_dominated
 from botorch.utils.sampling import (
-    batched_multinomial,
+    boltzmann_sample,
     draw_sobol_samples,
     get_polytope_samples,
     manual_seed,
+    sample_perturbed_subset_dims,
+    sample_truncated_normal_perturbations,
 )
-from botorch.utils.transforms import normalize, standardize, unnormalize
+from botorch.utils.transforms import unnormalize
 from torch import Tensor
-from torch.distributions import Normal
 from torch.quasirandom import SobolEngine

 TGenInitialConditions = Callable[
@@ -578,10 +578,12 @@ def gen_one_shot_kg_initial_conditions(

     # sampling from the optimizers
     n_value = int((1 - frac_random) * (q_aug - q))  # number of non-random ICs
-    eta = options.get("eta", 2.0)
-    weights = torch.exp(eta * standardize(fantasy_vals))
-    idx = torch.multinomial(weights, num_restarts * n_value, replacement=True)
-
+    idx = boltzmann_sample(
+        function_values=fantasy_vals,
+        num_samples=num_restarts * n_value,
+        eta=options.get("eta", 2.0),
+        replacement=True,
+    )
     # set the respective initial conditions to the sampled optimizers
     ics[..., -n_value:, :] = fantasy_cands[idx, 0].view(num_restarts, n_value, -1)
     return ics
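
A quick shape check for the reshaping on the last line of this hunk, with made-up sizes and `torch.randint` standing in for the Boltzmann draw:

```python
import torch

num_restarts, n_value, d = 4, 3, 2            # hypothetical sizes
fantasy_cands = torch.randn(10, 1, d)         # 10 fantasy optimizers, q=1
idx = torch.randint(10, (num_restarts * n_value,))  # stand-in for boltzmann_sample
picked = fantasy_cands[idx, 0].view(num_restarts, n_value, -1)
print(picked.shape)  # torch.Size([4, 3, 2]) -- fills the last n_value slots of ics
```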
@@ -699,14 +701,14 @@ def gen_one_shot_hvkg_initial_conditions(
             sequential=False,
         )
         # sampling from the optimizers
-        eta = options.get("eta", 2.0)
         if num_optim_restarts > 0:
-            probs = torch.nn.functional.softmax(eta * standardize(fantasy_vals), dim=0)
-            idx = torch.multinomial(
-                probs,
-                num_optim_restarts * acq_function.num_fantasies,
+            idx = boltzmann_sample(
+                function_values=fantasy_vals,
+                num_samples=num_optim_restarts * acq_function.num_fantasies,
+                eta=options.get("eta", 2.0),
                 replacement=True,
             )
-
             optim_ics = fantasy_cands[idx]
         if is_mf_hvkg:
             # add fixed features
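
Worth noting why the raw-`exp` weights removed in the KG hunk above and the `softmax` probabilities removed here can share a single replacement: `torch.multinomial` normalizes its weight vector, so both define the same sampling distribution. A quick check:

```python
import torch

z = torch.randn(8)
eta = 2.0
weights = torch.exp(eta * z)           # unnormalized Boltzmann weights
probs = torch.softmax(eta * z, dim=0)  # normalized version
assert torch.allclose(probs, weights / weights.sum())
```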
@@ -885,11 +887,10 @@ def gen_value_function_initial_conditions(
     # sampling from the optimizers
     n_value = int((1 - frac_random) * raw_samples)  # number of non-random ICs
     if n_value > 0:
-        eta = options.get("eta", 2.0)
-        weights = torch.exp(eta * standardize(fantasy_vals))
-        idx = batched_multinomial(
-            weights=weights.expand(*batch_shape, -1),
+        idx = boltzmann_sample(
+            function_values=fantasy_vals.expand(*batch_shape, -1),
             num_samples=n_value,
+            eta=options.get("eta", 2.0),
             replacement=True,
         ).permute(-1, *range(len(batch_shape)))
         resampled = fantasy_cands[idx]
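
The trailing `.permute(-1, *range(len(batch_shape)))` moves the sample dimension to the front, which is easy to misread. A small illustration with a made-up batch shape:

```python
import torch

batch_shape = torch.Size([2, 3])  # hypothetical fantasy batch shape
n_value = 5
idx = torch.zeros(*batch_shape, n_value, dtype=torch.long)  # stand-in draw
idx = idx.permute(-1, *range(len(batch_shape)))  # (2, 3, 5) -> (5, 2, 3)
print(idx.shape)  # torch.Size([5, 2, 3])
```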
@@ -979,18 +980,12 @@ def initialize_q_batch(
         return X[idcs], acq_vals[idcs]

     max_val, max_idx = torch.max(acq_vals, dim=0)
-    Z = (acq_vals - acq_vals.mean(dim=0)) / Ystd
-    etaZ = eta * Z
-    weights = torch.exp(etaZ)
-    while torch.isinf(weights).any():
-        etaZ *= 0.5
-        weights = torch.exp(etaZ)
-    if batch_shape == torch.Size():
-        idcs = torch.multinomial(weights, n)
-    else:
-        idcs = batched_multinomial(
-            weights=weights.permute(*range(1, len(batch_shape) + 1), 0), num_samples=n
-        ).permute(-1, *range(len(batch_shape)))
+    idcs = boltzmann_sample(
+        acq_vals.permute(*range(1, len(batch_shape) + 1), 0),
+        num_samples=n,
+        eta=eta,
+    ).permute(-1, *range(len(batch_shape)))

     # make sure we get the maximum
     if max_idx not in idcs:
         idcs[-1] = max_idx
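
The removed `while torch.isinf(weights).any()` loop guarded `torch.exp(eta * Z)` against overflow. If `boltzmann_sample` computes its probabilities via a softmax (an assumption; see its definition in `botorch/utils/sampling.py`), the guard becomes unnecessary thanks to the usual max-subtraction trick:

```python
import torch

z = torch.tensor([0.0, 250.0, 500.0])
print(torch.exp(z))             # tensor([1., inf, inf]) -- overflows in float32
print(torch.softmax(z, dim=0))  # tensor([0., 0., 1.]) -- finite: max is subtracted
```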
@@ -1239,133 +1234,6 @@ def sample_points_around_best(
     return perturbed_X


-def sample_truncated_normal_perturbations(
-    X: Tensor,
-    n_discrete_points: int,
-    sigma: float,
-    bounds: Tensor,
-    qmc: bool = True,
-) -> Tensor:
-    r"""Sample points around `X`.
-
-    Sample perturbed points around `X` such that the added perturbations
-    are sampled from N(0, sigma^2 I) and truncated to be within [0,1]^d.
-
-    Args:
-        X: A `n x d`-dim tensor of starting points.
-        n_discrete_points: The number of points to sample.
-        sigma: The standard deviation of the additive Gaussian noise for
-            perturbing the points.
-        bounds: A `2 x d`-dim tensor containing the bounds.
-        qmc: A boolean indicating whether to use qmc.
-
-    Returns:
-        A `n_discrete_points x d`-dim tensor containing the sampled points.
-    """
-    X = normalize(X, bounds=bounds)
-    d = X.shape[1]
-    # sample points from N(X_center, sigma^2 I), truncated to be within
-    # [0, 1]^d.
-    if X.shape[0] > 1:
-        rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device)
-        X = X[rand_indices]
-    if qmc:
-        std_bounds = torch.zeros(2, d, dtype=X.dtype, device=X.device)
-        std_bounds[1] = 1
-        u = draw_sobol_samples(bounds=std_bounds, n=n_discrete_points, q=1).squeeze(1)
-    else:
-        u = torch.rand((n_discrete_points, d), dtype=X.dtype, device=X.device)
-    # compute bounds to sample from
-    a = -X
-    b = 1 - X
-    # compute z-score of bounds
-    alpha = a / sigma
-    beta = b / sigma
-    normal = Normal(0, 1)
-    cdf_alpha = normal.cdf(alpha)
-    # use inverse transform
-    perturbation = normal.icdf(cdf_alpha + u * (normal.cdf(beta) - cdf_alpha)) * sigma
-    # add perturbation and clip points that are still outside
-    perturbed_X = (X + perturbation).clamp(0.0, 1.0)
-    return unnormalize(perturbed_X, bounds=bounds)
-
-
-def sample_perturbed_subset_dims(
-    X: Tensor,
-    bounds: Tensor,
-    n_discrete_points: int,
-    sigma: float = 1e-1,
-    qmc: bool = True,
-    prob_perturb: float | None = None,
-) -> Tensor:
-    r"""Sample around `X` by perturbing a subset of the dimensions.
-
-    By default, dimensions are perturbed with probability equal to
-    `min(20 / d, 1)`. As shown in [Regis]_, perturbing a small number
-    of dimensions can be beneficial. The perturbations are sampled
-    from N(0, sigma^2 I) and truncated to be within [0,1]^d.
-
-    Args:
-        X: A `n x d`-dim tensor of starting points. `X`
-            must be normalized to be within `[0, 1]^d`.
-        bounds: The bounds to sample perturbed values from.
-        n_discrete_points: The number of points to sample.
-        sigma: The standard deviation of the additive Gaussian noise for
-            perturbing the points.
-        qmc: A boolean indicating whether to use qmc.
-        prob_perturb: The probability of perturbing each dimension. If omitted,
-            defaults to `min(20 / d, 1)`.
-
-    Returns:
-        A `n_discrete_points x d`-dim tensor containing the sampled points.
-    """
-    if bounds.ndim != 2:
-        raise BotorchTensorDimensionError("bounds must be a `2 x d`-dim tensor.")
-    elif X.ndim != 2:
-        raise BotorchTensorDimensionError("X must be a `n x d`-dim tensor.")
-    d = bounds.shape[-1]
-    if prob_perturb is None:
-        # Only perturb a subset of the features
-        prob_perturb = min(20.0 / d, 1.0)
-
-    if X.shape[0] == 1:
-        X_cand = X.repeat(n_discrete_points, 1)
-    else:
-        rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device)
-        X_cand = X[rand_indices]
-    pert = sample_truncated_normal_perturbations(
-        X=X_cand,
-        n_discrete_points=n_discrete_points,
-        sigma=sigma,
-        bounds=bounds,
-        qmc=qmc,
-    )
-
-    # find cases where we are not perturbing any dimensions
-    mask = (
-        torch.rand(
-            n_discrete_points,
-            d,
-            dtype=bounds.dtype,
-            device=bounds.device,
-        )
-        <= prob_perturb
-    )
-    ind = (~mask).all(dim=-1).nonzero()
-    # perturb `n_perturb` of the dimensions
-    n_perturb = ceil(d * prob_perturb)
-    perturb_mask = torch.zeros(d, dtype=mask.dtype, device=mask.device)
-    perturb_mask[:n_perturb].fill_(1)
-    # TODO: use batched `torch.randperm` when available:
-    # https://github.com/pytorch/pytorch/issues/42502
-    for idx in ind:
-        mask[idx] = perturb_mask[torch.randperm(d, device=bounds.device)]
-    # Create candidate points
-    X_cand[mask] = pert[mask]
-    return X_cand
-
-
 def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
     r"""Determine whether a given acquisition function is non-negative.
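
Since the two perturbation samplers above were moved rather than deleted, downstream code now imports them from `botorch.utils.sampling` (matching the updated imports at the top of this diff). A minimal usage sketch, assuming the move preserved both signatures:

```python
import torch
from botorch.utils.sampling import (
    sample_perturbed_subset_dims,
    sample_truncated_normal_perturbations,
)

bounds = torch.stack([torch.zeros(6), torch.ones(6)])  # toy 6-dim unit box
X = torch.rand(5, 6)  # five starting points within the bounds

# Gaussian perturbations truncated to the bounds.
X_trunc = sample_truncated_normal_perturbations(
    X=X, n_discrete_points=16, sigma=0.1, bounds=bounds
)
# Perturb only a random subset of dimensions for each point.
X_subset = sample_perturbed_subset_dims(
    X=X, bounds=bounds, n_discrete_points=16, sigma=0.1
)
print(X_trunc.shape, X_subset.shape)  # torch.Size([16, 6]) torch.Size([16, 6])
```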
