diff --git a/energon/nn/layer/parallel_1d/_operation.py b/energon/nn/layer/parallel_1d/_operation.py
index d6b851e..d5ff005 100644
--- a/energon/nn/layer/parallel_1d/_operation.py
+++ b/energon/nn/layer/parallel_1d/_operation.py
@@ -1,9 +1,12 @@
+import importlib
+
 import torch
 
 try:
-    import fused_mix_prec_layer_norm_cuda
+    # import fused_mix_prec_layer_norm_cuda
+    energon_layer_norm = importlib.import_module("energon_layer_norm")
 except:
-    fused_mix_prec_layer_norm_cuda = None
+    energon_layer_norm = None
 
 
 class FusedLayerNormAffineFunction1D(torch.autograd.Function):
@@ -27,7 +30,7 @@ def forward(ctx, input, weight, bias, normalized_shape, eps):
         input_ = input.contiguous()
         weight_ = weight.contiguous()
         bias_ = bias.contiguous()
-        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
+        output, mean, invvar = energon_layer_norm.forward_affine(
             input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
         ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
         return output
@@ -38,9 +41,9 @@ def backward(ctx, grad_output):
         input_, weight_, bias_, mean, invvar = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         grad_input, grad_weight, grad_bias \
-            = fused_mix_prec_layer_norm_cuda.backward_affine(
+            = energon_layer_norm.backward_affine(
                 grad_output.contiguous(), mean, invvar,
                 input_, ctx.normalized_shape,
                 weight_, bias_, ctx.eps)
 
-        return grad_input, grad_weight, grad_bias, None, None
\ No newline at end of file
+        return grad_input, grad_weight, grad_bias, None, None
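
A minimal standalone sketch (not part of the patch) of the optional-import pattern this diff switches to: the fused layer-norm CUDA extension is looked up by name with importlib and callers fall back gracefully when it is not built. The helper name has_fused_layer_norm is hypothetical and introduced here only for illustration.

import importlib

# Load the optional CUDA extension by name; leave it as None when it is not installed.
try:
    energon_layer_norm = importlib.import_module("energon_layer_norm")
except ImportError:
    energon_layer_norm = None


def has_fused_layer_norm() -> bool:
    # Hypothetical helper (not in the diff): report whether the fused kernel is usable.
    return energon_layer_norm is not None


if __name__ == "__main__":
    print("fused layer norm available:", has_fused_layer_norm())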