move tensor concatenation out of update and before compute collective call (#1656)

Summary:

Previously unlanded D51399384 because the earlier version of the optimization caused an NCCL AllGather hang. This diff fixes the issue by moving the tensor concatenation to just before the collective is called. With this change, the model outputs are transformed to be identical to what they were previously, so the flow (sync, collective call, compute) matches the previous AUC implementation, while we still avoid concatenating on every update and keep the significant performance gains.
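
As a rough illustration (hypothetical helper names, not the actual torchrec code), the change moves the expensive concatenation from the per-batch update path to a single concat right before the collective/compute:

```python
import torch
from typing import List

# Before: each update() concatenated the new batch onto one large state tensor,
# paying an O(window) copy on every batch.
def update_concat_every_batch(state: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    return torch.cat([state, batch], dim=-1)

# After: update() only appends to a list (O(1) per batch); the single concat
# happens once, right before the collective call or the local compute.
def update_append_only(state: List[torch.Tensor], batch: torch.Tensor) -> None:
    state.append(batch)

def concat_before_collective(state: List[torch.Tensor]) -> torch.Tensor:
    return torch.cat(state, dim=-1)
```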

**Why did the NCCL collective hang last time?**
There was an edge case where ranks could hold mismatched numbers of tensors. The way the torchmetrics sync and collective call is written, it issues one allgather per tensor in a list of tensors. Since the number of tensors in the list is not guaranteed to be the same across ranks, an allgather can end up being called for a tensor that does not exist on some rank.

Imagine 3 tensors on rank 0 and 2 tensors on rank 1. When calling allgather on each tensor, rank 0's call for the 3rd tensor will cause an NCCL hang and a subsequent timeout, since there is no tensor on rank 1 that corresponds to it.
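
A minimal sketch of the failure mode and the fix, assuming a standard torch.distributed process group (the function names here are illustrative, not the actual torchmetrics/torchrec code; the real sync path uses torchmetrics' gather_all_tensors, which also pads to handle uneven lengths across ranks):

```python
import torch
import torch.distributed as dist
from typing import List

def allgather_per_tensor(tensors: List[torch.Tensor]) -> None:
    # Problematic pattern: one all_gather per list element. If rank 0 holds 3
    # tensors and rank 1 holds only 2, rank 0's third call has no matching
    # call on rank 1, so the collective hangs until the NCCL timeout.
    world_size = dist.get_world_size()
    for t in tensors:
        gathered = [torch.empty_like(t) for _ in range(world_size)]
        dist.all_gather(gathered, t)

def allgather_concatenated(tensors: List[torch.Tensor]) -> None:
    # Fixed pattern: concatenate locally first, so every rank issues exactly
    # one collective call regardless of how many tensors it accumulated.
    # (Plain dist.all_gather assumes equal sizes across ranks; the real code
    # relies on gather_all_tensors to handle uneven lengths.)
    merged = torch.cat(tensors, dim=-1)
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(merged) for _ in range(world_size)]
    dist.all_gather(gathered, merged)
```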

An additional GPU unit test is added to cover this scenario: on rank 0, 2 tensors are passed in, and on rank 1, 1 tensor is passed in. With the concat before the collective, the test passes; without it, the test hangs.
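
Roughly, such a two-rank test could take the following shape (a hedged sketch only; the metric construction and process-group setup are omitted, and the names are illustrative rather than the actual test code):

```python
import torch

def run_rank(rank: int, auc_metric) -> torch.Tensor:
    # rank 0 performs two updates, rank 1 performs one, so the per-rank state
    # lists have different lengths before sync.
    n_updates = 2 if rank == 0 else 1
    for _ in range(n_updates):
        predictions = torch.rand(1, 8)
        labels = torch.randint(0, 2, (1, 8)).float()
        weights = torch.ones(1, 8)
        auc_metric.update(predictions=predictions, labels=labels, weights=weights)
    # With concat-before-collective, each rank issues a single allgather and
    # compute() returns; the per-tensor pattern would hang here instead.
    return auc_metric.compute()
```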

Differential Revision: D53027097
iamzainhuda authored and facebook-github-bot committed Feb 1, 2024
1 parent 584bb44 commit 49efc37
Showing 3 changed files with 338 additions and 40 deletions.
178 changes: 138 additions & 40 deletions torchrec/metrics/auc.py
@@ -6,10 +6,15 @@
# LICENSE file in the root directory of this source tree.

from functools import partial
from typing import Any, cast, Dict, List, Optional, Type
from typing import Any, Callable, cast, Dict, List, Optional, Type, Union

import torch
import torch.distributed as dist
from pyre_extensions import override
from torch import Tensor
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.data import _flatten, dim_zero_cat
from torchmetrics.utilities.distributed import gather_all_tensors
from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.rec_metric import (
@@ -50,20 +55,33 @@ def _compute_auc_helper(

def compute_auc(
n_tasks: int,
predictions: torch.Tensor,
labels: torch.Tensor,
weights: torch.Tensor,
predictions: Union[List[torch.Tensor], torch.Tensor],
labels: Union[List[torch.Tensor], torch.Tensor],
weights: Union[List[torch.Tensor], torch.Tensor],
apply_bin: bool = False,
) -> torch.Tensor:
"""
Computes AUC (Area Under the Curve) for binary classification.
Args:
n_tasks (int): number of tasks.
predictions (torch.Tensor): tensor of size (n_tasks, n_examples).
labels (torch.Tensor): tensor of size (n_tasks, n_examples).
weights (torch.Tensor): tensor of size (n_tasks, n_examples).
predictions (Union[List[torch.Tensor], torch.Tensor]): tensor of size (n_tasks, n_examples), or a list of such tensors to be concatenated along the last dimension.
labels (Union[List[torch.Tensor], torch.Tensor]): tensor of size (n_tasks, n_examples), or a list of such tensors.
weights (Union[List[torch.Tensor], torch.Tensor]): tensor of size (n_tasks, n_examples), or a list of such tensors.
"""
if len(predictions) > 1:
# Lists of unconcatenated tensors can be passed in a non-distributed setting, so we need to concat here
# pyre-ignore[6]
predictions = torch.cat(predictions, dim=-1)
# pyre-ignore[6]
labels = torch.cat(labels, dim=-1)
# pyre-ignore[6]
weights = torch.cat(weights, dim=-1)
else:
predictions = predictions[0]
labels = labels[0]
weights = weights[0]

aucs = []
for predictions_i, labels_i, weights_i in zip(predictions, labels, weights):
auc = _compute_auc_helper(predictions_i, labels_i, weights_i, apply_bin)
@@ -73,9 +91,9 @@ def compute_auc(

def compute_auc_per_group(
n_tasks: int,
predictions: torch.Tensor,
labels: torch.Tensor,
weights: torch.Tensor,
predictions: Union[List[torch.Tensor], torch.Tensor],
labels: Union[List[torch.Tensor], torch.Tensor],
weights: Union[List[torch.Tensor], torch.Tensor],
grouping_keys: torch.Tensor,
) -> torch.Tensor:
"""
@@ -90,14 +108,25 @@ def compute_auc_per_group(
Returns:
torch.Tensor: tensor of size (n_tasks,), average of AUCs per group.
"""
# prepare model outputs
if len(predictions) > 1:
# Lists of unconcatenated tensors can be passed in a non-distributed setting, so we need to concat here
# pyre-ignore[6]
predictions = torch.cat(predictions, dim=-1)
# pyre-ignore[6]
labels = torch.cat(labels, dim=-1)
# pyre-ignore[6]
weights = torch.cat(weights, dim=-1)
else:
predictions = predictions[0]
labels = labels[0]
weights = weights[0]

aucs = []
if grouping_keys.numel() != 0 and grouping_keys[0] == -1:
# we added padding as the first element during init to avoid a floating point exception in sync()
# remove the padding to avoid numerical errors.
grouping_keys = grouping_keys[1:]
predictions = predictions[:, 1:]
labels = labels[:, 1:]
weights = weights[:, 1:]

# get unique group indices
group_indices = torch.unique(grouping_keys)
@@ -162,6 +191,7 @@ def __init__(

self._grouped_auc: bool = grouped_auc
self._apply_bin: bool = apply_bin
self._num_samples: int = 0
self._add_state(
PREDICTIONS,
[],
@@ -204,7 +234,7 @@ def __init__(
def _init_states(self) -> None:
if len(getattr(self, PREDICTIONS)) > 0:
return

self._num_samples = 0
getattr(self, PREDICTIONS).append(
torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device)
)
@@ -241,25 +271,42 @@ def update(
predictions = predictions.float()
labels = labels.float()
weights = weights.float()
num_samples = getattr(self, PREDICTIONS)[0].size(-1)
batch_size = predictions.size(-1)
start_index = max(num_samples + batch_size - self._window_size, 0)
start_index = max(self._num_samples + batch_size - self._window_size, 0)

# Using `self.predictions =` will cause Pyre errors.
getattr(self, PREDICTIONS)[0] = torch.cat(
[
cast(torch.Tensor, getattr(self, PREDICTIONS)[0])[:, start_index:],
predictions,
],
dim=-1,
)
getattr(self, LABELS)[0] = torch.cat(
[cast(torch.Tensor, getattr(self, LABELS)[0])[:, start_index:], labels],
dim=-1,
)
getattr(self, WEIGHTS)[0] = torch.cat(
[cast(torch.Tensor, getattr(self, WEIGHTS)[0])[:, start_index:], weights],
dim=-1,
)
w_preds = getattr(self, PREDICTIONS)
w_labels = getattr(self, LABELS)
w_weights = getattr(self, WEIGHTS)

# remove init states
if self._num_samples == 0:
for lst in [w_preds, w_labels, w_weights]:
lst.pop(0)

w_preds.append(predictions)
w_labels.append(labels)
w_weights.append(weights)

self._num_samples += batch_size

while self._num_samples > self._window_size:
diff = self._num_samples - self._window_size
if diff > w_preds[0].size(-1):
self._num_samples -= w_preds[0].size(-1)
# Remove the first element from predictions, labels, and weights
for lst in [w_preds, w_labels, w_weights]:
lst.pop(0)
else:
# Update the first element of predictions, labels, and weights
# Off by one potentially - keeping legacy behaviour
for lst in [w_preds, w_labels, w_weights]:
lst[0] = lst[0][:, diff:]
# if empty tensor, remove it
if torch.numel(lst[0]) == 0:
lst.pop(0)
self._num_samples -= diff

if self._grouped_auc:
if REQUIRED_INPUTS not in kwargs or (
(grouping_keys := kwargs[REQUIRED_INPUTS].get(GROUPING_KEYS)) is None
@@ -276,35 +323,86 @@ def update(
)

def _compute(self) -> List[MetricComputationReport]:
reports = [
reports = []
reports.append(
MetricComputationReport(
name=MetricName.AUC,
metric_prefix=MetricPrefix.WINDOW,
value=compute_auc(
self._n_tasks,
cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
cast(torch.Tensor, getattr(self, LABELS)[0]),
cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
cast(List[torch.Tensor], getattr(self, PREDICTIONS)),
cast(List[torch.Tensor], getattr(self, LABELS)),
cast(List[torch.Tensor], getattr(self, WEIGHTS)),
self._apply_bin,
),
)
]
)

if self._grouped_auc:
reports.append(
MetricComputationReport(
name=MetricName.GROUPED_AUC,
metric_prefix=MetricPrefix.WINDOW,
value=compute_auc_per_group(
self._n_tasks,
cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
cast(torch.Tensor, getattr(self, LABELS)[0]),
cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0]),
cast(List[torch.Tensor], getattr(self, PREDICTIONS)),
cast(List[torch.Tensor], getattr(self, LABELS)),
cast(List[torch.Tensor], getattr(self, WEIGHTS)),
cast(torch.Tensor, getattr(self, GROUPING_KEYS))[0],
),
)
)
return reports

@override
def _sync_dist(
self,
dist_sync_fn: Callable = gather_all_tensors, # pyre-ignore[24]
process_group: Optional[Any] = None, # pyre-ignore[2]
) -> None:
"""
This function is overridden from torchmetrics.Metric, since for AUC we want to concatenate the tensors
right before the allgather collective is called.
"""
input_dict = {attr: getattr(self, attr) for attr in self._reductions}
for attr, reduction_fn in self._reductions.items():
# pre-concatenate metric states that are lists to reduce number of all_gather operations
if (
reduction_fn == dim_zero_cat
and isinstance(input_dict[attr], list)
and len(input_dict[attr]) > 1
):
input_dict[attr] = [dim_zero_cat(input_dict[attr])]

for k, v in input_dict.items():
# only handle list states, and only when the list still has more than one tensor in it:
# concat the list of tensors into a single tensor so each rank issues one allgather per state
if isinstance(v, list) and len(v) > 1:
input_dict[k] = [torch.cat(v, dim=-1)]

output_dict = apply_to_collection(
input_dict,
Tensor,
dist_sync_fn,
group=process_group or self.process_group,
)

for attr, reduction_fn in self._reductions.items():
# pre-processing ops (stack or flatten for inputs)
if isinstance(output_dict[attr][0], Tensor):
output_dict[attr] = torch.stack(output_dict[attr])
elif isinstance(output_dict[attr][0], list):
output_dict[attr] = _flatten(output_dict[attr])

if not (callable(reduction_fn) or reduction_fn is None):
raise TypeError("reduction_fn must be callable or None")
reduced = (
reduction_fn(output_dict[attr])
if reduction_fn is not None
else output_dict[attr]
)
setattr(self, attr, reduced)

def reset(self) -> None:
super().reset()
self._init_states()
