Traceable LayerNorm #1864

Merged · 1 commit · Dec 14, 2024
239 changes: 237 additions & 2 deletions apex/normalization/fused_layer_norm.py
@@ -67,6 +67,131 @@ def backward(ctx, grad_output):
        return grad_input, grad_weight, grad_bias, None, None, None


if supports_custom_op():

    @torch.library.custom_op("apex::fused_layer_norm_affine_fwd", mutates_args=())
    def fused_layer_norm_affine_fwd(
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        normalized_shape: List[int],
        eps: float,
        memory_efficient: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        global fused_layer_norm_cuda
        if fused_layer_norm_cuda is None:
            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward_affine(
            input_, normalized_shape, weight_, bias_, eps
        )
        return output, mean, invvar

    @fused_layer_norm_affine_fwd.register_fake
    def fused_layer_norm_affine_fwd_fake(
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        normalized_shape: List[int],
        eps: float,
        memory_efficient: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        input = input.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        idiff = input.ndim - len(normalized_shape)
[Review comment, Collaborator] qq: what would idiff stand for?

[Reply, Contributor/author] This mimics the behavior at

    int idiff = input.ndimension() - normalized_shape.size();
        n = 1
        for i in range(idiff):
            n *= input.shape[i]
        if input.dtype in [torch.float16, torch.bfloat16]:
            dtype = torch.float32
        else:
            dtype = input.dtype
        mean = torch.empty([n], dtype=dtype, device=input.device)
        invvar = torch.empty_like(mean)
        return torch.empty_like(input), mean, invvar

    @torch.library.custom_op("apex::fused_layer_norm_affine_bwd", mutates_args=())
    def fused_layer_norm_affine_bwd(
        grad_output: torch.Tensor,
        mean: torch.Tensor,
        invvar: torch.Tensor,
        input_or_output: torch.Tensor,
        normalized_shape: List[int],
        weight: torch.Tensor,
        bias: torch.Tensor,
        eps: float,
        memory_efficient: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
            grad_output.contiguous(),
            mean,
            invvar,
            input_or_output,
            normalized_shape,
            weight,
            bias,
            eps,
            memory_efficient,
        )
        return grad_input, grad_weight, grad_bias

    @fused_layer_norm_affine_bwd.register_fake
    def fused_layer_norm_affine_bwd_fake(
        grad_output: torch.Tensor,
        mean: torch.Tensor,
        invvar: torch.Tensor,
        input_or_output: torch.Tensor,
        normalized_shape: List[int],
        weight: torch.Tensor,
        bias: torch.Tensor,
        eps: float,
        memory_efficient: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        grad_input = torch.empty_like(input_or_output)
        grad_weight = torch.empty_like(weight)
        grad_bias = torch.empty_like(bias)
        return grad_input, grad_weight, grad_bias

    def _fused_layer_norm_affine_backward(ctx, grad_output, grad_mean, grad_invvar):
        input_or_output, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias = fused_layer_norm_affine_bwd(
            grad_output,
            mean,
            invvar,
            input_or_output,
            ctx.normalized_shape,
            weight_,
            bias_,
            ctx.eps,
            ctx.memory_efficient,
        )
        return grad_input, grad_weight, grad_bias, None, None, None

    def _fused_layer_norm_affine_setup_context(ctx, inputs, output):
        input, weight, bias, normalized_shape, eps, memory_efficient = inputs
        output, mean, invvar = output
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        if memory_efficient:
            ctx.save_for_backward(output, weight_, bias_, None, invvar)
        else:
            ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        ctx.memory_efficient = memory_efficient

    fused_layer_norm_affine_fwd.register_autograd(
        _fused_layer_norm_affine_backward,
        setup_context=_fused_layer_norm_affine_setup_context,
    )
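
A quick way to sanity-check a registration like this is torch.library.opcheck, which tests the op's schema, the fake kernel, and the autograd wiring against the real kernel. A minimal sketch, not part of this PR, assuming a CUDA build of apex with the fused_layer_norm_cuda extension importable:

    # Sketch: validate the custom op's schema, fake kernel, and autograd
    # registration against the real CUDA kernel (shapes are illustrative).
    import torch
    from apex.normalization.fused_layer_norm import fused_layer_norm_affine_fwd

    x = torch.randn(8, 4, 64, device="cuda", requires_grad=True)
    w = torch.randn(64, device="cuda", requires_grad=True)
    b = torch.randn(64, device="cuda", requires_grad=True)
    torch.library.opcheck(fused_layer_norm_affine_fwd, (x, w, b, [64], 1e-5))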


class FusedRMSNormAffineFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False):
@@ -291,6 +416,110 @@ def backward(ctx, grad_output):
        return grad_input, None, None, None


if supports_custom_op():

    @torch.library.custom_op("apex::fused_layer_norm_fwd", mutates_args=())
    def fused_layer_norm_fwd(
        input: torch.Tensor,
        normalized_shape: List[int],
        eps: float,
        memory_efficient: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        global fused_layer_norm_cuda
        if fused_layer_norm_cuda is None:
            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

        input_ = input.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward(
            input_, normalized_shape, eps
        )
        return output, mean, invvar

    @fused_layer_norm_fwd.register_fake
    def fused_layer_norm_fwd_fake(
        input: torch.Tensor,
        normalized_shape: List[int],
        eps: float,
        memory_efficient: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        input = input.contiguous()
        idiff = input.ndim - len(normalized_shape)
        n = 1
        for i in range(idiff):
            n *= input.shape[i]
        if input.dtype in [torch.float16, torch.bfloat16]:
            dtype = torch.float32
        else:
            dtype = input.dtype
        mean = torch.empty([n], dtype=dtype, device=input.device)
        invvar = torch.empty_like(mean)
        return torch.empty_like(input), mean, invvar

    @torch.library.custom_op("apex::fused_layer_norm_bwd", mutates_args=())
    def fused_layer_norm_bwd(
        grad_output: torch.Tensor,
        mean: torch.Tensor,
        invvar: torch.Tensor,
        input_or_output: torch.Tensor,
        normalized_shape: List[int],
        eps: float,
        memory_efficient: bool = False,
    ) -> torch.Tensor:
        grad_input = fused_layer_norm_cuda.backward(
            grad_output.contiguous(),
            mean,
            invvar,
            input_or_output,
            normalized_shape,
            eps,
            memory_efficient,
        )
        return grad_input

    @fused_layer_norm_bwd.register_fake
    def fused_layer_norm_bwd_fake(
        grad_output: torch.Tensor,
        mean: torch.Tensor,
        invvar: torch.Tensor,
        input_or_output: torch.Tensor,
        normalized_shape: List[int],
        eps: float,
        memory_efficient: bool = False,
    ) -> torch.Tensor:
        grad_input = torch.empty_like(input_or_output)
        return grad_input

    def _fused_layer_norm_backward(ctx, grad_output, grad_mean, grad_invvar):
        input_or_output, mean, invvar = ctx.saved_tensors
        grad_input = fused_layer_norm_bwd(
            grad_output,
            mean,
            invvar,
            input_or_output,
            ctx.normalized_shape,
            ctx.eps,
            ctx.memory_efficient,
        )
        return grad_input, None, None, None

    def _fused_layer_norm_setup_context(ctx, inputs, output):
        input, normalized_shape, eps, memory_efficient = inputs
        output, mean, invvar = output
        input_ = input.contiguous()
        if memory_efficient:
            ctx.save_for_backward(output, None, invvar)
        else:
            ctx.save_for_backward(input_, mean, invvar)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        ctx.memory_efficient = memory_efficient

    fused_layer_norm_fwd.register_autograd(
        _fused_layer_norm_backward,
        setup_context=_fused_layer_norm_setup_context,
    )
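
The register_fake registrations are what let FakeTensor-based tracing (torch.compile, export) propagate shapes and dtypes without launching the CUDA kernel. A minimal sketch of the metadata the fake kernel produces, assuming apex imported with custom-op support on a CUDA machine; shapes are illustrative:

    # Sketch: under FakeTensorMode the call dispatches to
    # fused_layer_norm_fwd_fake, so only metadata is computed. For a
    # [8, 4, 64] fp16 input normalized over the last dim, idiff == 2 and
    # n == 8 * 4 == 32; mean/invvar are fp32 per the dtype promotion above.
    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        x = torch.empty(8, 4, 64, device="cuda", dtype=torch.float16)
        out, mean, invvar = torch.ops.apex.fused_layer_norm_fwd(x, [64], 1e-5)
    print(out.shape, mean.shape, invvar.shape)  # (8, 4, 64), (32,), (32,)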


class FusedRMSNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, normalized_shape, eps, memory_efficient=False):
@@ -435,13 +664,19 @@ def _fused_rms_norm_setup_context(ctx, inputs, output):
def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False):
    args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps, memory_efficient)
    with torch.amp.autocast('cuda', enabled=False):
        if supports_custom_op():
            return fused_layer_norm_affine_fwd(*args)[0]
        else:
            return FusedLayerNormAffineFunction.apply(*args)


def fused_layer_norm(input, normalized_shape, eps=1e-6, memory_efficient=False):
    args = _cast_if_autocast_enabled(input, normalized_shape, eps, memory_efficient)
    with torch.amp.autocast('cuda', enabled=False):
        if supports_custom_op():
            return fused_layer_norm_fwd(*args)[0]
        else:
            return FusedLayerNormFunction.apply(*args)
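
The public wrappers keep their single-tensor contract: the custom op returns the (output, mean, invvar) triple that backward needs, and the wrapper indexes [0] so existing callers are unaffected. A usage sketch, assuming a CUDA apex build; shapes are illustrative:

    # Sketch: callers see no change from the custom-op dispatch; only the
    # normalized output is returned.
    import torch
    from apex.normalization.fused_layer_norm import fused_layer_norm_affine

    x = torch.randn(16, 32, 16, device="cuda")
    w = torch.ones(32, 16, device="cuda")
    b = torch.zeros(32, 16, device="cuda")
    y = fused_layer_norm_affine(x, w, b, [32, 16], eps=1e-6)
    assert y.shape == x.shape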


def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False):
24 changes: 24 additions & 0 deletions tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
@@ -310,6 +310,30 @@ def test_layer_norm_export(self):
        self._verify_export(fused, fused_x)
        self._verify_export(fused_m, fused_x)

    @common_utils.parametrize("elementwise_affine", (True, False))
    def test_compile_fused_layer_norm(self, elementwise_affine):
        batch_size = 16
        normalized_shape = [32, 16]
        eager_mod = FusedLayerNorm(
            normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
        ).cuda()
        compiled_mod = torch.compile(fullgraph=True)(eager_mod)
        input_shape = [batch_size] + normalized_shape
        eager_x = torch.randn(input_shape, device="cuda").requires_grad_(True)
        compiled_x = eager_x.detach().clone().requires_grad_(True)

        expected = eager_mod(eager_x)
        actual = compiled_mod(compiled_x)
        torch.testing.assert_close(actual, expected.detach())

        g_eager = torch.rand_like(expected)
        with torch.no_grad():
            g_compiled = g_eager.detach().clone()
        expected.backward(g_eager)
        actual.backward(g_compiled)

        torch.testing.assert_close(eager_x.grad, compiled_x.grad)

    @common_utils.parametrize("elementwise_affine", (True, False))
    def test_compile_fused_rms_norm(self, elementwise_affine):
        batch_size = 16