From 49efc376639e1817e54c451100e78654fd52b503 Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Thu, 1 Feb 2024 12:15:48 -0800 Subject: [PATCH] move tensor concatenation out of update and before compute collective call (#1656) Summary: Previously unlanded D51399384 because there was a NCCL AllGather hang caused by the previous version of the optimization. This diff fixes the issue by moving the tensor concatenation before the collective is called. By doing this, model outputs are transformed to be identical to what was previous, and thus the flow (sync, collective call, compute) is identical to previous AUC implementation. But at the same time, we avoid concatting every update thus leveraging significant gains. **Why did the NCCL collective hang last time?** There existed an edge case where ranks could have mismatched number of tensors, the way torchmetrics sync and collective call is written it calls an allgather per tensor in a list of tensors. Since number of tensors in a list across ranks is not guaranteed to be the same, the issue can arise where an allgather is being called for a tensor that does not exist. Imagine, 3 tensors on rank 0 and 2 tensors on rank 1. For calling allgather on each tensor, on rank 0 calling on the 3rd tensor will cause a NCCL hang and subsequent timeout since there is no tensor on rank 1 that corresponds to it. A additional GPU unit test is added to cover this scenario, where on rank 0 - 2 tensors are passed in and on rank 1 - 1 tensor is passed in. With the concat before the collective, the test passes. Without the concat, the test will hang. Differential Revision: D53027097 --- torchrec/metrics/auc.py | 178 ++++++++++++++++++------ torchrec/metrics/test_utils/__init__.py | 126 +++++++++++++++++ torchrec/metrics/tests/test_auc.py | 74 ++++++++++ 3 files changed, 338 insertions(+), 40 deletions(-) diff --git a/torchrec/metrics/auc.py b/torchrec/metrics/auc.py index 688f04583..1b49b60fd 100644 --- a/torchrec/metrics/auc.py +++ b/torchrec/metrics/auc.py @@ -6,10 +6,15 @@ # LICENSE file in the root directory of this source tree. from functools import partial -from typing import Any, cast, Dict, List, Optional, Type +from typing import Any, Callable, cast, Dict, List, Optional, Type, Union import torch import torch.distributed as dist +from pyre_extensions import override +from torch import Tensor +from torchmetrics.utilities import apply_to_collection +from torchmetrics.utilities.data import _flatten, dim_zero_cat +from torchmetrics.utilities.distributed import gather_all_tensors from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix from torchrec.metrics.rec_metric import ( @@ -50,9 +55,9 @@ def _compute_auc_helper( def compute_auc( n_tasks: int, - predictions: torch.Tensor, - labels: torch.Tensor, - weights: torch.Tensor, + predictions: Union[List[torch.Tensor], torch.Tensor], + labels: Union[List[torch.Tensor], torch.Tensor], + weights: Union[List[torch.Tensor], torch.Tensor], apply_bin: bool = False, ) -> torch.Tensor: """ @@ -60,10 +65,23 @@ def compute_auc( Args: n_tasks (int): number of tasks. - predictions (torch.Tensor): tensor of size (n_tasks, n_examples). - labels (torch.Tensor): tensor of size (n_tasks, n_examples). - weights (torch.Tensor): tensor of size (n_tasks, n_examples). + predictions (List[torch.Tensor]): tensor of size (n_tasks, n_examples). + labels (List[torch.Tensor]): tensor of size (n_tasks, n_examples). + weights (List[torch.Tensor]): tensor of size (n_tasks, n_examples). """ + if len(predictions) > 1: + # Lists of tensors not concatted can be passed in a not distributed setting, so we need to concat here + # pyre-ignore[6] + predictions = torch.cat(predictions, dim=-1) + # pyre-ignore[6] + labels = torch.cat(labels, dim=-1) + # pyre-ignore[6] + weights = torch.cat(weights, dim=-1) + else: + predictions = predictions[0] + labels = labels[0] + weights = weights[0] + aucs = [] for predictions_i, labels_i, weights_i in zip(predictions, labels, weights): auc = _compute_auc_helper(predictions_i, labels_i, weights_i, apply_bin) @@ -73,9 +91,9 @@ def compute_auc( def compute_auc_per_group( n_tasks: int, - predictions: torch.Tensor, - labels: torch.Tensor, - weights: torch.Tensor, + predictions: Union[List[torch.Tensor], torch.Tensor], + labels: Union[List[torch.Tensor], torch.Tensor], + weights: Union[List[torch.Tensor], torch.Tensor], grouping_keys: torch.Tensor, ) -> torch.Tensor: """ @@ -90,14 +108,25 @@ def compute_auc_per_group( Returns: torch.Tensor: tensor of size (n_tasks,), average of AUCs per group. """ + # prepare model outputs + if len(predictions) > 1: + # Lists of tensors not concatted can be passed in a not distributed setting, so we need to concat here + # pyre-ignore[6] + predictions = torch.cat(predictions, dim=-1) + # pyre-ignore[6] + labels = torch.cat(labels, dim=-1) + # pyre-ignore[6] + weights = torch.cat(weights, dim=-1) + else: + predictions = predictions[0] + labels = labels[0] + weights = weights[0] + aucs = [] if grouping_keys.numel() != 0 and grouping_keys[0] == -1: # we added padding as the first elements during init to avoid floating point exception in sync() # removing the paddings to avoid numerical errors. grouping_keys = grouping_keys[1:] - predictions = predictions[:, 1:] - labels = labels[:, 1:] - weights = weights[:, 1:] # get unique group indices group_indices = torch.unique(grouping_keys) @@ -162,6 +191,7 @@ def __init__( self._grouped_auc: bool = grouped_auc self._apply_bin: bool = apply_bin + self._num_samples: int = 0 self._add_state( PREDICTIONS, [], @@ -204,7 +234,7 @@ def __init__( def _init_states(self) -> None: if len(getattr(self, PREDICTIONS)) > 0: return - + self._num_samples = 0 getattr(self, PREDICTIONS).append( torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device) ) @@ -241,25 +271,42 @@ def update( predictions = predictions.float() labels = labels.float() weights = weights.float() - num_samples = getattr(self, PREDICTIONS)[0].size(-1) batch_size = predictions.size(-1) - start_index = max(num_samples + batch_size - self._window_size, 0) + start_index = max(self._num_samples + batch_size - self._window_size, 0) + # Using `self.predictions =` will cause Pyre errors. - getattr(self, PREDICTIONS)[0] = torch.cat( - [ - cast(torch.Tensor, getattr(self, PREDICTIONS)[0])[:, start_index:], - predictions, - ], - dim=-1, - ) - getattr(self, LABELS)[0] = torch.cat( - [cast(torch.Tensor, getattr(self, LABELS)[0])[:, start_index:], labels], - dim=-1, - ) - getattr(self, WEIGHTS)[0] = torch.cat( - [cast(torch.Tensor, getattr(self, WEIGHTS)[0])[:, start_index:], weights], - dim=-1, - ) + w_preds = getattr(self, PREDICTIONS) + w_labels = getattr(self, LABELS) + w_weights = getattr(self, WEIGHTS) + + # remove init states + if self._num_samples == 0: + for lst in [w_preds, w_labels, w_weights]: + lst.pop(0) + + w_preds.append(predictions) + w_labels.append(labels) + w_weights.append(weights) + + self._num_samples += batch_size + + while self._num_samples > self._window_size: + diff = self._num_samples - self._window_size + if diff > w_preds[0].size(-1): + self._num_samples -= w_preds[0].size(-1) + # Remove the first element from predictions, labels, and weights + for lst in [w_preds, w_labels, w_weights]: + lst.pop(0) + else: + # Update the first element of predictions, labels, and weights + # Off by one potentially - keeping legacy behaviour + for lst in [w_preds, w_labels, w_weights]: + lst[0] = lst[0][:, diff:] + # if empty tensor, remove it + if torch.numel(lst[0]) == 0: + lst.pop(0) + self._num_samples -= diff + if self._grouped_auc: if REQUIRED_INPUTS not in kwargs or ( (grouping_keys := kwargs[REQUIRED_INPUTS].get(GROUPING_KEYS)) is None @@ -276,19 +323,21 @@ def update( ) def _compute(self) -> List[MetricComputationReport]: - reports = [ + reports = [] + reports.append( MetricComputationReport( name=MetricName.AUC, metric_prefix=MetricPrefix.WINDOW, value=compute_auc( self._n_tasks, - cast(torch.Tensor, getattr(self, PREDICTIONS)[0]), - cast(torch.Tensor, getattr(self, LABELS)[0]), - cast(torch.Tensor, getattr(self, WEIGHTS)[0]), + cast(List[torch.Tensor], getattr(self, PREDICTIONS)), + cast(List[torch.Tensor], getattr(self, LABELS)), + cast(List[torch.Tensor], getattr(self, WEIGHTS)), self._apply_bin, ), ) - ] + ) + if self._grouped_auc: reports.append( MetricComputationReport( @@ -296,15 +345,64 @@ def _compute(self) -> List[MetricComputationReport]: metric_prefix=MetricPrefix.WINDOW, value=compute_auc_per_group( self._n_tasks, - cast(torch.Tensor, getattr(self, PREDICTIONS)[0]), - cast(torch.Tensor, getattr(self, LABELS)[0]), - cast(torch.Tensor, getattr(self, WEIGHTS)[0]), - cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0]), + cast(List[torch.Tensor], getattr(self, PREDICTIONS)), + cast(List[torch.Tensor], getattr(self, LABELS)), + cast(List[torch.Tensor], getattr(self, WEIGHTS)), + cast(torch.Tensor, getattr(self, GROUPING_KEYS))[0], ), ) ) return reports + @override + def _sync_dist( + self, + dist_sync_fn: Callable = gather_all_tensors, # pyre-ignore[24] + process_group: Optional[Any] = None, # pyre-ignore[2] + ) -> None: + """ + This function is overridden from torchmetric.Metric, since for AUC we want to cat the tensors + right before the allgather collective is called. + """ + input_dict = {attr: getattr(self, attr) for attr in self._reductions} + for attr, reduction_fn in self._reductions.items(): + # pre-concatenate metric states that are lists to reduce number of all_gather operations + if ( + reduction_fn == dim_zero_cat + and isinstance(input_dict[attr], list) + and len(input_dict[attr]) > 1 + ): + input_dict[attr] = [dim_zero_cat(input_dict[attr])] + + for k, v in input_dict.items(): + # this needs to account for list only and IF list has more than one elem in it that is just tensors + # concat list of tensors into a single tensor + if isinstance(v, list) and len(v) > 1: + input_dict[k] = [torch.cat(v, dim=-1)] + + output_dict = apply_to_collection( + input_dict, + Tensor, + dist_sync_fn, + group=process_group or self.process_group, + ) + + for attr, reduction_fn in self._reductions.items(): + # pre-processing ops (stack or flatten for inputs) + if isinstance(output_dict[attr][0], Tensor): + output_dict[attr] = torch.stack(output_dict[attr]) + elif isinstance(output_dict[attr][0], list): + output_dict[attr] = _flatten(output_dict[attr]) + + if not (callable(reduction_fn) or reduction_fn is None): + raise TypeError("reduction_fn must be callable or None") + reduced = ( + reduction_fn(output_dict[attr]) + if reduction_fn is not None + else output_dict[attr] + ) + setattr(self, attr, reduced) + def reset(self) -> None: super().reset() self._init_states() diff --git a/torchrec/metrics/test_utils/__init__.py b/torchrec/metrics/test_utils/__init__.py index ae78306e2..d2440cc67 100644 --- a/torchrec/metrics/test_utils/__init__.py +++ b/torchrec/metrics/test_utils/__init__.py @@ -41,7 +41,10 @@ def gen_test_batch( weight_value: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, n_classes: Optional[int] = None, + seed: Optional[int] = None, ) -> Dict[str, torch.Tensor]: + if seed is not None: + torch.manual_seed(seed) if label_value is not None: label = label_value else: @@ -340,6 +343,129 @@ def get_launch_config(world_size: int, rdzv_endpoint: str) -> pet.LaunchConfig: ) +def rec_metric_gpu_sync_test_launcher( + target_clazz: Type[RecMetric], + target_compute_mode: RecComputeMode, + test_clazz: Optional[Type[TestMetric]], + metric_name: str, + task_names: List[str], + fused_update_limit: int, + compute_on_all_ranks: bool, + should_validate_update: bool, + world_size: int, + entry_point: Callable[..., None], + batch_size: int = BATCH_SIZE, + batch_window_size: int = BATCH_WINDOW_SIZE, + **kwargs: Any, +) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + # might need to update lanch config to include max_nodes = 2 + lc = get_launch_config( + world_size=world_size, rdzv_endpoint=os.path.join(tmpdir, "rdzv") + ) + + # launch using torch elastic, launches for each rank + pet.elastic_launch(lc, entrypoint=entry_point)( + target_clazz, + target_compute_mode, + test_clazz, + task_names, + metric_name, + world_size, + fused_update_limit, + compute_on_all_ranks, + should_validate_update, + batch_size, + batch_window_size, + ) + + +def sync_test_helper( + target_clazz: Type[RecMetric], + target_compute_mode: RecComputeMode, + test_clazz: Optional[Type[TestMetric]], + task_names: List[str], + metric_name: str, + world_size: int, + fused_update_limit: int = 0, + compute_on_all_ranks: bool = False, + should_validate_update: bool = False, + batch_size: int = BATCH_SIZE, + batch_window_size: int = BATCH_WINDOW_SIZE, + n_classes: Optional[int] = None, + zero_weights: bool = False, +) -> None: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + dist.init_process_group( + backend="gloo", + world_size=world_size, + rank=rank, + ) + + tasks = gen_test_tasks(task_names) + + auc = target_clazz( + world_size=world_size, + batch_size=batch_size, + my_rank=rank, + compute_on_all_ranks=compute_on_all_ranks, + tasks=tasks, + window_size=batch_window_size * world_size, + ) + + weight_value: Optional[torch.Tensor] = None + + _model_outs = [ + gen_test_batch( + label_name=task.label_name, + prediction_name=task.prediction_name, + weight_name=task.weight_name, + batch_size=batch_size, + n_classes=n_classes, + weight_value=weight_value, + seed=42, # we set seed because of how test metric places tensors on ranks + ) + for task in tasks + ] + model_outs = [] + model_outs.append({k: v for d in _model_outs for k, v in d.items()}) + + # we send an uneven number of tensors to each rank to test that GPU sync works + if rank == 0: + for _ in range(3): + labels, predictions, weights, _ = parse_task_model_outputs( + tasks, model_outs[0] + ) + auc.update(predictions=predictions, labels=labels, weights=weights) + elif rank == 1: + for _ in range(1): + labels, predictions, weights, _ = parse_task_model_outputs( + tasks, model_outs[0] + ) + auc.update(predictions=predictions, labels=labels, weights=weights) + + # check against test metric + test_metrics: TestRecMetricOutput = ({}, {}, {}, {}) + if test_clazz is not None: + # pyre-ignore[45]: Cannot instantiate abstract class `TestMetric`. + test_metric_obj = test_clazz(world_size, tasks) + # with how testmetric is setup we cannot do asymmertrical updates across ranks + # so we duplicate model_outs twice to match number of updates in aggregate + model_outs = model_outs * 2 + test_metrics = test_metric_obj.compute(model_outs, 2, batch_window_size, None) + + res = auc.compute() + + if res: + assert torch.allclose( + test_metrics[1][task_names[0]], + res[f"auc-{task_names[0]}|window_auc"], + ) + + dist.destroy_process_group() + + def rec_metric_value_test_launcher( target_clazz: Type[RecMetric], target_compute_mode: RecComputeMode, diff --git a/torchrec/metrics/tests/test_auc.py b/torchrec/metrics/tests/test_auc.py index e81922cd4..9bc7357a5 100644 --- a/torchrec/metrics/tests/test_auc.py +++ b/torchrec/metrics/tests/test_auc.py @@ -20,7 +20,9 @@ ) from torchrec.metrics.test_utils import ( metric_test_helper, + rec_metric_gpu_sync_test_launcher, rec_metric_value_test_launcher, + sync_test_helper, TestMetric, ) @@ -115,6 +117,27 @@ def test_fused_auc(self) -> None: ) +class AUCGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = AUCMetric + task_name: str = "auc" + + def test_sync_auc(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=AUCMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestAUCMetric, + metric_name=AUCGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) + + class AUCMetricValueTest(unittest.TestCase): r"""This set of tests verify the computation logic of AUC in several corner cases that we know the computation results. The goal is to @@ -177,6 +200,57 @@ def test_calc_auc_balanced(self) -> None: actual_auc = self.auc.compute()["auc-DefaultTask|window_auc"] torch.allclose(expected_auc, actual_auc) + def test_calc_uneven_updates(self) -> None: + auc = AUCMetric( + world_size=1, + my_rank=0, + batch_size=100, + tasks=[DefaultTaskInfo], + ) + + expected_auc = torch.tensor([0.4464], dtype=torch.float) + # first batch + self.labels["DefaultTask"] = torch.tensor([1, 0, 0]) + self.predictions["DefaultTask"] = torch.tensor([0.2, 0.6, 0.8]) + self.weights["DefaultTask"] = torch.tensor([0.13, 0.2, 0.5]) + + auc.update(**self.batches) + # second batch + self.labels["DefaultTask"] = torch.tensor([1, 1]) + self.predictions["DefaultTask"] = torch.tensor([0.4, 0.9]) + self.weights["DefaultTask"] = torch.tensor([0.8, 0.75]) + + auc.update(**self.batches) + multiple_batch = self.auc.compute()["auc-DefaultTask|window_auc"] + torch.allclose(expected_auc, multiple_batch) + + def test_stress_auc(self) -> None: + auc = AUCMetric( + world_size=1, + my_rank=0, + batch_size=100, + tasks=[DefaultTaskInfo], + ) + # init states, so we expect 3 (state tensors) * 4 bytes (float) + self.assertEqual(sum(auc.get_memory_usage().values()), 12) + + # window size is only 100, so we should expect expected AUC to be same + expected_auc = torch.tensor([1.0], dtype=torch.float) + # first batch + self.labels["DefaultTask"] = torch.tensor([[1, 0, 0, 1, 1]]) + self.predictions["DefaultTask"] = torch.tensor([[1, 0, 0, 1, 1]]) + self.weights["DefaultTask"] = torch.tensor([[1] * 5]) + + for _ in range(1000): + auc.update(**self.batches) + + # check memory, window size is 100, so we have upperbound of memory to expect + # so with a 100 window size / tensors of size 5 = 20 tensors (per state) * 3 states * 20 bytes per tensor of size 5 = 1200 bytes + self.assertEqual(sum(auc.get_memory_usage().values()), 1200) + + result = auc.compute()["auc-DefaultTask|window_auc"] + torch.allclose(expected_auc, result) + def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]: return [