Skip to content

Commit

Permalink
Move layer norm to phi (#40193)
Browse files Browse the repository at this point in the history
* update

* fix bugs; test=develop

* update; test=develop

* fix test compile error; test=develop

* fix cpu compile error; test=develop

* fix test error; test=develo

* fix layer_norm_op plugin error; test=develop

* fix error; test=develop

* fix test bug; test=develop

* update; test=develop

* polish code; test=develop

* fix bugs; test=develop

* remove unused depency; test=develop

* polish code; test=develop
  • Loading branch information
phlrain authored Mar 17, 2022
1 parent c335288 commit 681a686
Show file tree
Hide file tree
Showing 21 changed files with 1,036 additions and 698 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h"

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h"
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"

namespace paddle {
namespace inference {
Expand Down Expand Up @@ -83,7 +83,7 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size,
cudaMemcpyHostToDevice, stream);

paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
phi::LayerNormDirectCUDAFunctor<float> layer_norm;
layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
variance_d, begin_norm_axis, eps);
return cudaGetLastError() != cudaSuccess;
Expand Down Expand Up @@ -177,7 +177,7 @@ int LayerNormPluginDynamic::enqueue(
cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size,
cudaMemcpyHostToDevice, stream);

paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
phi::LayerNormDirectCUDAFunctor<float> layer_norm;
layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
variance_d, begin_norm_axis, eps);
} else {
Expand Down
24 changes: 15 additions & 9 deletions paddle/fluid/operators/fused/fused_dropout_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace memory = paddle::memory;

USE_OP_ITSELF(dropout);
USE_OP(layer_norm);
USE_OP_ITSELF(layer_norm);

template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
Expand Down Expand Up @@ -136,18 +138,23 @@ void LayerNorm(const std::vector<LayerNormParamType<T>> &scale,
const platform::CUDADeviceContext &ctx) {
framework::Scope scope;
auto place = ctx.GetPlace();
paddle::optional<const framework::LoDTensor &> scale_opt = paddle::none;
if (scale.size() > 0) {
auto var_scale = scope.Var("Scale");
auto tensor_scale = var_scale->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(scale, ctx, tensor_scale);
tensor_scale->Resize({cols});
scale_opt = *tensor_scale;
}

paddle::optional<const framework::LoDTensor &> bias_opt = paddle::none;
if (bias.size() > 0) {
auto var_bias = scope.Var("Bias");
auto tensor_bias = var_bias->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(bias, ctx, tensor_bias);
tensor_bias->Resize({cols});

bias_opt = *tensor_bias;
}

auto var_x = scope.Var("X");
Expand All @@ -157,20 +164,19 @@ void LayerNorm(const std::vector<LayerNormParamType<T>> &scale,

auto var_y = scope.Var("Y");
auto tensor_y = var_y->GetMutable<framework::LoDTensor>();
tensor_y->Resize({rows, cols});

auto var_mean = scope.Var("Mean");
auto tensor_mean = var_mean->GetMutable<framework::LoDTensor>();
tensor_mean->Resize({rows});

auto var_variance = scope.Var("Variance");
auto tensor_variance = var_variance->GetMutable<framework::LoDTensor>();

framework::AttributeMap attrs;
attrs.insert({"epsilon", epsilon});

auto op = framework::OpRegistry::CreateOp(
"layer_norm", {{"X", {"X"}}, {"Scale", {"Scale"}}, {"Bias", {"Bias"}}},
{{"Y", {"Y"}}, {"Mean", {"Mean"}}, {"Variance", {"Variance"}}}, attrs);
op->Run(scope, place);
tensor_variance->Resize({rows});
ctx.Wait();
phi::LayerNormKernel<T>(static_cast<const phi::GPUContext &>(ctx), *tensor_x,
scale_opt, bias_opt, 1e-5, 1, false, tensor_y,
tensor_mean, tensor_variance);
framework::TensorToVector(*tensor_y, ctx, y);
framework::TensorToVector(*tensor_mean, ctx, means);
framework::TensorToVector(*tensor_variance, ctx, vars);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ struct TestFusedLayernormResidualDropoutBias {
residual_vec[i * cols + j] + out2[i * cols + j];
}
}

LayerNorm<T>(scale_vec, layernorm_bias_vec, correct_out, &correct_means,
&correct_vars, &correct_layernorm_out, epsilon, rows, cols,
*ctx);
Expand Down
17 changes: 9 additions & 8 deletions paddle/fluid/operators/layer_norm_kernel.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -758,12 +758,14 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
*/
template <typename T, typename U, typename ScaleT = U,
typename MaskType = uint8_t>
void ln_bwd_1024_kernel_driver(
const platform::CUDADeviceContext &dev_ctx, const int rows, const int cols,
float epsilon, const T *x_ptr, const ScaleT *scale_ptr, const U *mean_ptr,
const U *var_ptr, const T *dout_ptr, T *dx_ptr, ScaleT *dscale_ptr,
ScaleT *dbias_ptr, const MaskType *mask_ptr = nullptr,
T factor = static_cast<T>(0), T *d_dropout_src_ptr = nullptr) {
void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
const int cols, float epsilon, const T *x_ptr,
const ScaleT *scale_ptr, const U *mean_ptr,
const U *var_ptr, const T *dout_ptr, T *dx_ptr,
ScaleT *dscale_ptr, ScaleT *dbias_ptr,
const MaskType *mask_ptr = nullptr,
T factor = static_cast<T>(0),
T *d_dropout_src_ptr = nullptr) {
auto stream = dev_ctx.stream();
if (cols == 1024) {
// step-1: compute dx and reduced part results of dscale and dbias.
Expand Down Expand Up @@ -1334,8 +1336,7 @@ static void LayerNormBackward(
const U *mean, const U *var, T *d_x,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, float epsilon,
int64_t batch_size, int64_t feature_size,
const platform::CUDADeviceContext &dev_ctx) {
int64_t batch_size, int64_t feature_size, const phi::GPUContext &dev_ctx) {
auto stream = dev_ctx.stream();
#ifdef __HIPCC__
const int kMaxBlockDim = 256;
Expand Down
10 changes: 1 addition & 9 deletions paddle/fluid/operators/layer_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/layer_norm_op.h"

#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
Expand Down Expand Up @@ -278,10 +277,3 @@ REGISTER_OPERATOR(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
ops::LayerNormGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(layer_norm_grad, ops::LayerNormGradOp,
ops::LayerNormGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
layer_norm, ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
layer_norm_grad,
ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, double>);
Loading

0 comments on commit 681a686

Please sign in to comment.