Skip to content

Commit

Permalink
StratifiedStandardize OutcomeTransform (#2671)
Browse files Browse the repository at this point in the history
Summary:

see title. This allows applying stratified standardization at the model level, which will enable selecting whether to use a Single-task or multi-task model in Ax while using the appropriate transform. I.e. One could specify ModelConfigs that use 1) `SingleTaskGP` + `Standardize`, 2) `MultiTaskGP` + `StratifiedStandardize`.

Differential Revision: D67728920
  • Loading branch information
sdaulton authored and facebook-github-bot committed Jan 9, 2025
1 parent 6026c6f commit e220777
Show file tree
Hide file tree
Showing 5 changed files with 372 additions and 43 deletions.
35 changes: 1 addition & 34 deletions botorch/models/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from botorch.models.model import FantasizeMixin
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.models.utils.assorted import get_task_value_remapping
from botorch.models.utils.gpytorch_modules import (
get_covar_module_with_dim_scaled_prior,
get_gaussian_likelihood_with_lognormal_prior,
Expand Down Expand Up @@ -82,40 +83,6 @@
from torch import Tensor


def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor | None:
"""Construct an mapping of discrete task values to contiguous int-valued floats.
Args:
task_values: A sorted long-valued tensor of task values.
dtype: The dtype of the model inputs (e.g. `X`), which the new
task values should have mapped to (e.g. float, double).
Returns:
A tensor of shape `task_values.max() + 1` that maps task values
to new task values. The indexing operation `mapper[task_value]`
will produce a tensor of new task values, of the same shape as
the original. The elements of the `mapper` tensor that do not
appear in the original `task_values` are mapped to `nan`. The
return value will be `None`, when the task values are contiguous
integers starting from zero.
"""
task_range = torch.arange(
len(task_values), dtype=task_values.dtype, device=task_values.device
)
mapper = None
if not torch.equal(task_values, task_range):
# Create a tensor that maps task values to new task values.
# The number of tasks should be small, so this should be quite efficient.
mapper = torch.full(
(int(task_values.max().item()) + 1,),
float("nan"),
dtype=dtype,
device=task_values.device,
)
mapper[task_values] = task_range.to(dtype=dtype)
return mapper


class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel, FantasizeMixin):
r"""Multi-Task exact GP model using an ICM (intrinsic co-regionalization model)
kernel. See [Bonilla2007MTGP]_ and [Swersky2013MTBO]_ for a reference on the
Expand Down
253 changes: 244 additions & 9 deletions botorch/models/transforms/outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@

import torch
from botorch.models.transforms.utils import (
nanstd,
norm_to_lognorm_mean,
norm_to_lognorm_variance,
)
from botorch.models.utils.assorted import get_task_value_remapping
from botorch.posteriors import GPyTorchPosterior, Posterior, TransformedPosterior
from botorch.utils.transforms import normalize_indices
from linear_operator.operators import CholLinearOperator, DiagLinearOperator
Expand Down Expand Up @@ -259,6 +261,25 @@ def __init__(
self._batch_shape = batch_shape
self._min_stdv = min_stdv

def _get_per_input_means_stdvs(
self, X: Tensor, include_stdvs_sq: bool
) -> tuple[Tensor, Tensor, Tensor | None]:
r"""Get per-input means and stdvs.
Args:
X: A `batch_shape x n x d`-dim tensor of input parameters.
include_stdvs_sq: Whether to include the stdvs squared.
This parameter is not used by this method
Returns:
A three-tuple with the means and stdvs:
- The per-input means.
- The per-input stdvs.
- The per-input stdvs squared.
"""
return self.means, self.stdvs, self._stdvs_sq

def forward(
self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
) -> tuple[Tensor, Tensor | None]:
Expand Down Expand Up @@ -313,9 +334,12 @@ def forward(
self.stdvs = stdvs
self._stdvs_sq = stdvs.pow(2)
self._is_trained = torch.tensor(True)

Y_tf = (Y - self.means) / self.stdvs
Yvar_tf = Yvar / self._stdvs_sq if Yvar is not None else None
include_stdvs_sq = Yvar is not None
means, stdvs, stdvs_sq = self._get_per_input_means_stdvs(
X=X, include_stdvs_sq=include_stdvs_sq
)
Y_tf = (Y - means) / stdvs
Yvar_tf = Yvar / stdvs_sq if include_stdvs_sq else None
return Y_tf, Yvar_tf

def subset_output(self, idcs: list[int]) -> OutcomeTransform:
Expand Down Expand Up @@ -376,9 +400,12 @@ def untransform(
"(e.g. `transform(Y)`) before calling `untransform`, since "
"means and standard deviations need to be computed."
)

Y_utf = self.means + self.stdvs * Y
Yvar_utf = self._stdvs_sq * Yvar if Yvar is not None else None
include_stdvs_sq = Yvar is not None
means, stdvs, stdvs_sq = self._get_per_input_means_stdvs(
X=X, include_stdvs_sq=include_stdvs_sq
)
Y_utf = means + stdvs * Y
Yvar_utf = stdvs_sq * Yvar if include_stdvs_sq else None
return Y_utf, Yvar_utf

@property
Expand Down Expand Up @@ -433,8 +460,9 @@ def untransform_posterior(
)
# GPyTorchPosterior (TODO: Should we Lazy-evaluate the mean here as well?)
mvn = posterior.distribution
offset = self.means
scale_fac = self.stdvs
offset, scale_fac, _ = self._get_per_input_means_stdvs(
X=X, include_stdvs_sq=False
)
if not posterior._is_mt:
mean_tf = offset.squeeze(-1) + scale_fac.squeeze(-1) * mvn.mean
scale_fac = scale_fac.squeeze(-1).expand_as(mean_tf)
Expand All @@ -449,7 +477,7 @@ def untransform_posterior(

if (
not mvn.islazy
# TODO: Figure out attribute namming weirdness here
# TODO: Figure out attribute naming weirdness here
or mvn._MultivariateNormal__unbroadcasted_scale_tril is not None
):
# if already computed, we can save a lot of time using scale_tril
Expand All @@ -465,6 +493,213 @@ def untransform_posterior(
return GPyTorchPosterior(mvn_tf)


class StratifiedStandardize(Standardize):
r"""Standardize outcomes (zero mean, unit variance) along stratification dimension.
This module is stateful: If in train mode, calling forward updates the
module state (i.e. the mean/std normalizing constants). If in eval mode,
calling forward simply applies the standardization using the current module
state.
"""

def __init__(
self,
task_values: Tensor,
stratification_idx: int,
batch_shape: torch.Size = torch.Size(), # noqa: B008
min_stdv: float = 1e-8,
# dtype: torch.dtype = torch.double,
) -> None:
r"""Standardize outcomes (zero mean, unit variance) along stratification dim.
Note: This currenlty only supports single output models
(including multi-task models that have a single output).
Args:
task_values: `t`-dim tensor of task values.
stratification_idx: The index of the stratification dimension.
batch_shape: The batch_shape of the training targets.
min_stddv: The minimum standard deviation for which to perform
standardization (if lower, only de-mean the data).
"""
OutcomeTransform.__init__(self)
self._stratification_idx = stratification_idx
task_values = task_values.unique(sorted=True)
self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.long)
if self.strata_mapping is None:
self.strata_mapping = task_values
n_strata = self.strata_mapping.shape[0]
self._min_stdv = min_stdv
self.register_buffer("means", torch.zeros(*batch_shape, n_strata, 1))
self.register_buffer("stdvs", torch.ones(*batch_shape, n_strata, 1))
self.register_buffer("_stdvs_sq", torch.ones(*batch_shape, n_strata, 1))
self.register_buffer("_is_trained", torch.tensor(False))
self._batch_shape = batch_shape
self._m = 1 # TODO: support multiple outputs
self._outputs = None

def forward(
self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
) -> tuple[Tensor, Tensor | None]:
r"""Standardize outcomes.
If the module is in train mode, this updates the module state (i.e. the
mean/std normalizing constants). If the module is in eval mode, simply
applies the normalization using the module state.
Args:
Y: A `batch_shape x n x m`-dim tensor of training targets.
Yvar: A `batch_shape x n x m`-dim tensor of observation noises
associated with the training targets (if applicable).
X: A `batch_shape x n x d`-dim tensor of input parameters.
Returns:
A two-tuple with the transformed outcomes:
- The transformed outcome observations.
- The transformed observation noise (if applicable).
"""
if X is None:
raise ValueError("X is required for StratifiedStandardize.")
if self.training:
if Y.shape[:-2] != self._batch_shape:
raise RuntimeError(
f"Expected Y.shape[:-2] to be {self._batch_shape}, matching "
"the `batch_shape` argument to `StratifiedStandardize`, but got "
f"Y.shape[:-2]={Y.shape[:-2]}."
)
elif Y.shape[-2] < 1:
raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")
elif Y.size(-1) != self._m:
raise RuntimeError(
f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
f"{self._m}."
)
self.means = self.means.to(dtype=X.dtype, device=X.device)
self.stdvs = self.stdvs.to(dtype=X.dtype, device=X.device)
self._stdvs_sq = self._stdvs_sq.to(dtype=X.dtype, device=X.device)
strata = X[..., self._stratification_idx].long()
unique_strata = strata.unique()
for s in unique_strata:
mapped_strata = self.strata_mapping[s]
mask = strata != s
Y_strata = Y.clone()
Y_strata[..., mask, :] = float("nan")
if Y.shape[-2] == 1:
stdvs = torch.ones(
(*Y_strata.shape[:-2], 1, Y_strata.shape[-1]),
dtype=Y.dtype,
device=Y.device,
)
else:
stdvs = nanstd(X=Y_strata, dim=-2)
stdvs = stdvs.where(
stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0)
)
means = Y_strata.nanmean(dim=-2)
self.means[..., mapped_strata, :] = means
self.stdvs[..., mapped_strata, :] = stdvs
self._stdvs_sq[..., mapped_strata, :] = stdvs.pow(2)
self._is_trained = torch.tensor(True)
training = self.training
self.training = False
tf_Y, tf_Yvar = super().forward(Y=Y, Yvar=Yvar, X=X)
self.training = training
return tf_Y, tf_Yvar

def _get_per_input_means_stdvs(
self, X: Tensor, include_stdvs_sq: bool
) -> tuple[Tensor, Tensor, Tensor | None]:
r"""Get per-input means and stdvs.
Args:
X: A `batch_shape x n x d`-dim tensor of input parameters.
include_stdvs_sq: Whether to include the stdvs squared.
Returns:
A three-tuple with the per-input means and stdvs:
- The per-input means.
- The per-input stdvs.
- The per-input stdvs squared.
"""
strata = X[..., self._stratification_idx].long()
mapped_strata = self.strata_mapping[strata].unsqueeze(-1)
# get means and stdvs for each strata
n_extra_batch_dims = mapped_strata.ndim - 2 - len(self._batch_shape)
view_shape = torch.Size([1] * n_extra_batch_dims) + self.means.shape
expand_shape = mapped_strata.shape[:n_extra_batch_dims] + self.means.shape
means = torch.gather(
input=self.means.view(view_shape).expand(expand_shape),
dim=-2,
index=mapped_strata,
)
stdvs = torch.gather(
input=self.stdvs.view(view_shape).expand(expand_shape),
dim=-2,
index=mapped_strata,
)
if include_stdvs_sq:
stdvs_sq = torch.gather(
input=self._stdvs_sq.view(view_shape).expand(expand_shape),
dim=-2,
index=mapped_strata,
)
else:
stdvs_sq = None
return means, stdvs, stdvs_sq

def subset_output(self, idcs: list[int]) -> OutcomeTransform:
r"""Subset the transform along the output dimension.
Args:
idcs: The output indices to subset the transform to.
Returns:
The current outcome transform, subset to the specified output indices.
"""
raise NotImplementedError

def untransform(
self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
) -> tuple[Tensor, Tensor | None]:
r"""Un-standardize outcomes.
Args:
Y: A `batch_shape x n x m`-dim tensor of standardized targets.
Yvar: A `batch_shape x n x m`-dim tensor of standardized observation
noises associated with the targets (if applicable).
X: A `batch_shape x n x d`-dim tensor of input parameters.
Returns:
A two-tuple with the un-standardized outcomes:
- The un-standardized outcome observations.
- The un-standardized observation noise (if applicable).
"""
if X is None:
raise ValueError("X is required for StratifiedStandardize.")
return super().untransform(Y=Y, Yvar=Yvar, X=X)

def untransform_posterior(
self, posterior: Posterior, X: Tensor | None = None
) -> GPyTorchPosterior | TransformedPosterior:
r"""Un-standardize the posterior.
Args:
posterior: A posterior in the standardized space.
X: A `batch_shape x n x d`-dim tensor of training inputs (if applicable).
Returns:
The un-standardized posterior. If the input posterior is a
`GPyTorchPosterior`, return a `GPyTorchPosterior`. Otherwise, return a
`TransformedPosterior`.
"""
if X is None:
raise ValueError("X is required for StratifiedStandardize.")
return super().untransform_posterior(posterior=posterior, X=X)


class Log(OutcomeTransform):
r"""Log-transform outcomes.
Expand Down
9 changes: 9 additions & 0 deletions botorch/models/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,12 @@ def interaction_features(X: Tensor) -> Tensor:
dim = X.shape[-1]
row_idcs, col_idcs = torch.triu_indices(dim, dim, offset=1)
return (X.unsqueeze(-1) @ X.unsqueeze(-2))[..., row_idcs, col_idcs].unsqueeze(-2)


def nanstd(X: Tensor, dim: int, keepdim: bool = False) -> Tensor:
n = (~torch.isnan(X)).sum(dim=dim)
return (
(X - X.nanmean(dim=dim, keepdim=True)).pow(2).nanmean(dim=dim, keepdim=keepdim)
* n
/ (n - 1)
).sqrt()
Loading

0 comments on commit e220777

Please sign in to comment.