Skip to content

Commit

Permalink
Add num positive and negative examples to torch metrics
Browse files Browse the repository at this point in the history
Summary: This adds num positive and num negative examples in addition to total num examples in the QPS computation for torch metrics.

Reviewed By: wilson100hong

Differential Revision: D56376508

fbshipit-source-id: 0d4dc4631da8bd3f365500072c45d71d1e75bf96
  • Loading branch information
drdarshan authored and facebook-github-bot committed Apr 29, 2024
1 parent ae3de28 commit 8014517
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torchrec/metrics/metrics_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class MetricName(MetricNameBase):
NDCG = "ndcg"
XAUC = "xauc"
SCALAR = "scalar"
TOTAL_POSITIVE_EXAMPLES = "total_positive_examples"
TOTAL_NEGATIVE_EXAMPLES = "total_negative_examples"


class MetricNamespaceBase(StrValueMixin, Enum):
Expand Down
18 changes: 18 additions & 0 deletions torchrec/metrics/tests/test_tower_qps.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,13 @@ def test_warmup_checkpointing(self) -> None:
window_size=200,
)
model_output = gen_test_batch(batch_size)
labels = model_output["label"]
num_positive_examples = labels[labels > 0].sum()
num_negative_examples = labels[labels <= 0].sum()

self.assertTrue(hasattr(qps._metrics_computations[0], "num_positive_examples"))
self.assertTrue(hasattr(qps._metrics_computations[0], "num_negative_examples"))

for i in range(5):
for _ in range(warmup_steps + extra_steps):
qps.update(
Expand All @@ -265,6 +272,17 @@ def test_warmup_checkpointing(self) -> None:
qps._metrics_computations[0].num_examples,
batch_size * (warmup_steps + extra_steps) * (i + 1),
)

self.assertEquals(
qps._metrics_computations[0].num_positive_examples,
num_positive_examples * (warmup_steps + extra_steps) * (i + 1),
)

self.assertEquals(
qps._metrics_computations[0].num_negative_examples,
num_negative_examples * (warmup_steps + extra_steps) * (i + 1),
)

# Mimic trainer crashing and loading a checkpoint.
qps._metrics_computations[0]._steps = 0

Expand Down
42 changes: 42 additions & 0 deletions torchrec/metrics/tower_qps.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
NUM_EXAMPLES = "num_examples"
WARMUP_EXAMPLES = "warmup_examples"
TIME_LAPSE = "time_lapse"
NUM_POSITIVE_EXAMPLES = "num_positive_examples"
NUM_NEGATIVE_EXAMPLES = "num_negative_examples"


def _compute_tower_qps(
Expand Down Expand Up @@ -68,6 +70,20 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
NUM_POSITIVE_EXAMPLES,
torch.zeros(self._n_tasks, dtype=torch.long),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
NUM_NEGATIVE_EXAMPLES,
torch.zeros(self._n_tasks, dtype=torch.long),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
TIME_LAPSE,
torch.zeros(self._n_tasks, dtype=torch.double),
Expand All @@ -87,10 +103,26 @@ def update(
**kwargs: Dict[str, Any],
) -> None:
self._steps += 1

num_examples_scalar = labels.shape[-1]
num_examples = torch.tensor(num_examples_scalar, dtype=torch.long)
self_num_examples = getattr(self, NUM_EXAMPLES)
self_num_examples += num_examples

num_positive_examples_scalar = int(labels[labels > 0].sum())
num_positive_examples = torch.tensor(
num_positive_examples_scalar, dtype=torch.long
)
self_num_positive_examples = getattr(self, NUM_POSITIVE_EXAMPLES)
self_num_positive_examples += num_positive_examples

num_negative_examples_scalar = int(labels[labels <= 0].sum())
num_negative_examples = torch.tensor(
num_negative_examples_scalar, dtype=torch.long
)
self_num_negative_examples = getattr(self, NUM_NEGATIVE_EXAMPLES)
self_num_negative_examples += num_negative_examples

ts = time.monotonic()
if self._steps <= self._warmup_steps:
self_warmup_examples = getattr(self, WARMUP_EXAMPLES)
Expand Down Expand Up @@ -131,6 +163,16 @@ def _compute(self) -> List[MetricComputationReport]:
metric_prefix=MetricPrefix.DEFAULT,
value=cast(torch.Tensor, self.num_examples).detach(),
),
MetricComputationReport(
name=MetricName.TOTAL_POSITIVE_EXAMPLES,
metric_prefix=MetricPrefix.DEFAULT,
value=cast(torch.Tensor, self.num_positive_examples).detach(),
),
MetricComputationReport(
name=MetricName.TOTAL_NEGATIVE_EXAMPLES,
metric_prefix=MetricPrefix.DEFAULT,
value=cast(torch.Tensor, self.num_negative_examples).detach(),
),
]


Expand Down

0 comments on commit 8014517

Please sign in to comment.