Skip to content

Commit

Permalink
Open-Sourcing MAP SAAS (#2694)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2694

This commit open-sources the MAP-SAAS model, originally implemented by dme65 to provide a more efficient alternative to the fully Bayesian SAAS model.

Reviewed By: Balandat

Differential Revision: D68522782
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Jan 23, 2025
1 parent 0df5521 commit 4a8b91b
Show file tree
Hide file tree
Showing 7 changed files with 1,309 additions and 31 deletions.
135 changes: 135 additions & 0 deletions botorch/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,19 @@
from typing import Any
from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage

import torch

from botorch.exceptions.errors import ModelFittingError, UnsupportedError
from botorch.exceptions.warnings import OptimizationWarning
from botorch.logging import logger
from botorch.models import SingleTaskGP
from botorch.models.approximate_gp import ApproximateGPyTorchModel
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
from botorch.models.map_saas import get_map_saas_model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.optim.closures import get_loss_closure_with_grads
from botorch.optim.core import _LBFGSB_MAXITER_MAXFUN_REGEX
from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch
Expand All @@ -38,11 +44,13 @@
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from linear_operator.utils.errors import NotPSDError
from pyro.infer.mcmc import MCMC, NUTS
from torch import device, Tensor
from torch.distributions import HalfCauchy
from torch.nn import Parameter
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -382,3 +390,130 @@ def fit_fully_bayesian_model_nuts(
# Load the MCMC samples back into the BoTorch model
model.load_mcmc_samples(mcmc_samples)
model.eval()


def get_fitted_map_saas_model(
    train_X: Tensor,
    train_Y: Tensor,
    train_Yvar: Tensor | None = None,
    input_transform: InputTransform | None = None,
    outcome_transform: OutcomeTransform | None = None,
    tau: float | None = None,
    optimizer_kwargs: dict[str, Any] | None = None,
) -> SingleTaskGP:
    """Build a MAP SAAS model with a Matern kernel and fit it.

    The model is created by `get_map_saas_model` and fitted by maximizing an
    `ExactMarginalLogLikelihood` with `fit_gpytorch_mll`.

    Args:
        train_X: Tensor of shape `n x d` with training inputs.
        train_Y: Tensor of shape `n x 1` with training targets.
        train_Yvar: Optional tensor of shape `n x 1` with observed noise,
            inferred if None.
        input_transform: An optional input transform. If given, it is switched
            to training mode before being handed to the model.
        outcome_transform: An optional outcome transform.
        tau: Fixed value of the global shrinkage tau. If None, the model
            places a HC(0.1) prior on tau.
        optimizer_kwargs: A dict of options for the optimizer passed
            to fit_gpytorch_mll. None and empty dicts are treated alike.

    Returns:
        A fitted SingleTaskGP with a Matern kernel.
    """
    # Always hand a dict to the fitting routine, never None.
    fit_options = optimizer_kwargs if optimizer_kwargs else {}

    # The transform must be in training mode when the model is constructed.
    train_mode_transform = (
        None if input_transform is None else input_transform.train()
    )

    saas_gp = get_map_saas_model(
        train_X=train_X,
        train_Y=train_Y,
        train_Yvar=train_Yvar,
        input_transform=train_mode_transform,
        outcome_transform=outcome_transform,
        tau=tau,
    )
    fit_gpytorch_mll(
        ExactMarginalLogLikelihood(model=saas_gp, likelihood=saas_gp.likelihood),
        optimizer_kwargs=fit_options,
    )
    return saas_gp


def get_fitted_map_saas_ensemble(
    train_X: Tensor,
    train_Y: Tensor,
    train_Yvar: Tensor | None = None,
    input_transform: InputTransform | None = None,
    outcome_transform: OutcomeTransform | None = None,
    taus: Tensor | list[float] | None = None,
    num_taus: int = 4,
    optimizer_kwargs: dict[str, Any] | None = None,
) -> SaasFullyBayesianSingleTaskGP:
    """Get a fitted SAAS ensemble using several different tau values.

    One MAP SAAS model is fitted per tau with `get_fitted_map_saas_model`. The
    fitted hyperparameters (mean, outputscale, lengthscale and, when the noise
    is inferred, noise) are stacked and loaded as samples into a
    `SaasFullyBayesianSingleTaskGP`.

    Args:
        train_X: Tensor of shape `n x d` with training inputs.
        train_Y: Tensor of shape `n x 1` with training targets.
        train_Yvar: Optional tensor of shape `n x 1` with observed noise,
            inferred if None.
        input_transform: An optional input transform.
        outcome_transform: An optional outcome transform.
        taus: Global shrinkage values to use. If None, we sample `num_taus`
            values from an HC(0.1) distribution.
        num_taus: How many taus to sample. Only used if `taus` is None.
        optimizer_kwargs: A dict of options for the optimizer passed
            to fit_gpytorch_mll.

    Returns:
        A fitted SaasFullyBayesianSingleTaskGP with a Matern kernel.

    Raises:
        ValueError: If fewer than two tau values are available, i.e. `taus`
            is empty or has a single element, or `taus` is None and
            `num_taus` is 0 or 1.
    """
    tkwargs = {"device": train_X.device, "dtype": train_X.dtype}
    if taus is None:
        taus = HalfCauchy(0.1).sample([num_taus]).to(**tkwargs)
    num_samples = len(taus)
    if num_samples == 0:
        # Without this guard the loop below is skipped and an ensemble with
        # zero hyperparameter samples would be built silently.
        raise ValueError(
            "At least two values of tau are required to fit an ensemble, "
            "but none were provided."
        )
    if num_samples == 1:
        raise ValueError(
            "Use `get_fitted_map_saas_model` if you only specify one value of tau"
        )

    # Noise is only a fitted hyperparameter when it is not observed.
    infer_noise = train_Yvar is None

    mean = torch.zeros(num_samples, **tkwargs)
    outputscale = torch.zeros(num_samples, **tkwargs)
    lengthscale = torch.zeros(num_samples, train_X.shape[-1], **tkwargs)
    noise = torch.zeros(num_samples, **tkwargs) if infer_noise else None

    # Fit a model for each tau and save the hyperparameters
    for i, tau in enumerate(taus):
        model = get_fitted_map_saas_model(
            train_X,
            train_Y,
            train_Yvar=train_Yvar,
            input_transform=input_transform,
            outcome_transform=outcome_transform,
            tau=tau,
            optimizer_kwargs=optimizer_kwargs,
        )
        mean[i] = model.mean_module.constant.detach().clone()
        outputscale[i] = model.covar_module.outputscale.detach().clone()
        lengthscale[i, :] = model.covar_module.base_kernel.lengthscale.detach().clone()
        if infer_noise:
            noise[i] = model.likelihood.noise.detach().clone()

    # Load the samples into a fully Bayesian SAAS model
    ensemble_model = SaasFullyBayesianSingleTaskGP(
        train_X=train_X,
        train_Y=train_Y,
        train_Yvar=train_Yvar,
        input_transform=(
            input_transform.train() if input_transform is not None else None
        ),
        outcome_transform=outcome_transform,
    )
    mcmc_samples = {
        "mean": mean,
        "outputscale": outputscale,
        "lengthscale": lengthscale,
    }
    if infer_noise:
        mcmc_samples["noise"] = noise
    ensemble_model.train()
    ensemble_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
    ensemble_model.eval()
    return ensemble_model
4 changes: 4 additions & 0 deletions botorch/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from botorch.models.higher_order_gp import HigherOrderGP

from botorch.models.map_saas import add_saas_prior, AdditiveMapSaasSingleTaskGP
from botorch.models.model import ModelList
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood

__all__ = [
"add_saas_prior",
"AdditiveMapSaasSingleTaskGP",
"AffineDeterministicModel",
"AffineFidelityCostModel",
"ApproximateGPyTorchModel",
Expand Down
Loading

0 comments on commit 4a8b91b

Please sign in to comment.