From 79d9099f13319fe0909ef88533f336c14af8bb3e Mon Sep 17 00:00:00 2001
From: Atul Jangra
Date: Fri, 31 Jan 2025 20:30:22 -0800
Subject: [PATCH] Remove unused memory checks to speed up `compute`

Summary:
The memory checks add non-trivial overhead to every compute step because they involve many tensor size calls. In our runs, they accounted for around 20% of the time spent in the rec metric compute step. Given that these checks are not being used anymore, let's remove them.

This diff removes the checks from the metric_module. In the next set of diffs, I'll remove the `memory_usage_limit_mb` argument from the callsites.

Differential Revision: D68995122
---
 torchrec/metrics/metric_module.py            |  41 +-----
 torchrec/metrics/tests/test_metric_module.py | 136 -------------------
 2 files changed, 2 insertions(+), 175 deletions(-)

diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py
index e4417e2c0..17c1cfe26 100644
--- a/torchrec/metrics/metric_module.py
+++ b/torchrec/metrics/metric_module.py
@@ -109,9 +109,6 @@
 
 MODEL_METRIC_LABEL: str = "model"
 
-MEMORY_AVG_WARNING_PERCENTAGE = 20
-MEMORY_AVG_WARNING_WARMUP = 100
-
 MetricValue = Union[torch.Tensor, float]
 
 
@@ -146,7 +143,7 @@ class RecMetricModule(nn.Module):
         throughput_metric (Optional[ThroughputMetric]): the ThroughputMetric.
         state_metrics (Optional[Dict[str, StateMetric]]): the dict of StateMetrics.
         compute_interval_steps (int): the intervals between two compute calls in the unit of batch number
-        memory_usage_limit_mb (float): the memory usage limit for OOM check
+        memory_usage_limit_mb (float): [Unused] the memory usage limit for OOM check
 
     Call Args:
         Not supported.
@@ -177,8 +174,6 @@ class RecMetricModule(nn.Module):
     rec_metrics: RecMetricList
     throughput_metric: Optional[ThroughputMetric]
    state_metrics: Dict[str, StateMetric]
-    memory_usage_limit_mb: float
-    memory_usage_mb_avg: float
     oom_count: int
     compute_count: int
     last_compute_time: float
@@ -195,6 +190,7 @@ def __init__(
         compute_interval_steps: int = 100,
         min_compute_interval: float = 0.0,
         max_compute_interval: float = float("inf"),
+        # Unused, but needed for backwards compatibility. TODO: Remove from callsites
         memory_usage_limit_mb: float = 512,
     ) -> None:
         super().__init__()
@@ -205,8 +201,6 @@ def __init__(
         self.trained_batches: int = 0
         self.batch_size = batch_size
         self.world_size = world_size
-        self.memory_usage_limit_mb = memory_usage_limit_mb
-        self.memory_usage_mb_avg = 0.0
         self.oom_count = 0
         self.compute_count = 0
 
@@ -230,37 +224,6 @@ def __init__(
         )
         self.last_compute_time = -1.0
 
-    def get_memory_usage(self) -> int:
-        r"""Total memory of unique RecMetric tensors in bytes"""
-        total = {}
-        for metric in self.rec_metrics.rec_metrics:
-            total.update(metric.get_memory_usage())
-        return sum(total.values())
-
-    def check_memory_usage(self, compute_count: int) -> None:
-        memory_usage_mb = self.get_memory_usage() / (10**6)
-        if memory_usage_mb > self.memory_usage_limit_mb:
-            self.oom_count += 1
-            logger.warning(
-                f"MetricModule is using {memory_usage_mb}MB. "
-                f"This is larger than the limit{self.memory_usage_limit_mb}MB. "
-                f"This is the f{self.oom_count}th OOM."
-            )
-
-        if (
-            compute_count > MEMORY_AVG_WARNING_WARMUP
-            and memory_usage_mb
-            > self.memory_usage_mb_avg * ((100 + MEMORY_AVG_WARNING_PERCENTAGE) / 100)
-        ):
-            logger.warning(
-                f"MetricsModule is using more than {MEMORY_AVG_WARNING_PERCENTAGE}% of "
-                f"the average memory usage. Current usage: {memory_usage_mb}MB."
-            )
-
-        self.memory_usage_mb_avg = (
-            self.memory_usage_mb_avg * (compute_count - 1) + memory_usage_mb
-        ) / compute_count
-
     def _update_rec_metrics(
         self, model_out: Dict[str, torch.Tensor], **kwargs: Any
     ) -> None:
diff --git a/torchrec/metrics/tests/test_metric_module.py b/torchrec/metrics/tests/test_metric_module.py
index c5968b463..dd6b9518a 100644
--- a/torchrec/metrics/tests/test_metric_module.py
+++ b/torchrec/metrics/tests/test_metric_module.py
@@ -353,142 +353,6 @@ def test_initial_states_rank0_checkpointing(self) -> None:
             lc, entrypoint=self._run_trainer_initial_states_checkpointing
         )()
 
-    def test_empty_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = EmptyMetricsConfig
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        self.assertEqual(metric_module.get_memory_usage(), 0)
-
-    def test_ne_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = DefaultMetricsConfig
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        # Default NEMetric's dtype is
-        # float64 (8 bytes) * 16 tensors of size 1 = 128 bytes
-        # Tensors in NeMetricComputation:
-        # 8 in _default, 8 specific attributes: 4 attributes, 4 window
-        self.assertEqual(metric_module.get_memory_usage(), 128)
-        metric_module.update(gen_test_batch(128))
-        self.assertEqual(metric_module.get_memory_usage(), 160)
-
-    def test_calibration_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = dataclasses.replace(
-            DefaultMetricsConfig,
-            rec_metrics={
-                RecMetricEnum.CALIBRATION: RecMetricDef(
-                    rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
-                )
-            },
-        )
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        # Default calibration metric dtype is
-        # float64 (8 bytes) * 8 tensors, size 1 = 64 bytes
-        # Tensors in CalibrationMetricComputation:
-        # 4 in _default, 4 specific attributes: 2 attribute, 2 window
-        self.assertEqual(metric_module.get_memory_usage(), 64)
-        metric_module.update(gen_test_batch(128))
-        self.assertEqual(metric_module.get_memory_usage(), 80)
-
-    def test_auc_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = dataclasses.replace(
-            DefaultMetricsConfig,
-            rec_metrics={
-                RecMetricEnum.AUC: RecMetricDef(
-                    rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
-                )
-            },
-        )
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        # 3 (tensors) * 4 (float)
-        self.assertEqual(metric_module.get_memory_usage(), 12)
-        metric_module.update(gen_test_batch(128))
-        # 3 (tensors) * 128 (batch_size) * 4 (float)
-        self.assertEqual(metric_module.get_memory_usage(), 1536)
-
-        # Test memory usage over multiple updates does not increase unexpectedly, we don't need to force OOM as just knowing if the memory usage is increeasing how we expect is enough
-        for _ in range(10):
-            metric_module.update(gen_test_batch(128))
-
-        # 3 tensors * 128 batch size * 4 float * 11 updates
-        self.assertEqual(metric_module.get_memory_usage(), 16896)
-
-        # Ensure reset frees memory correctly
-        metric_module.reset()
-        self.assertEqual(metric_module.get_memory_usage(), 12)
-
-    def test_check_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = DefaultMetricsConfig
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        metric_module.update(gen_test_batch(128))
-        with patch("torchrec.metrics.metric_module.logger") as logger_mock:
-            # Memory usage is fine.
-            metric_module.memory_usage_mb_avg = 160 / (10**6)
-            metric_module.check_memory_usage(1000)
-            self.assertEqual(metric_module.oom_count, 0)
-            self.assertEqual(logger_mock.warning.call_count, 0)
-
-            # OOM but memory usage does not exceed avg.
-            metric_module.memory_usage_limit_mb = 0.000001
-            metric_module.memory_usage_mb_avg = 160 / (10**6)
-            metric_module.check_memory_usage(1000)
-            self.assertEqual(metric_module.oom_count, 1)
-            self.assertEqual(logger_mock.warning.call_count, 1)
-
-            # OOM and memory usage exceed avg but warmup is not over.
-            metric_module.memory_usage_mb_avg = 160 / (10**6) / 10
-            metric_module.check_memory_usage(2)
-            self.assertEqual(metric_module.oom_count, 2)
-            self.assertEqual(logger_mock.warning.call_count, 2)
-
-            # OOM and memory usage exceed avg and warmup is over.
-            metric_module.memory_usage_mb_avg = 160 / (10**6) / 1.25
-            metric_module.check_memory_usage(1002)
-            self.assertEqual(metric_module.oom_count, 3)
-            self.assertEqual(logger_mock.warning.call_count, 4)
-
     def test_should_compute(self) -> None:
         metric_module = generate_metric_module(
             TestMetricModule,
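
Callsite compatibility: `memory_usage_limit_mb` stays in the `RecMetricModule` constructor for now, so existing callsites keep working; the value is simply ignored. A minimal sketch of what this means for a caller (assuming the remaining constructor arguments are optional and keep their defaults; the concrete values are illustrative, not taken from this diff):

    from torchrec.metrics.metric_module import RecMetricModule

    # Existing callsite: still passes the limit, which is now a no-op.
    metric_module = RecMetricModule(
        batch_size=128,
        world_size=64,
        memory_usage_limit_mb=512,  # accepted for backwards compatibility, ignored
    )

    # After the follow-up diffs clean up the callsites, the keyword is simply dropped:
    metric_module = RecMetricModule(
        batch_size=128,
        world_size=64,
    )

Keeping the keyword argument (with its old default) until every callsite is updated is what makes this change safe to land independently of the callsite cleanup.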