Use gpytorch constraints for bounds on parameters during model fitting.
Summary:
gpytorch now supports defining constraints on its submodules. This makes it possible to specify parameter constraints where they belong (on the model) and have model fitting handle them in a generic way.
Note that constraints whose transform is not `None` are enforced automatically by applying that transform. This can be an issue for quasi-second-order optimizers, though, because the objective becomes flat once the line search overshoots past the effective constraint.
Hence, skipping the transform and imposing an explicit bound during optimization is preferred. It may also be beneficial to use the transform in conjunction with an explicit bound; that still needs to be evaluated.
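
For illustration, a minimal sketch (assuming the gpytorch API used in the diff below; the noise floor value is the constant introduced in this commit) of how a likelihood carries such a constraint:

# Minimal sketch: attach an un-transformed lower bound to the noise parameter.
from gpytorch.constraints.constraints import GreaterThan
from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
from gpytorch.priors.torch_priors import GammaPrior

MIN_INFERRED_NOISE_LEVEL = 1e-6

likelihood = GaussianLikelihood(
    noise_prior=GammaPrior(1.1, 0.05),
    # transform=None: the raw noise is not reparameterized, so the bound must
    # be enforced explicitly by the fitting routine (e.g. by clamping).
    noise_constraint=GreaterThan(MIN_INFERRED_NOISE_LEVEL, transform=None),
)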

Reviewed By: bletham

Differential Revision: D14840983

fbshipit-source-id: 6f52ec9eb0b970a692963083125e58df55a46de5
Balandat committed Apr 26, 2019
1 parent 44f18fd commit 5333d5b
Showing 5 changed files with 255 additions and 115 deletions.
16 changes: 9 additions & 7 deletions botorch/models/gp_regression.py
@@ -8,6 +8,7 @@
from typing import Any, Optional

import torch
from gpytorch.constraints.constraints import GreaterThan
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.kernels.scale_kernel import ScaleKernel
@@ -27,6 +28,9 @@
from .gpytorch import GPyTorchModel


MIN_INFERRED_NOISE_LEVEL = 1e-6


class SingleTaskGP(ExactGP, GPyTorchModel):
r"""A single-task Exact GP model.
@@ -58,10 +62,10 @@ def __init__(
raise ValueError(f"Unsupported shape {train_X.shape} for train_X.")
if likelihood is None:
likelihood = GaussianLikelihood(
noise_prior=GammaPrior(1.1, 0.05), batch_size=batch_size
noise_prior=GammaPrior(1.1, 0.05),
batch_size=batch_size,
noise_constraint=GreaterThan(MIN_INFERRED_NOISE_LEVEL, transform=None),
)
# TODO: Use gpytorch constraints
likelihood.parameter_bounds = {"noise_covar.raw_noise": (-15, None)}
else:
self._likelihood_state_dict = deepcopy(likelihood.state_dict())
super().__init__(train_X, train_Y, likelihood)
@@ -183,15 +187,13 @@ class HeteroskedasticSingleTaskGP(SingleTaskGP):
def __init__(self, train_X: Tensor, train_Y: Tensor, train_Y_se: Tensor) -> None:
train_Y_log_var = 2 * torch.log(train_Y_se)
noise_likelihood = GaussianLikelihood(
noise_prior=SmoothedBoxPrior(-3, 5, 0.5, transform=torch.log)
noise_prior=SmoothedBoxPrior(-3, 5, 0.5, transform=torch.log),
noise_constraint=GreaterThan(MIN_INFERRED_NOISE_LEVEL, transform=None),
)
noise_model = SingleTaskGP(
train_X=train_X, train_Y=train_Y_log_var, likelihood=noise_likelihood
)
likelihood = _GaussianLikelihoodBase(HeteroskedasticNoise(noise_model))
likelihood.parameter_bounds = {
"noise_covar.noise_model.likelihood.noise_covar.raw_noise": (-15, None)
}
super().__init__(train_X=train_X, train_Y=train_Y, likelihood=likelihood)
self.to(train_X)

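As a hedged sketch (not part of the commit), the constraints registered this way can be iterated over via gpytorch's `named_parameters_and_constraints`, which is what the fitting utilities below rely on; the training data shapes and single-output convention are illustrative assumptions:

# Hedged sketch: list registered constraints and whether they are enforced
# via a transform. Data shapes are illustrative assumptions.
import torch
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from botorch.models.gp_regression import SingleTaskGP

train_X = torch.rand(20, 2)
train_Y = torch.sin(train_X[:, 0])
model = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

for name, _, constraint in mll.named_parameters_and_constraints():
    if constraint is not None:
        print(name, type(constraint).__name__, "enforced:", constraint.enforced)
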
25 changes: 23 additions & 2 deletions botorch/optim/fit.py
@@ -26,6 +26,7 @@ class OptimizationIteration(NamedTuple):

def fit_gpytorch_torch(
mll: MarginalLogLikelihood,
bounds: Optional[ParameterBounds] = None,
optimizer_cls: Optimizer = Adam,
lr: float = 0.05,
maxiter: int = 100,
@@ -41,6 +42,10 @@ def fit_gpytorch_torch(
Args:
mll: MarginalLogLikelihood to be maximized.
bounds: A ParameterBounds dictionary mapping parameter names to tuples
of lower and upper bounds. Bounds specified here take precedence
over bounds on the same parameters specified in the constraints
registered with the module.
optimizer_cls: Torch optimizer to use. Must not need a closure.
Defaults to Adam.
lr: Starting learning rate.
@@ -60,6 +65,17 @@
params=[{"params": mll.parameters()}], lr=lr, **optimizer_args
)

# get bounds specified in model (if any)
bounds_: ParameterBounds = {}
if hasattr(mll, "named_parameters_and_constraints"):
for param_name, _, constraint in mll.named_parameters_and_constraints():
if constraint is not None and not constraint.enforced:
bounds_[param_name] = constraint.lower_bound, constraint.upper_bound

# update with user-supplied bounds (overwrites if already exists)
if bounds is not None:
bounds_.update(bounds)

iterations = []
t1 = time.time()

@@ -80,11 +96,16 @@
loss_trajectory.append(loss.item())
for name, param in mll.named_parameters():
param_trajectory[name].append(param.detach().clone())
if disp and (i % 10 == 0 or i == (maxiter - 1)):
print(f"Iter {i +1}/{maxiter}: {loss.item()}")
if disp and ((i + 1) % 10 == 0 or i == (maxiter - 1)):
print(f"Iter {i + 1}/{maxiter}: {loss.item()}")
if track_iterations:
iterations.append(OptimizationIteration(i, loss.item(), time.time() - t1))
optimizer.step()
# project onto bounds:
if bounds_:
for pname, param in mll.named_parameters():
if pname in bounds_:
param.data = param.data.clamp(*bounds_[pname])
i += 1
converged = check_convergence(
loss_trajectory=loss_trajectory,
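A hedged usage sketch for the updated signature (not part of the commit): user-supplied bounds override constraint-derived bounds, and both are applied by clamping the parameters after each optimizer step. The fully-qualified parameter name and data shapes below are assumptions:

# Hedged sketch: fit with an explicit, user-supplied bound on the raw noise.
import torch
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from botorch.models.gp_regression import SingleTaskGP
from botorch.optim.fit import fit_gpytorch_torch

train_X = torch.rand(20, 2)
train_Y = torch.sin(train_X[:, 0])
model = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

# "likelihood.noise_covar.raw_noise" is an assumed parameter name on the mll;
# this bound takes precedence over the constraint registered on the likelihood.
result = fit_gpytorch_torch(
    mll,
    bounds={"likelihood.noise_covar.raw_noise": (1e-4, None)},
    maxiter=50,
)
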
38 changes: 20 additions & 18 deletions botorch/optim/numpy_converter.py
@@ -2,7 +2,7 @@

from collections import OrderedDict
from math import inf
from typing import Dict, List, NamedTuple, Optional, Tuple
from typing import Dict, List, NamedTuple, Optional, Set, Tuple

import numpy as np
import torch
@@ -19,18 +19,22 @@ class TorchAttr(NamedTuple):


def module_to_array(
module: Module, bounds: Optional[ParameterBounds] = None
module: Module,
bounds: Optional[ParameterBounds] = None,
exclude: Optional[Set[str]] = None,
) -> Tuple[np.ndarray, Dict[str, TorchAttr], Optional[np.ndarray]]:
r"""Extract named parameters from a module into a numpy array.
Only extracts parameters with requires_grad, since it is meant for optimizing.
Args:
module: A module with parameters. May specify parameter constraints in
a `parameter_bounds` attribute.
bounds: A ParameterBounds dictionary mapping parameter names to tuples of
lower and upper bounds. Bounds specified here take precedence over
bounds specified in the `parameter_bounds` attribute of the module.
a `named_parameters_and_constraints` method.
bounds: A ParameterBounds dictionary mapping parameter names to tuples
of lower and upper bounds. Bounds specified here take precedence
over bounds on the same parameters specified in the constraints
registered with the module.
exclude: A list of parameter names that are to be excluded from extraction.
Returns:
np.ndarray: The parameter values
@@ -43,23 +47,21 @@
lower: List[np.ndarray] = []
upper: List[np.ndarray] = []
property_dict = OrderedDict()
exclude = set() if exclude is None else exclude

# extract parameter bounds from module.model.parameter_bounds and
# module.likelihood.parameter_bounds (if present)
model_bounds = getattr(getattr(module, "model", None), "parameter_bounds", {})
bounds_ = {".".join(["model", key]): val for key, val in model_bounds.items()}
likelihood_bounds = getattr(
getattr(module, "likelihood", None), "parameter_bounds", {}
)
bounds_.update(
{".".join(["likelihood", key]): val for key, val in likelihood_bounds.items()}
)
# update with user-supplied bounds
# get bounds specified in model (if any)
bounds_: ParameterBounds = {}
if hasattr(module, "named_parameters_and_constraints"):
for param_name, _, constraint in module.named_parameters_and_constraints():
if constraint is not None and not constraint.enforced:
bounds_[param_name] = constraint.lower_bound, constraint.upper_bound

# update with user-supplied bounds (overwrites if already exists)
if bounds is not None:
bounds_.update(bounds)

for p_name, t in module.named_parameters():
if t.requires_grad:
if p_name not in exclude and t.requires_grad:
property_dict[p_name] = TorchAttr(
shape=t.shape, dtype=t.dtype, device=t.device
)
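For completeness, a hedged sketch of the updated `module_to_array` interface (parameter names are assumptions, not taken from the commit): constraint-derived and user-supplied bounds come back as an array aligned with the flattened parameters, and excluded parameters are skipped entirely:

# Hedged sketch: flatten parameters to numpy, with bounds and an exclusion set.
import torch
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from botorch.models.gp_regression import SingleTaskGP
from botorch.optim.numpy_converter import module_to_array

train_X = torch.rand(20, 2)
train_Y = torch.sin(train_X[:, 0])
model = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

x, property_dict, bounds_np = module_to_array(
    module=mll,
    bounds={"likelihood.noise_covar.raw_noise": (1e-4, None)},  # assumed name
    exclude={"model.mean_module.constant"},  # assumed name
)
# x: flat numpy array of parameters with requires_grad (minus excluded ones)
# property_dict: shape/dtype/device per parameter, for mapping x back to tensors
# bounds_np: bounds aligned with x, or None if no bounds were found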