Skip to content
This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Update _operation.py #22

Merged
merged 1 commit into from
Apr 12, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions energon/nn/layer/parallel_1d/_operation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import importlib

import torch

try:
import fused_mix_prec_layer_norm_cuda
# import fused_mix_prec_layer_norm_cuda
energon_layer_norm = importlib.import_module("energon_layer_norm")
except:
fused_mix_prec_layer_norm_cuda = None
energon_layer_norm = None


class FusedLayerNormAffineFunction1D(torch.autograd.Function):
Expand All @@ -27,7 +30,7 @@ def forward(ctx, input, weight, bias, normalized_shape, eps):
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
output, mean, invvar = energon_layer_norm.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
Expand All @@ -38,9 +41,9 @@ def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= fused_mix_prec_layer_norm_cuda.backward_affine(
= energon_layer_norm.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)

return grad_input, grad_weight, grad_bias, None, None
return grad_input, grad_weight, grad_bias, None, None