diff --git a/torchrec/metrics/auc.py b/torchrec/metrics/auc.py index 688f04583..0637eda52 100644 --- a/torchrec/metrics/auc.py +++ b/torchrec/metrics/auc.py @@ -6,10 +6,11 @@ # LICENSE file in the root directory of this source tree. from functools import partial -from typing import Any, cast, Dict, List, Optional, Type +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type import torch import torch.distributed as dist +from torchmetrics.utilities.distributed import gather_all_tensors from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix from torchrec.metrics.rec_metric import ( @@ -26,6 +27,29 @@ REQUIRED_INPUTS = "required_inputs" +def _concat_if_needed( + predictions: List[torch.Tensor], + labels: List[torch.Tensor], + weights: List[torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + This check exists because of how the state is organized due to quirks in RecMetrics. + Since we do not do tensor concatenatation in the compute or update call, there are cases (in non-distributed settings) + where the tensors from updates are not concatted into a single tensor. Which is determined by the length of the list. + """ + preds_t, labels_t, weights_t = None, None, None + if len(predictions) > 1: + preds_t = torch.cat(predictions, dim=-1) + labels_t = torch.cat(labels, dim=-1) + weights_t = torch.cat(weights, dim=-1) + else: + preds_t = predictions[0] + labels_t = labels[0] + weights_t = weights[0] + + return preds_t, labels_t, weights_t + + def _compute_auc_helper( predictions: torch.Tensor, labels: torch.Tensor, @@ -50,9 +74,9 @@ def _compute_auc_helper( def compute_auc( n_tasks: int, - predictions: torch.Tensor, - labels: torch.Tensor, - weights: torch.Tensor, + predictions: List[torch.Tensor], + labels: List[torch.Tensor], + weights: List[torch.Tensor], apply_bin: bool = False, ) -> torch.Tensor: """ @@ -60,12 +84,13 @@ def compute_auc( Args: n_tasks (int): number of tasks. - predictions (torch.Tensor): tensor of size (n_tasks, n_examples). - labels (torch.Tensor): tensor of size (n_tasks, n_examples). - weights (torch.Tensor): tensor of size (n_tasks, n_examples). + predictions (List[torch.Tensor]): tensor of size (n_tasks, n_examples). + labels (List[torch.Tensor]): tensor of size (n_tasks, n_examples). + weights (List[torch.Tensor]): tensor of size (n_tasks, n_examples). """ + preds_t, labels_t, weights_t = _concat_if_needed(predictions, labels, weights) aucs = [] - for predictions_i, labels_i, weights_i in zip(predictions, labels, weights): + for predictions_i, labels_i, weights_i in zip(preds_t, labels_t, weights_t): auc = _compute_auc_helper(predictions_i, labels_i, weights_i, apply_bin) aucs.append(auc.view(1)) return torch.cat(aucs) @@ -73,36 +98,34 @@ def compute_auc( def compute_auc_per_group( n_tasks: int, - predictions: torch.Tensor, - labels: torch.Tensor, - weights: torch.Tensor, + predictions: List[torch.Tensor], + labels: List[torch.Tensor], + weights: List[torch.Tensor], grouping_keys: torch.Tensor, ) -> torch.Tensor: """ Computes AUC (Area Under the Curve) for binary classification for groups of predictions/labels. Args: n_tasks (int): number of tasks - predictions (torch.Tensor): tensor of size (n_tasks, n_examples) - labels (torch.Tensor): tensor of size (n_tasks, n_examples) - weights (torch.Tensor): tensor of size (n_tasks, n_examples) + predictions (List[torch.Tensor]): tensor of size (n_tasks, n_examples) + labels (List[torch.Tensor]: tensor of size (n_tasks, n_examples) + weights (List[torch.Tensor]): tensor of size (n_tasks, n_examples) grouping_keys (torch.Tensor): tensor of size (n_examples,) Returns: torch.Tensor: tensor of size (n_tasks,), average of AUCs per group. """ + preds_t, labels_t, weights_t = _concat_if_needed(predictions, labels, weights) aucs = [] if grouping_keys.numel() != 0 and grouping_keys[0] == -1: # we added padding as the first elements during init to avoid floating point exception in sync() # removing the paddings to avoid numerical errors. grouping_keys = grouping_keys[1:] - predictions = predictions[:, 1:] - labels = labels[:, 1:] - weights = weights[:, 1:] # get unique group indices group_indices = torch.unique(grouping_keys) - for (predictions_i, labels_i, weights_i) in zip(predictions, labels, weights): + for (predictions_i, labels_i, weights_i) in zip(preds_t, labels_t, weights_t): # Loop over each group auc_groups_sum = torch.tensor([0], dtype=torch.float32) for group_idx in group_indices: @@ -162,6 +185,7 @@ def __init__( self._grouped_auc: bool = grouped_auc self._apply_bin: bool = apply_bin + self._num_samples: int = 0 self._add_state( PREDICTIONS, [], @@ -204,7 +228,7 @@ def __init__( def _init_states(self) -> None: if len(getattr(self, PREDICTIONS)) > 0: return - + self._num_samples = 0 getattr(self, PREDICTIONS).append( torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device) ) @@ -241,25 +265,42 @@ def update( predictions = predictions.float() labels = labels.float() weights = weights.float() - num_samples = getattr(self, PREDICTIONS)[0].size(-1) batch_size = predictions.size(-1) - start_index = max(num_samples + batch_size - self._window_size, 0) + start_index = max(self._num_samples + batch_size - self._window_size, 0) + # Using `self.predictions =` will cause Pyre errors. - getattr(self, PREDICTIONS)[0] = torch.cat( - [ - cast(torch.Tensor, getattr(self, PREDICTIONS)[0])[:, start_index:], - predictions, - ], - dim=-1, - ) - getattr(self, LABELS)[0] = torch.cat( - [cast(torch.Tensor, getattr(self, LABELS)[0])[:, start_index:], labels], - dim=-1, - ) - getattr(self, WEIGHTS)[0] = torch.cat( - [cast(torch.Tensor, getattr(self, WEIGHTS)[0])[:, start_index:], weights], - dim=-1, - ) + w_preds = getattr(self, PREDICTIONS) + w_labels = getattr(self, LABELS) + w_weights = getattr(self, WEIGHTS) + + # remove init states + if self._num_samples == 0: + for lst in [w_preds, w_labels, w_weights]: + lst.pop(0) + + w_preds.append(predictions) + w_labels.append(labels) + w_weights.append(weights) + + self._num_samples += batch_size + + while self._num_samples > self._window_size: + diff = self._num_samples - self._window_size + if diff > w_preds[0].size(-1): + self._num_samples -= w_preds[0].size(-1) + # Remove the first element from predictions, labels, and weights + for lst in [w_preds, w_labels, w_weights]: + lst.pop(0) + else: + # Update the first element of predictions, labels, and weights + # Off by one potentially - keeping legacy behaviour + for lst in [w_preds, w_labels, w_weights]: + lst[0] = lst[0][:, diff:] + # if empty tensor, remove it + if torch.numel(lst[0]) == 0: + lst.pop(0) + self._num_samples -= diff + if self._grouped_auc: if REQUIRED_INPUTS not in kwargs or ( (grouping_keys := kwargs[REQUIRED_INPUTS].get(GROUPING_KEYS)) is None @@ -276,19 +317,21 @@ def update( ) def _compute(self) -> List[MetricComputationReport]: - reports = [ + reports = [] + reports.append( MetricComputationReport( name=MetricName.AUC, metric_prefix=MetricPrefix.WINDOW, value=compute_auc( self._n_tasks, - cast(torch.Tensor, getattr(self, PREDICTIONS)[0]), - cast(torch.Tensor, getattr(self, LABELS)[0]), - cast(torch.Tensor, getattr(self, WEIGHTS)[0]), + cast(List[torch.Tensor], getattr(self, PREDICTIONS)), + cast(List[torch.Tensor], getattr(self, LABELS)), + cast(List[torch.Tensor], getattr(self, WEIGHTS)), self._apply_bin, ), ) - ] + ) + if self._grouped_auc: reports.append( MetricComputationReport( @@ -296,15 +339,31 @@ def _compute(self) -> List[MetricComputationReport]: metric_prefix=MetricPrefix.WINDOW, value=compute_auc_per_group( self._n_tasks, - cast(torch.Tensor, getattr(self, PREDICTIONS)[0]), - cast(torch.Tensor, getattr(self, LABELS)[0]), - cast(torch.Tensor, getattr(self, WEIGHTS)[0]), - cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0]), + cast(List[torch.Tensor], getattr(self, PREDICTIONS)), + cast(List[torch.Tensor], getattr(self, LABELS)), + cast(List[torch.Tensor], getattr(self, WEIGHTS)), + cast(torch.Tensor, getattr(self, GROUPING_KEYS))[0], ), ) ) return reports + def _sync_dist( + self, + dist_sync_fn: Callable = gather_all_tensors, # pyre-ignore[24] + process_group: Optional[Any] = None, # pyre-ignore[2] + ) -> None: + """ + This function is overridden from torchmetric.Metric, since for AUC we want to concat the tensors + right before the allgather collective is called. It directly changes the attributes/states, which + is ok because end of function sets the attributes to reduced values + """ + for attr in self._reductions: # pragma: no cover + val = getattr(self, attr) + if isinstance(val, list) and len(val) > 1: + setattr(self, attr, [torch.cat(val, dim=-1)]) + super()._sync_dist(dist_sync_fn, process_group) + def reset(self) -> None: super().reset() self._init_states() diff --git a/torchrec/metrics/test_utils/__init__.py b/torchrec/metrics/test_utils/__init__.py index ae78306e2..d9c57c4aa 100644 --- a/torchrec/metrics/test_utils/__init__.py +++ b/torchrec/metrics/test_utils/__init__.py @@ -41,7 +41,10 @@ def gen_test_batch( weight_value: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, n_classes: Optional[int] = None, + seed: Optional[int] = None, ) -> Dict[str, torch.Tensor]: + if seed is not None: + torch.manual_seed(seed) if label_value is not None: label = label_value else: @@ -340,6 +343,151 @@ def get_launch_config(world_size: int, rdzv_endpoint: str) -> pet.LaunchConfig: ) +def rec_metric_gpu_sync_test_launcher( + target_clazz: Type[RecMetric], + target_compute_mode: RecComputeMode, + test_clazz: Optional[Type[TestMetric]], + metric_name: str, + task_names: List[str], + fused_update_limit: int, + compute_on_all_ranks: bool, + should_validate_update: bool, + world_size: int, + entry_point: Callable[..., None], + batch_size: int = BATCH_SIZE, + batch_window_size: int = BATCH_WINDOW_SIZE, + **kwargs: Any, +) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + lc = get_launch_config( + world_size=world_size, rdzv_endpoint=os.path.join(tmpdir, "rdzv") + ) + + # launch using torch elastic, launches for each rank + pet.elastic_launch(lc, entrypoint=entry_point)( + target_clazz, + target_compute_mode, + test_clazz, + task_names, + metric_name, + world_size, + fused_update_limit, + compute_on_all_ranks, + should_validate_update, + batch_size, + batch_window_size, + ) + + +def sync_test_helper( + target_clazz: Type[RecMetric], + target_compute_mode: RecComputeMode, + test_clazz: Optional[Type[TestMetric]], + task_names: List[str], + metric_name: str, + world_size: int, + fused_update_limit: int = 0, + compute_on_all_ranks: bool = False, + should_validate_update: bool = False, + batch_size: int = BATCH_SIZE, + batch_window_size: int = BATCH_WINDOW_SIZE, + n_classes: Optional[int] = None, + zero_weights: bool = False, +) -> None: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + dist.init_process_group( + backend="gloo", + world_size=world_size, + rank=rank, + ) + + tasks = gen_test_tasks(task_names) + + auc = target_clazz( + world_size=world_size, + batch_size=batch_size, + my_rank=rank, + compute_on_all_ranks=compute_on_all_ranks, + tasks=tasks, + window_size=batch_window_size * world_size, + ) + + weight_value: Optional[torch.Tensor] = None + + _model_outs = [ + gen_test_batch( + label_name=task.label_name, + prediction_name=task.prediction_name, + weight_name=task.weight_name, + batch_size=batch_size, + n_classes=n_classes, + weight_value=weight_value, + seed=42, # we set seed because of how test metric places tensors on ranks + ) + for task in tasks + ] + model_outs = [] + model_outs.append({k: v for d in _model_outs for k, v in d.items()}) + + # we send an uneven number of tensors to each rank to test that GPU sync works + if rank == 0: + for _ in range(3): + labels, predictions, weights, _ = parse_task_model_outputs( + tasks, model_outs[0] + ) + auc.update(predictions=predictions, labels=labels, weights=weights) + elif rank == 1: + for _ in range(1): + labels, predictions, weights, _ = parse_task_model_outputs( + tasks, model_outs[0] + ) + auc.update(predictions=predictions, labels=labels, weights=weights) + + # check against test metric + test_metrics: TestRecMetricOutput = ({}, {}, {}, {}) + if test_clazz is not None: + # pyre-ignore[45]: Cannot instantiate abstract class `TestMetric`. + test_metric_obj = test_clazz(world_size, tasks) + # with how testmetric is setup we cannot do asymmertrical updates across ranks + # so we duplicate model_outs twice to match number of updates in aggregate + model_outs = model_outs * 2 + test_metrics = test_metric_obj.compute(model_outs, 2, batch_window_size, None) + + res = auc.compute() + + if rank == 0: + assert torch.allclose( + test_metrics[1][task_names[0]], + res[f"auc-{task_names[0]}|window_auc"], + ) + + # we also test the case where other rank has more tensors than rank 0 + auc.reset() + if rank == 0: + for _ in range(1): + labels, predictions, weights, _ = parse_task_model_outputs( + tasks, model_outs[0] + ) + auc.update(predictions=predictions, labels=labels, weights=weights) + elif rank == 1: + for _ in range(3): + labels, predictions, weights, _ = parse_task_model_outputs( + tasks, model_outs[0] + ) + auc.update(predictions=predictions, labels=labels, weights=weights) + + res = auc.compute() + + if rank == 0: + assert torch.allclose( + test_metrics[1][task_names[0]], + res[f"auc-{task_names[0]}|window_auc"], + ) + + dist.destroy_process_group() + + def rec_metric_value_test_launcher( target_clazz: Type[RecMetric], target_compute_mode: RecComputeMode, diff --git a/torchrec/metrics/tests/test_auc.py b/torchrec/metrics/tests/test_auc.py index e81922cd4..0bf19fa94 100644 --- a/torchrec/metrics/tests/test_auc.py +++ b/torchrec/metrics/tests/test_auc.py @@ -20,7 +20,9 @@ ) from torchrec.metrics.test_utils import ( metric_test_helper, + rec_metric_gpu_sync_test_launcher, rec_metric_value_test_launcher, + sync_test_helper, TestMetric, ) @@ -115,6 +117,27 @@ def test_fused_auc(self) -> None: ) +class AUCGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = AUCMetric + task_name: str = "auc" + + def test_sync_auc(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=AUCMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestAUCMetric, + metric_name=AUCGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) + + class AUCMetricValueTest(unittest.TestCase): r"""This set of tests verify the computation logic of AUC in several corner cases that we know the computation results. The goal is to @@ -177,6 +200,89 @@ def test_calc_auc_balanced(self) -> None: actual_auc = self.auc.compute()["auc-DefaultTask|window_auc"] torch.allclose(expected_auc, actual_auc) + def test_calc_uneven_updates(self) -> None: + auc = AUCMetric( + world_size=1, + my_rank=0, + batch_size=100, + tasks=[DefaultTaskInfo], + ) + + expected_auc = torch.tensor([0.4464], dtype=torch.float) + # first batch + self.labels["DefaultTask"] = torch.tensor([1, 0, 0]) + self.predictions["DefaultTask"] = torch.tensor([0.2, 0.6, 0.8]) + self.weights["DefaultTask"] = torch.tensor([0.13, 0.2, 0.5]) + + auc.update(**self.batches) + # second batch + self.labels["DefaultTask"] = torch.tensor([1, 1]) + self.predictions["DefaultTask"] = torch.tensor([0.4, 0.9]) + self.weights["DefaultTask"] = torch.tensor([0.8, 0.75]) + + auc.update(**self.batches) + multiple_batch = self.auc.compute()["auc-DefaultTask|window_auc"] + torch.allclose(expected_auc, multiple_batch) + + def test_window_size_auc(self) -> None: + # for determinisitc batches + torch.manual_seed(0) + + auc = AUCMetric( + world_size=1, + my_rank=0, + batch_size=5, + window_size=100, + tasks=[DefaultTaskInfo], + ) + + # init states, so we expect 3 (state tensors) * 4 bytes (float) + self.assertEqual(sum(auc.get_memory_usage().values()), 12) + + # bs = 5 + self.labels["DefaultTask"] = torch.rand(5) + self.predictions["DefaultTask"] = torch.rand(5) + self.weights["DefaultTask"] = torch.rand(5) + + for _ in range(1000): + auc.update(**self.batches) + + # check memory, window size is 100, so we have upperbound of memory to expect + # so with a 100 window size / tensors of size 5 = 20 tensors (per state) * 3 states * 20 bytes per tensor of size 5 = 1200 bytes + self.assertEqual(sum(auc.get_memory_usage().values()), 1200) + # with bs 5, we expect 20 tensors per state, so 60 tensors + self.assertEqual(len(auc.get_memory_usage().values()), 60) + + torch.allclose( + auc.compute()["auc-DefaultTask|window_auc"], + torch.tensor([0.4859], dtype=torch.float), + ) + + # test auc memory usage with window size equal to incoming batch + auc = AUCMetric( + world_size=1, + my_rank=0, + batch_size=100, + window_size=100, + tasks=[DefaultTaskInfo], + ) + + self.labels["DefaultTask"] = torch.rand(100) + self.predictions["DefaultTask"] = torch.rand(100) + self.weights["DefaultTask"] = torch.rand(100) + + for _ in range(10): + auc.update(**self.batches) + + # passing in batch size == window size, we expect for each state just one tensor of size 400, sum to 1200 as previous + self.assertEqual(sum(auc.get_memory_usage().values()), 1200) + self.assertEqual(len(auc.get_memory_usage().values()), 3) + + torch.allclose( + auc.compute()["auc-DefaultTask|window_auc"], + torch.tensor([0.4859], dtype=torch.float), + ) + def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]: return [ diff --git a/torchrec/metrics/tests/test_metric_module.py b/torchrec/metrics/tests/test_metric_module.py index 146271664..4fd17e98f 100644 --- a/torchrec/metrics/tests/test_metric_module.py +++ b/torchrec/metrics/tests/test_metric_module.py @@ -434,15 +434,15 @@ def test_auc_memory_usage(self) -> None: # 3 (tensors) * 4 (float) self.assertEqual(metric_module.get_memory_usage(), 12) metric_module.update(gen_test_batch(128)) - # 24 (initial states) + 3 (tensors) * 128 (batch_size) * 4 (float) - self.assertEqual(metric_module.get_memory_usage(), 1548) + # 3 (tensors) * 128 (batch_size) * 4 (float) + self.assertEqual(metric_module.get_memory_usage(), 1536) # Test memory usage over multiple updates does not increase unexpectedly, we don't need to force OOM as just knowing if the memory usage is increeasing how we expect is enough for _ in range(10): metric_module.update(gen_test_batch(128)) - # 24 initial states + 3 tensors * 128 batch size * 4 float * 11 updates - 12 initial memory - self.assertEqual(metric_module.get_memory_usage(), 16908) + # 3 tensors * 128 batch size * 4 float * 11 updates + self.assertEqual(metric_module.get_memory_usage(), 16896) # Ensure reset frees memory correctly metric_module.reset()