From e220777cd669935ef1e53dc0f55cd359af411cf4 Mon Sep 17 00:00:00 2001
From: Sam Daulton <sdaulton@meta.com>
Date: Thu, 9 Jan 2025 10:12:10 -0800
Subject: [PATCH] StratifiedStandardize OutcomeTransform (#2671)

Summary: see title. This allows applying stratified standardization at the
model level, which will enable selecting between a single-task and a
multi-task model in Ax while using the appropriate transform. That is, one
could specify ModelConfigs that use 1) `SingleTaskGP` + `Standardize`, or
2) `MultiTaskGP` + `StratifiedStandardize`.

Differential Revision: D67728920
---
 botorch/models/multitask.py            |  35 +---
 botorch/models/transforms/outcome.py   | 253 ++++++++++++++++++++++++-
 botorch/models/transforms/utils.py     |   9 +
 botorch/models/utils/assorted.py       |  34 ++++
 test/models/transforms/test_outcome.py |  84 ++++++++
 5 files changed, 372 insertions(+), 43 deletions(-)

diff --git a/botorch/models/multitask.py b/botorch/models/multitask.py
index eaf0fa41a5..876e2c8ede 100644
--- a/botorch/models/multitask.py
+++ b/botorch/models/multitask.py
@@ -39,6 +39,7 @@
 from botorch.models.model import FantasizeMixin
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform, Standardize
+from botorch.models.utils.assorted import get_task_value_remapping
 from botorch.models.utils.gpytorch_modules import (
     get_covar_module_with_dim_scaled_prior,
     get_gaussian_likelihood_with_lognormal_prior,
@@ -82,40 +83,6 @@
 from torch import Tensor
 
 
-def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor | None:
-    """Construct an mapping of discrete task values to contiguous int-valued floats.
-
-    Args:
-        task_values: A sorted long-valued tensor of task values.
-        dtype: The dtype of the model inputs (e.g. `X`), which the new
-            task values should have mapped to (e.g. float, double).
-
-    Returns:
-        A tensor of shape `task_values.max() + 1` that maps task values
-        to new task values. The indexing operation `mapper[task_value]`
-        will produce a tensor of new task values, of the same shape as
-        the original. The elements of the `mapper` tensor that do not
-        appear in the original `task_values` are mapped to `nan`. The
-        return value will be `None`, when the task values are contiguous
-        integers starting from zero.
-    """
-    task_range = torch.arange(
-        len(task_values), dtype=task_values.dtype, device=task_values.device
-    )
-    mapper = None
-    if not torch.equal(task_values, task_range):
-        # Create a tensor that maps task values to new task values.
-        # The number of tasks should be small, so this should be quite efficient.
-        mapper = torch.full(
-            (int(task_values.max().item()) + 1,),
-            float("nan"),
-            dtype=dtype,
-            device=task_values.device,
-        )
-        mapper[task_values] = task_range.to(dtype=dtype)
-    return mapper
-
-
 class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel, FantasizeMixin):
     r"""Multi-Task exact GP model using an ICM (intrinsic co-regionalization model)
     kernel. See [Bonilla2007MTGP]_ and [Swersky2013MTBO]_ for a reference on the
diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py
index 4c01204dd6..b9f58f63bc 100644
--- a/botorch/models/transforms/outcome.py
+++ b/botorch/models/transforms/outcome.py
@@ -27,9 +27,11 @@
 import torch
 from botorch.models.transforms.utils import (
+    nanstd,
     norm_to_lognorm_mean,
     norm_to_lognorm_variance,
 )
+from botorch.models.utils.assorted import get_task_value_remapping
 from botorch.posteriors import GPyTorchPosterior, Posterior, TransformedPosterior
 from botorch.utils.transforms import normalize_indices
 from linear_operator.operators import CholLinearOperator, DiagLinearOperator
@@ -259,6 +261,25 @@ def __init__(
         self._batch_shape = batch_shape
         self._min_stdv = min_stdv
 
+    def _get_per_input_means_stdvs(
+        self, X: Tensor, include_stdvs_sq: bool
+    ) -> tuple[Tensor, Tensor, Tensor | None]:
+        r"""Get per-input means and stdvs.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+            include_stdvs_sq: Whether to include the stdvs squared.
+                This parameter is not used by this method.
+
+        Returns:
+            A three-tuple with the means and stdvs:
+
+            - The per-input means.
+            - The per-input stdvs.
+            - The per-input stdvs squared.
+        """
+        return self.means, self.stdvs, self._stdvs_sq
+
     def forward(
         self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
     ) -> tuple[Tensor, Tensor | None]:
@@ -313,9 +334,12 @@ def forward(
         self.stdvs = stdvs
         self._stdvs_sq = stdvs.pow(2)
         self._is_trained = torch.tensor(True)
-
-        Y_tf = (Y - self.means) / self.stdvs
-        Yvar_tf = Yvar / self._stdvs_sq if Yvar is not None else None
+        include_stdvs_sq = Yvar is not None
+        means, stdvs, stdvs_sq = self._get_per_input_means_stdvs(
+            X=X, include_stdvs_sq=include_stdvs_sq
+        )
+        Y_tf = (Y - means) / stdvs
+        Yvar_tf = Yvar / stdvs_sq if include_stdvs_sq else None
         return Y_tf, Yvar_tf
 
     def subset_output(self, idcs: list[int]) -> OutcomeTransform:
@@ -376,9 +400,12 @@ def untransform(
             "(e.g. `transform(Y)`) before calling `untransform`, since "
             "means and standard deviations need to be computed."
         )
-
-        Y_utf = self.means + self.stdvs * Y
-        Yvar_utf = self._stdvs_sq * Yvar if Yvar is not None else None
+        include_stdvs_sq = Yvar is not None
+        means, stdvs, stdvs_sq = self._get_per_input_means_stdvs(
+            X=X, include_stdvs_sq=include_stdvs_sq
+        )
+        Y_utf = means + stdvs * Y
+        Yvar_utf = stdvs_sq * Yvar if include_stdvs_sq else None
         return Y_utf, Yvar_utf
 
     @property
@@ -433,8 +460,9 @@ def untransform_posterior(
         )
         # GPyTorchPosterior (TODO: Should we Lazy-evaluate the mean here as well?)
         mvn = posterior.distribution
-        offset = self.means
-        scale_fac = self.stdvs
+        offset, scale_fac, _ = self._get_per_input_means_stdvs(
+            X=X, include_stdvs_sq=False
+        )
         if not posterior._is_mt:
             mean_tf = offset.squeeze(-1) + scale_fac.squeeze(-1) * mvn.mean
             scale_fac = scale_fac.squeeze(-1).expand_as(mean_tf)
@@ -449,7 +477,7 @@ def untransform_posterior(
 
         if (
             not mvn.islazy
-            # TODO: Figure out attribute namming weirdness here
+            # TODO: Figure out attribute naming weirdness here
             or mvn._MultivariateNormal__unbroadcasted_scale_tril is not None
         ):
             # if already computed, we can save a lot of time using scale_tril
@@ -465,6 +493,213 @@ def untransform_posterior(
         return GPyTorchPosterior(mvn_tf)
 
 
+class StratifiedStandardize(Standardize):
+    r"""Standardize outcomes (zero mean, unit variance) along stratification dimension.
+
+    This module is stateful: If in train mode, calling forward updates the
+    module state (i.e. the mean/std normalizing constants). If in eval mode,
+    calling forward simply applies the standardization using the current module
+    state.
+    """
+
+    def __init__(
+        self,
+        task_values: Tensor,
+        stratification_idx: int,
+        batch_shape: torch.Size = torch.Size(),  # noqa: B008
+        min_stdv: float = 1e-8,
+        # dtype: torch.dtype = torch.double,
+    ) -> None:
+        r"""Standardize outcomes (zero mean, unit variance) along stratification dim.
+
+        Note: This currently only supports single-output models
+        (including multi-task models that have a single output).
+
+        Args:
+            task_values: `t`-dim tensor of task values.
+            stratification_idx: The index of the stratification dimension.
+            batch_shape: The batch_shape of the training targets.
+            min_stdv: The minimum standard deviation for which to perform
+                standardization (if lower, only de-mean the data).
+        """
+        OutcomeTransform.__init__(self)
+        self._stratification_idx = stratification_idx
+        task_values = task_values.unique(sorted=True)
+        self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.long)
+        if self.strata_mapping is None:
+            self.strata_mapping = task_values
+        n_strata = self.strata_mapping.shape[0]
+        self._min_stdv = min_stdv
+        self.register_buffer("means", torch.zeros(*batch_shape, n_strata, 1))
+        self.register_buffer("stdvs", torch.ones(*batch_shape, n_strata, 1))
+        self.register_buffer("_stdvs_sq", torch.ones(*batch_shape, n_strata, 1))
+        self.register_buffer("_is_trained", torch.tensor(False))
+        self._batch_shape = batch_shape
+        self._m = 1  # TODO: support multiple outputs
+        self._outputs = None
+
+    def forward(
+        self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
+    ) -> tuple[Tensor, Tensor | None]:
+        r"""Standardize outcomes.
+
+        If the module is in train mode, this updates the module state (i.e. the
+        mean/std normalizing constants). If the module is in eval mode, simply
+        applies the normalization using the module state.
+
+        Args:
+            Y: A `batch_shape x n x m`-dim tensor of training targets.
+            Yvar: A `batch_shape x n x m`-dim tensor of observation noises
+                associated with the training targets (if applicable).
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+
+        Returns:
+            A two-tuple with the transformed outcomes:
+
+            - The transformed outcome observations.
+            - The transformed observation noise (if applicable).
+        """
+        if X is None:
+            raise ValueError("X is required for StratifiedStandardize.")
+        if self.training:
+            if Y.shape[:-2] != self._batch_shape:
+                raise RuntimeError(
+                    f"Expected Y.shape[:-2] to be {self._batch_shape}, matching "
+                    "the `batch_shape` argument to `StratifiedStandardize`, but got "
+                    f"Y.shape[:-2]={Y.shape[:-2]}."
+                )
+            elif Y.shape[-2] < 1:
+                raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")
+            elif Y.size(-1) != self._m:
+                raise RuntimeError(
+                    f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
+                    f"{self._m}."
+                )
+            self.means = self.means.to(dtype=X.dtype, device=X.device)
+            self.stdvs = self.stdvs.to(dtype=X.dtype, device=X.device)
+            self._stdvs_sq = self._stdvs_sq.to(dtype=X.dtype, device=X.device)
+            strata = X[..., self._stratification_idx].long()
+            unique_strata = strata.unique()
+            for s in unique_strata:
+                mapped_strata = self.strata_mapping[s]
+                mask = strata != s
+                Y_strata = Y.clone()
+                Y_strata[..., mask, :] = float("nan")
+                if Y.shape[-2] == 1:
+                    stdvs = torch.ones(
+                        (*Y_strata.shape[:-2], 1, Y_strata.shape[-1]),
+                        dtype=Y.dtype,
+                        device=Y.device,
+                    )
+                else:
+                    stdvs = nanstd(X=Y_strata, dim=-2)
+                stdvs = stdvs.where(
+                    stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0)
+                )
+                means = Y_strata.nanmean(dim=-2)
+                self.means[..., mapped_strata, :] = means
+                self.stdvs[..., mapped_strata, :] = stdvs
+                self._stdvs_sq[..., mapped_strata, :] = stdvs.pow(2)
+            self._is_trained = torch.tensor(True)
+        training = self.training
+        self.training = False
+        tf_Y, tf_Yvar = super().forward(Y=Y, Yvar=Yvar, X=X)
+        self.training = training
+        return tf_Y, tf_Yvar
+
+    def _get_per_input_means_stdvs(
+        self, X: Tensor, include_stdvs_sq: bool
+    ) -> tuple[Tensor, Tensor, Tensor | None]:
+        r"""Get per-input means and stdvs.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+            include_stdvs_sq: Whether to include the stdvs squared.
+
+        Returns:
+            A three-tuple with the per-input means and stdvs:
+
+            - The per-input means.
+            - The per-input stdvs.
+            - The per-input stdvs squared.
+        """
+        strata = X[..., self._stratification_idx].long()
+        mapped_strata = self.strata_mapping[strata].unsqueeze(-1)
+        # get means and stdvs for each stratum
+        n_extra_batch_dims = mapped_strata.ndim - 2 - len(self._batch_shape)
+        view_shape = torch.Size([1] * n_extra_batch_dims) + self.means.shape
+        expand_shape = mapped_strata.shape[:n_extra_batch_dims] + self.means.shape
+        means = torch.gather(
+            input=self.means.view(view_shape).expand(expand_shape),
+            dim=-2,
+            index=mapped_strata,
+        )
+        stdvs = torch.gather(
+            input=self.stdvs.view(view_shape).expand(expand_shape),
+            dim=-2,
+            index=mapped_strata,
+        )
+        if include_stdvs_sq:
+            stdvs_sq = torch.gather(
+                input=self._stdvs_sq.view(view_shape).expand(expand_shape),
+                dim=-2,
+                index=mapped_strata,
+            )
+        else:
+            stdvs_sq = None
+        return means, stdvs, stdvs_sq
+
+    def subset_output(self, idcs: list[int]) -> OutcomeTransform:
+        r"""Subset the transform along the output dimension.
+
+        Args:
+            idcs: The output indices to subset the transform to.
+
+        Raises:
+            NotImplementedError: Subsetting is not supported by this transform.
+        """
+        raise NotImplementedError
+
+    def untransform(
+        self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
+    ) -> tuple[Tensor, Tensor | None]:
+        r"""Un-standardize outcomes.
+
+        Args:
+            Y: A `batch_shape x n x m`-dim tensor of standardized targets.
+            Yvar: A `batch_shape x n x m`-dim tensor of standardized observation
+                noises associated with the targets (if applicable).
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+
+        Returns:
+            A two-tuple with the un-standardized outcomes:
+
+            - The un-standardized outcome observations.
+            - The un-standardized observation noise (if applicable).
+        """
+        if X is None:
+            raise ValueError("X is required for StratifiedStandardize.")
+        return super().untransform(Y=Y, Yvar=Yvar, X=X)
+
+    def untransform_posterior(
+        self, posterior: Posterior, X: Tensor | None = None
+    ) -> GPyTorchPosterior | TransformedPosterior:
+        r"""Un-standardize the posterior.
+
+        Args:
+            posterior: A posterior in the standardized space.
+            X: A `batch_shape x n x d`-dim tensor of training inputs (if applicable).
+
+        Returns:
+            The un-standardized posterior. If the input posterior is a
+            `GPyTorchPosterior`, return a `GPyTorchPosterior`. Otherwise, return a
+            `TransformedPosterior`.
+        """
+        if X is None:
+            raise ValueError("X is required for StratifiedStandardize.")
+        return super().untransform_posterior(posterior=posterior, X=X)
+
+
 class Log(OutcomeTransform):
     r"""Log-transform outcomes.
diff --git a/botorch/models/transforms/utils.py b/botorch/models/transforms/utils.py
index 17901d2efb..5d209c7048 100644
--- a/botorch/models/transforms/utils.py
+++ b/botorch/models/transforms/utils.py
@@ -141,3 +141,12 @@ def interaction_features(X: Tensor) -> Tensor:
     dim = X.shape[-1]
     row_idcs, col_idcs = torch.triu_indices(dim, dim, offset=1)
     return (X.unsqueeze(-1) @ X.unsqueeze(-2))[..., row_idcs, col_idcs].unsqueeze(-2)
+
+
+def nanstd(X: Tensor, dim: int, keepdim: bool = False) -> Tensor:
+    n = (~torch.isnan(X)).sum(dim=dim)
+    return (
+        (X - X.nanmean(dim=dim, keepdim=True)).pow(2).nanmean(dim=dim, keepdim=keepdim)
+        * n
+        / (n - 1)
+    ).sqrt()
diff --git a/botorch/models/utils/assorted.py b/botorch/models/utils/assorted.py
index fa146cb5dd..f0970e6a5f 100644
--- a/botorch/models/utils/assorted.py
+++ b/botorch/models/utils/assorted.py
@@ -397,3 +397,37 @@ class fantasize(_Flag):
     r"""A flag denoting whether we are currently in a `fantasize` context."""
 
     _state: bool = False
+
+
+def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor | None:
+    """Construct a mapping of discrete task values to contiguous int-valued floats.
+
+    Args:
+        task_values: A sorted long-valued tensor of task values.
+        dtype: The dtype of the model inputs (e.g. `X`), which the new
+            task values should be mapped to (e.g. float, double).
+
+    Returns:
+        A tensor of shape `task_values.max() + 1` that maps task values
+        to new task values. The indexing operation `mapper[task_value]`
+        will produce a tensor of new task values, of the same shape as
+        the original. The elements of the `mapper` tensor that do not
+        appear in the original `task_values` are mapped to `nan`. The
+        return value will be `None`, when the task values are contiguous
+        integers starting from zero.
+    """
+    task_range = torch.arange(
+        len(task_values), dtype=task_values.dtype, device=task_values.device
+    )
+    mapper = None
+    if not torch.equal(task_values, task_range):
+        # Create a tensor that maps task values to new task values.
+        # The number of tasks should be small, so this should be quite efficient.
+        mapper = torch.full(
+            (int(task_values.max().item()) + 1,),
+            float("nan"),
+            dtype=dtype,
+            device=task_values.device,
+        )
+        mapper[task_values] = task_range.to(dtype=dtype)
+    return mapper
diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py
index 49fa23862f..7d3c751ab3 100644
--- a/test/models/transforms/test_outcome.py
+++ b/test/models/transforms/test_outcome.py
@@ -15,6 +15,7 @@
     OutcomeTransform,
     Power,
     Standardize,
+    StratifiedStandardize,
 )
 from botorch.models.transforms.utils import (
     norm_to_lognorm_mean,
@@ -365,6 +366,89 @@ def test_standardize_state_dict(self):
         new_transform.load_state_dict(state_dict)
         self.assertTrue(new_transform._is_trained)
 
+    def test_stratified_standardize(self):
+        n = 5
+        for dtype, batch_shape in itertools.product(
+            (torch.float, torch.double), (torch.Size([]), torch.Size([3]))
+        ):
+            X = torch.rand(*batch_shape, n, 2, dtype=dtype, device=self.device)
+            X[..., -1] = torch.tensor([0, 1, 0, 1, 0], dtype=dtype, device=self.device)
+            Y = torch.randn(*batch_shape, n, 1, dtype=dtype, device=self.device)
+            Yvar = torch.rand(*batch_shape, n, 1, dtype=dtype, device=self.device)
+            strata_tf = StratifiedStandardize(
+                task_values=torch.tensor([0, 1], dtype=torch.long, device=self.device),
+                stratification_idx=-1,
+                batch_shape=batch_shape,
+            )
+            tf_Y, tf_Yvar = strata_tf(Y=Y, Yvar=Yvar, X=X)
+            mask0 = X[..., -1] == 0
+            mask1 = ~mask0
+            Y0 = Y[mask0].view(*batch_shape, -1, 1)
+            Yvar0 = Yvar[mask0].view(*batch_shape, -1, 1)
+            X0 = X[mask0].view(*batch_shape, -1, 1)
+            Y1 = Y[mask1].view(*batch_shape, -1, 1)
+            Yvar1 = Yvar[mask1].view(*batch_shape, -1, 1)
+            X1 = X[mask1].view(*batch_shape, -1, 1)
+            tf0 = Standardize(
+                m=1,
+                batch_shape=batch_shape,
+            )
+            tf_Y0, tf_Yvar0 = tf0(Y=Y0, Yvar=Yvar0, X=X0)
+            tf1 = Standardize(
+                m=1,
+                batch_shape=batch_shape,
+            )
+            tf_Y1, tf_Yvar1 = tf1(Y=Y1, Yvar=Yvar1, X=X1)
+            # check that stratified means are expected
+            self.assertTrue(torch.allclose(strata_tf.means[..., :1, :], tf0.means))
+            self.assertTrue(torch.allclose(strata_tf.means[..., 1:, :], tf1.means))
+            self.assertTrue(torch.allclose(strata_tf.stdvs[..., :1, :], tf0.stdvs))
+            self.assertTrue(torch.allclose(strata_tf.stdvs[..., 1:, :], tf1.stdvs))
+            # check the transformed values
+            self.assertTrue(
+                torch.allclose(tf_Y0, tf_Y[mask0].view(*batch_shape, -1, 1))
+            )
+            self.assertTrue(
+                torch.allclose(tf_Y1, tf_Y[mask1].view(*batch_shape, -1, 1))
+            )
+            self.assertTrue(
+                torch.allclose(tf_Yvar0, tf_Yvar[mask0].view(*batch_shape, -1, 1))
+            )
+            self.assertTrue(
+                torch.allclose(tf_Yvar1, tf_Yvar[mask1].view(*batch_shape, -1, 1))
+            )
+            untf_Y, untf_Yvar = strata_tf.untransform(Y=tf_Y, Yvar=tf_Yvar, X=X)
+            # test untransform
+            self.assertTrue(torch.allclose(Y, untf_Y))
+            self.assertTrue(torch.allclose(Yvar, untf_Yvar))
+
+            # test untransform_posterior
+            for lazy in (True, False):
+                shape = batch_shape + torch.Size([n, 1])
+                posterior = _get_test_posterior(
+                    shape,
+                    device=self.device,
+                    dtype=dtype,
+                    interleaved=False,
+                    lazy=lazy,
+                )
+                p_utf = strata_tf.untransform_posterior(posterior, X=X)
+                self.assertEqual(p_utf.device.type, self.device.type)
+                self.assertTrue(p_utf.dtype == dtype)
+                strata_means, strata_stdvs, _ = strata_tf._get_per_input_means_stdvs(
+                    X=X, include_stdvs_sq=False
+                )
+                mean_expected = strata_means + strata_stdvs * posterior.mean
+                variance_expected = strata_stdvs**2 * posterior.variance
+                self.assertAllClose(p_utf.mean, mean_expected)
+                self.assertAllClose(p_utf.variance, variance_expected)
+                samples = p_utf.rsample()
+                self.assertEqual(samples.shape, torch.Size([1]) + shape)
+                samples = p_utf.rsample(sample_shape=torch.Size([4]))
+                self.assertEqual(samples.shape, torch.Size([4]) + shape)
+                samples2 = p_utf.rsample(sample_shape=torch.Size([4, 2]))
+                self.assertEqual(samples2.shape, torch.Size([4, 2]) + shape)
+
     def test_log(self):
         ms = (1, 2)
         batch_shapes = (torch.Size(), torch.Size([2]))
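Usage sketch (illustrative only, not part of the patch above): applying the new transform directly, following the pattern in `test_stratified_standardize`; the variable names below are invented for the example. At the model level, the intended use per the summary is to pass such a transform as a model's `outcome_transform` (e.g. to `MultiTaskGP`), assuming the model forwards its training inputs to the transform.

import torch
from botorch.models.transforms.outcome import StratifiedStandardize

# Inputs whose last column is the stratum (task) indicator, here 0 or 1.
X = torch.rand(6, 2, dtype=torch.double)
X[..., -1] = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.double)
Y = torch.randn(6, 1, dtype=torch.double)

tf = StratifiedStandardize(
    task_values=torch.tensor([0, 1], dtype=torch.long),
    stratification_idx=-1,
)
# X is required so the transform can look up each observation's stratum.
Y_tf, _ = tf(Y=Y, X=X)
# Round-trip back to the original scale.
Y_utf, _ = tf.untransform(Y=Y_tf, X=X)
assert torch.allclose(Y, Y_utf)

Per-stratum means and stdvs are stored in buffers and indexed via `strata_mapping`, which is why `untransform` and `untransform_posterior` also require `X` to look up the right statistics for each input.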