diff --git a/apex/normalization/fused_layer_norm.py b/apex/normalization/fused_layer_norm.py
index f2c29100..c6e70002 100644
--- a/apex/normalization/fused_layer_norm.py
+++ b/apex/normalization/fused_layer_norm.py
@@ -67,6 +67,131 @@ def backward(ctx, grad_output):
         return grad_input, grad_weight, grad_bias, None, None, None
 
 
+if supports_custom_op():
+
+    @torch.library.custom_op("apex::fused_layer_norm_affine_fwd", mutates_args=())
+    def fused_layer_norm_affine_fwd(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        normalized_shape: List[int],
+        eps: float,
+        memory_efficient: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        global fused_layer_norm_cuda
+        if fused_layer_norm_cuda is None:
+            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+
+        input_ = input.contiguous()
+        weight_ = weight.contiguous()
+        bias_ = bias.contiguous()
+        output, mean, invvar = fused_layer_norm_cuda.forward_affine(
+            input_, normalized_shape, weight_, bias_, eps
+        )
+        return output, mean, invvar
+
+    @fused_layer_norm_affine_fwd.register_fake
+    def fused_layer_norm_affine_fwd_fake(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        normalized_shape: List[int],
+        eps: float,
+        memory_efficient: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        input = input.contiguous()
+        weight = weight.contiguous()
+        bias = bias.contiguous()
+        idiff = input.ndim - len(normalized_shape)
+        n = 1
+        for i in range(idiff):
+            n *= input.shape[i]
+        if input.dtype in [torch.float16, torch.bfloat16]:
+            dtype = torch.float32
+        else:
+            dtype = input.dtype
+        mean = torch.empty([n], dtype=dtype, device=input.device)
+        invvar = torch.empty_like(mean)
+        return torch.empty_like(input), mean, invvar
+
+    @torch.library.custom_op("apex::fused_layer_norm_affine_bwd", mutates_args=())
+    def fused_layer_norm_affine_bwd(
+        grad_output: torch.Tensor,
+        mean: torch.Tensor,
+        invvar: torch.Tensor,
+        input_or_output: torch.Tensor,
+        normalized_shape: List[int],
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        eps: float,
+        memory_efficient: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
+            grad_output.contiguous(),
+            mean,
+            invvar,
+            input_or_output,
+            normalized_shape,
+            weight,
+            bias,
+            eps,
+            memory_efficient,
+        )
+        return grad_input, grad_weight, grad_bias
+
+    @fused_layer_norm_affine_bwd.register_fake
+    def fused_layer_norm_affine_bwd_fake(
+        grad_output: torch.Tensor,
+        mean: torch.Tensor,
+        invvar: torch.Tensor,
+        input_or_output: torch.Tensor,
+        normalized_shape: List[int],
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        eps: float,
+        memory_efficient: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        grad_input = torch.empty_like(input_or_output)
+        grad_weight = torch.empty_like(weight)
+        grad_bias = torch.empty_like(bias)
+        return grad_input, grad_weight, grad_bias
+
+    def _fused_layer_norm_affine_backward(ctx, grad_output, grad_mean, grad_invvar):
+        input_or_output, weight_, bias_, mean, invvar = ctx.saved_tensors
+        grad_input = grad_weight = grad_bias = None
+        grad_input, grad_weight, grad_bias = fused_layer_norm_affine_bwd(
+            grad_output,
+            mean,
+            invvar,
+            input_or_output,
+            ctx.normalized_shape,
+            weight_,
+            bias_,
+            ctx.eps,
+            ctx.memory_efficient,
+        )
+        return grad_input, grad_weight, grad_bias, None, None, None
+
+    def _fused_layer_norm_affine_setup_context(ctx, inputs, output):
+        input, weight, bias, normalized_shape, eps, memory_efficient = inputs
+        output, mean, invvar = output
+        input_ = input.contiguous()
+        weight_ = weight.contiguous()
+        bias_ = bias.contiguous()
+        if memory_efficient:
+            ctx.save_for_backward(output, weight_, bias_, None, invvar)
+        else:
+            ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
+        ctx.normalized_shape = normalized_shape
+        ctx.eps = eps
+        ctx.memory_efficient = memory_efficient
+
+    fused_layer_norm_affine_fwd.register_autograd(
+        _fused_layer_norm_affine_backward,
+        setup_context=_fused_layer_norm_affine_setup_context,
+    )
+
+
 class FusedRMSNormAffineFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False):
@@ -291,6 +416,110 @@ def backward(ctx, grad_output):
         return grad_input, None, None, None
 
 
+if supports_custom_op():
+
+    @torch.library.custom_op("apex::fused_layer_norm_fwd", mutates_args=())
+    def fused_layer_norm_fwd(
+        input: torch.Tensor,
+        normalized_shape: List[int],
+        eps: float,
+        memory_efficient: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        global fused_layer_norm_cuda
+        if fused_layer_norm_cuda is None:
+            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+
+        input_ = input.contiguous()
+        output, mean, invvar = fused_layer_norm_cuda.forward(
+            input_, normalized_shape, eps
+        )
+        return output, mean, invvar
+
+    @fused_layer_norm_fwd.register_fake
+    def fused_layer_norm_fwd_fake(
+        input: torch.Tensor,
+        normalized_shape: List[int],
+        eps: float,
+        memory_efficient: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        input = input.contiguous()
+        idiff = input.ndim - len(normalized_shape)
+        n = 1
+        for i in range(idiff):
+            n *= input.shape[i]
+        if input.dtype in [torch.float16, torch.bfloat16]:
+            dtype = torch.float32
+        else:
+            dtype = input.dtype
+        mean = torch.empty([n], dtype=dtype, device=input.device)
+        invvar = torch.empty_like(mean)
+        return torch.empty_like(input), mean, invvar
+
+    @torch.library.custom_op("apex::fused_layer_norm_bwd", mutates_args=())
+    def fused_layer_norm_bwd(
+        grad_output: torch.Tensor,
+        mean: torch.Tensor,
+        invvar: torch.Tensor,
+        input_or_output: torch.Tensor,
+        normalized_shape: List[int],
+        eps: float,
+        memory_efficient: bool = False,
+    ) -> torch.Tensor:
+        grad_input = fused_layer_norm_cuda.backward(
+            grad_output.contiguous(),
+            mean,
+            invvar,
+            input_or_output,
+            normalized_shape,
+            eps,
+            memory_efficient,
+        )
+        return grad_input
+
+    @fused_layer_norm_bwd.register_fake
+    def fused_layer_norm_bwd_fake(
+        grad_output: torch.Tensor,
+        mean: torch.Tensor,
+        invvar: torch.Tensor,
+        input_or_output: torch.Tensor,
+        normalized_shape: List[int],
+        eps: float,
+        memory_efficient: bool = False,
+    ) -> torch.Tensor:
+        grad_input = torch.empty_like(input_or_output)
+        return grad_input
+
+    def _fused_layer_norm_backward(ctx, grad_output, grad_mean, grad_invvar):
+        input_or_output, mean, invvar = ctx.saved_tensors
+        grad_input = fused_layer_norm_bwd(
+            grad_output,
+            mean,
+            invvar,
+            input_or_output,
+            ctx.normalized_shape,
+            ctx.eps,
+            ctx.memory_efficient,
+        )
+        return grad_input, None, None, None
+
+    def _fused_layer_norm_setup_context(ctx, inputs, output):
+        input, normalized_shape, eps, memory_efficient = inputs
+        output, mean, invvar = output
+        input_ = input.contiguous()
+        if memory_efficient:
+            ctx.save_for_backward(output, None, invvar)
+        else:
+            ctx.save_for_backward(input_, mean, invvar)
+        ctx.normalized_shape = normalized_shape
+        ctx.eps = eps
+        ctx.memory_efficient = memory_efficient
+
+    fused_layer_norm_fwd.register_autograd(
+        _fused_layer_norm_backward,
+        setup_context=_fused_layer_norm_setup_context,
+    )
+
+
 class FusedRMSNormFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, normalized_shape, eps, memory_efficient=False):
@@ -435,13 +664,19 @@ def _fused_rms_norm_setup_context(ctx, inputs, output):
 def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False):
     args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps, memory_efficient)
     with torch.amp.autocast('cuda', enabled=False):
-        return FusedLayerNormAffineFunction.apply(*args)
+        if supports_custom_op():
+            return fused_layer_norm_affine_fwd(*args)[0]
+        else:
+            return FusedLayerNormAffineFunction.apply(*args)
 
 
 def fused_layer_norm(input, normalized_shape, eps=1e-6, memory_efficient=False):
     args = _cast_if_autocast_enabled(input, normalized_shape, eps, memory_efficient)
     with torch.amp.autocast('cuda', enabled=False):
-        return FusedLayerNormFunction.apply(*args)
+        if supports_custom_op():
+            return fused_layer_norm_fwd(*args)[0]
+        else:
+            return FusedLayerNormFunction.apply(*args)
 
 
 def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False):
diff --git a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py b/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
index abd0ac72..fa97076d 100644
--- a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
+++ b/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
@@ -310,6 +310,30 @@ def test_layer_norm_export(self):
         self._verify_export(fused, fused_x)
         self._verify_export(fused_m, fused_x)
 
+    @common_utils.parametrize("elementwise_affine", (True, False))
+    def test_compile_fused_layer_norm(self, elementwise_affine):
+        batch_size = 16
+        normalized_shape = [32, 16]
+        eager_mod = FusedLayerNorm(
+            normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
+        ).cuda()
+        compiled_mod = torch.compile(fullgraph=True)(eager_mod)
+        input_shape = [batch_size] + normalized_shape
+        eager_x = torch.randn(input_shape, device="cuda").requires_grad_(True)
+        compiled_x = eager_x.detach().clone().requires_grad_(True)
+
+        expected = eager_mod(eager_x)
+        actual = compiled_mod(compiled_x)
+        torch.testing.assert_close(actual, expected.detach())
+
+        g_eager = torch.rand_like(expected)
+        with torch.no_grad():
+            g_compiled = g_eager.detach().clone()
+        expected.backward(g_eager)
+        actual.backward(g_compiled)
+
+        torch.testing.assert_close(eager_x.grad, compiled_x.grad)
+
     @common_utils.parametrize("elementwise_affine", (True, False))
     def test_compile_fused_rms_norm(self, elementwise_affine):
         batch_size = 16
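
Usage sketch (illustrative, not part of the patch): assuming apex is built with the fused_layer_norm_cuda extension and runs on a PyTorch version where supports_custom_op() returns True, the apex:: custom ops registered above let FusedLayerNorm be captured by torch.compile with fullgraph=True, mirroring the test added in this diff.

# Minimal sketch, assuming a CUDA build of apex and supports_custom_op() == True.
import torch
from apex.normalization import FusedLayerNorm

mod = FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=True).cuda()
compiled_mod = torch.compile(mod, fullgraph=True)  # no graph break on the apex custom ops

x = torch.randn(16, 32, 16, device="cuda", requires_grad=True)
out = compiled_mod(x)    # forward dispatches to apex::fused_layer_norm_affine_fwd
out.sum().backward()     # backward runs apex::fused_layer_norm_affine_bwd via register_autograd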