Add total positive/negative example counters to the TowerQPS metric.

Review notes for this revision of the patch:
- Examples are counted with boolean-mask sums, (labels > 0).sum() and
  (labels <= 0).sum(), in both tower_qps.py and its test. Summing the
  selected label values instead reports 0 negatives for 0/1 labels.
- The new test assertions use assertEqual (assertEquals is a deprecated
  alias that newer Python versions no longer provide).
- The post-image blob ids on the "index" lines predate these edits, and the
  indentation of context lines was restored from Python syntax; verify
  against the target tree before applying.

diff --git a/torchrec/metrics/metrics_namespace.py b/torchrec/metrics/metrics_namespace.py
index f9610eccb..d804e4f12 100644
--- a/torchrec/metrics/metrics_namespace.py
+++ b/torchrec/metrics/metrics_namespace.py
@@ -63,6 +63,8 @@ class MetricName(MetricNameBase):
     NDCG = "ndcg"
     XAUC = "xauc"
     SCALAR = "scalar"
+    TOTAL_POSITIVE_EXAMPLES = "total_positive_examples"
+    TOTAL_NEGATIVE_EXAMPLES = "total_negative_examples"
 
 
 class MetricNamespaceBase(StrValueMixin, Enum):
diff --git a/torchrec/metrics/tests/test_tower_qps.py b/torchrec/metrics/tests/test_tower_qps.py
index 4afb37d82..72c174e4d 100644
--- a/torchrec/metrics/tests/test_tower_qps.py
+++ b/torchrec/metrics/tests/test_tower_qps.py
@@ -250,6 +250,14 @@ def test_warmup_checkpointing(self) -> None:
             window_size=200,
         )
         model_output = gen_test_batch(batch_size)
+        labels = model_output["label"]
+        # Expected per-update counts: labels > 0 are positive, labels <= 0 negative.
+        num_positive_examples = (labels > 0).sum()
+        num_negative_examples = (labels <= 0).sum()
+
+        self.assertTrue(hasattr(qps._metrics_computations[0], "num_positive_examples"))
+        self.assertTrue(hasattr(qps._metrics_computations[0], "num_negative_examples"))
+
         for i in range(5):
             for _ in range(warmup_steps + extra_steps):
                 qps.update(
@@ -265,6 +273,17 @@ def test_warmup_checkpointing(self) -> None:
                 qps._metrics_computations[0].num_examples,
                 batch_size * (warmup_steps + extra_steps) * (i + 1),
             )
+
+            self.assertEqual(
+                qps._metrics_computations[0].num_positive_examples,
+                num_positive_examples * (warmup_steps + extra_steps) * (i + 1),
+            )
+
+            self.assertEqual(
+                qps._metrics_computations[0].num_negative_examples,
+                num_negative_examples * (warmup_steps + extra_steps) * (i + 1),
+            )
+
             # Mimic trainer crashing and loading a checkpoint.
             qps._metrics_computations[0]._steps = 0
 
diff --git a/torchrec/metrics/tower_qps.py b/torchrec/metrics/tower_qps.py
index 8e72824c6..7d9f4ad0a 100644
--- a/torchrec/metrics/tower_qps.py
+++ b/torchrec/metrics/tower_qps.py
@@ -28,6 +28,8 @@
 NUM_EXAMPLES = "num_examples"
 WARMUP_EXAMPLES = "warmup_examples"
 TIME_LAPSE = "time_lapse"
+NUM_POSITIVE_EXAMPLES = "num_positive_examples"
+NUM_NEGATIVE_EXAMPLES = "num_negative_examples"
 
 
 def _compute_tower_qps(
@@ -68,6 +70,20 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             dist_reduce_fx="sum",
             persistent=True,
         )
+        self._add_state(
+            NUM_POSITIVE_EXAMPLES,
+            torch.zeros(self._n_tasks, dtype=torch.long),
+            add_window_state=True,
+            dist_reduce_fx="sum",
+            persistent=True,
+        )
+        self._add_state(
+            NUM_NEGATIVE_EXAMPLES,
+            torch.zeros(self._n_tasks, dtype=torch.long),
+            add_window_state=True,
+            dist_reduce_fx="sum",
+            persistent=True,
+        )
         self._add_state(
             TIME_LAPSE,
             torch.zeros(self._n_tasks, dtype=torch.double),
@@ -87,10 +103,28 @@ def update(
         **kwargs: Dict[str, Any],
     ) -> None:
         self._steps += 1
+
         num_examples_scalar = labels.shape[-1]
         num_examples = torch.tensor(num_examples_scalar, dtype=torch.long)
         self_num_examples = getattr(self, NUM_EXAMPLES)
         self_num_examples += num_examples
+
+        # Count (rather than sum) labels by sign: summing the selected label
+        # values would report 0 negatives whenever labels are 0/1.
+        num_positive_examples_scalar = int((labels > 0).sum())
+        num_positive_examples = torch.tensor(
+            num_positive_examples_scalar, dtype=torch.long
+        )
+        self_num_positive_examples = getattr(self, NUM_POSITIVE_EXAMPLES)
+        self_num_positive_examples += num_positive_examples
+
+        num_negative_examples_scalar = int((labels <= 0).sum())
+        num_negative_examples = torch.tensor(
+            num_negative_examples_scalar, dtype=torch.long
+        )
+        self_num_negative_examples = getattr(self, NUM_NEGATIVE_EXAMPLES)
+        self_num_negative_examples += num_negative_examples
+
         ts = time.monotonic()
         if self._steps <= self._warmup_steps:
             self_warmup_examples = getattr(self, WARMUP_EXAMPLES)
@@ -131,6 +165,16 @@ def _compute(self) -> List[MetricComputationReport]:
                 metric_prefix=MetricPrefix.DEFAULT,
                 value=cast(torch.Tensor, self.num_examples).detach(),
             ),
+            MetricComputationReport(
+                name=MetricName.TOTAL_POSITIVE_EXAMPLES,
+                metric_prefix=MetricPrefix.DEFAULT,
+                value=cast(torch.Tensor, self.num_positive_examples).detach(),
+            ),
+            MetricComputationReport(
+                name=MetricName.TOTAL_NEGATIVE_EXAMPLES,
+                metric_prefix=MetricPrefix.DEFAULT,
+                value=cast(torch.Tensor, self.num_negative_examples).detach(),
+            ),
         ]
 
 