Traceable LayerNorm (#1864)

yanboliang authored Dec 14, 2024
1 parent 2863aa0 commit 73375b3

Showing 2 changed files with 261 additions and 2 deletions.
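
In short: the commit registers apex's fused LayerNorm CUDA kernels as torch.library custom ops, with fake (meta) implementations and autograd rules, so that torch.compile can trace FusedLayerNorm without graph breaks. A minimal usage sketch, assuming apex was built with its CUDA extension and a GPU is available (module and shapes below are illustrative):

# Usage sketch, not part of the commit.
import torch
from apex.normalization import FusedLayerNorm

mod = FusedLayerNorm(normalized_shape=[32, 16]).cuda()
compiled = torch.compile(mod, fullgraph=True)  # now traces the apex custom ops

x = torch.randn(8, 32, 16, device="cuda", requires_grad=True)
y = compiled(x)
y.sum().backward()
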
239 changes: 237 additions & 2 deletions apex/normalization/fused_layer_norm.py
@@ -67,6 +67,131 @@ def backward(ctx, grad_output):
return grad_input, grad_weight, grad_bias, None, None, None


if supports_custom_op():

@torch.library.custom_op("apex::fused_layer_norm_affine_fwd", mutates_args=())
def fused_layer_norm_affine_fwd(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
normalized_shape: List[int],
eps: float,
memory_efficient: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward_affine(
input_, normalized_shape, weight_, bias_, eps
)
return output, mean, invvar

@fused_layer_norm_affine_fwd.register_fake
def fused_layer_norm_affine_fwd_fake(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
normalized_shape: List[int],
eps: float,
memory_efficient: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
input = input.contiguous()
weight = weight.contiguous()
bias = bias.contiguous()
idiff = input.ndim - len(normalized_shape)
n = 1
for i in range(idiff):
n *= input.shape[i]
if input.dtype in [torch.float16, torch.bfloat16]:
dtype = torch.float32
else:
dtype = input.dtype
mean = torch.empty([n], dtype=dtype, device=input.device)
invvar = torch.empty_like(mean)
return torch.empty_like(input), mean, invvar

@torch.library.custom_op("apex::fused_layer_norm_affine_bwd", mutates_args=())
def fused_layer_norm_affine_bwd(
grad_output: torch.Tensor,
mean: torch.Tensor,
invvar: torch.Tensor,
input_or_output: torch.Tensor,
normalized_shape: List[int],
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
memory_efficient: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
grad_output.contiguous(),
mean,
invvar,
input_or_output,
normalized_shape,
weight,
bias,
eps,
memory_efficient,
)
return grad_input, grad_weight, grad_bias

@fused_layer_norm_affine_bwd.register_fake
def fused_layer_norm_affine_bwd_fake(
grad_output: torch.Tensor,
mean: torch.Tensor,
invvar: torch.Tensor,
input_or_output: torch.Tensor,
normalized_shape: List[int],
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
memory_efficient: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
grad_input = torch.empty_like(input_or_output)
grad_weight = torch.empty_like(weight)
grad_bias = torch.empty_like(bias)
return grad_input, grad_weight, grad_bias

def _fused_layer_norm_affine_backward(ctx, grad_output, grad_mean, grad_invvar):
input_or_output, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias = fused_layer_norm_affine_bwd(
grad_output,
mean,
invvar,
input_or_output,
ctx.normalized_shape,
weight_,
bias_,
ctx.eps,
ctx.memory_efficient,
)
return grad_input, grad_weight, grad_bias, None, None, None

def _fused_layer_norm_affine_setup_context(ctx, inputs, output):
input, weight, bias, normalized_shape, eps, memory_efficient = inputs
output, mean, invvar = output
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
if memory_efficient:
ctx.save_for_backward(output, weight_, bias_, None, invvar)
else:
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
ctx.normalized_shape = normalized_shape
ctx.eps = eps
ctx.memory_efficient = memory_efficient

fused_layer_norm_affine_fwd.register_autograd(
_fused_layer_norm_affine_backward,
setup_context=_fused_layer_norm_affine_setup_context,
)
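
For reference, the block above follows PyTorch's custom-op protocol: torch.library.custom_op registers the real (CUDA) kernel, register_fake supplies a shape/dtype-only implementation so the op can be traced with fake tensors, and register_autograd attaches the backward together with a setup_context hook that stashes whatever the backward needs. A self-contained toy sketch of the same pattern, where the made-up op demo::scale_fwd stands in for the CUDA kernel (none of these names are part of apex):

# Toy illustration of the custom-op pattern; all names here are made up.
import torch

@torch.library.custom_op("demo::scale_fwd", mutates_args=())
def scale_fwd(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real implementation (apex calls into fused_layer_norm_cuda here instead).
    return x * factor

@scale_fwd.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake/meta implementation: shapes and dtypes only, no data access,
    # which is what lets torch.compile trace through the op.
    return torch.empty_like(x)

def _scale_backward(ctx, grad_out):
    # One gradient per input; None for the non-tensor factor argument,
    # mirroring the None returns in the apex backward above.
    return grad_out * ctx.factor, None

def _scale_setup_context(ctx, inputs, output):
    _, factor = inputs
    ctx.factor = factor

scale_fwd.register_autograd(_scale_backward, setup_context=_scale_setup_context)

x = torch.randn(4, requires_grad=True)
y = torch.compile(lambda t: scale_fwd(t, 3.0), fullgraph=True)(x)
y.sum().backward()  # x.grad is a tensor of 3.0s
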


class FusedRMSNormAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False):
@@ -291,6 +416,110 @@ def backward(ctx, grad_output):
return grad_input, None, None, None


if supports_custom_op():

@torch.library.custom_op("apex::fused_layer_norm_fwd", mutates_args=())
def fused_layer_norm_fwd(
input: torch.Tensor,
normalized_shape: List[int],
eps: float,
memory_efficient: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

input_ = input.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward(
input_, normalized_shape, eps
)
return output, mean, invvar

@fused_layer_norm_fwd.register_fake
def fused_layer_norm_fwd_fake(
input: torch.Tensor,
normalized_shape: List[int],
eps: float,
memory_efficient: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
input = input.contiguous()
idiff = input.ndim - len(normalized_shape)
n = 1
for i in range(idiff):
n *= input.shape[i]
if input.dtype in [torch.float16, torch.bfloat16]:
dtype = torch.float32
else:
dtype = input.dtype
mean = torch.empty([n], dtype=dtype, device=input.device)
invvar = torch.empty_like(mean)
return torch.empty_like(input), mean, invvar

@torch.library.custom_op("apex::fused_layer_norm_bwd", mutates_args=())
def fused_layer_norm_bwd(
grad_output: torch.Tensor,
mean: torch.Tensor,
invvar: torch.Tensor,
input_or_output: torch.Tensor,
normalized_shape: List[int],
eps: float,
memory_efficient: bool = False,
) -> torch.Tensor:
grad_input = fused_layer_norm_cuda.backward(
grad_output.contiguous(),
mean,
invvar,
input_or_output,
normalized_shape,
eps,
memory_efficient,
)
return grad_input

@fused_layer_norm_bwd.register_fake
def fused_layer_norm_bwd_fake(
grad_output: torch.Tensor,
mean: torch.Tensor,
invvar: torch.Tensor,
input_or_output: torch.Tensor,
normalized_shape: List[int],
eps: float,
memory_efficient: bool = False,
) -> torch.Tensor:
grad_input = torch.empty_like(input_or_output)
return grad_input

def _fused_layer_norm_backward(ctx, grad_output, grad_mean, grad_invvar):
input_or_output, mean, invvar = ctx.saved_tensors
grad_input = fused_layer_norm_bwd(
grad_output,
mean,
invvar,
input_or_output,
ctx.normalized_shape,
ctx.eps,
ctx.memory_efficient,
)
return grad_input, None, None, None

def _fused_layer_norm_setup_context(ctx, inputs, output):
input, normalized_shape, eps, memory_efficient = inputs
output, mean, invvar = output
input_ = input.contiguous()
if memory_efficient:
ctx.save_for_backward(output, None, invvar)
else:
ctx.save_for_backward(input_, mean, invvar)
ctx.normalized_shape = normalized_shape
ctx.eps = eps
ctx.memory_efficient = memory_efficient

fused_layer_norm_fwd.register_autograd(
_fused_layer_norm_backward,
setup_context=_fused_layer_norm_setup_context,
)
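
The fake kernels above only have to reproduce shapes and dtypes: mean and invvar get one entry per normalized row (the product of the leading, non-normalized dimensions) and are kept in fp32 whenever the input is fp16 or bf16, while the output mirrors the input. A small worked sketch of that shape arithmetic (illustrative only, using a meta tensor rather than the apex op):

# Shape bookkeeping mirroring the register_fake implementations above.
import torch

input = torch.empty(8, 128, 1024, dtype=torch.float16, device="meta")
normalized_shape = [1024]

idiff = input.ndim - len(normalized_shape)  # 3 - 1 = 2 leading dims
n = 1
for i in range(idiff):
    n *= input.shape[i]                     # 8 * 128 = 1024 rows

stats_dtype = (
    torch.float32
    if input.dtype in (torch.float16, torch.bfloat16)
    else input.dtype
)
print(n, stats_dtype)  # 1024 torch.float32 -> mean/invvar have shape [1024]
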


class FusedRMSNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, normalized_shape, eps, memory_efficient=False):
@@ -435,13 +664,19 @@ def _fused_rms_norm_setup_context(ctx, inputs, output):
def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False):
args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps, memory_efficient)
with torch.amp.autocast('cuda', enabled=False):
-        return FusedLayerNormAffineFunction.apply(*args)
+        if supports_custom_op():
+            return fused_layer_norm_affine_fwd(*args)[0]
+        else:
+            return FusedLayerNormAffineFunction.apply(*args)


def fused_layer_norm(input, normalized_shape, eps=1e-6, memory_efficient=False):
args = _cast_if_autocast_enabled(input, normalized_shape, eps, memory_efficient)
with torch.amp.autocast('cuda', enabled=False):
-        return FusedLayerNormFunction.apply(*args)
+        if supports_custom_op():
+            return fused_layer_norm_fwd(*args)[0]
+        else:
+            return FusedLayerNormFunction.apply(*args)
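
With the custom ops registered, the wrappers above route through them whenever supports_custom_op() is true and fall back to the original autograd.Functions otherwise. A hedged usage sketch of the affine wrapper (assumes apex was built with the fused_layer_norm_cuda extension and a CUDA device is present):

# Usage sketch of the functional wrapper, not part of the commit.
import torch
from apex.normalization.fused_layer_norm import fused_layer_norm_affine

x = torch.randn(8, 32, 16, device="cuda", requires_grad=True)
w = torch.ones(32, 16, device="cuda", requires_grad=True)
b = torch.zeros(32, 16, device="cuda", requires_grad=True)

y = fused_layer_norm_affine(x, w, b, normalized_shape=[32, 16], eps=1e-6)
y.sum().backward()  # gradients flow through the registered custom-op backward
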


def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False):
24 changes: 24 additions & 0 deletions tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
@@ -310,6 +310,30 @@ def test_layer_norm_export(self):
self._verify_export(fused, fused_x)
self._verify_export(fused_m, fused_x)

@common_utils.parametrize("elementwise_affine", (True, False))
def test_compile_fused_layer_norm(self, elementwise_affine):
batch_size = 16
normalized_shape = [32, 16]
eager_mod = FusedLayerNorm(
normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
).cuda()
compiled_mod = torch.compile(fullgraph=True)(eager_mod)
input_shape = [batch_size] + normalized_shape
eager_x = torch.randn(input_shape, device="cuda").requires_grad_(True)
compiled_x = eager_x.detach().clone().requires_grad_(True)

expected = eager_mod(eager_x)
actual = compiled_mod(compiled_x)
torch.testing.assert_close(actual, expected.detach())

g_eager = torch.rand_like(expected)
with torch.no_grad():
g_compiled = g_eager.detach().clone()
expected.backward(g_eager)
actual.backward(g_compiled)

torch.testing.assert_close(eager_x.grad, compiled_x.grad)

@common_utils.parametrize("elementwise_affine", (True, False))
def test_compile_fused_rms_norm(self, elementwise_affine):
batch_size = 16