diff --git a/torchrec/metrics/auc.py b/torchrec/metrics/auc.py
index 688f04583..1b49b60fd 100644
--- a/torchrec/metrics/auc.py
+++ b/torchrec/metrics/auc.py
@@ -6,10 +6,15 @@
 # LICENSE file in the root directory of this source tree.
 
 from functools import partial
-from typing import Any, cast, Dict, List, Optional, Type
+from typing import Any, Callable, cast, Dict, List, Optional, Type, Union
 
 import torch
 import torch.distributed as dist
+from pyre_extensions import override
+from torch import Tensor
+from torchmetrics.utilities import apply_to_collection
+from torchmetrics.utilities.data import _flatten, dim_zero_cat
+from torchmetrics.utilities.distributed import gather_all_tensors
 from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo
 from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
 from torchrec.metrics.rec_metric import (
@@ -50,9 +55,9 @@ def _compute_auc_helper(
 
 def compute_auc(
     n_tasks: int,
-    predictions: torch.Tensor,
-    labels: torch.Tensor,
-    weights: torch.Tensor,
+    predictions: Union[List[torch.Tensor], torch.Tensor],
+    labels: Union[List[torch.Tensor], torch.Tensor],
+    weights: Union[List[torch.Tensor], torch.Tensor],
     apply_bin: bool = False,
 ) -> torch.Tensor:
     """
@@ -60,10 +65,23 @@ def compute_auc(
 
     Args:
         n_tasks (int): number of tasks.
-        predictions (torch.Tensor): tensor of size (n_tasks, n_examples).
-        labels (torch.Tensor): tensor of size (n_tasks, n_examples).
-        weights (torch.Tensor): tensor of size (n_tasks, n_examples).
+        predictions (Union[List[torch.Tensor], torch.Tensor]): tensor of size (n_tasks, n_examples), or a list of such tensors to be concatenated along the last dimension.
+        labels (Union[List[torch.Tensor], torch.Tensor]): tensor of size (n_tasks, n_examples), or a list of such tensors.
+        weights (Union[List[torch.Tensor], torch.Tensor]): tensor of size (n_tasks, n_examples), or a list of such tensors.
     """
+    if len(predictions) > 1:
+        # Lists of tensors that have not been concatenated can be passed in a non-distributed setting, so concatenate them here.
+        # pyre-ignore[6]
+        predictions = torch.cat(predictions, dim=-1)
+        # pyre-ignore[6]
+        labels = torch.cat(labels, dim=-1)
+        # pyre-ignore[6]
+        weights = torch.cat(weights, dim=-1)
+    else:
+        predictions = predictions[0]
+        labels = labels[0]
+        weights = weights[0]
+
     aucs = []
     for predictions_i, labels_i, weights_i in zip(predictions, labels, weights):
         auc = _compute_auc_helper(predictions_i, labels_i, weights_i, apply_bin)
@@ -73,9 +91,9 @@ def compute_auc(
 
 def compute_auc_per_group(
     n_tasks: int,
-    predictions: torch.Tensor,
-    labels: torch.Tensor,
-    weights: torch.Tensor,
+    predictions: Union[List[torch.Tensor], torch.Tensor],
+    labels: Union[List[torch.Tensor], torch.Tensor],
+    weights: Union[List[torch.Tensor], torch.Tensor],
     grouping_keys: torch.Tensor,
 ) -> torch.Tensor:
     """
@@ -90,14 +108,25 @@ def compute_auc_per_group(
     Returns:
         torch.Tensor: tensor of size (n_tasks,), average of AUCs per group.
    """
+    # prepare model outputs
+    if len(predictions) > 1:
+        # Lists of tensors that have not been concatenated can be passed in a non-distributed setting, so concatenate them here.
+        # pyre-ignore[6]
+        predictions = torch.cat(predictions, dim=-1)
+        # pyre-ignore[6]
+        labels = torch.cat(labels, dim=-1)
+        # pyre-ignore[6]
+        weights = torch.cat(weights, dim=-1)
+    else:
+        predictions = predictions[0]
+        labels = labels[0]
+        weights = weights[0]
+
     aucs = []
     if grouping_keys.numel() != 0 and grouping_keys[0] == -1:
         # we added padding as the first element during init to avoid a floating point exception in sync();
         # remove the padding to avoid numerical errors.
         grouping_keys = grouping_keys[1:]
-        predictions = predictions[:, 1:]
-        labels = labels[:, 1:]
-        weights = weights[:, 1:]
 
     # get unique group indices
     group_indices = torch.unique(grouping_keys)
@@ -162,6 +191,7 @@ def __init__(
 
         self._grouped_auc: bool = grouped_auc
         self._apply_bin: bool = apply_bin
+        self._num_samples: int = 0
         self._add_state(
             PREDICTIONS,
             [],
@@ -204,7 +234,7 @@ def __init__(
     def _init_states(self) -> None:
        if len(getattr(self, PREDICTIONS)) > 0:
            return
-
+        self._num_samples = 0
        getattr(self, PREDICTIONS).append(
            torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device)
        )
@@ -241,25 +271,42 @@ def update(
        predictions = predictions.float()
        labels = labels.float()
        weights = weights.float()
-        num_samples = getattr(self, PREDICTIONS)[0].size(-1)
        batch_size = predictions.size(-1)
-        start_index = max(num_samples + batch_size - self._window_size, 0)
+        start_index = max(self._num_samples + batch_size - self._window_size, 0)
+
        # Using `self.predictions =` will cause Pyre errors.
-        getattr(self, PREDICTIONS)[0] = torch.cat(
-            [
-                cast(torch.Tensor, getattr(self, PREDICTIONS)[0])[:, start_index:],
-                predictions,
-            ],
-            dim=-1,
-        )
-        getattr(self, LABELS)[0] = torch.cat(
-            [cast(torch.Tensor, getattr(self, LABELS)[0])[:, start_index:], labels],
-            dim=-1,
-        )
-        getattr(self, WEIGHTS)[0] = torch.cat(
-            [cast(torch.Tensor, getattr(self, WEIGHTS)[0])[:, start_index:], weights],
-            dim=-1,
-        )
+        w_preds = getattr(self, PREDICTIONS)
+        w_labels = getattr(self, LABELS)
+        w_weights = getattr(self, WEIGHTS)
+
+        # remove init states
+        if self._num_samples == 0:
+            for lst in [w_preds, w_labels, w_weights]:
+                lst.pop(0)
+
+        w_preds.append(predictions)
+        w_labels.append(labels)
+        w_weights.append(weights)
+
+        self._num_samples += batch_size
+
+        while self._num_samples > self._window_size:
+            diff = self._num_samples - self._window_size
+            if diff > w_preds[0].size(-1):
+                self._num_samples -= w_preds[0].size(-1)
+                # Remove the first element from predictions, labels, and weights
+                for lst in [w_preds, w_labels, w_weights]:
+                    lst.pop(0)
+            else:
+                # Update the first element of predictions, labels, and weights
+                # Off by one potentially - keeping legacy behaviour
+                for lst in [w_preds, w_labels, w_weights]:
+                    lst[0] = lst[0][:, diff:]
+                    # if empty tensor, remove it
+                    if torch.numel(lst[0]) == 0:
+                        lst.pop(0)
+                self._num_samples -= diff
+
        if self._grouped_auc:
            if REQUIRED_INPUTS not in kwargs or (
                (grouping_keys := kwargs[REQUIRED_INPUTS].get(GROUPING_KEYS)) is None
@@ -276,19 +323,21 @@ def update(
            )
 
    def _compute(self) -> List[MetricComputationReport]:
-        reports = [
+        reports = []
+        reports.append(
            MetricComputationReport(
                name=MetricName.AUC,
                metric_prefix=MetricPrefix.WINDOW,
                value=compute_auc(
                    self._n_tasks,
-                    cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
-                    cast(torch.Tensor, getattr(self, LABELS)[0]),
-                    cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
+                    cast(List[torch.Tensor], getattr(self, PREDICTIONS)),
+                    cast(List[torch.Tensor], getattr(self, LABELS)),
+                    cast(List[torch.Tensor], getattr(self, WEIGHTS)),
                    self._apply_bin,
                ),
            )
-        ]
+        )
+
        if self._grouped_auc:
            reports.append(
                MetricComputationReport(
@@ -296,15 +345,64 @@ def _compute(self) -> List[MetricComputationReport]:
                    metric_prefix=MetricPrefix.WINDOW,
                    value=compute_auc_per_group(
                        self._n_tasks,
-                        cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
-                        cast(torch.Tensor, getattr(self, LABELS)[0]),
-                        cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
-                        cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0]),
+                        cast(List[torch.Tensor], getattr(self, PREDICTIONS)),
+                        cast(List[torch.Tensor], getattr(self, LABELS)),
+                        cast(List[torch.Tensor], getattr(self, WEIGHTS)),
+                        cast(torch.Tensor, getattr(self, GROUPING_KEYS))[0],
                    ),
                )
            )
        return reports
 
+    @override
+    def _sync_dist(
+        self,
+        dist_sync_fn: Callable = gather_all_tensors,  # pyre-ignore[24]
+        process_group: Optional[Any] = None,  # pyre-ignore[2]
+    ) -> None:
+        """
+        This function is overridden from torchmetrics.Metric, since for AUC we want to concatenate the tensors
+        right before the all_gather collective is called.
+        """
+        input_dict = {attr: getattr(self, attr) for attr in self._reductions}
+        for attr, reduction_fn in self._reductions.items():
+            # pre-concatenate metric states that are lists to reduce the number of all_gather operations
+            if (
+                reduction_fn == dim_zero_cat
+                and isinstance(input_dict[attr], list)
+                and len(input_dict[attr]) > 1
+            ):
+                input_dict[attr] = [dim_zero_cat(input_dict[attr])]
+
+        for k, v in input_dict.items():
+            # this only needs to handle list states that contain more than one plain tensor;
+            # concatenate the list of tensors into a single tensor
+            if isinstance(v, list) and len(v) > 1:
+                input_dict[k] = [torch.cat(v, dim=-1)]
+
+        output_dict = apply_to_collection(
+            input_dict,
+            Tensor,
+            dist_sync_fn,
+            group=process_group or self.process_group,
+        )
+
+        for attr, reduction_fn in self._reductions.items():
+            # pre-processing ops (stack or flatten for inputs)
+            if isinstance(output_dict[attr][0], Tensor):
+                output_dict[attr] = torch.stack(output_dict[attr])
+            elif isinstance(output_dict[attr][0], list):
+                output_dict[attr] = _flatten(output_dict[attr])
+
+            if not (callable(reduction_fn) or reduction_fn is None):
+                raise TypeError("reduction_fn must be callable or None")
+            reduced = (
+                reduction_fn(output_dict[attr])
+                if reduction_fn is not None
+                else output_dict[attr]
+            )
+            setattr(self, attr, reduced)
+
    def reset(self) -> None:
        super().reset()
        self._init_states()
diff --git a/torchrec/metrics/test_utils/__init__.py b/torchrec/metrics/test_utils/__init__.py
index ae78306e2..d2440cc67 100644
--- a/torchrec/metrics/test_utils/__init__.py
+++ b/torchrec/metrics/test_utils/__init__.py
@@ -41,7 +41,10 @@ def gen_test_batch(
    weight_value: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
    n_classes: Optional[int] = None,
+    seed: Optional[int] = None,
 ) -> Dict[str, torch.Tensor]:
+    if seed is not None:
+        torch.manual_seed(seed)
    if label_value is not None:
        label = label_value
    else:
@@ -340,6 +343,129 @@ def get_launch_config(world_size: int, rdzv_endpoint: str) -> pet.LaunchConfig:
    )
 
 
+def rec_metric_gpu_sync_test_launcher(
+    target_clazz: Type[RecMetric],
+    target_compute_mode: RecComputeMode,
+    test_clazz: Optional[Type[TestMetric]],
+    metric_name: str,
+    task_names: List[str],
+    fused_update_limit: int,
+    compute_on_all_ranks: bool,
+    should_validate_update: bool,
+    world_size: int,
+    entry_point: Callable[..., None],
+    batch_size: int = BATCH_SIZE,
+    batch_window_size: int = BATCH_WINDOW_SIZE,
+    **kwargs: Any,
+) -> None:
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # might need to update the launch config to include max_nodes = 2
+        lc = get_launch_config(
+            world_size=world_size, rdzv_endpoint=os.path.join(tmpdir, "rdzv")
+        )
+
+        # launch using torch elastic; launches one process per rank
+        pet.elastic_launch(lc, entrypoint=entry_point)(
+            target_clazz,
+            target_compute_mode,
+            test_clazz,
+            task_names,
+            metric_name,
+            world_size,
+            fused_update_limit,
+            compute_on_all_ranks,
+            should_validate_update,
+            batch_size,
+            batch_window_size,
+        )
+
+
+def sync_test_helper(
+    target_clazz: Type[RecMetric],
+    target_compute_mode: RecComputeMode,
+    test_clazz: Optional[Type[TestMetric]],
+    task_names: List[str],
+    metric_name: str,
+    world_size: int,
+    fused_update_limit: int = 0,
+    compute_on_all_ranks: bool = False,
+    should_validate_update: bool = False,
+    batch_size: int = BATCH_SIZE,
+    batch_window_size: int = BATCH_WINDOW_SIZE,
+    n_classes: Optional[int] = None,
+    zero_weights: bool = False,
+) -> None:
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    dist.init_process_group(
+        backend="gloo",
+        world_size=world_size,
+        rank=rank,
+    )
+
+    tasks = gen_test_tasks(task_names)
+
+    auc = target_clazz(
+        world_size=world_size,
+        batch_size=batch_size,
+        my_rank=rank,
+        compute_on_all_ranks=compute_on_all_ranks,
+        tasks=tasks,
+        window_size=batch_window_size * world_size,
+    )
+
+    weight_value: Optional[torch.Tensor] = None
+
+    _model_outs = [
+        gen_test_batch(
+            label_name=task.label_name,
+            prediction_name=task.prediction_name,
+            weight_name=task.weight_name,
+            batch_size=batch_size,
+            n_classes=n_classes,
+            weight_value=weight_value,
+            seed=42,  # we set seed because of how test metric places tensors on ranks
+        )
+        for task in tasks
+    ]
+    model_outs = []
+    model_outs.append({k: v for d in _model_outs for k, v in d.items()})
+
+    # we send an uneven number of tensors to each rank to test that GPU sync works
+    if rank == 0:
+        for _ in range(3):
+            labels, predictions, weights, _ = parse_task_model_outputs(
+                tasks, model_outs[0]
+            )
+            auc.update(predictions=predictions, labels=labels, weights=weights)
+    elif rank == 1:
+        for _ in range(1):
+            labels, predictions, weights, _ = parse_task_model_outputs(
+                tasks, model_outs[0]
+            )
+            auc.update(predictions=predictions, labels=labels, weights=weights)
+
+    # check against test metric
+    test_metrics: TestRecMetricOutput = ({}, {}, {}, {})
+    if test_clazz is not None:
+        # pyre-ignore[45]: Cannot instantiate abstract class `TestMetric`.
+        test_metric_obj = test_clazz(world_size, tasks)
+        # with how TestMetric is set up we cannot do asymmetrical updates across ranks,
+        # so we duplicate model_outs to match the aggregate number of updates
+        model_outs = model_outs * 2
+        test_metrics = test_metric_obj.compute(model_outs, 2, batch_window_size, None)
+
+    res = auc.compute()
+
+    if res:
+        assert torch.allclose(
+            test_metrics[1][task_names[0]],
+            res[f"auc-{task_names[0]}|window_auc"],
+        )
+
+    dist.destroy_process_group()
+
+
 def rec_metric_value_test_launcher(
    target_clazz: Type[RecMetric],
    target_compute_mode: RecComputeMode,
diff --git a/torchrec/metrics/tests/test_auc.py b/torchrec/metrics/tests/test_auc.py
index e81922cd4..9bc7357a5 100644
--- a/torchrec/metrics/tests/test_auc.py
+++ b/torchrec/metrics/tests/test_auc.py
@@ -20,7 +20,9 @@
 )
 from torchrec.metrics.test_utils import (
    metric_test_helper,
+    rec_metric_gpu_sync_test_launcher,
    rec_metric_value_test_launcher,
+    sync_test_helper,
    TestMetric,
 )
 
@@ -115,6 +117,27 @@ def test_fused_auc(self) -> None:
        )
 
 
+class AUCGPUSyncTest(unittest.TestCase):
+    clazz: Type[RecMetric] = AUCMetric
+    task_name: str = "auc"
+
+    def test_sync_auc(self) -> None:
+        rec_metric_gpu_sync_test_launcher(
+            target_clazz=AUCMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestAUCMetric,
+            metric_name=AUCGPUSyncTest.task_name,
+            task_names=["t1"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=2,
+            batch_size=5,
+            batch_window_size=20,
+            entry_point=sync_test_helper,
+        )
+
+
 class AUCMetricValueTest(unittest.TestCase):
    r"""This set of tests verify the computation logic of AUC in several
    corner cases that we know the computation results. The goal is to
@@ -177,6 +200,57 @@ def test_calc_auc_balanced(self) -> None:
        actual_auc = self.auc.compute()["auc-DefaultTask|window_auc"]
        torch.allclose(expected_auc, actual_auc)
 
+    def test_calc_uneven_updates(self) -> None:
+        auc = AUCMetric(
+            world_size=1,
+            my_rank=0,
+            batch_size=100,
+            tasks=[DefaultTaskInfo],
+        )
+
+        expected_auc = torch.tensor([0.4464], dtype=torch.float)
+        # first batch
+        self.labels["DefaultTask"] = torch.tensor([1, 0, 0])
+        self.predictions["DefaultTask"] = torch.tensor([0.2, 0.6, 0.8])
+        self.weights["DefaultTask"] = torch.tensor([0.13, 0.2, 0.5])
+
+        auc.update(**self.batches)
+        # second batch
+        self.labels["DefaultTask"] = torch.tensor([1, 1])
+        self.predictions["DefaultTask"] = torch.tensor([0.4, 0.9])
+        self.weights["DefaultTask"] = torch.tensor([0.8, 0.75])
+
+        auc.update(**self.batches)
+        multiple_batch = auc.compute()["auc-DefaultTask|window_auc"]
+        torch.allclose(expected_auc, multiple_batch)
+
+    def test_stress_auc(self) -> None:
+        auc = AUCMetric(
+            world_size=1,
+            my_rank=0,
+            batch_size=100,
+            tasks=[DefaultTaskInfo],
+        )
+        # init states, so we expect 3 (state tensors) * 4 bytes (float) = 12 bytes
+        self.assertEqual(sum(auc.get_memory_usage().values()), 12)
+
+        # window size is only 100, so we expect the AUC to stay the same
+        expected_auc = torch.tensor([1.0], dtype=torch.float)
+        # first batch
+        self.labels["DefaultTask"] = torch.tensor([[1, 0, 0, 1, 1]])
+        self.predictions["DefaultTask"] = torch.tensor([[1, 0, 0, 1, 1]])
+        self.weights["DefaultTask"] = torch.tensor([[1] * 5])
+
+        for _ in range(1000):
+            auc.update(**self.batches)
+
+        # check memory; the 100-sample window gives an upper bound on memory:
+        # 100 window size / tensors of size 5 = 20 tensors per state * 3 states * 20 bytes per tensor of size 5 = 1200 bytes
+        self.assertEqual(sum(auc.get_memory_usage().values()), 1200)
+
+        result = auc.compute()["auc-DefaultTask|window_auc"]
+        torch.allclose(expected_auc, result)
+
 
 def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
    return [
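
A minimal standalone sketch (not part of the diff) of the sliding-window bookkeeping that the new update() performs: each call appends one per-batch tensor to a list state, then the oldest tensors are popped or sliced until at most window_size samples remain. The helper name trim_window and the shapes used here are illustrative assumptions, not torchrec API.

import torch
from typing import List

def trim_window(state: List[torch.Tensor], num_samples: int, window_size: int) -> int:
    """Drop or slice the oldest tensors so at most `window_size` samples remain.

    `state` holds one (n_tasks, batch_size) tensor per update; returns the new sample count.
    """
    while num_samples > window_size:
        diff = num_samples - window_size
        oldest = state[0].size(-1)
        if diff > oldest:
            # the whole oldest batch falls outside the window
            state.pop(0)
            num_samples -= oldest
        else:
            # only the leading `diff` columns of the oldest batch fall outside
            state[0] = state[0][:, diff:]
            if state[0].numel() == 0:
                state.pop(0)
            num_samples -= diff
    return num_samples

# usage: keep a window of 8 samples per task across batches of 5
window: List[torch.Tensor] = []
num_samples = 0
for _ in range(4):
    batch = torch.rand(1, 5)  # (n_tasks=1, batch_size=5)
    window.append(batch)
    num_samples += batch.size(-1)
    num_samples = trim_window(window, num_samples, window_size=8)

assert num_samples <= 8
assert sum(t.size(-1) for t in window) == num_samples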
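
Similarly, a rough illustration of the idea behind the overridden _sync_dist: a list-valued state is concatenated into a single tensor before gathering, so each state costs one all_gather-style collective instead of one per list element. sync_states and fake_gather are hypothetical stand-ins for this sketch, not the torchmetrics API.

import torch
from typing import Callable, Dict, List

def sync_states(
    states: Dict[str, List[torch.Tensor]],
    gather: Callable[[torch.Tensor], List[torch.Tensor]],
) -> Dict[str, List[torch.Tensor]]:
    """Concatenate each list state into one tensor, then gather it once."""
    synced: Dict[str, List[torch.Tensor]] = {}
    for name, tensors in states.items():
        local = torch.cat(tensors, dim=-1) if len(tensors) > 1 else tensors[0]
        synced[name] = gather(local)  # one collective per state
    return synced

# usage with a fake 2-rank "gather" that simply echoes the local tensor twice
fake_gather = lambda t: [t, t]
states = {"predictions": [torch.rand(1, 3), torch.rand(1, 2)]}
out = sync_states(states, fake_gather)
assert out["predictions"][0].shape == (1, 5)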