
Commit

Merge branch 'main' into main
esantorella authored Mar 10, 2023
2 parents aa6f5cf + eaa6fb2 commit fcb9584
Showing 12 changed files with 2,407 additions and 4,324 deletions.
1 change: 1 addition & 0 deletions botorch/acquisition/input_constructors.py
@@ -843,6 +843,7 @@ def construct_inputs_qNEHVI(
"cache_pending": kwargs.get("cache_pending", True),
"max_iep": kwargs.get("max_iep", 0),
"incremental_nehvi": kwargs.get("incremental_nehvi", True),
"cache_root": kwargs.get("cache_root", True),
}


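The one-line change above threads a cache_root option (default True) through construct_inputs_qNEHVI alongside the other optional qNEHVI flags. A minimal sketch of the kwargs-forwarding idiom; construct_options is a hypothetical stand-in for illustration, not BoTorch API:

from typing import Any, Dict

def construct_options(**kwargs: Any) -> Dict[str, Any]:
    # Optional flags are read with kwargs.get so callers can override defaults.
    return {
        "incremental_nehvi": kwargs.get("incremental_nehvi", True),
        "cache_root": kwargs.get("cache_root", True),  # new flag, defaults to True
    }

assert construct_options()["cache_root"] is True
assert construct_options(cache_root=False)["cache_root"] is False
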
1 change: 0 additions & 1 deletion botorch/acquisition/multi_objective/objective.py
@@ -82,7 +82,6 @@ def __init__(
r"""Initialize Objective.
Args:
weights: `m'`-dim tensor of outcome weights.
outcomes: A list of the `m'` indices that the weights should be
applied to.
num_outcomes: The total number of outcomes `m`
6 changes: 6 additions & 0 deletions botorch/exceptions/warnings.py
@@ -49,3 +49,9 @@ class BotorchTensorDimensionWarning(BotorchWarning):
r"""Warning raised when a tensor possibly violates a botorch convention."""

pass


class UserInputWarning(BotorchWarning):
r"""Warning raised when a potential issue is detected with user provided inputs."""

pass
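
Because UserInputWarning subclasses BotorchWarning, existing filters on BotorchWarning keep working unchanged. A minimal sketch of emitting the new category and escalating it to an error during debugging:

import warnings

from botorch.exceptions.warnings import UserInputWarning

# Emit the new warning category (as the Normalize transform below does when
# learn_bounds=False is combined with missing bounds).
warnings.warn("Potential issue with user-provided inputs.", UserInputWarning)

# Escalate it to an exception while debugging:
with warnings.catch_warnings():
    warnings.simplefilter("error", category=UserInputWarning)
    try:
        warnings.warn("message", UserInputWarning)
    except UserInputWarning:
        print("caught as an exception")
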
42 changes: 41 additions & 1 deletion botorch/models/transforms/input.py
@@ -22,6 +22,7 @@

import torch
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.exceptions.warnings import UserInputWarning
from botorch.models.transforms.utils import subset_transform
from botorch.models.utils import fantasize
from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE
@@ -505,6 +506,7 @@ def __init__(
transform_on_fantasize: bool = True,
reverse: bool = False,
min_range: float = 1e-8,
learn_bounds: Optional[bool] = None,
) -> None:
r"""Normalize the inputs to the unit cube.
@@ -527,7 +529,13 @@
the inputs.
min_range: Amount of noise to add to the range to ensure no division by
zero errors.
learn_bounds: Whether to learn the bounds in train mode. Defaults
to False if bounds are provided, otherwise defaults to True.
"""
if learn_bounds is not None:
self.learn_coefficients = learn_bounds
else:
self.learn_coefficients = bounds is None
transform_dimension = d if indices is None else len(indices)
if bounds is not None:
if indices is not None and bounds.size(-1) == d:
@@ -544,7 +552,12 @@
else:
coefficient = torch.ones(*batch_shape, 1, transform_dimension)
offset = torch.zeros(*batch_shape, 1, transform_dimension)
self.learn_coefficients = True
if self.learn_coefficients is False:
warn(
"learn_bounds is False and no bounds were provided. The bounds "
"will not be updated and the transform will be a no-op.",
UserInputWarning,
)
super().__init__(
d=d,
coefficient=coefficient,
@@ -586,6 +599,21 @@ def _update_coefficients(self, X) -> None:
self._coefficient = torch.amax(X, dim=reduce_dims).unsqueeze(-2) - self.offset
self._coefficient.clamp_(min=self.min_range)

def get_init_args(self) -> Dict[str, Any]:
r"""Get the arguments necessary to construct an exact copy of the transform."""
return {
"d": self._d,
"indices": getattr(self, "indices", None),
"bounds": self.bounds,
"batch_shape": self.batch_shape,
"transform_on_train": self.transform_on_train,
"transform_on_eval": self.transform_on_eval,
"transform_on_fantasize": self.transform_on_fantasize,
"reverse": self.reverse,
"min_range": self.min_range,
"learn_bounds": self.learn_bounds,
}


class InputStandardize(AffineInputTransform):
r"""Standardize inputs (zero mean, unit variance).
@@ -796,6 +824,18 @@ def equals(self, other: InputTransform) -> bool:
and self.tau == other.tau
)

def get_init_args(self) -> Dict[str, Any]:
r"""Get the arguments necessary to construct an exact copy of the transform."""
return {
"integer_indices": self.integer_indices,
"categorical_features": self.categorical_features,
"transform_on_train": self.transform_on_train,
"transform_on_eval": self.transform_on_eval,
"transform_on_fantasize": self.transform_on_fantasize,
"approximate": self.approximate,
"tau": self.tau,
}


class Log10(ReversibleInputTransform, Module):
r"""A base-10 log transformation."""
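Taken together, the Normalize changes above support three initialization modes, and the new get_init_args (added to both Normalize and Round) makes transforms reconstructible from their own arguments. A short sketch, assuming the constructor signature shown in the hunks above:

import torch

from botorch.models.transforms.input import Normalize

bounds = torch.tensor([[0.0, 0.0], [1.0, 2.0]])  # 2 x d tensor of (lower, upper)

# Bounds provided, learn_bounds left unset: bounds stay fixed.
fixed = Normalize(d=2, bounds=bounds)

# Bounds provided as a starting point, then re-estimated from data in train mode.
learned = Normalize(d=2, bounds=bounds, learn_bounds=True)

# No bounds and learn_bounds=False: emits UserInputWarning; the transform is a no-op.
noop = Normalize(d=2, learn_bounds=False)

# get_init_args round-trips a transform into an equivalent copy.
clone = Normalize(**fixed.get_init_args())
assert fixed.equals(clone)
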
26 changes: 3 additions & 23 deletions scripts/run_tutorials.py
@@ -10,16 +10,13 @@
import datetime
import os
import subprocess
import tempfile
import time
from pathlib import Path
from subprocess import CalledProcessError
from typing import Any, Dict, Optional, Tuple

import nbformat
import pandas as pd
from memory_profiler import memory_usage
from nbconvert import PythonExporter


IGNORE_ALWAYS = { # ignored in smoke tests and full runs
@@ -30,7 +27,6 @@
RUN_IF_SMOKE_TEST_IGNORE_IF_STANDARD = { # only used in smoke tests
"thompson_sampling.ipynb", # very slow without KeOps + GPU
"composite_mtbo.ipynb", # TODO: very slow, figure out if we can make it faster
"Multi_objective_multi_fidelity_BO.ipynb", # TODO: very slow, speed up
# Causing the tutorials to crash when run without smoke test. Likely OOM.
# Fix planned.
"constraint_active_search.ipynb",
@@ -65,33 +61,18 @@ def get_output_file_path(smoke_test: bool) -> str:
return fname


def parse_ipynb(file: Path) -> str:
with open(file, "r") as nb_file:
nb_str = nb_file.read()
nb = nbformat.reads(nb_str, nbformat.NO_CONVERT)
exporter = PythonExporter()
script, _ = exporter.from_notebook_node(nb)
return script


def run_script(
script: str, timeout_minutes: int, env: Optional[Dict[str, str]] = None
tutorial: Path, timeout_minutes: int, env: Optional[Dict[str, str]] = None
) -> None:
# need to keep the file around & close it so subprocess does not run into I/O issues
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf_name = tf.name
with open(tf_name, "w") as tmp_script:
tmp_script.write(script)
if env is not None:
env = {**os.environ, **env}
run_out = subprocess.run(
["ipython", tf_name],
["papermill", tutorial, "|"],
capture_output=True,
text=True,
env=env,
timeout=timeout_minutes * 60,
)
os.remove(tf_name)
return run_out


@@ -103,13 +84,12 @@ def run_tutorial(
them as a string, and returns runtime and memory information as a dict.
"""
timeout_minutes = 5 if smoke_test else 30
script = parse_ipynb(tutorial)
tic = time.monotonic()
print(f"Running tutorial {tutorial.name}.")
env = {"SMOKE_TEST": "True"} if smoke_test else None
try:
mem_usage, run_out = memory_usage(
(run_script, (script, timeout_minutes), {"env": env}),
(run_script, (tutorial, timeout_minutes), {"env": env}),
retval=True,
include_children=True,
)
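The runner now hands the notebook file straight to papermill instead of converting it to a script with nbconvert and executing it under ipython, which eliminates the temp-file bookkeeping removed above. A standalone sketch of the new call; the notebook path is illustrative, and "|" is the output-sink argument the script passes through to papermill:

import subprocess

run_out = subprocess.run(
    ["papermill", "tutorials/example_tutorial.ipynb", "|"],  # hypothetical path
    capture_output=True,
    text=True,
    timeout=30 * 60,  # mirrors the non-smoke-test timeout above
)
print(run_out.returncode)
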
1 change: 1 addition & 0 deletions setup.py
@@ -23,6 +23,7 @@
"kaleido",
"matplotlib",
"memory_profiler",
"papermill",
"pykeops",
"torchvision",
]
3 changes: 3 additions & 0 deletions test/acquisition/test_input_constructors.py
@@ -715,6 +715,7 @@ def test_construct_inputs_qNEHVI(self):
self.assertTrue(kwargs["cache_pending"])
self.assertEqual(kwargs["max_iep"], 0)
self.assertTrue(kwargs["incremental_nehvi"])
self.assertTrue(kwargs["cache_root"])

# Test check for block designs
mock_model = mock.Mock()
@@ -748,6 +749,7 @@
cache_pending=False,
max_iep=1,
incremental_nehvi=False,
cache_root=False,
)
ref_point_expected = objective(objective_thresholds)
self.assertTrue(torch.equal(kwargs["ref_point"], ref_point_expected))
@@ -768,6 +770,7 @@
self.assertFalse(kwargs["cache_pending"])
self.assertEqual(kwargs["max_iep"], 1)
self.assertFalse(kwargs["incremental_nehvi"])
self.assertFalse(kwargs["cache_root"])

# Test with risk measures.
with self.assertRaisesRegex(UnsupportedError, "feasibility-weighted"):
3 changes: 3 additions & 0 deletions test/exceptions/test_warnings.py
@@ -15,6 +15,7 @@
InputDataWarning,
OptimizationWarning,
SamplingWarning,
UserInputWarning,
)
from botorch.utils.testing import BotorchTestCase

@@ -28,6 +29,7 @@ def test_botorch_warnings_hierarchy(self):
self.assertIsInstance(OptimizationWarning(), BotorchWarning)
self.assertIsInstance(SamplingWarning(), BotorchWarning)
self.assertIsInstance(BotorchTensorDimensionWarning(), BotorchWarning)
self.assertIsInstance(UserInputWarning(), BotorchWarning)

def test_botorch_warnings(self):
for WarningClass in (
@@ -38,6 +40,7 @@ def test_botorch_warnings(self):
InputDataWarning,
OptimizationWarning,
SamplingWarning,
UserInputWarning,
):
with warnings.catch_warnings(record=True) as ws, settings.debug(True):
warnings.warn("message", WarningClass)
4 changes: 2 additions & 2 deletions test/models/test_deterministic.py
@@ -96,7 +96,7 @@ def test_AffineDeterministicModel(self):
X = torch.rand(*shape)
p = model.posterior(X)
mean_exp = model.b + (X.unsqueeze(-1) * a).sum(dim=-2)
self.assertTrue(torch.equal(p.mean, mean_exp))
self.assertAllClose(p.mean, mean_exp)
# # test two-dim output
a = torch.rand(3, 2)
model = AffineDeterministicModel(a)
Expand All @@ -105,7 +105,7 @@ def test_AffineDeterministicModel(self):
X = torch.rand(*shape)
p = model.posterior(X)
mean_exp = model.b + (X.unsqueeze(-1) * a).sum(dim=-2)
self.assertTrue(torch.equal(p.mean, mean_exp))
self.assertAllClose(p.mean, mean_exp)
# test subset output
X = torch.rand(4, 3)
subset_model = model.subset_output([0])
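Replacing torch.equal with assertAllClose trades bitwise equality for tolerance-based comparison, which is appropriate when the expected value is itself computed with floating-point reductions. A small illustration using torch.testing.assert_close, which assertAllClose presumably wraps:

import torch

a = torch.tensor([0.1], dtype=torch.float64) + torch.tensor([0.2], dtype=torch.float64)
b = torch.tensor([0.3], dtype=torch.float64)

torch.testing.assert_close(a, b)  # passes: difference is within float64 tolerances
assert not torch.equal(a, b)      # bitwise comparison fails (a is 0.30000000000000004)
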
22 changes: 20 additions & 2 deletions test/models/transforms/test_input.py
@@ -11,6 +11,7 @@
import torch
from botorch import settings
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.exceptions.warnings import UserInputWarning
from botorch.models.transforms.input import (
AffineInputTransform,
AppendFeatures,
@@ -155,16 +156,29 @@ def test_normalize(self):
self.assertEqual(nlz._d, 2)
self.assertEqual(nlz.mins.shape, torch.Size([3, 1, 2]))
self.assertEqual(nlz.ranges.shape, torch.Size([3, 1, 2]))
self.assertTrue(nlz.equals(Normalize(**nlz.get_init_args())))

# basic init, fixed bounds
# learn_bounds=False with no bounds.
with self.assertWarnsRegex(UserInputWarning, "learn_bounds"):
Normalize(d=2, learn_bounds=False)

# learn_bounds=True with bounds provided.
bounds = torch.zeros(2, 2, device=self.device, dtype=dtype)
nlz = Normalize(d=2, bounds=bounds, learn_bounds=True)
self.assertTrue(nlz.learn_bounds)
self.assertTrue(torch.equal(nlz.mins, bounds[..., 0:1, :]))
self.assertTrue(
torch.equal(nlz.ranges, bounds[..., 1:2, :] - bounds[..., 0:1, :])
)

# basic init, fixed bounds
nlz = Normalize(d=2, bounds=bounds)
self.assertFalse(nlz.learn_bounds)
self.assertTrue(nlz.training)
self.assertEqual(nlz._d, 2)
self.assertTrue(torch.equal(nlz.mins, bounds[..., 0:1, :]))
self.assertTrue(
torch.equal(nlz.mins, bounds[..., 1:2, :] - bounds[..., 0:1, :])
torch.equal(nlz.ranges, bounds[..., 1:2, :] - bounds[..., 0:1, :])
)
# with grad
bounds.requires_grad = True
@@ -180,6 +194,7 @@
nlz.eval()
self.assertIsNone(nlz.coefficient.grad_fn)
self.assertIsNone(nlz.offset.grad_fn)
self.assertTrue(nlz.equals(Normalize(**nlz.get_init_args())))

# basic init, provided indices
with self.assertRaises(ValueError):
@@ -204,6 +219,7 @@
== torch.tensor([0], dtype=torch.long, device=self.device)
).all()
)
self.assertTrue(nlz.equals(Normalize(**nlz.get_init_args())))

# test .to
other_dtype = torch.float if dtype == torch.double else torch.double
@@ -594,13 +610,15 @@ def test_round_transform(self):
self.assertTrue(round_tf.training)
self.assertFalse(round_tf.approximate)
self.assertEqual(round_tf.tau, 1e-3)
self.assertTrue(round_tf.equals(Round(**round_tf.get_init_args())))

# With tensor indices.
round_tf = Round(
integer_indices=torch.tensor(int_idcs, dtype=dtype, device=self.device),
categorical_features=categorical_feats,
)
self.assertEqual(round_tf.integer_indices.tolist(), int_idcs)
self.assertTrue(round_tf.equals(Round(**round_tf.get_init_args())))

# basic usage
for batch_shape, approx, categorical_features in itertools.product(
(Diffs for the remaining two changed files were not loaded.)
