diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index a143dd1c23..9923d75c35 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -532,7 +532,7 @@ def __init__( OutcomeTransform.__init__(self) self._stratification_idx = stratification_idx task_values = task_values.unique(sorted=True) - self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.long) + self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.double) if self.strata_mapping is None: self.strata_mapping = task_values n_strata = self.strata_mapping.shape[0] @@ -576,7 +576,7 @@ def forward( strata = X[..., self._stratification_idx].long() unique_strata = strata.unique() for s in unique_strata: - mapped_strata = self.strata_mapping[s] + mapped_strata = self.strata_mapping[s].long() mask = strata != s Y_strata = Y.clone() Y_strata[..., mask, :] = float("nan") @@ -616,7 +616,7 @@ def _get_per_input_means_stdvs( - The per-input stdvs squared. """ strata = X[..., self._stratification_idx].long() - mapped_strata = self.strata_mapping[strata].unsqueeze(-1) + mapped_strata = self.strata_mapping[strata].unsqueeze(-1).long() # get means and stdvs for each strata n_extra_batch_dims = mapped_strata.ndim - 2 - len(self._batch_shape) expand_shape = mapped_strata.shape[:n_extra_batch_dims] + self.means.shape diff --git a/botorch/models/utils/assorted.py b/botorch/models/utils/assorted.py index 6cf04a5923..521b355515 100644 --- a/botorch/models/utils/assorted.py +++ b/botorch/models/utils/assorted.py @@ -422,6 +422,8 @@ def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor return value will be `None`, when the task values are contiguous integers starting from zero. """ + if dtype not in (torch.float, torch.double): + raise ValueError(f"dtype must be torch.float or torch.double, but got {dtype}.") task_range = torch.arange( len(task_values), dtype=task_values.dtype, device=task_values.device ) diff --git a/test/models/test_multitask.py b/test/models/test_multitask.py index b9edc3e073..f36ca4301c 100644 --- a/test/models/test_multitask.py +++ b/test/models/test_multitask.py @@ -700,3 +700,12 @@ def test_get_task_value_remapping(self) -> None: mapping = get_task_value_remapping(task_values, dtype) self.assertTrue(torch.equal(mapping[[1, 3]], expected_mapping_no_nan)) self.assertTrue(torch.isnan(mapping[[0, 2]]).all()) + + def test_get_task_value_remapping_invalid_dtype(self) -> None: + task_values = torch.tensor([1, 3]) + for dtype in (torch.int32, torch.long, torch.bool): + with self.assertRaisesRegex( + ValueError, + f"dtype must be torch.float or torch.double, but got {dtype}.", + ): + get_task_value_remapping(task_values, dtype) diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py index 29d31593b9..03a5c1167a 100644 --- a/test/models/transforms/test_outcome.py +++ b/test/models/transforms/test_outcome.py @@ -372,16 +372,24 @@ def test_stratified_standardize(self): n = 5 seed = randint(0, 100) torch.manual_seed(seed) - for dtype, batch_shape in itertools.product( - (torch.float, torch.double), (torch.Size([]), torch.Size([3])) + for dtype, batch_shape, task_values in itertools.product( + (torch.float, torch.double), + (torch.Size([]), torch.Size([3])), + ( + torch.tensor([0, 1], dtype=torch.long, device=self.device), + torch.tensor([0, 3], dtype=torch.long, device=self.device), + ), ): torch.manual_seed(seed) + tval = task_values[1].item() X = torch.rand(*batch_shape, n, 2, dtype=dtype, device=self.device) - X[..., -1] = torch.tensor([0, 1, 0, 1, 0], dtype=dtype, device=self.device) + X[..., -1] = torch.tensor( + [0, tval, 0, tval, 0], dtype=dtype, device=self.device + ) Y = torch.randn(*batch_shape, n, 1, dtype=dtype, device=self.device) Yvar = torch.rand(*batch_shape, n, 1, dtype=dtype, device=self.device) strata_tf = StratifiedStandardize( - task_values=torch.tensor([0, 1], dtype=torch.long, device=self.device), + task_values=task_values, stratification_idx=-1, batch_shape=batch_shape, ) @@ -400,9 +408,11 @@ def test_stratified_standardize(self): tf_Y1, tf_Yvar1 = tf1(Y=Y1, Yvar=Yvar1, X=X1) # check that stratified means are expected self.assertAllClose(strata_tf.means[..., :1, :], tf0.means) - self.assertAllClose(strata_tf.means[..., 1:, :], tf1.means) + # use remapped task values to index + self.assertAllClose(strata_tf.means[..., 1:2, :], tf1.means) self.assertAllClose(strata_tf.stdvs[..., :1, :], tf0.stdvs) - self.assertAllClose(strata_tf.stdvs[..., 1:, :], tf1.stdvs) + # use remapped task values to index + self.assertAllClose(strata_tf.stdvs[..., 1:2, :], tf1.stdvs) # check the transformed values self.assertAllClose(tf_Y0, tf_Y[mask0].view(*batch_shape, -1, 1)) self.assertAllClose(tf_Y1, tf_Y[mask1].view(*batch_shape, -1, 1))