From e337ed0dde879caa105ff0e3c710e600a9eb5176 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Tue, 3 Sep 2024 17:28:15 -0700 Subject: [PATCH] Removed unused dw tensor in Triton RMSNorm ghstack-source-id: c2337c6f976b41288498b7f3aa9b6f3d54d49ad9 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/567 --- torchtitan/models/norms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py index 266453301..315274521 100644 --- a/torchtitan/models/norms.py +++ b/torchtitan/models/norms.py @@ -284,7 +284,6 @@ def backward(ctx, dy): M, N = dy.shape dx = torch.empty_like(x) - dw = torch.empty_like(weight) sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)