From e220777cd669935ef1e53dc0f55cd359af411cf4 Mon Sep 17 00:00:00 2001
From: Sam Daulton <sdaulton@meta.com>
Date: Thu, 9 Jan 2025 10:12:10 -0800
Subject: [PATCH] StratifiedStandardize OutcomeTransform (#2671)

Summary:

See title. This adds a `StratifiedStandardize` outcome transform, which applies stratified standardization at the model level. This enables selecting between a single-task and a multi-task model in Ax while using the appropriate transform for each, e.g. by specifying ModelConfigs that use 1) `SingleTaskGP` + `Standardize`, or 2) `MultiTaskGP` + `StratifiedStandardize`.
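
As a rough usage sketch (shapes and data are illustrative; the stratum label is encoded in the last column of `X`):

```python
import torch
from botorch.models.transforms.outcome import StratifiedStandardize

X = torch.rand(6, 2)
X[..., -1] = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.0, 1.0])  # stratum per row
Y = torch.randn(6, 1)
tf = StratifiedStandardize(
    task_values=torch.tensor([0, 1], dtype=torch.long),
    stratification_idx=-1,
)
Y_tf, _ = tf(Y=Y, X=X)  # each stratum standardized with its own mean/std
Y_orig, _ = tf.untransform(Y=Y_tf, X=X)  # recovers Y up to numerical error
```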

Differential Revision: D67728920
---
 botorch/models/multitask.py            |  35 +---
 botorch/models/transforms/outcome.py   | 253 ++++++++++++++++++++++++-
 botorch/models/transforms/utils.py     |   9 +
 botorch/models/utils/assorted.py       |  34 ++++
 test/models/transforms/test_outcome.py |  84 ++++++++
 5 files changed, 372 insertions(+), 43 deletions(-)

diff --git a/botorch/models/multitask.py b/botorch/models/multitask.py
index eaf0fa41a5..876e2c8ede 100644
--- a/botorch/models/multitask.py
+++ b/botorch/models/multitask.py
@@ -39,6 +39,7 @@
 from botorch.models.model import FantasizeMixin
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform, Standardize
+from botorch.models.utils.assorted import get_task_value_remapping
 from botorch.models.utils.gpytorch_modules import (
     get_covar_module_with_dim_scaled_prior,
     get_gaussian_likelihood_with_lognormal_prior,
@@ -82,40 +83,6 @@
 from torch import Tensor
 
 
-def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor | None:
-    """Construct an mapping of discrete task values to contiguous int-valued floats.
-
-    Args:
-        task_values: A sorted long-valued tensor of task values.
-        dtype: The dtype of the model inputs (e.g. `X`), which the new
-            task values should have mapped to (e.g. float, double).
-
-    Returns:
-        A tensor of shape `task_values.max() + 1` that maps task values
-        to new task values. The indexing operation `mapper[task_value]`
-        will produce a tensor of new task values, of the same shape as
-        the original. The elements of the `mapper` tensor that do not
-        appear in the original `task_values` are mapped to `nan`. The
-        return value will be `None`, when the task values are contiguous
-        integers starting from zero.
-    """
-    task_range = torch.arange(
-        len(task_values), dtype=task_values.dtype, device=task_values.device
-    )
-    mapper = None
-    if not torch.equal(task_values, task_range):
-        # Create a tensor that maps task values to new task values.
-        # The number of tasks should be small, so this should be quite efficient.
-        mapper = torch.full(
-            (int(task_values.max().item()) + 1,),
-            float("nan"),
-            dtype=dtype,
-            device=task_values.device,
-        )
-        mapper[task_values] = task_range.to(dtype=dtype)
-    return mapper
-
-
 class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel, FantasizeMixin):
     r"""Multi-Task exact GP model using an ICM (intrinsic co-regionalization model)
     kernel. See [Bonilla2007MTGP]_ and [Swersky2013MTBO]_ for a reference on the
diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py
index 4c01204dd6..b9f58f63bc 100644
--- a/botorch/models/transforms/outcome.py
+++ b/botorch/models/transforms/outcome.py
@@ -27,9 +27,11 @@
 
 import torch
 from botorch.models.transforms.utils import (
+    nanstd,
     norm_to_lognorm_mean,
     norm_to_lognorm_variance,
 )
+from botorch.models.utils.assorted import get_task_value_remapping
 from botorch.posteriors import GPyTorchPosterior, Posterior, TransformedPosterior
 from botorch.utils.transforms import normalize_indices
 from linear_operator.operators import CholLinearOperator, DiagLinearOperator
@@ -259,6 +261,25 @@ def __init__(
         self._batch_shape = batch_shape
         self._min_stdv = min_stdv
 
+    def _get_per_input_means_stdvs(
+        self, X: Tensor, include_stdvs_sq: bool
+    ) -> tuple[Tensor, Tensor, Tensor | None]:
+        r"""Get per-input means and stdvs.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+            include_stdvs_sq: Whether to include the stdvs squared.
+                This parameter is not used by this method; it is part of the
+                shared signature that subclasses override.
+
+        Returns:
+            A three-tuple with the means and stdvs:
+
+            - The per-input means.
+            - The per-input stdvs.
+            - The per-input stdvs squared.
+        """
+        return self.means, self.stdvs, self._stdvs_sq
+
     def forward(
         self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
     ) -> tuple[Tensor, Tensor | None]:
@@ -313,9 +334,12 @@ def forward(
             self.stdvs = stdvs
             self._stdvs_sq = stdvs.pow(2)
             self._is_trained = torch.tensor(True)
-
-        Y_tf = (Y - self.means) / self.stdvs
-        Yvar_tf = Yvar / self._stdvs_sq if Yvar is not None else None
+        include_stdvs_sq = Yvar is not None
+        means, stdvs, stdvs_sq = self._get_per_input_means_stdvs(
+            X=X, include_stdvs_sq=include_stdvs_sq
+        )
+        Y_tf = (Y - means) / stdvs
+        Yvar_tf = Yvar / stdvs_sq if include_stdvs_sq else None
         return Y_tf, Yvar_tf
 
     def subset_output(self, idcs: list[int]) -> OutcomeTransform:
@@ -376,9 +400,12 @@ def untransform(
                 "(e.g. `transform(Y)`) before calling `untransform`, since "
                 "means and standard deviations need to be computed."
             )
-
-        Y_utf = self.means + self.stdvs * Y
-        Yvar_utf = self._stdvs_sq * Yvar if Yvar is not None else None
+        include_stdvs_sq = Yvar is not None
+        means, stdvs, stdvs_sq = self._get_per_input_means_stdvs(
+            X=X, include_stdvs_sq=include_stdvs_sq
+        )
+        Y_utf = means + stdvs * Y
+        Yvar_utf = stdvs_sq * Yvar if include_stdvs_sq else None
         return Y_utf, Yvar_utf
 
     @property
@@ -433,8 +460,9 @@ def untransform_posterior(
             )
         # GPyTorchPosterior (TODO: Should we Lazy-evaluate the mean here as well?)
         mvn = posterior.distribution
-        offset = self.means
-        scale_fac = self.stdvs
+        offset, scale_fac, _ = self._get_per_input_means_stdvs(
+            X=X, include_stdvs_sq=False
+        )
         if not posterior._is_mt:
             mean_tf = offset.squeeze(-1) + scale_fac.squeeze(-1) * mvn.mean
             scale_fac = scale_fac.squeeze(-1).expand_as(mean_tf)
@@ -449,7 +477,7 @@ def untransform_posterior(
 
         if (
             not mvn.islazy
-            # TODO: Figure out attribute namming weirdness here
+            # TODO: Figure out attribute naming weirdness here
             or mvn._MultivariateNormal__unbroadcasted_scale_tril is not None
         ):
             # if already computed, we can save a lot of time using scale_tril
@@ -465,6 +493,213 @@ def untransform_posterior(
         return GPyTorchPosterior(mvn_tf)
 
 
+class StratifiedStandardize(Standardize):
+    r"""Standardize outcomes (zero mean, unit variance) along stratification dimension.
+
+    This module is stateful: If in train mode, calling forward updates the
+    module state (i.e. the mean/std normalizing constants). If in eval mode,
+    calling forward simply applies the standardization using the current module
+    state.
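+
+    Example (illustrative; assumes the stratum label is in the last input column):
+        >>> task_values = torch.tensor([0, 1], dtype=torch.long)
+        >>> tf = StratifiedStandardize(task_values=task_values, stratification_idx=-1)
+        >>> Y_tf, Yvar_tf = tf(Y=Y, Yvar=Yvar, X=X)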
+    """
+
+    def __init__(
+        self,
+        task_values: Tensor,
+        stratification_idx: int,
+        batch_shape: torch.Size = torch.Size(),  # noqa: B008
+        min_stdv: float = 1e-8,
+    ) -> None:
+        r"""Standardize outcomes (zero mean, unit variance) along stratification dim.
+
+        Note: This currenlty only supports single output models
+        (including multi-task models that have a single output).
+
+        Args:
+            task_values: `t`-dim tensor of task values.
+            stratification_idx: The index of the stratification dimension.
+            batch_shape: The batch_shape of the training targets.
+            min_stdv: The minimum standard deviation for which to perform
+                standardization (if lower, only de-mean the data).
+        """
+        OutcomeTransform.__init__(self)
+        self._stratification_idx = stratification_idx
+        task_values = task_values.unique(sorted=True)
+        # Build the remapping in floating point (`nan` cannot be represented
+        # in a long tensor) and cast back afterwards; the `nan` slots are never
+        # indexed because `X` only contains values from `task_values`.
+        strata_mapping = get_task_value_remapping(task_values, dtype=torch.double)
+        if strata_mapping is None:
+            self.strata_mapping = task_values
+        else:
+            self.strata_mapping = strata_mapping.long()
+        n_strata = self.strata_mapping.shape[0]
+        self._min_stdv = min_stdv
+        self.register_buffer("means", torch.zeros(*batch_shape, n_strata, 1))
+        self.register_buffer("stdvs", torch.ones(*batch_shape, n_strata, 1))
+        self.register_buffer("_stdvs_sq", torch.ones(*batch_shape, n_strata, 1))
+        self.register_buffer("_is_trained", torch.tensor(False))
+        self._batch_shape = batch_shape
+        self._m = 1  # TODO: support multiple outputs
+        self._outputs = None
+
+    def forward(
+        self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
+    ) -> tuple[Tensor, Tensor | None]:
+        r"""Standardize outcomes.
+
+        If the module is in train mode, this updates the module state (i.e. the
+        mean/std normalizing constants). If the module is in eval mode, simply
+        applies the normalization using the module state.
+
+        Args:
+            Y: A `batch_shape x n x m`-dim tensor of training targets.
+            Yvar: A `batch_shape x n x m`-dim tensor of observation noises
+                associated with the training targets (if applicable).
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+
+        Returns:
+            A two-tuple with the transformed outcomes:
+
+            - The transformed outcome observations.
+            - The transformed observation noise (if applicable).
+        """
+        if X is None:
+            raise ValueError("X is required for StratifiedStandardize.")
+        if self.training:
+            if Y.shape[:-2] != self._batch_shape:
+                raise RuntimeError(
+                    f"Expected Y.shape[:-2] to be {self._batch_shape}, matching "
+                    "the `batch_shape` argument to `StratifiedStandardize`, but got "
+                    f"Y.shape[:-2]={Y.shape[:-2]}."
+                )
+            elif Y.shape[-2] < 1:
+                raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")
+            elif Y.size(-1) != self._m:
+                raise RuntimeError(
+                    f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
+                    f"{self._m}."
+                )
+            self.means = self.means.to(dtype=X.dtype, device=X.device)
+            self.stdvs = self.stdvs.to(dtype=X.dtype, device=X.device)
+            self._stdvs_sq = self._stdvs_sq.to(dtype=X.dtype, device=X.device)
+            strata = X[..., self._stratification_idx].long()
+            unique_strata = strata.unique()
+            for s in unique_strata:
+                mapped_strata = self.strata_mapping[s]
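+                # Mark observations from other strata as `nan` so that the
+                # nan-aware reductions below yield per-stratum statistics.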
+                mask = strata != s
+                Y_strata = Y.clone()
+                Y_strata[..., mask, :] = float("nan")
+                if Y.shape[-2] == 1:
+                    stdvs = torch.ones(
+                        (*Y_strata.shape[:-2], 1, Y_strata.shape[-1]),
+                        dtype=Y.dtype,
+                        device=Y.device,
+                    )
+                else:
+                    stdvs = nanstd(X=Y_strata, dim=-2)
+                stdvs = stdvs.where(
+                    stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0)
+                )
+                means = Y_strata.nanmean(dim=-2)
+                self.means[..., mapped_strata, :] = means
+                self.stdvs[..., mapped_strata, :] = stdvs
+                self._stdvs_sq[..., mapped_strata, :] = stdvs.pow(2)
+            self._is_trained = torch.tensor(True)
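+        # Temporarily switch to eval mode so that the parent class's `forward`
+        # applies the just-fitted standardization instead of re-estimating it.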
+        training = self.training
+        self.training = False
+        tf_Y, tf_Yvar = super().forward(Y=Y, Yvar=Yvar, X=X)
+        self.training = training
+        return tf_Y, tf_Yvar
+
+    def _get_per_input_means_stdvs(
+        self, X: Tensor, include_stdvs_sq: bool
+    ) -> tuple[Tensor, Tensor, Tensor | None]:
+        r"""Get per-input means and stdvs.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+            include_stdvs_sq: Whether to include the stdvs squared.
+
+        Returns:
+            A three-tuple with the per-input means and stdvs:
+
+            - The per-input means.
+            - The per-input stdvs.
+            - The per-input stdvs squared.
+        """
+        strata = X[..., self._stratification_idx].long()
+        mapped_strata = self.strata_mapping[strata].unsqueeze(-1)
+        # get means and stdvs for each stratum
+        n_extra_batch_dims = mapped_strata.ndim - 2 - len(self._batch_shape)
+        view_shape = torch.Size([1] * n_extra_batch_dims) + self.means.shape
+        expand_shape = mapped_strata.shape[:n_extra_batch_dims] + self.means.shape
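+        # Broadcast the stored per-stratum statistics over any extra batch
+        # dimensions of `X`, then gather the row for each input's stratum.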
+        means = torch.gather(
+            input=self.means.view(view_shape).expand(expand_shape),
+            dim=-2,
+            index=mapped_strata,
+        )
+        stdvs = torch.gather(
+            input=self.stdvs.view(view_shape).expand(expand_shape),
+            dim=-2,
+            index=mapped_strata,
+        )
+        if include_stdvs_sq:
+            stdvs_sq = torch.gather(
+                input=self._stdvs_sq.view(view_shape).expand(expand_shape),
+                dim=-2,
+                index=mapped_strata,
+            )
+        else:
+            stdvs_sq = None
+        return means, stdvs, stdvs_sq
+
+    def subset_output(self, idcs: list[int]) -> OutcomeTransform:
+        r"""Subset the transform along the output dimension.
+
+        Args:
+            idcs: The output indices to subset the transform to.
+
+        Returns:
+            The current outcome transform, subset to the specified output indices.
+        """
+        raise NotImplementedError
+
+    def untransform(
+        self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
+    ) -> tuple[Tensor, Tensor | None]:
+        r"""Un-standardize outcomes.
+
+        Args:
+            Y: A `batch_shape x n x m`-dim tensor of standardized targets.
+            Yvar: A `batch_shape x n x m`-dim tensor of standardized observation
+                noises associated with the targets (if applicable).
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+
+        Returns:
+            A two-tuple with the un-standardized outcomes:
+
+            - The un-standardized outcome observations.
+            - The un-standardized observation noise (if applicable).
+        """
+        if X is None:
+            raise ValueError("X is required for StratifiedStandardize.")
+        return super().untransform(Y=Y, Yvar=Yvar, X=X)
+
+    def untransform_posterior(
+        self, posterior: Posterior, X: Tensor | None = None
+    ) -> GPyTorchPosterior | TransformedPosterior:
+        r"""Un-standardize the posterior.
+
+        Args:
+            posterior: A posterior in the standardized space.
+            X: A `batch_shape x n x d`-dim tensor of training inputs (if applicable).
+
+        Returns:
+            The un-standardized posterior. If the input posterior is a
+            `GPyTorchPosterior`, return a `GPyTorchPosterior`. Otherwise, return a
+            `TransformedPosterior`.
+        """
+        if X is None:
+            raise ValueError("X is required for StratifiedStandardize.")
+        return super().untransform_posterior(posterior=posterior, X=X)
+
+
 class Log(OutcomeTransform):
     r"""Log-transform outcomes.
 
diff --git a/botorch/models/transforms/utils.py b/botorch/models/transforms/utils.py
index 17901d2efb..5d209c7048 100644
--- a/botorch/models/transforms/utils.py
+++ b/botorch/models/transforms/utils.py
@@ -141,3 +141,12 @@ def interaction_features(X: Tensor) -> Tensor:
     dim = X.shape[-1]
     row_idcs, col_idcs = torch.triu_indices(dim, dim, offset=1)
     return (X.unsqueeze(-1) @ X.unsqueeze(-2))[..., row_idcs, col_idcs].unsqueeze(-2)
+
+
+def nanstd(X: Tensor, dim: int, keepdim: bool = False) -> Tensor:
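+    r"""Compute the standard deviation along `dim`, ignoring `nan` values.
+
+    Applies Bessel's correction, with `n` counting the non-nan entries along
+    `dim`, to match `torch.std` on fully observed data.
+    """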
+    n = (~torch.isnan(X)).sum(dim=dim, keepdim=keepdim)
+    return (
+        (X - X.nanmean(dim=dim, keepdim=True)).pow(2).nanmean(dim=dim, keepdim=keepdim)
+        * n
+        / (n - 1)
+    ).sqrt()
diff --git a/botorch/models/utils/assorted.py b/botorch/models/utils/assorted.py
index fa146cb5dd..f0970e6a5f 100644
--- a/botorch/models/utils/assorted.py
+++ b/botorch/models/utils/assorted.py
@@ -397,3 +397,37 @@ class fantasize(_Flag):
     r"""A flag denoting whether we are currently in a `fantasize` context."""
 
     _state: bool = False
+
+
+def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor | None:
+    """Construct an mapping of discrete task values to contiguous int-valued floats.
+
+    Args:
+        task_values: A sorted long-valued tensor of task values.
+        dtype: The dtype of the model inputs (e.g. `X`) to which the new
+            task values should be mapped (e.g. float, double).
+
+    Returns:
+        A tensor of shape `task_values.max() + 1` that maps task values
+        to new task values. The indexing operation `mapper[task_value]`
+        will produce a tensor of new task values, of the same shape as
+        the original. The elements of the `mapper` tensor that do not
+        appear in the original `task_values` are mapped to `nan`. The
+        return value will be `None` when the task values are contiguous
+        integers starting from zero.
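+
+    Example:
+        >>> task_values = torch.tensor([1, 3], dtype=torch.long)
+        >>> get_task_value_remapping(task_values, torch.float)
+        tensor([nan, 0., nan, 1.])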
+    """
+    task_range = torch.arange(
+        len(task_values), dtype=task_values.dtype, device=task_values.device
+    )
+    mapper = None
+    if not torch.equal(task_values, task_range):
+        # Create a tensor that maps task values to new task values.
+        # The number of tasks should be small, so this should be quite efficient.
+        mapper = torch.full(
+            (int(task_values.max().item()) + 1,),
+            float("nan"),
+            dtype=dtype,
+            device=task_values.device,
+        )
+        mapper[task_values] = task_range.to(dtype=dtype)
+    return mapper
diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py
index 49fa23862f..7d3c751ab3 100644
--- a/test/models/transforms/test_outcome.py
+++ b/test/models/transforms/test_outcome.py
@@ -15,6 +15,7 @@
     OutcomeTransform,
     Power,
     Standardize,
+    StratifiedStandardize,
 )
 from botorch.models.transforms.utils import (
     norm_to_lognorm_mean,
@@ -365,6 +366,89 @@ def test_standardize_state_dict(self):
                 new_transform.load_state_dict(state_dict)
                 self.assertTrue(new_transform._is_trained)
 
+    def test_stratified_standardize(self):
+        n = 5
+        for dtype, batch_shape in itertools.product(
+            (torch.float, torch.double), (torch.Size([]), torch.Size([3]))
+        ):
+            X = torch.rand(*batch_shape, n, 2, dtype=dtype, device=self.device)
+            X[..., -1] = torch.tensor([0, 1, 0, 1, 0], dtype=dtype, device=self.device)
+            Y = torch.randn(*batch_shape, n, 1, dtype=dtype, device=self.device)
+            Yvar = torch.rand(*batch_shape, n, 1, dtype=dtype, device=self.device)
+            strata_tf = StratifiedStandardize(
+                task_values=torch.tensor([0, 1], dtype=torch.long, device=self.device),
+                stratification_idx=-1,
+                batch_shape=batch_shape,
+            )
+            tf_Y, tf_Yvar = strata_tf(Y=Y, Yvar=Yvar, X=X)
+            mask0 = X[..., -1] == 0
+            mask1 = ~mask0
+            Y0 = Y[mask0].view(*batch_shape, -1, 1)
+            Yvar0 = Yvar[mask0].view(*batch_shape, -1, 1)
+            X0 = X[mask0].view(*batch_shape, -1, 1)
+            Y1 = Y[mask1].view(*batch_shape, -1, 1)
+            Yvar1 = Yvar[mask1].view(*batch_shape, -1, 1)
+            X1 = X[mask1].view(*batch_shape, -1, 1)
+            tf0 = Standardize(
+                m=1,
+                batch_shape=batch_shape,
+            )
+            tf_Y0, tf_Yvar0 = tf0(Y=Y0, Yvar=Yvar0, X=X0)
+            tf1 = Standardize(
+                m=1,
+                batch_shape=batch_shape,
+            )
+            tf_Y1, tf_Yvar1 = tf1(Y=Y1, Yvar=Yvar1, X=X1)
+            # check that the stratified statistics match independent per-stratum fits
+            self.assertTrue(torch.allclose(strata_tf.means[..., :1, :], tf0.means))
+            self.assertTrue(torch.allclose(strata_tf.means[..., 1:, :], tf1.means))
+            self.assertTrue(torch.allclose(strata_tf.stdvs[..., :1, :], tf0.stdvs))
+            self.assertTrue(torch.allclose(strata_tf.stdvs[..., 1:, :], tf1.stdvs))
+            # check the transformed values
+            self.assertTrue(
+                torch.allclose(tf_Y0, tf_Y[mask0].view(*batch_shape, -1, 1))
+            )
+            self.assertTrue(
+                torch.allclose(tf_Y1, tf_Y[mask1].view(*batch_shape, -1, 1))
+            )
+            self.assertTrue(
+                torch.allclose(tf_Yvar0, tf_Yvar[mask0].view(*batch_shape, -1, 1))
+            )
+            self.assertTrue(
+                torch.allclose(tf_Yvar1, tf_Yvar[mask1].view(*batch_shape, -1, 1))
+            )
+            untf_Y, untf_Yvar = strata_tf.untransform(Y=tf_Y, Yvar=tf_Yvar, X=X)
+            # test untransform
+            self.assertTrue(torch.allclose(Y, untf_Y))
+            self.assertTrue(torch.allclose(Yvar, untf_Yvar))
+
+            # test untransform_posterior
+            for lazy in (True, False):
+                shape = batch_shape + torch.Size([n, 1])
+                posterior = _get_test_posterior(
+                    shape,
+                    device=self.device,
+                    dtype=dtype,
+                    interleaved=False,
+                    lazy=lazy,
+                )
+                p_utf = strata_tf.untransform_posterior(posterior, X=X)
+                self.assertEqual(p_utf.device.type, self.device.type)
+                self.assertTrue(p_utf.dtype == dtype)
+                strata_means, strata_stdvs, _ = strata_tf._get_per_input_means_stdvs(
+                    X=X, include_stdvs_sq=False
+                )
+                mean_expected = strata_means + strata_stdvs * posterior.mean
+                variance_expected = strata_stdvs**2 * posterior.variance
+                self.assertAllClose(p_utf.mean, mean_expected)
+                self.assertAllClose(p_utf.variance, variance_expected)
+                samples = p_utf.rsample()
+                self.assertEqual(samples.shape, torch.Size([1]) + shape)
+                samples = p_utf.rsample(sample_shape=torch.Size([4]))
+                self.assertEqual(samples.shape, torch.Size([4]) + shape)
+                samples2 = p_utf.rsample(sample_shape=torch.Size([4, 2]))
+                self.assertEqual(samples2.shape, torch.Size([4, 2]) + shape)
+
     def test_log(self):
         ms = (1, 2)
         batch_shapes = (torch.Size(), torch.Size([2]))