Vizier output transforms #2643

Draft pull request: wants to merge 24 commits into the base branch main.

This view shows changes from 1 commit (8e27422); the full list of 24 commits in the PR follows.

Commits (24, all by CompRhys):
8047aa1  wip: topk ic generation (Nov 19, 2024)
f5a8d64  tests: add tests (Nov 20, 2024)
96f9bef  Merge remote-tracking branch 'upstream/main' into topk-icgen (Nov 24, 2024)
a022462  fix: micro-optimization suggestion from review (Nov 24, 2024)
e75239d  fix: don't use unnormalize due to unexpected behaviour with constant … (Nov 25, 2024)
8e27422  doc: initialize_q_batch_topk -> initialize_q_batch_topn (Nov 25, 2024)
662caf1  tests: achive full coverage (Nov 26, 2024)
75eea37  clean: remote debug snippet (Nov 26, 2024)
5e0fe59  Merge remote-tracking branch 'upstream/main' into topk-icgen (Nov 27, 2024)
ec4d7f8  fea: add InfeasibleTranforms from vizier (Nov 27, 2024)
351c3f8  fea: add the logwarp transform from vizier (Nov 27, 2024)
a656685  wip half rank (Nov 29, 2024)
88a2e5d  fea: use unnormalize in more places but add flag to turn off the cons… (Dec 2, 2024)
e0202e2  doc: add docstring for the new update_constant_bounds argument (Dec 2, 2024)
7f439c5  Merge branch 'topk-icgen' into vizier-output-transforms (Dec 2, 2024)
44225d8  Merge remote-tracking branch 'upstream/main' into vizier-output-trans… (Dec 2, 2024)
3a87cc6  wip: untransform still doesn't work (Dec 2, 2024)
2cc3efa  Merge remote-tracking branch 'upstream/main' into vizier-output-trans… (Dec 3, 2024)
926d9e2  fea: add half rank transform (Dec 3, 2024)
301fa5c  test: add tests for half-rank (Dec 3, 2024)
21c14b2  test: reduce the number of tests run whilst ensuring coverage (Dec 3, 2024)
c63c571  fix: fix some review comments (Dec 3, 2024)
cbee6d1  fea: add forward when not in train (Dec 4, 2024)
b8311ae  fix: force=True to enforce device safety (Dec 4, 2024)
doc: initialize_q_batch_topk -> initialize_q_batch_topn
CompRhys committed Nov 25, 2024
commit 8e274227e51adec300fe1bbd9da7ecee5d47ccf9
botorch/optim/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -25,7 +25,7 @@
 from botorch.optim.initializers import (
     initialize_q_batch,
     initialize_q_batch_nonneg,
-    initialize_q_batch_topk,
+    initialize_q_batch_topn,
 )
 from botorch.optim.optimize import (
     gen_batch_initial_conditions,
@@ -47,7 +47,7 @@
     "gen_batch_initial_conditions",
     "initialize_q_batch",
     "initialize_q_batch_nonneg",
-    "initialize_q_batch_topk",
+    "initialize_q_batch_topn",
     "OptimizationResult",
     "OptimizationStatus",
     "optimize_acqf",
botorch/optim/initializers.py (8 changes: 4 additions & 4 deletions)
@@ -329,9 +329,9 @@
     device = bounds.device
     bounds_cpu = bounds.cpu()

-    if options.get("topk"):
-        init_func = initialize_q_batch_topk
+    if options.get("topn"):
+        init_func = initialize_q_batch_topn
         init_func_opts = ["sorted", "largest"]
[Codecov / codecov/patch: added lines botorch/optim/initializers.py#L333-L334 were not covered by tests]
     elif options.get("nonnegative") or is_nonnegative(acq_function):
         init_func = initialize_q_batch_nonneg
         init_func_opts = ["alpha", "eta"]
@@ -342,7 +342,7 @@
     for opt in init_func_opts:
         # default value of "largest" to "acq_function.maximize" if it exists
         if opt == "largest" and hasattr(acq_function, "maximize"):
             init_kwargs[opt] = acq_function.maximize
[Codecov / codecov/patch: added line botorch/optim/initializers.py#L345 was not covered by tests]

         if opt in options:
             init_kwargs[opt] = options.get(opt)
@@ -1079,7 +1079,7 @@
     return X[idcs], acq_vals[idcs]


-def initialize_q_batch_topk(
+def initialize_q_batch_topn(
     X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, sorted: bool = True
 ) -> tuple[Tensor, Tensor]:
     r"""Take the top `n` initial conditions for candidate generation.
@@ -1100,7 +1100,7 @@
         >>> # for model with `d=6`:
         >>> qUCB = qUpperConfidenceBound(model, beta=0.1)
         >>> X_rnd = torch.rand(500, 3, 6)
-        >>> X_init, acq_init = initialize_q_batch_topk(
+        >>> X_init, acq_init = initialize_q_batch_topn(
         ...     X=X_rnd, acq_vals=qUCB(X_rnd), n=10
         ... )
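
With this commit, the option key checked in gen_batch_initial_conditions changes from "topk" to "topn", and the initializer itself is renamed to initialize_q_batch_topn (imported as `from botorch.optim import initialize_q_batch_topn`, per the __init__.py change above). As a rough illustration of the top-n selection idea, the sketch below mirrors the docstring example with a random stand-in for the acquisition values; it is a simplified re-implementation for illustration, not the BoTorch code:

import torch

def topn_sketch(X, acq_vals, n, largest=True, sorted=True):
    # Keep the `n` candidates with the highest (or lowest, if largest=False)
    # acquisition values, analogous to what initialize_q_batch_topn does.
    vals, idcs = acq_vals.topk(n, largest=largest, sorted=sorted)
    return X[idcs], vals

X_rnd = torch.rand(500, 3, 6)   # 500 raw q-batches with q=3, d=6
acq_rnd = torch.rand(500)       # stand-in for qUCB(X_rnd) in the docstring example
X_init, acq_init = topn_sketch(X_rnd, acq_rnd, n=10)
print(X_init.shape)             # torch.Size([10, 3, 6])

Based on the `options.get("topn")` branch, the same initializer appears to be selectable through the options dict of gen_batch_initial_conditions (e.g. options={"topn": True}), with `largest` defaulting to `acq_function.maximize` whenever the acquisition function exposes that attribute.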

test/optim/test_initializers.py (18 changes: 9 additions & 9 deletions)
@@ -31,13 +31,13 @@
 from botorch.models import SingleTaskGP
 from botorch.models.model_list_gp_regression import ModelListGP
 from botorch.optim.initializers import (
-    initialize_q_batch,
-    initialize_q_batch_nonneg,
-    initialize_q_batch_topk,
     gen_batch_initial_conditions,
     gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
     gen_value_function_initial_conditions,
+    initialize_q_batch,
+    initialize_q_batch_nonneg,
+    initialize_q_batch_topn,
     sample_perturbed_subset_dims,
     sample_points_around_best,
     sample_q_batches_from_polytope,
@@ -157,37 +157,37 @@ def test_initialize_q_batch(self):
         with self.assertRaises(RuntimeError):
             initialize_q_batch(X=X, acq_vals=acq_vals, n=10)

-    def test_initialize_q_batch_topk(self):
+    def test_initialize_q_batch_topn(self):
         for dtype in (torch.float, torch.double):
             # basic test
             X = torch.rand(5, 3, 4, device=self.device, dtype=dtype)
             acq_vals = torch.rand(5, device=self.device, dtype=dtype)
-            ics_X, ics_acq_vals = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2)
+            ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
             self.assertEqual(ics_X.shape, torch.Size([2, 3, 4]))
             self.assertEqual(ics_X.device, X.device)
             self.assertEqual(ics_X.dtype, X.dtype)
             self.assertEqual(ics_acq_vals.shape, torch.Size([2]))
             self.assertEqual(ics_acq_vals.device, acq_vals.device)
             self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype)
             # ensure nothing happens if we want all samples
-            ics_X, ics_acq_vals = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=5)
+            ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=5)
             self.assertTrue(torch.equal(X, ics_X))
             self.assertTrue(torch.equal(acq_vals, ics_acq_vals))
             # make sure things work with constant inputs
             acq_vals = torch.ones(5, device=self.device, dtype=dtype)
-            ics, _ = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2)
+            ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
             self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
             self.assertEqual(ics.device, X.device)
             self.assertEqual(ics.dtype, X.dtype)
             # ensure raises correct warning
             acq_vals = torch.zeros(5, device=self.device, dtype=dtype)
             with warnings.catch_warnings(record=True) as w:
-                ics, _ = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2)
+                ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
             self.assertEqual(len(w), 1)
             self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning))
             self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
             with self.assertRaises(RuntimeError):
-                initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=10)
+                initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=10)

     def test_initialize_q_batch_largeZ(self):
         for dtype in (torch.float, torch.double):
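
For reference, the renamed test keeps the same behavioural checks as before. In plain usage terms it expects roughly the following (a hypothetical snippet run against this branch, not part of the diff):

import warnings
import torch
from botorch.optim import initialize_q_batch_topn

X = torch.rand(5, 3, 4)
acq_vals = torch.zeros(5)  # degenerate case: every candidate looks equally bad

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2)
# the test asserts that a BadInitialCandidatesWarning is issued here and that
# `ics` still has shape (2, 3, 4)

# asking for more initial conditions than there are candidates should fail:
# initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=10)  # expected RuntimeError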