Skip to content

Commit

Permalink
move tensor concatenation out of update and before compute collective…
Browse files Browse the repository at this point in the history
… call (#1656)

Summary:
Pull Request resolved: #1656

Previously unlanded D51399384 because there was a NCCL AllGather hang caused by the previous version of the optimization. This diff fixes the issue by moving the tensor concatenation before the collective is called. By doing this, model outputs are transformed to be identical to what was previous, and thus the flow (sync, collective call, compute) is identical to previous AUC implementation. But at the same time, we avoid concatting every update thus leveraging significant gains.

**Why did the NCCL collective hang last time?**
There existed an edge case where ranks could have mismatched number of tensors, the way torchmetrics sync and collective call is written it calls an allgather per tensor in a list of tensors. Since number of tensors in a list across ranks is not guaranteed to be the same, the issue can arise where an allgather is being called for a tensor that does not exist.

Imagine, 3 tensors on rank 0 and 2 tensors on rank 1. For calling allgather on each tensor, on rank 0 calling on the 3rd tensor will cause a NCCL hang and subsequent timeout since there is no tensor on rank 1 that corresponds to it. In the case where other ranks have more tensors than rank 0, the collective would not hang but those additional tensors would not be collected on to rank 0 resulting in a wrong calculation. The unit test covers for this case as well.

A additional GPU unit test is added to cover this scenario, where on rank 0 - 2 tensors are passed in and on rank 1 - 1 tensor is passed in (and vice versa). With the concat before the collective, the test passes. Without the concat, the test will hang.

Reviewed By: dstaay-fb

Differential Revision: D53027097

fbshipit-source-id: 37746acda3a1b83120e46520f1b22a98dc6d51f1
  • Loading branch information
iamzainhuda authored and facebook-github-bot committed Feb 23, 2024
1 parent 84c92d2 commit bc52746
Show file tree
Hide file tree
Showing 4 changed files with 362 additions and 49 deletions.
149 changes: 104 additions & 45 deletions torchrec/metrics/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
# LICENSE file in the root directory of this source tree.

from functools import partial
from typing import Any, cast, Dict, List, Optional, Type
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type

import torch
import torch.distributed as dist
from torchmetrics.utilities.distributed import gather_all_tensors
from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.rec_metric import (
Expand All @@ -26,6 +27,29 @@
REQUIRED_INPUTS = "required_inputs"


def _concat_if_needed(
predictions: List[torch.Tensor],
labels: List[torch.Tensor],
weights: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
This check exists because of how the state is organized due to quirks in RecMetrics.
Since we do not do tensor concatenatation in the compute or update call, there are cases (in non-distributed settings)
where the tensors from updates are not concatted into a single tensor. Which is determined by the length of the list.
"""
preds_t, labels_t, weights_t = None, None, None
if len(predictions) > 1:
preds_t = torch.cat(predictions, dim=-1)
labels_t = torch.cat(labels, dim=-1)
weights_t = torch.cat(weights, dim=-1)
else:
preds_t = predictions[0]
labels_t = labels[0]
weights_t = weights[0]

return preds_t, labels_t, weights_t


def _compute_auc_helper(
predictions: torch.Tensor,
labels: torch.Tensor,
Expand All @@ -50,59 +74,58 @@ def _compute_auc_helper(

def compute_auc(
n_tasks: int,
predictions: torch.Tensor,
labels: torch.Tensor,
weights: torch.Tensor,
predictions: List[torch.Tensor],
labels: List[torch.Tensor],
weights: List[torch.Tensor],
apply_bin: bool = False,
) -> torch.Tensor:
"""
Computes AUC (Area Under the Curve) for binary classification.
Args:
n_tasks (int): number of tasks.
predictions (torch.Tensor): tensor of size (n_tasks, n_examples).
labels (torch.Tensor): tensor of size (n_tasks, n_examples).
weights (torch.Tensor): tensor of size (n_tasks, n_examples).
predictions (List[torch.Tensor]): tensor of size (n_tasks, n_examples).
labels (List[torch.Tensor]): tensor of size (n_tasks, n_examples).
weights (List[torch.Tensor]): tensor of size (n_tasks, n_examples).
"""
preds_t, labels_t, weights_t = _concat_if_needed(predictions, labels, weights)
aucs = []
for predictions_i, labels_i, weights_i in zip(predictions, labels, weights):
for predictions_i, labels_i, weights_i in zip(preds_t, labels_t, weights_t):
auc = _compute_auc_helper(predictions_i, labels_i, weights_i, apply_bin)
aucs.append(auc.view(1))
return torch.cat(aucs)


def compute_auc_per_group(
n_tasks: int,
predictions: torch.Tensor,
labels: torch.Tensor,
weights: torch.Tensor,
predictions: List[torch.Tensor],
labels: List[torch.Tensor],
weights: List[torch.Tensor],
grouping_keys: torch.Tensor,
) -> torch.Tensor:
"""
Computes AUC (Area Under the Curve) for binary classification for groups of predictions/labels.
Args:
n_tasks (int): number of tasks
predictions (torch.Tensor): tensor of size (n_tasks, n_examples)
labels (torch.Tensor): tensor of size (n_tasks, n_examples)
weights (torch.Tensor): tensor of size (n_tasks, n_examples)
predictions (List[torch.Tensor]): tensor of size (n_tasks, n_examples)
labels (List[torch.Tensor]: tensor of size (n_tasks, n_examples)
weights (List[torch.Tensor]): tensor of size (n_tasks, n_examples)
grouping_keys (torch.Tensor): tensor of size (n_examples,)
Returns:
torch.Tensor: tensor of size (n_tasks,), average of AUCs per group.
"""
preds_t, labels_t, weights_t = _concat_if_needed(predictions, labels, weights)
aucs = []
if grouping_keys.numel() != 0 and grouping_keys[0] == -1:
# we added padding as the first elements during init to avoid floating point exception in sync()
# removing the paddings to avoid numerical errors.
grouping_keys = grouping_keys[1:]
predictions = predictions[:, 1:]
labels = labels[:, 1:]
weights = weights[:, 1:]

# get unique group indices
group_indices = torch.unique(grouping_keys)

for (predictions_i, labels_i, weights_i) in zip(predictions, labels, weights):
for (predictions_i, labels_i, weights_i) in zip(preds_t, labels_t, weights_t):
# Loop over each group
auc_groups_sum = torch.tensor([0], dtype=torch.float32)
for group_idx in group_indices:
Expand Down Expand Up @@ -162,6 +185,7 @@ def __init__(

self._grouped_auc: bool = grouped_auc
self._apply_bin: bool = apply_bin
self._num_samples: int = 0
self._add_state(
PREDICTIONS,
[],
Expand Down Expand Up @@ -204,7 +228,7 @@ def __init__(
def _init_states(self) -> None:
if len(getattr(self, PREDICTIONS)) > 0:
return

self._num_samples = 0
getattr(self, PREDICTIONS).append(
torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device)
)
Expand Down Expand Up @@ -241,25 +265,42 @@ def update(
predictions = predictions.float()
labels = labels.float()
weights = weights.float()
num_samples = getattr(self, PREDICTIONS)[0].size(-1)
batch_size = predictions.size(-1)
start_index = max(num_samples + batch_size - self._window_size, 0)
start_index = max(self._num_samples + batch_size - self._window_size, 0)

# Using `self.predictions =` will cause Pyre errors.
getattr(self, PREDICTIONS)[0] = torch.cat(
[
cast(torch.Tensor, getattr(self, PREDICTIONS)[0])[:, start_index:],
predictions,
],
dim=-1,
)
getattr(self, LABELS)[0] = torch.cat(
[cast(torch.Tensor, getattr(self, LABELS)[0])[:, start_index:], labels],
dim=-1,
)
getattr(self, WEIGHTS)[0] = torch.cat(
[cast(torch.Tensor, getattr(self, WEIGHTS)[0])[:, start_index:], weights],
dim=-1,
)
w_preds = getattr(self, PREDICTIONS)
w_labels = getattr(self, LABELS)
w_weights = getattr(self, WEIGHTS)

# remove init states
if self._num_samples == 0:
for lst in [w_preds, w_labels, w_weights]:
lst.pop(0)

w_preds.append(predictions)
w_labels.append(labels)
w_weights.append(weights)

self._num_samples += batch_size

while self._num_samples > self._window_size:
diff = self._num_samples - self._window_size
if diff > w_preds[0].size(-1):
self._num_samples -= w_preds[0].size(-1)
# Remove the first element from predictions, labels, and weights
for lst in [w_preds, w_labels, w_weights]:
lst.pop(0)
else:
# Update the first element of predictions, labels, and weights
# Off by one potentially - keeping legacy behaviour
for lst in [w_preds, w_labels, w_weights]:
lst[0] = lst[0][:, diff:]
# if empty tensor, remove it
if torch.numel(lst[0]) == 0:
lst.pop(0)
self._num_samples -= diff

if self._grouped_auc:
if REQUIRED_INPUTS not in kwargs or (
(grouping_keys := kwargs[REQUIRED_INPUTS].get(GROUPING_KEYS)) is None
Expand All @@ -276,35 +317,53 @@ def update(
)

def _compute(self) -> List[MetricComputationReport]:
reports = [
reports = []
reports.append(
MetricComputationReport(
name=MetricName.AUC,
metric_prefix=MetricPrefix.WINDOW,
value=compute_auc(
self._n_tasks,
cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
cast(torch.Tensor, getattr(self, LABELS)[0]),
cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
cast(List[torch.Tensor], getattr(self, PREDICTIONS)),
cast(List[torch.Tensor], getattr(self, LABELS)),
cast(List[torch.Tensor], getattr(self, WEIGHTS)),
self._apply_bin,
),
)
]
)

if self._grouped_auc:
reports.append(
MetricComputationReport(
name=MetricName.GROUPED_AUC,
metric_prefix=MetricPrefix.WINDOW,
value=compute_auc_per_group(
self._n_tasks,
cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
cast(torch.Tensor, getattr(self, LABELS)[0]),
cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0]),
cast(List[torch.Tensor], getattr(self, PREDICTIONS)),
cast(List[torch.Tensor], getattr(self, LABELS)),
cast(List[torch.Tensor], getattr(self, WEIGHTS)),
cast(torch.Tensor, getattr(self, GROUPING_KEYS))[0],
),
)
)
return reports

def _sync_dist(
self,
dist_sync_fn: Callable = gather_all_tensors, # pyre-ignore[24]
process_group: Optional[Any] = None, # pyre-ignore[2]
) -> None:
"""
This function is overridden from torchmetric.Metric, since for AUC we want to concat the tensors
right before the allgather collective is called. It directly changes the attributes/states, which
is ok because end of function sets the attributes to reduced values
"""
for attr in self._reductions: # pragma: no cover
val = getattr(self, attr)
if isinstance(val, list) and len(val) > 1:
setattr(self, attr, [torch.cat(val, dim=-1)])
super()._sync_dist(dist_sync_fn, process_group)

def reset(self) -> None:
super().reset()
self._init_states()
Expand Down
Loading

0 comments on commit bc52746

Please sign in to comment.