From 4e2b1ad86be368cddb054a4b181a089588f6b9de Mon Sep 17 00:00:00 2001 From: "Tian Zheng (Engrg-Hardware 1)" Date: Tue, 14 Nov 2023 06:34:26 +0000 Subject: [PATCH 1/4] Rename output --- paddle/phi/api/yaml/fused_ops.yaml | 2 +- paddle/phi/infermeta/fusion.cc | 6 ++++-- paddle/phi/infermeta/fusion.h | 2 +- .../kernels/fusion/gpu/fused_scale_bias_add_relu_kernel.cu | 4 +++- test/legacy_test/test_fused_scale_bias_add_relu_op.py | 2 +- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index de8d47beee64f..ec620cd13a63b 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -252,7 +252,7 @@ - op : fused_scale_bias_add_relu args : (Tensor x1, Tensor scale1, Tensor bias1, Tensor x2, Tensor scale2, Tensor bias2, bool fuse_dual, bool exhaustive_search) optional : scale2, bias2 - output : Tensor(y) + output : Tensor(out) infer_meta : func : FusedScaleBiasAddReluInferMeta kernel : diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 0bda38a08d651..c14d84ca0a32b 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -2112,7 +2112,7 @@ void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1, const MetaTensor& bias2, bool fuse_dual, bool exhaustive_search, - MetaTensor* y) { + MetaTensor* out) { // check optional inputs if (fuse_dual) { bool has_scale2 = !!scale2; @@ -2127,7 +2127,9 @@ void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1, fuse_dual)); } // set output dims - y->set_dims(x1.dims()); + out->set_dims(x1.dims()); + out->set_dtype(x1.dtype()); + out->set_layout(x1.layout()); } void SqueezeExcitationInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 8ae35445c6800..e433a47eb67ba 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -552,7 +552,7 @@ void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1, const MetaTensor& bias2, bool fuse_prologue, bool exhaustive_search, - MetaTensor* y); + MetaTensor* out); void SqueezeExcitationInferMeta(const MetaTensor& x, const MetaTensor& filter, diff --git a/paddle/phi/kernels/fusion/gpu/fused_scale_bias_add_relu_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_scale_bias_add_relu_kernel.cu index ff5edd689f7f3..373dd28fdb874 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_scale_bias_add_relu_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_scale_bias_add_relu_kernel.cu @@ -46,13 +46,15 @@ void FusedScaleBiasAddReluKernel(const Context& dev_ctx, const paddle::optional& bias2, bool fuse_dual, bool exhaustive_search, - DenseTensor* y) { + DenseTensor* out) { PADDLE_ENFORCE_GE(dev_ctx.GetComputeCapability(), 80, phi::errors::PreconditionNotMet( "This op only supports Ampere and later devices, " "but got compute capability: %d.", dev_ctx.GetComputeCapability())); + + DenseTensor* y = out; auto& plan_cache = phi::autotune::AutoTuneCache::Instance().GetConvV8( phi::autotune::AlgorithmType::kScaleBiasAddRelu); diff --git a/test/legacy_test/test_fused_scale_bias_add_relu_op.py b/test/legacy_test/test_fused_scale_bias_add_relu_op.py index 44952e1ea23d1..0755dfc6a83c1 100644 --- a/test/legacy_test/test_fused_scale_bias_add_relu_op.py +++ b/test/legacy_test/test_fused_scale_bias_add_relu_op.py @@ -90,7 +90,7 @@ def setUp(self): } self.outputs = { - 'y': y_output, + 'out': y_output, } def has_cuda(self): From 287998e59d2ecc66e1b7b7f4e29c5fc5dcd7bd5f Mon Sep 17 00:00:00 2001 From: "Tian Zheng 
(Engrg-Hardware 1)" Date: Sun, 29 Oct 2023 08:22:42 +0000 Subject: [PATCH 2/4] Add fused_dconv_drelu_dbn_op --- .../pir/dialect/op_generator/ops_api_gen.py | 2 + paddle/phi/api/yaml/fused_ops.yaml | 10 + paddle/phi/infermeta/fusion.cc | 130 ++ paddle/phi/infermeta/fusion.h | 35 + paddle/phi/kernels/CMakeLists.txt | 7 +- paddle/phi/kernels/autotune/cache.cc | 7 + paddle/phi/kernels/autotune/cache.h | 5 +- .../gpu/fused_dconv_drelu_dbn_kernel.cu | 1162 +++++++++++++++++ test/legacy_test/CMakeLists.txt | 1 + .../test_fused_dconv_drelu_dbn_op.py | 478 +++++++ 10 files changed, 1833 insertions(+), 4 deletions(-) create mode 100644 paddle/phi/kernels/fusion/gpu/fused_dconv_drelu_dbn_kernel.cu create mode 100644 test/legacy_test/test_fused_dconv_drelu_dbn_op.py diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 6f0a552b529d9..f654f7f021a6a 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -76,6 +76,7 @@ 'fused_multi_transformer_xpu', 'fused_scale_bias_relu_conv_bn', 'fused_scale_bias_add_relu', + 'fused_dconv_drelu_dbn', 'fusion_transpose_flatten_concat', 'skip_layernorm', 'generate_sequence_xpu', @@ -118,6 +119,7 @@ 'fused_elemwise_add_activation', 'fused_scale_bias_relu_conv_bn', 'fused_scale_bias_add_relu', + 'fused_dconv_drelu_dbn', 'recv_v2', 'rnn_', 'seed', diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index ec620cd13a63b..a47099bbfcee7 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -174,6 +174,16 @@ optional : bias, residual, norm_weight, norm_bias, residual_out support_dygraph_mode : true +- op : fused_dconv_drelu_dbn + args : (Tensor grad_output, Tensor weight, Tensor grad_output_add, Tensor residual_input, Tensor bn1_eqscale, Tensor bn1_eqbias, Tensor conv_input, Tensor bn1_mean, Tensor bn1_inv_std, Tensor bn1_gamma, Tensor bn1_beta, Tensor bn1_input, Tensor bn2_mean, Tensor bn2_inv_std, Tensor bn2_gamma, Tensor bn2_beta, Tensor bn2_input, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, str data_format, bool fuse_shortcut, bool fuse_dual, bool fuse_add, bool exhaustive_search) + output : Tensor(grad_weight), Tensor(grad_bn1_input), Tensor(grad_bn1_gamma), Tensor(grad_bn1_beta), Tensor(grad_bn2_input), Tensor(grad_bn2_gamma), Tensor(grad_bn2_beta) + optional : grad_output_add, residual_input, bn1_eqscale, bn1_eqbias, conv_input, bn2_mean, bn2_inv_std, bn2_gamma, bn2_beta, bn2_input, grad_bn2_input, grad_bn2_gamma, grad_bn2_beta + infer_meta : + func : FusedDconvDreluDbnInferMeta + kernel : + func : fused_dconv_drelu_dbn + data_type : grad_output + - op : fused_dropout_add args : (Tensor x, Tensor y, Tensor seed_tensor, Scalar p, bool is_test, str mode, int seed = 0, bool fix_seed = false) optional : seed_tensor diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index c14d84ca0a32b..740b5cf24ad3b 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -2132,6 +2132,136 @@ void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1, out->set_layout(x1.layout()); } +void FusedDconvDreluDbnInferMeta(const MetaTensor& grad_output, + const MetaTensor& weight, + const MetaTensor& grad_output_add, + const MetaTensor& residual_input, + const MetaTensor& bn1_eqscale, + const MetaTensor& bn1_eqbias, + const MetaTensor& conv_input, + const MetaTensor& bn1_mean, + const MetaTensor& 
bn1_inv_std, + const MetaTensor& bn1_gamma, + const MetaTensor& bn1_beta, + const MetaTensor& bn1_input, + const MetaTensor& bn2_mean, + const MetaTensor& bn2_inv_std, + const MetaTensor& bn2_gamma, + const MetaTensor& bn2_beta, + const MetaTensor& bn2_input, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const std::string& padding_algorithm, + int groups, + const std::string& data_format, + bool fuse_shortcut, + bool fuse_dual, + bool fuse_add, + bool exhaustive_search, + MetaTensor* grad_weight, + MetaTensor* grad_bn1_input, + MetaTensor* grad_bn1_gamma, + MetaTensor* grad_bn1_beta, + MetaTensor* grad_bn2_input, + MetaTensor* grad_bn2_gamma, + MetaTensor* grad_bn2_beta) { + // Check if data format is NHWC + PADDLE_ENFORCE_EQ( + data_format, + "NHWC", + phi::errors::InvalidArgument( + "Operator(FusedScaleBiasReluConvBnstats) only supports data format " + "of " + "channel last (NHWC) now. But recieved: data_format = '%s'.", + data_format)); + + PADDLE_ENFORCE_EQ( + groups, + 1, + phi::errors::InvalidArgument("Expect group to be 1, got %d.", groups)); + + PADDLE_ENFORCE_EQ( + fuse_shortcut && fuse_dual, + 0, + phi::errors::InvalidArgument( + "fuse_shortcut and fuse_dual should not be set at the same time." + "Got fuse_shortcut=%d, fuse_dual=%d.", + fuse_shortcut, + fuse_dual)); + + if (fuse_add) { + PADDLE_ENFORCE_EQ( + !!grad_output_add, + true, + phi::errors::InvalidArgument( + "grad_output_add must be provided when fuse_add = true." + "Got fuse_add=%d, grad_output_add=%d.", + fuse_add, + !!grad_output_add)); + } + if (fuse_shortcut) { + PADDLE_ENFORCE_EQ( + !!residual_input, + true, + phi::errors::InvalidArgument( + "residual_input must be provided when fuse_shortcut = true." + "Got fuse_shortcut =%d, residual_input=%d.", + fuse_shortcut, + !!residual_input)); + } + if (fuse_shortcut || fuse_dual) { + PADDLE_ENFORCE_EQ( + !!conv_input, + true, + phi::errors::InvalidArgument( + "conv_input must be provided when either fuse_shortcut " + "or fuse_dual is set to true. Got conv_input=%d, fuse_shortcut=%d, " + "fuse_dual=%d.", + !!conv_input, + fuse_shortcut, + fuse_dual)); + } else { + PADDLE_ENFORCE_EQ( + bn1_eqscale && bn1_eqbias, + true, + phi::errors::InvalidArgument( + "bn1_eqscale and bn1_eqbias must be provided when neither " + "fuse_shortcut " + "or fuse_dual is set. Got bn1_eqscale=%d, bn1_eqbias=%d.", + !!bn1_eqscale, + !!bn1_eqbias)); + } + if (fuse_dual) { + PADDLE_ENFORCE_EQ( + bn2_mean && bn2_inv_std && bn2_gamma && bn2_beta && bn2_input, + true, + phi::errors::InvalidArgument("bn2_mean, bn2_inv_std, bn2_gamma, " + "bn2_beta, bn2_input must be provided " + "when fuse_dual is set. 
Got bn2_mean=%d, " + "bn2_inv_std=%d, bn2_gamma=%d, " + "bn2_beta=%d, bn2_input=%d.", + !!bn2_mean, + !!bn2_inv_std, + !!bn2_gamma, + !!bn2_beta, + !!bn2_input)); + } + grad_weight->set_dims(weight.dims()); + grad_bn1_input->set_dims(bn1_input.dims()); + grad_bn1_gamma->set_dims(bn1_gamma.dims()); + grad_bn1_beta->set_dims(bn1_beta.dims()); + if (grad_bn2_input) { + grad_bn2_input->set_dims(bn1_input.dims()); + } + if (grad_bn2_gamma) { + grad_bn2_gamma->set_dims(bn1_gamma.dims()); + } + if (grad_bn2_beta) { + grad_bn2_beta->set_dims(bn1_beta.dims()); + } +} + void SqueezeExcitationInferMeta(const MetaTensor& x, const MetaTensor& filter, const MetaTensor& filter_max, diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index e433a47eb67ba..46580488a6c49 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -554,6 +554,41 @@ void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1, bool exhaustive_search, MetaTensor* out); +void FusedDconvDreluDbnInferMeta(const MetaTensor& grad_output, + const MetaTensor& weight, + const MetaTensor& grad_output_add, + const MetaTensor& residual_input, + const MetaTensor& bn1_eqscale, + const MetaTensor& bn1_eqbias, + const MetaTensor& conv_input, + const MetaTensor& bn1_mean, + const MetaTensor& bn1_inv_std, + const MetaTensor& bn1_gamma, + const MetaTensor& bn1_beta, + const MetaTensor& bn1_input, + const MetaTensor& bn2_mean, + const MetaTensor& bn2_inv_std, + const MetaTensor& bn2_gamma, + const MetaTensor& bn2_beta, + const MetaTensor& bn2_input, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const std::string& padding_algorithm, + int groups, + const std::string& data_format, + bool fuse_shortcut, + bool fuse_dual, + bool fuse_add, + bool exhaustive_search, + MetaTensor* grad_weight, + MetaTensor* grad_bn1_input, + MetaTensor* grad_bn1_gamma, + MetaTensor* grad_bn1_beta, + MetaTensor* grad_bn2_input, + MetaTensor* grad_bn2_gamma, + MetaTensor* grad_bn2_beta); + void SqueezeExcitationInferMeta(const MetaTensor& x, const MetaTensor& filter, const MetaTensor& filter_max, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 5355b4e08c21b..36cbac4a8683d 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -168,9 +168,10 @@ if(WITH_CUTLASS) endif() if(NOT WITH_CUDNN_FRONTEND) - list(REMOVE_ITEM kernel_cu - "fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu" - "fusion/gpu/fused_scale_bias_add_relu_kernel.cu") + list( + REMOVE_ITEM kernel_cu "fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu" + "fusion/gpu/fused_scale_bias_add_relu_kernel.cu" + "fusion/gpu/fused_dconv_drelu_dbn_kernel.cu") endif() set(cc_search_pattern diff --git a/paddle/phi/kernels/autotune/cache.cc b/paddle/phi/kernels/autotune/cache.cc index 0f1c9171264d1..72611d3482a6f 100644 --- a/paddle/phi/kernels/autotune/cache.cc +++ b/paddle/phi/kernels/autotune/cache.cc @@ -55,6 +55,13 @@ std::string AlgorithmTypeString(int64_t algo_type) { } else if (algo_type == static_cast(AlgorithmType::kScaleBiasAddRelu)) { return "scale_bias_add_relu"; + } else if (algo_type == + static_cast(AlgorithmType::kDgradDreluBnBwdWeight)) { + return "dgrad_drelu_bnbwdweight"; + } else if (algo_type == static_cast(AlgorithmType::kDbnApply)) { + return "dbn_apply"; + } else if (algo_type == static_cast(AlgorithmType::kBnActWgrad)) { + return "bn_act_wgrad"; } #endif return std::to_string(algo_type); diff --git a/paddle/phi/kernels/autotune/cache.h 
b/paddle/phi/kernels/autotune/cache.h index ff47bfffcc448..fcb9058cd0a76 100644 --- a/paddle/phi/kernels/autotune/cache.h +++ b/paddle/phi/kernels/autotune/cache.h @@ -58,7 +58,10 @@ enum class AlgorithmType { kScaleBiasReluConvBNstats = 13, kBNFinalize = 14, kScaleBiasAddRelu = 15, - kAlgorithmCount = 16 + kDgradDreluBnBwdWeight = 16, + kDbnApply = 17, + kBnActWgrad = 18, + kAlgorithmCount = 19 #endif }; diff --git a/paddle/phi/kernels/fusion/gpu/fused_dconv_drelu_dbn_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dconv_drelu_dbn_kernel.cu new file mode 100644 index 0000000000000..e194ae3f4756b --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_dconv_drelu_dbn_kernel.cu @@ -0,0 +1,1162 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/phi/backends/gpu/cuda/cudnn_helper.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/flags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/autotune/cache.h" +#include "paddle/phi/kernels/cpu/conv_util.h" +#include "paddle/phi/kernels/funcs/batch_norm_utils.h" +#include "paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h" + +PHI_DECLARE_bool(cudnn_deterministic); +PHI_DECLARE_bool(cudnn_exhaustive_search); + +namespace phi { +namespace fusion { + +using helper = phi::CudnnFrontendConvHelper; + +template +using CudnnDataType = phi::backends::gpu::CudnnDataType; + +namespace { +cudnn_frontend::Operation MakeDreluOp(cudnnDataType_t dtype, + cudnn_frontend::Tensor const& dy_desc, + cudnn_frontend::Tensor const& x_desc, + cudnn_frontend::Tensor const& dx_desc) { + auto op_desc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_BWD) + .setComputeType(dtype) + .build(); + auto op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setdyDesc(dy_desc) + .setxDesc(x_desc) + .setdxDesc(dx_desc) + .setpwDesc(op_desc) + .build(); + VLOG(6) << op.describe(); + return op; +} + +cudnn_frontend::Operation MakeBnbwdweightOp( + cudnnDataType_t dtype, + cudnn_frontend::Tensor const& x_desc, + cudnn_frontend::Tensor const& mean_desc, + cudnn_frontend::Tensor const& invstd_desc, + cudnn_frontend::Tensor const& bn_scale_desc, + cudnn_frontend::Tensor const& dy_desc, + cudnn_frontend::Tensor const& dbn_bias_desc, + cudnn_frontend::Tensor const& dbn_scale_desc, + cudnn_frontend::Tensor const& eq_dy_scale_desc, + cudnn_frontend::Tensor const& eq_x_scale_desc, + cudnn_frontend::Tensor const& eqbias_desc) { + auto op = + cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_BN_BWD_WEIGHTS_DESCRIPTOR) + .setComputeType(dtype) + .setxDesc(x_desc) + .setSavedMeanAndInvVar(mean_desc, invstd_desc) + .setScale(bn_scale_desc) + .setdyDesc(dy_desc) + .setEqScalesAndBias(eq_dy_scale_desc, eq_x_scale_desc, eqbias_desc) + .setDScaleAndDBias(dbn_scale_desc, dbn_bias_desc) + .build(); + VLOG(6) << op.describe(); + return op; +} +} // namespace + 
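+// Readability note for the graph assembled below: the pointwise nodes
+// recompute the BN1 output that fed the forward ReLU,
+//   bn1_out = (bn1_input - bn1_mean) * bn1_inv_std * bn1_gamma + bn1_beta,
+// optionally add either the dual-BN output (fuse_dual) or residual_input
+// (fuse_shortcut), and feed the sum to the ReLU backward node as its mask.
+// The dgrad node back-propagates grad_output (plus grad_output_add when
+// fuse_add is set) through the convolution, and the BN_BWD_WEIGHTS nodes
+// produce dgamma/dbeta along with the (dy_scale, x_scale, eqbias)
+// coefficients that are later consumed as bn*_coeff_a/b/c by _DbnApplyImpl.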
+template +void _DgradDreluBnBwdWeightImpl(const Context& dev_ctx, + const DenseTensor* grad_output, + const DenseTensor* weight, + const DenseTensor* bn1_mean, + const DenseTensor* bn1_inv_std, + const DenseTensor* bn1_gamma, + const DenseTensor* bn1_beta, + const DenseTensor* bn1_input, + const DenseTensor* residual_input, + const DenseTensor* bn2_mean, + const DenseTensor* bn2_inv_std, + const DenseTensor* bn2_gamma, + const DenseTensor* bn2_beta, + const DenseTensor* bn2_input, + const DenseTensor* grad_output_add, + bool fuse_shortcut, + bool fuse_dual, + bool fuse_add, + const std::vector& strides, + const std::vector& dilations, + const std::vector& pre_padding, + const std::vector& post_padding, + bool exhaustive_search, + bool deterministic, + DenseTensor* grad_conv_input, + DenseTensor* grad_bn1_gamma, + DenseTensor* grad_bn1_beta, + DenseTensor* bn1_coeff_a, + DenseTensor* bn1_coeff_b, + DenseTensor* bn1_coeff_c, + DenseTensor* grad_bn2_gamma, + DenseTensor* grad_bn2_beta, + DenseTensor* bn2_coeff_a, + DenseTensor* bn2_coeff_b, + DenseTensor* bn2_coeff_c) { + auto& plan_cache = phi::autotune::AutoTuneCache::Instance().GetConvV8( + phi::autotune::AlgorithmType::kDgradDreluBnBwdWeight); + // get handles + auto handle = dev_ctx.cudnn_handle(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + // transform filter to NHWC layout + DenseTensor w_tensor_transformed(weight->dtype()); + ResizeToChannelLast(dev_ctx, weight, &w_tensor_transformed); + TransToChannelLast(dev_ctx, weight, &w_tensor_transformed); + // build tensor descriptors + cudnnTensorFormat_t layout_format = CUDNN_TENSOR_NHWC; + auto tensor_format = + phi::backends::gpu::ToCudnnDataType(grad_output->dtype()); + auto tensor_format_math = CUDNN_DATA_FLOAT; + auto compute_dtype = CUDNN_DATA_FLOAT; + // get dims in CUDNN manner: [N, C, H, W] + auto dim_x = phi::backends::gpu::TransformDimOrder( + phi::vectorize(bn1_input->dims())); + auto dim_filt = phi::backends::gpu::TransformDimOrder( + phi::vectorize(w_tensor_transformed.dims())); + auto dim_y = phi::backends::gpu::TransformDimOrder( + phi::vectorize(grad_output->dims())); + std::vector dim_scale(dim_x.size(), 1); + dim_scale[1] = dim_x[1]; // [1, C, 1, 1] + + std::vector data_ptrs; + std::vector uids; + int64_t uid = 100; + + // Build tensor descriptors + // dgrad inputs + auto dy_desc = helper::GetGeneralTensorDescriptor( + dim_y, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(const_cast(grad_output->data())); + uids.push_back(uid); + + auto w_desc = helper::GetGeneralTensorDescriptor( + dim_filt, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(const_cast(w_tensor_transformed.data())); + uids.push_back(uid); + + // dBN1 inputs + auto bn1_mean_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + data_ptrs.push_back(const_cast(bn1_mean->data())); + uids.push_back(uid); + + auto bn1_inv_std_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + data_ptrs.push_back(const_cast(bn1_inv_std->data())); + uids.push_back(uid); + + auto bn1_scale_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + data_ptrs.push_back(const_cast(bn1_gamma->data())); + uids.push_back(uid); + + auto bn1_bias_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + data_ptrs.push_back(const_cast(bn1_beta->data())); + uids.push_back(uid); + + auto bn1_x_desc 
= helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(const_cast(bn1_input->data())); + uids.push_back(uid); + + // dBN2 inputs + auto bn2_mean_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + if (fuse_dual) { + data_ptrs.push_back(const_cast(bn2_mean->data())); + uids.push_back(uid); + } + + auto bn2_inv_std_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + if (fuse_dual) { + data_ptrs.push_back(const_cast(bn2_inv_std->data())); + uids.push_back(uid); + } + + auto bn2_scale_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + if (fuse_dual) { + data_ptrs.push_back(const_cast(bn2_gamma->data())); + uids.push_back(uid); + } + + auto bn2_bias_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + if (fuse_dual) { + data_ptrs.push_back(const_cast(bn2_beta->data())); + uids.push_back(uid); + } + + auto bn2_x_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format); + if (fuse_dual) { + data_ptrs.push_back(const_cast(bn2_input->data())); + uids.push_back(uid); + } + + // shortcut input + auto relu_x_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format); + if (fuse_shortcut) { + data_ptrs.push_back(const_cast(residual_input->data())); + uids.push_back(uid); + } + + // fuse_add inputs + auto dy_branch_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format); + if (fuse_add) { + data_ptrs.push_back(const_cast(grad_output_add->data())); + uids.push_back(uid); + } + + // virtual outputs + auto dx_dgrad_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_add0 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_add1 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_mul1 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_add2 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_mul2 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto final_bitmask_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_dual_add1 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_dual_mul1 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_dual_add2 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_dual_mul2 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + + // drelu outputs + auto dx_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(grad_conv_input->data()); + uids.push_back(uid); + + // dBN1 outputs + auto bn1_dgamma_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + data_ptrs.push_back(grad_bn1_gamma->data()); + uids.push_back(uid); + + auto bn1_dbeta_desc = 
helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + data_ptrs.push_back(grad_bn1_beta->data()); + uids.push_back(uid); + + auto bn1_eqscale_dy_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + data_ptrs.push_back(bn1_coeff_a->data()); + uids.push_back(uid); + + auto bn1_eqscale_x_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + data_ptrs.push_back(bn1_coeff_b->data()); + uids.push_back(uid); + + auto bn1_eqbias_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + data_ptrs.push_back(bn1_coeff_c->data()); + uids.push_back(uid); + + // dBN2 outputs + auto bn2_dgamma_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + if (fuse_dual) { + data_ptrs.push_back(grad_bn2_gamma->data()); + uids.push_back(uid); + } + auto bn2_dbeta_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + if (fuse_dual) { + data_ptrs.push_back(grad_bn2_beta->data()); + uids.push_back(uid); + } + auto bn2_eqscale_dy_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + if (fuse_dual) { + data_ptrs.push_back(bn2_coeff_a->data()); + uids.push_back(uid); + } + auto bn2_eqscale_x_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + if (fuse_dual) { + data_ptrs.push_back(bn2_coeff_b->data()); + uids.push_back(uid); + } + auto bn2_eqbias_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format_math); + if (fuse_dual) { + data_ptrs.push_back(bn2_coeff_c->data()); + uids.push_back(uid); + } + + // build ops + std::vector ops; + // make dgrad op + std::vector stride_int64 = helper::GetInt64Array(strides); + std::vector dilation_int64 = helper::GetInt64Array(dilations); + int64_t data_dim = pre_padding.size(); + auto conv_desc = cudnn_frontend::ConvDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setSpatialDimCount(data_dim) + .setSpatialStride(data_dim, stride_int64.data()) + .setPrePadding(data_dim, pre_padding.data()) + .setPostPadding(data_dim, post_padding.data()) + .setDilation(data_dim, dilation_int64.data()) + .build(); + VLOG(6) << conv_desc.describe(); + + auto dgrad_op = + cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) + .setdyDesc(dy_desc) + .setwDesc(w_desc) + .setdxDesc(dx_dgrad_desc) + .setcDesc(conv_desc) + .setAlpha(1.0f) + .setBeta(0.0f) + .build(); + VLOG(6) << dgrad_op.describe(); + ops.push_back(&dgrad_op); + + cudnn_frontend::Tensor* p_drelu_input_desc = &dx_dgrad_desc; + auto add0_op = helper::MakePointwiseOp(CUDNN_POINTWISE_ADD, + compute_dtype, + dx_dgrad_desc, + dy_branch_desc, + after_add0); + if (fuse_add) { + ops.push_back(&add0_op); + p_drelu_input_desc = &after_add0; + } + // make pointwise nodes + auto add1_op = helper::MakePointwiseOp(CUDNN_POINTWISE_ADD, + compute_dtype, + bn1_x_desc, + bn1_mean_desc, + after_add1, + 1.0, + -1.0); + ops.push_back(&add1_op); + + auto mul1_op = helper::MakePointwiseOp(CUDNN_POINTWISE_MUL, + compute_dtype, + after_add1, + bn1_inv_std_desc, + after_mul1); + ops.push_back(&mul1_op); + + auto mul2_op = helper::MakePointwiseOp(CUDNN_POINTWISE_MUL, + compute_dtype, + after_mul1, + bn1_scale_desc, + after_mul2); + 
ops.push_back(&mul2_op); + + auto add2_op = helper::MakePointwiseOp(CUDNN_POINTWISE_ADD, + compute_dtype, + after_mul2, + bn1_bias_desc, + after_add2); + ops.push_back(&add2_op); + + auto dual_add1_op = helper::MakePointwiseOp(CUDNN_POINTWISE_ADD, + compute_dtype, + bn2_x_desc, + bn2_mean_desc, + after_dual_add1, + 1.0, + -1.0); + if (fuse_dual) ops.push_back(&dual_add1_op); + + auto dual_mul1_op = helper::MakePointwiseOp(CUDNN_POINTWISE_MUL, + compute_dtype, + after_dual_add1, + bn2_inv_std_desc, + after_dual_mul1); + if (fuse_dual) ops.push_back(&dual_mul1_op); + + auto dual_mul2_op = helper::MakePointwiseOp(CUDNN_POINTWISE_MUL, + compute_dtype, + after_dual_mul1, + bn2_scale_desc, + after_dual_mul2); + if (fuse_dual) ops.push_back(&dual_mul2_op); + + auto dual_add2_op = helper::MakePointwiseOp(CUDNN_POINTWISE_ADD, + compute_dtype, + after_dual_mul2, + bn2_bias_desc, + after_dual_add2); + if (fuse_dual) ops.push_back(&dual_add2_op); + + cudnn_frontend::Tensor* p_bmask_input_desc = + fuse_shortcut ? &relu_x_desc : &after_dual_add2; + auto bmask_add_op = helper::MakePointwiseOp(CUDNN_POINTWISE_ADD, + compute_dtype, + after_add2, + *p_bmask_input_desc, + final_bitmask_desc); + if (fuse_shortcut || fuse_dual) ops.push_back(&bmask_add_op); + + cudnn_frontend::Tensor* p_drelu_bmask_desc = + (fuse_shortcut || fuse_dual) ? &final_bitmask_desc : &after_add2; + auto drelu_op = MakeDreluOp( + compute_dtype, *p_drelu_input_desc, *p_drelu_bmask_desc, dx_desc); + ops.push_back(&drelu_op); + + auto bn_bwd_weight_op = MakeBnbwdweightOp(compute_dtype, + bn1_x_desc, + bn1_mean_desc, + bn1_inv_std_desc, + bn1_scale_desc, + dx_desc, + bn1_dbeta_desc, + bn1_dgamma_desc, + bn1_eqscale_dy_desc, + bn1_eqscale_x_desc, + bn1_eqbias_desc); + ops.push_back(&bn_bwd_weight_op); + + auto dual_bn_bwd_weight_op = MakeBnbwdweightOp(compute_dtype, + bn2_x_desc, + bn2_mean_desc, + bn2_inv_std_desc, + bn2_scale_desc, + dx_desc, + bn2_dbeta_desc, + bn2_dgamma_desc, + bn2_eqscale_dy_desc, + bn2_eqscale_x_desc, + bn2_eqbias_desc); + if (fuse_dual) ops.push_back(&dual_bn_bwd_weight_op); + + // build op graph + auto op_graph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle) + .setOperationGraph(ops.size(), ops.data()) + .build(); + VLOG(6) << op_graph.describe(); + + cudnn_frontend::feature_vector_t feature_vector; + phi::autotune::BuildFeatureVector(&feature_vector, + dim_x, + dim_filt, + fuse_shortcut, + fuse_dual, + fuse_add, + strides, + dilations, + pre_padding, + post_padding); + + if (plan_cache.FindPlan(feature_vector, handle)) { + const cudnn_frontend::ExecutionPlan* cached_plan = nullptr; + int64_t workspace_size = 0; + plan_cache.GetPlanAndWorkspaceSize( + feature_vector, &cached_plan, &workspace_size, handle); + helper::ExecutePlan(handle, + &workspace_handle, + &data_ptrs, + &uids, + cached_plan->get_raw_desc(), + workspace_size); + return; + } + + auto plans = helper::FindExecutionPlans(&op_graph, + exhaustive_search, + deterministic, + &data_ptrs, + &uids, + handle, + &workspace_handle); + + helper::ExecutePlansAndCache(handle, + &workspace_handle, + &data_ptrs, + &uids, + &plans, + exhaustive_search, + feature_vector, + &plan_cache); +} + +template +void _DbnApplyImpl(const Context& dev_ctx, + const DenseTensor* dY_tensor, + const DenseTensor* X_tensor, + const DenseTensor* A_tensor, + const DenseTensor* B_tensor, + const DenseTensor* C_tensor, + const DenseTensor* X_dual_tensor, + const DenseTensor* A_dual_tensor, + const DenseTensor* B_dual_tensor, + const DenseTensor* C_dual_tensor, + bool 
fuse_dual, + bool exhaustive_search, + bool deterministic, + DenseTensor* dX_tensor, + DenseTensor* dX_dual_tensor) { + auto& plan_cache = phi::autotune::AutoTuneCache::Instance().GetConvV8( + phi::autotune::AlgorithmType::kDbnApply); + // get handles + auto handle = dev_ctx.cudnn_handle(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + cudnnTensorFormat_t layout_format = CUDNN_TENSOR_NHWC; + auto tensor_format = phi::backends::gpu::ToCudnnDataType(dY_tensor->dtype()); + auto tensor_format_math = CUDNN_DATA_FLOAT; + auto compute_dtype = CUDNN_DATA_FLOAT; + // build tensor descriptors + auto dim_x = phi::backends::gpu::TransformDimOrder( + phi::vectorize(X_tensor->dims())); + std::vector dim_a(dim_x.size(), 1); + dim_a[1] = dim_x[1]; // [1, C, 1, 1] + + std::vector data_ptrs; + std::vector uids; + int64_t uid = 100; + + // inputs + auto dY_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format, false); + data_ptrs.push_back(const_cast(dY_tensor->data())); + uids.push_back(uid); + + auto X_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format, false); + data_ptrs.push_back(const_cast(X_tensor->data())); + uids.push_back(uid); + + auto A_desc = helper::GetGeneralTensorDescriptor( + dim_a, layout_format, ++uid, 16, tensor_format_math, false); + data_ptrs.push_back(const_cast(A_tensor->data())); + uids.push_back(uid); + + auto B_desc = helper::GetGeneralTensorDescriptor( + dim_a, layout_format, ++uid, 16, tensor_format_math, false); + data_ptrs.push_back(const_cast(B_tensor->data())); + uids.push_back(uid); + + auto C_desc = helper::GetGeneralTensorDescriptor( + dim_a, layout_format, ++uid, 16, tensor_format_math, false); + data_ptrs.push_back(const_cast(C_tensor->data())); + uids.push_back(uid); + + auto X_dual_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format, false); + if (fuse_dual) { + data_ptrs.push_back(const_cast(X_dual_tensor->data())); + uids.push_back(uid); + } + + auto A_dual_desc = helper::GetGeneralTensorDescriptor( + dim_a, layout_format, ++uid, 16, tensor_format_math, false); + if (fuse_dual) { + data_ptrs.push_back(const_cast(A_dual_tensor->data())); + uids.push_back(uid); + } + + auto B_dual_desc = helper::GetGeneralTensorDescriptor( + dim_a, layout_format, ++uid, 16, tensor_format_math, false); + if (fuse_dual) { + data_ptrs.push_back(const_cast(B_dual_tensor->data())); + uids.push_back(uid); + } + + auto C_dual_desc = helper::GetGeneralTensorDescriptor( + dim_a, layout_format, ++uid, 16, tensor_format_math, false); + if (fuse_dual) { + data_ptrs.push_back(const_cast(C_dual_tensor->data())); + uids.push_back(uid); + } + + // virtual outputs + auto after_mul0 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_mul1 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_add = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + + auto after_mul0_dual = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_mul1_dual = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_add_dual = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + // outputs + auto dX_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 
16, tensor_format, false); + data_ptrs.push_back(dX_tensor->data()); + uids.push_back(uid); + + auto dX_dual_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format, false); + if (fuse_dual) { + data_ptrs.push_back(dX_dual_tensor->data()); + uids.push_back(uid); + } + + // op desc + auto mul0_op = helper::MakePointwiseOp( + CUDNN_POINTWISE_MUL, compute_dtype, A_desc, dY_desc, after_mul0); + + auto mul1_op = helper::MakePointwiseOp( + CUDNN_POINTWISE_MUL, compute_dtype, B_desc, X_desc, after_mul1); + + auto add0_op = helper::MakePointwiseOp( + CUDNN_POINTWISE_ADD, compute_dtype, after_mul0, after_mul1, after_add); + + auto add1_op = helper::MakePointwiseOp( + CUDNN_POINTWISE_ADD, compute_dtype, after_add, C_desc, dX_desc); + + auto mul0_op_dual = helper::MakePointwiseOp(CUDNN_POINTWISE_MUL, + compute_dtype, + A_dual_desc, + dY_desc, + after_mul0_dual); + + auto mul1_op_dual = helper::MakePointwiseOp(CUDNN_POINTWISE_MUL, + compute_dtype, + B_dual_desc, + X_dual_desc, + after_mul1_dual); + + auto add0_op_dual = helper::MakePointwiseOp(CUDNN_POINTWISE_ADD, + compute_dtype, + after_mul0_dual, + after_mul1_dual, + after_add_dual); + + auto add1_op_dual = helper::MakePointwiseOp(CUDNN_POINTWISE_ADD, + compute_dtype, + after_add_dual, + C_dual_desc, + dX_dual_desc); + + // build op graph + std::vector ops = { + &mul0_op, &mul1_op, &add0_op, &add1_op}; + if (fuse_dual) { + ops.insert(ops.end(), + {&mul0_op_dual, &mul1_op_dual, &add0_op_dual, &add1_op_dual}); + } + + auto op_graph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle) + .setOperationGraph(ops.size(), ops.data()) + .build(); + VLOG(6) << op_graph.describe(); + + cudnn_frontend::feature_vector_t feature_vector; + phi::autotune::BuildFeatureVector(&feature_vector, dim_x, fuse_dual); + + if (plan_cache.FindPlan(feature_vector, handle)) { + const cudnn_frontend::ExecutionPlan* cached_plan = nullptr; + int64_t workspace_size = 0; + plan_cache.GetPlanAndWorkspaceSize( + feature_vector, &cached_plan, &workspace_size, handle); + helper::ExecutePlan(handle, + &workspace_handle, + &data_ptrs, + &uids, + cached_plan->get_raw_desc(), + workspace_size); + return; + } + + auto plans = helper::FindExecutionPlans(&op_graph, + exhaustive_search, + deterministic, + &data_ptrs, + &uids, + handle, + &workspace_handle); + + helper::ExecutePlansAndCache(handle, + &workspace_handle, + &data_ptrs, + &uids, + &plans, + exhaustive_search, + feature_vector, + &plan_cache); +} + +template +void _BnActWgradImpl(const Context& dev_ctx, + const DenseTensor* conv_input, + const DenseTensor* grad_output, + const DenseTensor* bn_eqscale, + const DenseTensor* bn_eqbias, + bool fuse_bn_act, + const std::vector& strides, + const std::vector& dilations, + const std::vector& pre_padding, + const std::vector& post_padding, + bool exhaustive_search, + bool deterministic, + DenseTensor* dw_tensor) { + auto& plan_cache = phi::autotune::AutoTuneCache::Instance().GetConvV8( + phi::autotune::AlgorithmType::kBnActWgrad); + // get handles + auto handle = dev_ctx.cudnn_handle(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + // transform filter to NHWC layout + DenseTensor dw_tensor_transformed(dw_tensor->dtype()); + ResizeToChannelLast(dev_ctx, dw_tensor, &dw_tensor_transformed); + // create tensor descriptors + cudnnTensorFormat_t layout_format = CUDNN_TENSOR_NHWC; + auto tensor_format = phi::backends::gpu::ToCudnnDataType(conv_input->dtype()); + auto tensor_format_math = CUDNN_DATA_FLOAT; + auto compute_dtype = 
CUDNN_DATA_FLOAT; + // create tensor discriptors + auto dim_x = phi::backends::gpu::TransformDimOrder( + phi::vectorize(conv_input->dims())); + auto dim_filt = phi::backends::gpu::TransformDimOrder( + phi::vectorize(dw_tensor_transformed.dims())); + auto dim_y = phi::backends::gpu::TransformDimOrder( + phi::vectorize(grad_output->dims())); + std::vector dim_scale(dim_x.size(), 1); + dim_scale[1] = dim_x[1]; // [1, C, 1, 1] + + std::vector data_ptrs; + std::vector uids; + int64_t uid = 100; + + // inputs + auto x_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(const_cast(conv_input->data())); + uids.push_back(uid); + + auto dy_desc = helper::GetGeneralTensorDescriptor( + dim_y, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(const_cast(grad_output->data())); + uids.push_back(uid); + + auto scale_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format); + if (fuse_bn_act) { + data_ptrs.push_back(const_cast(bn_eqscale->data())); + uids.push_back(uid); + } + auto bias_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format); + if (fuse_bn_act) { + data_ptrs.push_back(const_cast(bn_eqbias->data())); + uids.push_back(uid); + } + + // outputs + auto dw_desc = helper::GetGeneralTensorDescriptor( + dim_filt, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(dw_tensor_transformed.data()); + uids.push_back(uid); + + // virtual outputs + auto after_scale = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + + auto after_bias = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + + auto after_relu = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + + // build ops + std::vector stride_int64 = helper::GetInt64Array(strides); + std::vector dilation_int64 = helper::GetInt64Array(dilations); + int64_t data_dim = pre_padding.size(); + auto conv_desc = cudnn_frontend::ConvDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setSpatialDimCount(data_dim) + .setSpatialStride(data_dim, stride_int64.data()) + .setPrePadding(data_dim, pre_padding.data()) + .setPostPadding(data_dim, post_padding.data()) + .setDilation(data_dim, dilation_int64.data()) + .build(); + VLOG(6) << conv_desc.describe(); + + cudnn_frontend::Tensor* p_wgrad_x_desc = fuse_bn_act ? 
&after_relu : &x_desc; + auto wgrad_op = + cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) + .setdyDesc(dy_desc) + .setdwDesc(dw_desc) + .setxDesc(*p_wgrad_x_desc) + .setcDesc(conv_desc) + .setAlpha(1.0f) + .setBeta(0.0f) + .build(); + VLOG(6) << wgrad_op.describe(); + + auto scale_op = helper::MakePointwiseOp( + CUDNN_POINTWISE_MUL, compute_dtype, x_desc, scale_desc, after_scale); + + auto bias_op = helper::MakePointwiseOp( + CUDNN_POINTWISE_ADD, compute_dtype, after_scale, bias_desc, after_bias); + + auto relu_desc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_FWD) + .setComputeType(compute_dtype) + .build(); + + auto relu_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(after_bias) + .setyDesc(after_relu) + .setpwDesc(relu_desc) + .build(); + + // build op graph + std::vector ops; + if (fuse_bn_act) + ops = {&wgrad_op, &scale_op, &bias_op, &relu_op}; + else + ops = {&wgrad_op}; + + auto op_graph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle) + .setOperationGraph(ops.size(), ops.data()) + .build(); + VLOG(6) << op_graph.describe(); + + cudnn_frontend::feature_vector_t feature_vector; + phi::autotune::BuildFeatureVector(&feature_vector, + dim_x, + dim_filt, + fuse_bn_act, + strides, + dilations, + pre_padding, + post_padding); + + if (plan_cache.FindPlan(feature_vector, handle)) { + const cudnn_frontend::ExecutionPlan* cached_plan = nullptr; + int64_t workspace_size = 0; + plan_cache.GetPlanAndWorkspaceSize( + feature_vector, &cached_plan, &workspace_size, handle); + helper::ExecutePlan(handle, + &workspace_handle, + &data_ptrs, + &uids, + cached_plan->get_raw_desc(), + workspace_size); + TransToChannelFirst(dev_ctx, &dw_tensor_transformed, dw_tensor); + return; + } + + auto plans = helper::FindExecutionPlans(&op_graph, + exhaustive_search, + deterministic, + &data_ptrs, + &uids, + handle, + &workspace_handle); + + helper::ExecutePlansAndCache(handle, + &workspace_handle, + &data_ptrs, + &uids, + &plans, + exhaustive_search, + feature_vector, + &plan_cache); + // transfer back to NCWH + TransToChannelFirst(dev_ctx, &dw_tensor_transformed, dw_tensor); +} + +/* +his op includes 3 kernels: +1. FusedDgradDreluBnBwdWeight +Ref: +https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#dgraddrelubnbwdweight +It fuses the backward of the following patterns: +(1) BN -> ReLU -> Conv + +(2) BN1 -> Add -> ReLU -> Conv + BN2 ----^ |---> (optional branch) + +(3) BN -> Add -> ReLU -> Conv + (shortcut)--^ |---> (optional branch) + +The meaning of three attributes are: +- fuse_shortcut: Whether a shortcut is added in the forward pattern, as in (2). +- fuse_dual: Whether two BN outputs are added in the forward pattern, as in (3). +- fuse_add: Whether ReLU output is used in a forward node other than Conv, + marked in (2)(3) as (optional branch). In this case, the gradient of the +branch should be added to the output dgrad. + +2. DbnApply +Ref: +https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#dualdbnapply +By default it performs the following: +dX = A* dY + B * X + C +With fuse_dual: +dX = A * dY + B * X + C +dX_dual = A_dual * dY + B_dual * X_dual + C_dual + +3. 
ConvBnWgrad +Ref: +https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#convbnwgrad +It fuses the following pattern: + +X = ReLU(BN_X * Scale + Bias) +dW = Wgrad(dY, X) + +Requirements: +- All tensors should have layout NHWC, except that weight, grad_weight are NCHW. +- bn_dgamma, bn_dbeta, bn_mean, bn_inv_std, bn_scale, bn_bias should have shape +[C] and dtype FP32. +- bn1_eqscale, bn1_eqbias should shape [C] and dtype FP16. +- bn_input, grad_input, residual_input, conv_input should have input shape of +Conv and dtype FP16. +*/ +template +void FusedDconvDreluDbnKernel( + const Context& dev_ctx, + const DenseTensor& grad_output, + const DenseTensor& weight, + const paddle::optional& grad_output_add, + const paddle::optional& residual_input, + const paddle::optional& bn1_eqscale, + const paddle::optional& bn1_eqbias, + const paddle::optional& conv_input, + const DenseTensor& bn1_mean, + const DenseTensor& bn1_inv_std, + const DenseTensor& bn1_gamma, + const DenseTensor& bn1_beta, + const DenseTensor& bn1_input, + const paddle::optional& bn2_mean, + const paddle::optional& bn2_inv_std, + const paddle::optional& bn2_gamma, + const paddle::optional& bn2_beta, + const paddle::optional& bn2_input, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const std::string& padding_algorithm, + int groups, + const std::string& data_format, + bool fuse_shortcut, + bool fuse_dual, + bool fuse_add, + bool exhaustive_search, + DenseTensor* grad_weight, + DenseTensor* grad_bn1_input, + DenseTensor* grad_bn1_gamma, + DenseTensor* grad_bn1_beta, + DenseTensor* grad_bn2_input, + DenseTensor* grad_bn2_gamma, + DenseTensor* grad_bn2_beta) { + PADDLE_ENFORCE_GE(dev_ctx.GetComputeCapability(), + 80, + phi::errors::PreconditionNotMet( + "This op only supports Ampere and later devices, " + "but got compute capability: %d.", + dev_ctx.GetComputeCapability())); + auto cudnn_version = phi::backends::gpu::DnnVersion(); + PADDLE_ENFORCE_GE(cudnn_version, + 8900, + phi::errors::PreconditionNotMet( + "This op only supports CUDNN version >= 8900, " + "but got %d.", + cudnn_version)); + // Attributes + bool fuse_wgrad_bn_act = !(fuse_shortcut || fuse_dual); + exhaustive_search = exhaustive_search || FLAGS_cudnn_exhaustive_search; + bool deterministic = FLAGS_cudnn_deterministic; + PADDLE_ENFORCE_EQ(exhaustive_search && deterministic, + false, + phi::errors::InvalidArgument( + "Cann't set exhaustive_search True and " + "FLAGS_cudnn_deterministic True at same time.")); + // update padding and dilation + std::vector paddings_vec = paddings; + std::vector dilations_vec = dilations; + auto in_dims = bn1_input.dims(); + auto filter_dims = weight.dims(); + DDim in_data_dims = slice_ddim(in_dims, 1, in_dims.size() - 1); + DDim filter_data_dims = slice_ddim( + filter_dims, 2, filter_dims.size()); // weight is in NCHW format + std::vector ksize = phi::vectorize(filter_data_dims); + phi::UpdatePaddingAndDilation(&paddings_vec, + &dilations_vec, + padding_algorithm, + in_data_dims, + strides, + ksize); + int data_dim = strides.size(); // 2d or 3d + std::vector pre_padding(data_dim, 0); + std::vector post_padding(data_dim, 0); + for (size_t i = 0; i < data_dim; ++i) { + pre_padding[i] = static_cast(paddings_vec[2 * i]); + post_padding[i] = static_cast(paddings_vec[2 * i + 1]); + } + // alloc output variables + dev_ctx.template Alloc(grad_weight); + dev_ctx.template Alloc(grad_bn1_input); + dev_ctx.template Alloc(grad_bn1_gamma); + dev_ctx.template Alloc(grad_bn1_beta); + if 
(fuse_shortcut || fuse_dual) { + dev_ctx.template Alloc(grad_bn2_input); + } + if (fuse_dual) { + dev_ctx.template Alloc(grad_bn2_gamma); + dev_ctx.template Alloc(grad_bn2_beta); + } + // intermediate buffers + DenseTensor grad_conv_input(bn1_input.dtype()); + if (fuse_shortcut) { + grad_conv_input.ShareDataWith(*grad_bn2_input); + } else { + grad_conv_input.Resize(bn1_input.dims()); + dev_ctx.template Alloc(&grad_conv_input); + } + auto bn_dtype = bn1_mean.dtype(); + auto bn_dims = bn1_mean.dims(); + DenseTensor bn1_coeff_a(bn_dtype); + DenseTensor bn1_coeff_b(bn_dtype); + DenseTensor bn1_coeff_c(bn_dtype); + DenseTensor bn2_coeff_a(bn_dtype); + DenseTensor bn2_coeff_b(bn_dtype); + DenseTensor bn2_coeff_c(bn_dtype); + bn1_coeff_a.Resize(bn_dims); + dev_ctx.template Alloc(&bn1_coeff_a); + bn1_coeff_b.Resize(bn_dims); + dev_ctx.template Alloc(&bn1_coeff_b); + bn1_coeff_c.Resize(bn_dims); + dev_ctx.template Alloc(&bn1_coeff_c); + if (fuse_dual) { + bn2_coeff_a.Resize(bn_dims); + dev_ctx.template Alloc(&bn2_coeff_a); + bn2_coeff_b.Resize(bn_dims); + dev_ctx.template Alloc(&bn2_coeff_b); + bn2_coeff_c.Resize(bn_dims); + dev_ctx.template Alloc(&bn2_coeff_c); + } + // Step 1: DgradDreluBnBwdWeight + _DgradDreluBnBwdWeightImpl(dev_ctx, + &grad_output, + &weight, + &bn1_mean, + &bn1_inv_std, + &bn1_gamma, + &bn1_beta, + &bn1_input, + paddle::get_pointer(residual_input), + paddle::get_pointer(bn2_mean), + paddle::get_pointer(bn2_inv_std), + paddle::get_pointer(bn2_gamma), + paddle::get_pointer(bn2_beta), + paddle::get_pointer(bn2_input), + paddle::get_pointer(grad_output_add), + fuse_shortcut, + fuse_dual, + fuse_add, + strides, + dilations_vec, + pre_padding, + post_padding, + exhaustive_search, + deterministic, + &grad_conv_input, + grad_bn1_gamma, + grad_bn1_beta, + &bn1_coeff_a, + &bn1_coeff_b, + &bn1_coeff_c, + grad_bn2_gamma, + grad_bn2_beta, + &bn2_coeff_a, + &bn2_coeff_b, + &bn2_coeff_c); + // Step 2: dBN Apply + _DbnApplyImpl(dev_ctx, + &grad_conv_input, + &bn1_input, + &bn1_coeff_a, + &bn1_coeff_b, + &bn1_coeff_c, + paddle::get_pointer(bn2_input), + &bn2_coeff_a, + &bn2_coeff_b, + &bn2_coeff_c, + fuse_dual, + exhaustive_search, + deterministic, + grad_bn1_input, + grad_bn2_input); + + // Step 3: Wgrad + const DenseTensor* wgrad_conv_input = + fuse_wgrad_bn_act ? 
&bn1_input : paddle::get_pointer(conv_input); + _BnActWgradImpl(dev_ctx, + wgrad_conv_input, + &grad_output, + paddle::get_pointer(bn1_eqscale), + paddle::get_pointer(bn1_eqbias), + fuse_wgrad_bn_act, + strides, + dilations, + pre_padding, + post_padding, + exhaustive_search, + deterministic, + grad_weight); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_dconv_drelu_dbn, + GPU, + ALL_LAYOUT, + phi::fusion::FusedDconvDreluDbnKernel, + phi::dtype::float16) { + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32); +} diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 6fab419ac86df..fa7bf95c2b7ae 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -507,6 +507,7 @@ endif() if(NOT WITH_CUDNN_FRONTEND) list(REMOVE_ITEM TEST_OPS test_fused_scale_bias_relu_conv_bn_op) list(REMOVE_ITEM TEST_OPS test_fused_scale_bias_add_relu_op) + list(REMOVE_ITEM TEST_OPS test_fused_dconv_drelu_dbn_op) endif() # Some ops need to check results when gc is enabled diff --git a/test/legacy_test/test_fused_dconv_drelu_dbn_op.py b/test/legacy_test/test_fused_dconv_drelu_dbn_op.py new file mode 100644 index 0000000000000..8862404b6acb9 --- /dev/null +++ b/test/legacy_test/test_fused_dconv_drelu_dbn_op.py @@ -0,0 +1,478 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
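+#
+# These tests exercise the fused_dconv_drelu_dbn op by appending it to a
+# static program and comparing its outputs (dW, dX, dGamma, dBeta) against
+# an eager-mode reference that runs BatchNorm -> (Add) -> ReLU -> Conv2D
+# forward and differentiates it with paddle.autograd.backward; the saved
+# mean/inv_std and the equivalent scale/bias fed to the fused op are
+# recomputed from the same inputs in NumPy.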
+
+
+import unittest
+
+import numpy as np
+from op_test import OpTest, skip_check_grad_ci
+
+import paddle
+from paddle import nn
+from paddle.base import core, framework
+from paddle.base.executor import Executor
+
+
+def skip_unit_test():
+    return (
+        not paddle.is_compiled_with_cuda()
+        or paddle.device.cuda.get_device_capability()[0] < 8
+    )
+
+
+skip_msg = "only supported with CUDA and Ampere or later devices"
+
+
+@skip_check_grad_ci(reason="no grad op")
+@unittest.skipIf(skip_unit_test(), skip_msg)
+class TestFusedDconvDreluDbnOp(OpTest):
+    def setUp(self):
+        self.__class__.op_type = "fused_dconv_drelu_dbn"
+        self.dtype = np.float16
+        self.math_type = np.float32
+        self.outputs = None
+        self.padding_algorithm = "EXPLICIT"
+        self.data_format = "NHWC"
+        self.groups = 1
+        self.rtol = 1e-5
+        self.atol = 2e-2
+
+        self.init_dilation()
+        self.init_test_case()
+        self.init_paddings()
+        self.init_attr()
+
+        self.X1 = np.random.random(self.input_size).astype(self.dtype) - 0.5
+        self.X2 = np.random.random(self.input_size).astype(self.dtype) - 0.5
+        self.dY1 = np.random.random(self.output_size).astype(self.dtype) - 0.5
+        self.dY2 = np.random.random(self.input_size).astype(self.dtype) - 0.5
+
+        paddle.disable_static()
+        paddle.seed(0)
+        paddle.set_default_dtype(self.dtype)
+        self.bn1 = nn.BatchNorm(
+            self.input_size[-1],
+            momentum=self.momentum,
+            epsilon=self.epsilon,
+            data_layout=self.data_format,
+        )
+        self.bn2 = nn.BatchNorm(
+            self.input_size[-1],
+            momentum=self.momentum,
+            epsilon=self.epsilon,
+            data_layout=self.data_format,
+        )
+        self.relu = nn.ReLU()
+        self.conv = nn.Conv2D(
+            in_channels=self.input_size[-1],
+            out_channels=self.filter_size[0],
+            kernel_size=self.filter_size[-1],
+            stride=self.stride,
+            padding=self.pad,
+            groups=1,
+            bias_attr=False,
+            data_format=self.data_format,
+        )
+
+        self.w_input = self.conv.weight.numpy().astype(self.dtype)
+        self.bn1_scale_input = self.bn1.weight.numpy()
+        self.bn1_bias_input = self.bn1.bias.numpy()
+        self.bn1_running_mean_input = self.bn1._mean.numpy()
+        self.bn1_running_var_input = self.bn1._variance.numpy()
+
+        self.bn2_scale_input = self.bn2.weight.numpy()
+        self.bn2_bias_input = self.bn2.bias.numpy()
+        self.bn2_running_mean_input = self.bn2._mean.numpy()
+        self.bn2_running_var_input = self.bn2._variance.numpy()
+
+    def has_cuda(self):
+        return core.is_compiled_with_cuda()
+
+    def get_feed_map(self, inputs, place):
+        feed_map = {}
+        for name in inputs:
+            tensor = core.LoDTensor()
+            tensor.set(inputs[name], place)
+            feed_map[name] = tensor
+        return feed_map
+
+    def calc_normal_pass(self):
+        """
+        Given dY, get dX for the following pattern:
+        (1) X1 -> BN1 -> ReLU -> Conv -> Y
+        (2) with fuse_dual = True:
+            X1 -> BN1 -> Add -> ReLU -> Conv -> Y
+            X2 -> BN2 ---/
+        (3) with fuse_shortcut = True:
+            X1 -> BN1 -> Add -> ReLU -> Conv -> Y
+            X2 ----------/
+        (4) with fuse_add = True:
+                          /-------> Y2
+            X1 -> BN1 -> ReLU -> Conv -> Y
+        fuse_add is also compatible with cases (2) and (3)
+        """
+        # inputs
+        x1_tensor = paddle.to_tensor(self.X1, stop_gradient=False)
+        x2_tensor = paddle.to_tensor(self.X2, stop_gradient=False)
+        dy1_tensor = paddle.to_tensor(self.dY1, stop_gradient=False)
+        dy2_tensor = paddle.to_tensor(self.dY2, stop_gradient=False)
+
+        if self.fuse_dual:
+            before_relu = self.bn1(x1_tensor) + self.bn2(x2_tensor)
+        elif self.fuse_shortcut:
+            before_relu = self.bn1(x1_tensor) + x2_tensor
+        else:
+            before_relu = self.bn1(x1_tensor)
+
+        after_relu = self.relu(before_relu)
+        y1_tensor = self.conv(after_relu)
+        y2_tensor = after_relu * 1
+
+        if self.fuse_add:
+            paddle.autograd.backward(
+                [y1_tensor, y2_tensor], [dy1_tensor, dy2_tensor], True
+            )
+        else:
+            paddle.autograd.backward([y1_tensor], [dy1_tensor], True)
+
+        self.conv_x = after_relu.numpy()
+        # ['dW', 'dX1', "BN1_dGamma", "BN1_dBeta"]
+        outputs = [
+            self.conv.weight.grad.numpy(),
+            x1_tensor.grad.numpy(),
+            self.bn1.weight.grad.numpy(),
+            self.bn1.bias.grad.numpy(),
+        ]
+        if self.fuse_dual or self.fuse_shortcut:
+            # ['dX2']
+            outputs.append(x2_tensor.grad.numpy())
+        if self.fuse_dual:
+            # ['BN2_dGamma', 'BN2_dBeta']
+            outputs += [
+                self.bn2.weight.grad.numpy(),
+                self.bn2.bias.grad.numpy(),
+            ]
+        return outputs
+
+    def _calc_mean_invstd(
+        self,
+        input,
+        bn_scale_np,
+        bn_bias_np,
+    ):
+        input = input.astype(self.math_type).reshape((-1, input.shape[-1]))
+        sample_mean = input.mean(axis=0)
+        sample_var = input.var(axis=0)
+        sample_invstd = 1 / np.sqrt(sample_var + self.epsilon)
+        sample_eqscale = bn_scale_np * sample_invstd
+        sample_eqbias = -bn_scale_np * sample_invstd * sample_mean + bn_bias_np
+        return (
+            sample_mean,
+            sample_invstd,
+            sample_eqscale.astype(self.dtype),
+            sample_eqbias.astype(self.dtype),
+        )
+
+    def calc_mean_invstd(self, place):
+        (
+            self.bn1_saved_mean,
+            self.bn1_saved_invstd,
+            self.bn1_eqscale,
+            self.bn1_eqbias,
+        ) = self._calc_mean_invstd(
+            self.X1,
+            self.bn1_scale_input,
+            self.bn1_bias_input,
+        )
+
+        (
+            self.bn2_saved_mean,
+            self.bn2_saved_invstd,
+            _,
+            _,
+        ) = self._calc_mean_invstd(
+            self.X2,
+            self.bn2_scale_input,
+            self.bn2_bias_input,
+        )
+
+    def calc_fused_pass(self, place):
+        self.calc_mean_invstd(place)
+
+        paddle.enable_static()
+        program = framework.Program()
+        block = program.global_block()
+        bn_size = [self.input_size[-1]]
+
+        dY1 = block.create_var(
+            name="dY1", shape=self.output_size, dtype='float16'
+        )
+        dY2 = block.create_var(
+            name="dY2", shape=self.input_size, dtype='float16'
+        )
+        W = block.create_var(name="W", shape=self.filter_size, dtype='float16')
+        dW = block.create_var(
+            name="dW", shape=self.filter_size, dtype='float16'
+        )
+        X1 = block.create_var(name="X1", shape=self.input_size, dtype='float16')
+        X2 = block.create_var(name="X2", shape=self.input_size, dtype='float16')
+        Conv_X = block.create_var(
+            name="Conv_X", shape=self.input_size, dtype='float16'
+        )
+        BN1_mean = block.create_var(
+            name="BN1_mean", shape=bn_size, dtype='float32'
+        )
+        BN1_inv_std = block.create_var(
+            name="BN1_inv_std", shape=bn_size, dtype='float32'
+        )
+        BN1_scale = block.create_var(
+            name="BN1_scale", shape=bn_size, dtype='float32'
+        )
+        BN1_bias = block.create_var(
+            name="BN1_bias", shape=bn_size, dtype='float32'
+        )
+        BN1_eqscale = block.create_var(
+            name="BN1_eqscale", shape=bn_size, dtype='float16'
+        )
+        BN1_eqbias = block.create_var(
+            name="BN1_eqbias", shape=bn_size, dtype='float16'
+        )
+        BN2_mean = block.create_var(
+            name="BN2_mean", shape=bn_size, dtype='float32'
+        )
+        BN2_inv_std = block.create_var(
+            name="BN2_inv_std", shape=bn_size, dtype='float32'
+        )
+        BN2_scale = block.create_var(
+            name="BN2_scale", shape=bn_size, dtype='float32'
+        )
+        BN2_bias = block.create_var(
+            name="BN2_bias", shape=bn_size, dtype='float32'
+        )
+        # outputs
+        BN1_dGamma = block.create_var(
+            name="BN1_dGamma", shape=bn_size, dtype='float32'
+        )
+        BN1_dBeta = block.create_var(
+            name="BN1_dBeta", shape=bn_size, dtype='float32'
+        )
+        BN2_dGamma = block.create_var(
+            name="BN2_dGamma", shape=bn_size, dtype='float32'
+        )
+        BN2_dBeta = block.create_var(
+            name="BN2_dBeta", shape=bn_size, dtype='float32'
+        )
+        dX1 = block.create_var(
+            name="dX1", shape=self.input_size, dtype='float16'
+        )
+        dX2 = block.create_var(
+            name="dX2", shape=self.input_size, dtype='float16'
+        )
+
+        op_attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'dilations': self.dilations,
+            'padding_algorithm': self.padding_algorithm,
+            'groups': self.groups,
+            'data_format': self.data_format,
+            'fuse_shortcut': self.fuse_shortcut,
+            'fuse_dual': self.fuse_dual,
+            'fuse_add': self.fuse_add,
+            'exhaustive_search': self.exhaustive_search,
+        }
+
+        op_inputs = {
+            'grad_output': dY1,
+            'weight': W,
+            'bn1_mean': BN1_mean,
+            'bn1_inv_std': BN1_inv_std,
+            'bn1_gamma': BN1_scale,
+            'bn1_beta': BN1_bias,
+            'bn1_input': X1,
+        }
+
+        op_outputs = {
+            'grad_bn1_input': dX1,
+            'grad_bn1_gamma': BN1_dGamma,
+            'grad_bn1_beta': BN1_dBeta,
+            'grad_weight': dW,
+        }
+
+        if self.fuse_add:
+            op_inputs['grad_output_add'] = dY2
+
+        if self.fuse_shortcut:
+            op_inputs['residual_input'] = X2
+            op_outputs['grad_bn2_input'] = dX2
+
+        if self.fuse_dual:
+            extra_inputs = {
+                'bn2_mean': BN2_mean,
+                'bn2_inv_std': BN2_inv_std,
+                'bn2_gamma': BN2_scale,
+                'bn2_beta': BN2_bias,
+                'bn2_input': X2,
+            }
+            op_inputs.update(extra_inputs)
+
+            extra_outputs = {
+                'grad_bn2_input': dX2,
+                'grad_bn2_gamma': BN2_dGamma,
+                'grad_bn2_beta': BN2_dBeta,
+            }
+            op_outputs.update(extra_outputs)
+
+        if self.fuse_shortcut or self.fuse_dual:
+            op_inputs['conv_input'] = Conv_X
+        else:
+            op_inputs['bn1_eqscale'] = BN1_eqscale
+            op_inputs['bn1_eqbias'] = BN1_eqbias
+
+        op = block.append_op(
+            type=self.__class__.op_type,
+            inputs=op_inputs,
+            outputs=op_outputs,
+            attrs=op_attrs,
+        )
+
+        # execute program
+        graph_inputs = {
+            'dY1': self.dY1,
+            'dY2': self.dY2,
+            'W': self.w_input,
+            'X1': self.X1,
+            'X2': self.X2,
+            'BN1_mean': self.bn1_saved_mean,
+            'BN1_inv_std': self.bn1_saved_invstd,
+            'BN1_scale': self.bn1_scale_input,
+            'BN1_bias': self.bn1_bias_input,
+            'Conv_X': self.conv_x,
+            'BN1_eqscale': self.bn1_eqscale,
+            'BN1_eqbias': self.bn1_eqbias,
+            'BN2_mean': self.bn2_saved_mean,
+            'BN2_inv_std': self.bn2_saved_invstd,
+            'BN2_scale': self.bn2_scale_input,
+            'BN2_bias': self.bn2_bias_input,
+        }
+
+        feed_map = self.get_feed_map(graph_inputs, place)
+        fetch_list = ['dW', 'dX1', "BN1_dGamma", "BN1_dBeta"]
+        if self.fuse_dual or self.fuse_shortcut:
+            fetch_list += ['dX2']
+        if self.fuse_dual:
+            fetch_list += ['BN2_dGamma', 'BN2_dBeta']
+
+        executor = Executor(place)
+        outs = executor.run(
+            program, feed=feed_map, fetch_list=fetch_list, return_numpy=True
+        )
+        return outs, fetch_list
+
+    def test_check_output(self):
+        if self.has_cuda():
+            place = core.CUDAPlace(0)
+            outputs_expected = self.calc_normal_pass()
+            outputs_actual, _ = self.calc_fused_pass(place)
+
+            assert len(outputs_expected) == len(outputs_actual)
+            for expected, actual in zip(outputs_expected, outputs_actual):
+                np.testing.assert_allclose(
+                    expected,
+                    actual,
+                    rtol=self.rtol,
+                    atol=self.atol,
+                )
+
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [2, 5, 5, 16]  # NHWC
+        self.output_size = [2, 5, 5, 32]
+        assert np.mod(self.input_size[-1], self.groups) == 0
+        f_c = self.input_size[-1] // self.groups
+        self.filter_size = [32, f_c, 1, 1]
+        self.momentum = 0.9
+        self.epsilon = 1e-5
+        self.accumulation_count = (
+            self.input_size[0] * self.input_size[1] * self.input_size[2]
+        )
+
+    def init_dilation(self):
+        self.dilations = [1, 1]
+
+    def init_paddings(self):
+        self.pad = [0, 0]
+        self.padding_algorithm = "EXPLICIT"
+
+    def init_attr(self):
+        self.fuse_add = False
+        self.fuse_shortcut = False
+        self.fuse_dual = False
+        self.exhaustive_search = False
+
+
+@skip_check_grad_ci(reason="no grad op")
+@unittest.skipIf(skip_unit_test(), skip_msg)
+class TestFusedDconvDreluDbnOpShortcut(TestFusedDconvDreluDbnOp):
+    def init_attr(self):
+        self.fuse_add = False
+        self.fuse_shortcut = True
+        self.fuse_dual = False
+        self.exhaustive_search = False
+
+
+@skip_check_grad_ci(reason="no grad op")
+@unittest.skipIf(skip_unit_test(), skip_msg)
+class TestFusedDconvDreluDbnOpDual(TestFusedDconvDreluDbnOp):
+    def init_attr(self):
+        self.fuse_add = False
+        self.fuse_shortcut = False
+        self.fuse_dual = True
+        self.exhaustive_search = False
+
+
+@skip_check_grad_ci(reason="no grad op")
+@unittest.skipIf(skip_unit_test(), skip_msg)
+class TestFusedDconvDreluDbnOpShortcutAdd(TestFusedDconvDreluDbnOp):
+    def init_attr(self):
+        self.fuse_add = True
+        self.fuse_shortcut = True
+        self.fuse_dual = False
+        self.exhaustive_search = False
+
+
+@skip_check_grad_ci(reason="no grad op")
+@unittest.skipIf(skip_unit_test(), skip_msg)
+class TestFusedDconvDreluDbnOpDualAdd(TestFusedDconvDreluDbnOp):
+    def init_attr(self):
+        self.fuse_add = True
+        self.fuse_shortcut = False
+        self.fuse_dual = True
+        self.exhaustive_search = False
+
+
+@skip_check_grad_ci(reason="no grad op")
+@unittest.skipIf(skip_unit_test(), skip_msg)
+class TestFusedDconvDreluDbnOpExhaustive(TestFusedDconvDreluDbnOp):
+    def init_attr(self):
+        self.fuse_add = False
+        self.fuse_shortcut = False
+        self.fuse_dual = False
+        self.exhaustive_search = True
+
+
+if __name__ == '__main__':
+    np.random.seed(0)
+    unittest.main()

From 2226cf97a73c1706a59a06ebdd66e6bada513aaa Mon Sep 17 00:00:00 2001
From: "Tian Zheng (Engrg-Hardware 1)"
Date: Tue, 14 Nov 2023 08:13:33 +0000
Subject: [PATCH 3/4] Add to CI test

---
 tools/gpups_test.sh | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh
index 142812064928e..f3ea012f69c5e 100644
--- a/tools/gpups_test.sh
+++ b/tools/gpups_test.sh
@@ -70,6 +70,7 @@ parallel_list="^init_phi_test$|\
 ^test_fused_attention_op_api$|\
 ^test_fused_attention_op_api_static_build$|\
 ^test_fused_attention_op_static_build$|\
+^test_fused_dconv_drelu_dbn_op$|\
 ^test_fused_bias_dropout_residual_layer_norm_op$|\
 ^test_fused_bias_dropout_residual_layer_norm_op_api$|\
 ^test_fused_comm_buffer$|\
@@ -94,8 +95,8 @@ parallel_list="^init_phi_test$|\
 ^test_fused_multi_transformer_int8_op$|\
 ^test_fused_residual_dropout_bias$|\
 ^test_fused_rotary_position_embedding$|\
-^test_fused_scale_bias_relu_conv_bn_op$|\
 ^test_fused_scale_bias_add_relu_op$|\
+^test_fused_scale_bias_relu_conv_bn_op$|\
 ^test_fused_token_prune_op$|\
 ^test_fused_transformer_encoder_layer$|\
 ^test_fused_transformer_with_amp_decorator$|\

From 58a38a0d2c1066615c1908facae38db50e982fbf Mon Sep 17 00:00:00 2001
From: "Tian Zheng (Engrg-Hardware 1)"
Date: Thu, 30 Nov 2023 03:36:41 +0000
Subject: [PATCH 4/4] Review changes

---
 paddle/phi/infermeta/fusion.cc                | 21 ++++++++++++-------
 .../gpu/fused_dconv_drelu_dbn_kernel.cu       |  2 +-
 .../test_fused_dconv_drelu_dbn_op.py          |  2 +-
 3 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc
index 740b5cf24ad3b..cb4de6f93f600 100644
--- a/paddle/phi/infermeta/fusion.cc
+++ b/paddle/phi/infermeta/fusion.cc
@@ -2247,18 +2247,25 @@ void FusedDconvDreluDbnInferMeta(const MetaTensor& grad_output,
                           !!bn2_beta,
                           !!bn2_input));
   }
-  grad_weight->set_dims(weight.dims());
-  grad_bn1_input->set_dims(bn1_input.dims());
-  grad_bn1_gamma->set_dims(bn1_gamma.dims());
-  grad_bn1_beta->set_dims(bn1_beta.dims());
+
+  auto set_unchanged_meta = [](MetaTensor* out, const MetaTensor& input) {
+    out->set_dims(input.dims());
+    out->set_dtype(input.dtype());
+    out->set_layout(input.layout());
+  };
+
+  set_unchanged_meta(grad_weight, weight);
+  set_unchanged_meta(grad_bn1_input, bn1_input);
+  set_unchanged_meta(grad_bn1_gamma, bn1_gamma);
+  set_unchanged_meta(grad_bn1_beta, bn1_beta);
   if (grad_bn2_input) {
-    grad_bn2_input->set_dims(bn1_input.dims());
+    set_unchanged_meta(grad_bn2_input, bn1_input);
   }
   if (grad_bn2_gamma) {
-    grad_bn2_gamma->set_dims(bn1_gamma.dims());
+    set_unchanged_meta(grad_bn2_gamma, bn1_gamma);
   }
   if (grad_bn2_beta) {
-    grad_bn2_beta->set_dims(bn1_beta.dims());
+    set_unchanged_meta(grad_bn2_beta, bn1_beta);
   }
 }
diff --git a/paddle/phi/kernels/fusion/gpu/fused_dconv_drelu_dbn_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dconv_drelu_dbn_kernel.cu
index e194ae3f4756b..6b041753c1a38 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_dconv_drelu_dbn_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_dconv_drelu_dbn_kernel.cu
@@ -1,4 +1,4 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/test/legacy_test/test_fused_dconv_drelu_dbn_op.py b/test/legacy_test/test_fused_dconv_drelu_dbn_op.py
index 8862404b6acb9..d038d8a83caa2 100644
--- a/test/legacy_test/test_fused_dconv_drelu_dbn_op.py
+++ b/test/legacy_test/test_fused_dconv_drelu_dbn_op.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.