Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3/4] CUDNNv8 ResNet Fusion: Add fused_dconv_drelu_dbn OP #58986

Merged
merged 4 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
'fused_multi_transformer_xpu',
'fused_scale_bias_relu_conv_bn',
'fused_scale_bias_add_relu',
'fused_dconv_drelu_dbn',
'fusion_transpose_flatten_concat',
'skip_layernorm',
'generate_sequence_xpu',
Expand Down Expand Up @@ -118,6 +119,7 @@
'fused_elemwise_add_activation',
'fused_scale_bias_relu_conv_bn',
'fused_scale_bias_add_relu',
'fused_dconv_drelu_dbn',
'recv_v2',
'rnn_',
'seed',
Expand Down
12 changes: 11 additions & 1 deletion paddle/phi/api/yaml/fused_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,16 @@
optional : bias, residual, norm_weight, norm_bias, residual_out
support_dygraph_mode : true

- op : fused_dconv_drelu_dbn
args : (Tensor grad_output, Tensor weight, Tensor grad_output_add, Tensor residual_input, Tensor bn1_eqscale, Tensor bn1_eqbias, Tensor conv_input, Tensor bn1_mean, Tensor bn1_inv_std, Tensor bn1_gamma, Tensor bn1_beta, Tensor bn1_input, Tensor bn2_mean, Tensor bn2_inv_std, Tensor bn2_gamma, Tensor bn2_beta, Tensor bn2_input, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, str data_format, bool fuse_shortcut, bool fuse_dual, bool fuse_add, bool exhaustive_search)
output : Tensor(grad_weight), Tensor(grad_bn1_input), Tensor(grad_bn1_gamma), Tensor(grad_bn1_beta), Tensor(grad_bn2_input), Tensor(grad_bn2_gamma), Tensor(grad_bn2_beta)
optional : grad_output_add, residual_input, bn1_eqscale, bn1_eqbias, conv_input, bn2_mean, bn2_inv_std, bn2_gamma, bn2_beta, bn2_input, grad_bn2_input, grad_bn2_gamma, grad_bn2_beta
infer_meta :
func : FusedDconvDreluDbnInferMeta
kernel :
func : fused_dconv_drelu_dbn
data_type : grad_output

- op : fused_dropout_add
args : (Tensor x, Tensor y, Tensor seed_tensor, Scalar p, bool is_test, str mode, int seed = 0, bool fix_seed = false)
optional : seed_tensor
Expand Down Expand Up @@ -252,7 +262,7 @@
- op : fused_scale_bias_add_relu
args : (Tensor x1, Tensor scale1, Tensor bias1, Tensor x2, Tensor scale2, Tensor bias2, bool fuse_dual, bool exhaustive_search)
optional : scale2, bias2
output : Tensor(y)
output : Tensor(out)
infer_meta :
func : FusedScaleBiasAddReluInferMeta
kernel :
Expand Down
143 changes: 141 additions & 2 deletions paddle/phi/infermeta/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2112,7 +2112,7 @@ void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1,
const MetaTensor& bias2,
bool fuse_dual,
bool exhaustive_search,
MetaTensor* y) {
MetaTensor* out) {
// check optional inputs
if (fuse_dual) {
bool has_scale2 = !!scale2;
Expand All @@ -2127,7 +2127,146 @@ void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1,
fuse_dual));
}
// set output dims
y->set_dims(x1.dims());
out->set_dims(x1.dims());
out->set_dtype(x1.dtype());
out->set_layout(x1.layout());
}

// Infers output metadata for the fused dConv + dReLU + dBN backward op
// (cuDNN v8 ResNet fusion). Validates the flag-dependent optional inputs,
// then propagates dims/dtype/layout from the corresponding forward-pass
// tensors onto the gradient outputs.
//
// Constraints enforced here:
//  - data_format must be "NHWC" (channel-last), groups must be 1.
//  - fuse_shortcut and fuse_dual are mutually exclusive.
//  - fuse_add      => grad_output_add is required.
//  - fuse_shortcut => residual_input is required.
//  - fuse_shortcut || fuse_dual => conv_input is required;
//    otherwise bn1_eqscale and bn1_eqbias are required.
//  - fuse_dual     => all bn2_* inputs are required.
void FusedDconvDreluDbnInferMeta(const MetaTensor& grad_output,
                                 const MetaTensor& weight,
                                 const MetaTensor& grad_output_add,
                                 const MetaTensor& residual_input,
                                 const MetaTensor& bn1_eqscale,
                                 const MetaTensor& bn1_eqbias,
                                 const MetaTensor& conv_input,
                                 const MetaTensor& bn1_mean,
                                 const MetaTensor& bn1_inv_std,
                                 const MetaTensor& bn1_gamma,
                                 const MetaTensor& bn1_beta,
                                 const MetaTensor& bn1_input,
                                 const MetaTensor& bn2_mean,
                                 const MetaTensor& bn2_inv_std,
                                 const MetaTensor& bn2_gamma,
                                 const MetaTensor& bn2_beta,
                                 const MetaTensor& bn2_input,
                                 const std::vector<int>& paddings,
                                 const std::vector<int>& dilations,
                                 const std::vector<int>& strides,
                                 const std::string& padding_algorithm,
                                 int groups,
                                 const std::string& data_format,
                                 bool fuse_shortcut,
                                 bool fuse_dual,
                                 bool fuse_add,
                                 bool exhaustive_search,
                                 MetaTensor* grad_weight,
                                 MetaTensor* grad_bn1_input,
                                 MetaTensor* grad_bn1_gamma,
                                 MetaTensor* grad_bn1_beta,
                                 MetaTensor* grad_bn2_input,
                                 MetaTensor* grad_bn2_gamma,
                                 MetaTensor* grad_bn2_beta) {
  // Only channel-last layout is supported by the cuDNN fusion engine.
  PADDLE_ENFORCE_EQ(
      data_format,
      "NHWC",
      phi::errors::InvalidArgument(
          "Operator(FusedDconvDreluDbn) only supports data format "
          "of "
          "channel last (NHWC) now. But received: data_format = '%s'.",
          data_format));

  PADDLE_ENFORCE_EQ(
      groups,
      1,
      phi::errors::InvalidArgument("Expect group to be 1, got %d.", groups));

  // The two fusion variants are mutually exclusive.
  PADDLE_ENFORCE_EQ(
      fuse_shortcut && fuse_dual,
      0,
      phi::errors::InvalidArgument(
          "fuse_shortcut and fuse_dual should not be set at the same time."
          "Got fuse_shortcut=%d, fuse_dual=%d.",
          fuse_shortcut,
          fuse_dual));

  // Optional-input presence checks: each flag implies specific inputs.
  if (fuse_add) {
    PADDLE_ENFORCE_EQ(
        !!grad_output_add,
        true,
        phi::errors::InvalidArgument(
            "grad_output_add must be provided when fuse_add = true."
            "Got fuse_add=%d, grad_output_add=%d.",
            fuse_add,
            !!grad_output_add));
  }
  if (fuse_shortcut) {
    PADDLE_ENFORCE_EQ(
        !!residual_input,
        true,
        phi::errors::InvalidArgument(
            "residual_input must be provided when fuse_shortcut = true."
            "Got fuse_shortcut=%d, residual_input=%d.",
            fuse_shortcut,
            !!residual_input));
  }
  if (fuse_shortcut || fuse_dual) {
    PADDLE_ENFORCE_EQ(
        !!conv_input,
        true,
        phi::errors::InvalidArgument(
            "conv_input must be provided when either fuse_shortcut "
            "or fuse_dual is set to true. Got conv_input=%d, fuse_shortcut=%d, "
            "fuse_dual=%d.",
            !!conv_input,
            fuse_shortcut,
            fuse_dual));
  } else {
    PADDLE_ENFORCE_EQ(
        bn1_eqscale && bn1_eqbias,
        true,
        phi::errors::InvalidArgument(
            "bn1_eqscale and bn1_eqbias must be provided when neither "
            "fuse_shortcut "
            "nor fuse_dual is set. Got bn1_eqscale=%d, bn1_eqbias=%d.",
            !!bn1_eqscale,
            !!bn1_eqbias));
  }
  if (fuse_dual) {
    PADDLE_ENFORCE_EQ(
        bn2_mean && bn2_inv_std && bn2_gamma && bn2_beta && bn2_input,
        true,
        phi::errors::InvalidArgument("bn2_mean, bn2_inv_std, bn2_gamma, "
                                     "bn2_beta, bn2_input must be provided "
                                     "when fuse_dual is set. Got bn2_mean=%d, "
                                     "bn2_inv_std=%d, bn2_gamma=%d, "
                                     "bn2_beta=%d, bn2_input=%d.",
                                     !!bn2_mean,
                                     !!bn2_inv_std,
                                     !!bn2_gamma,
                                     !!bn2_beta,
                                     !!bn2_input));
  }

  // Each gradient output mirrors the meta of its forward-pass counterpart.
  auto set_unchanged_meta = [](MetaTensor* out, const MetaTensor& input) {
    out->set_dims(input.dims());
    out->set_dtype(input.dtype());
    out->set_layout(input.layout());
  };

  set_unchanged_meta(grad_weight, weight);
  set_unchanged_meta(grad_bn1_input, bn1_input);
  set_unchanged_meta(grad_bn1_gamma, bn1_gamma);
  set_unchanged_meta(grad_bn1_beta, bn1_beta);
  // NOTE(review): the bn2 gradients are derived from the bn1_* metas, not
  // bn2_*. This is presumably safe because both BN branches share shapes in
  // the ResNet fusion pattern — confirm against the kernel implementation.
  if (grad_bn2_input) {
    set_unchanged_meta(grad_bn2_input, bn1_input);
  }
  if (grad_bn2_gamma) {
    set_unchanged_meta(grad_bn2_gamma, bn1_gamma);
  }
  if (grad_bn2_beta) {
    set_unchanged_meta(grad_bn2_beta, bn1_beta);
  }
}

void SqueezeExcitationInferMeta(const MetaTensor& x,
Expand Down
37 changes: 36 additions & 1 deletion paddle/phi/infermeta/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,42 @@ void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1,
const MetaTensor& bias2,
bool fuse_prologue,
bool exhaustive_search,
MetaTensor* y);
MetaTensor* out);

// Infer-meta for the fused dConv + dReLU + dBN backward op (cuDNN v8 ResNet
// fusion). Sets dims/dtype/layout of the gradient outputs from their
// forward-pass counterparts; bn2-related inputs/outputs are optional and
// only required for the fuse_dual variant. Implementation in fusion.cc.
void FusedDconvDreluDbnInferMeta(const MetaTensor& grad_output,
                                 const MetaTensor& weight,
                                 const MetaTensor& grad_output_add,
                                 const MetaTensor& residual_input,
                                 const MetaTensor& bn1_eqscale,
                                 const MetaTensor& bn1_eqbias,
                                 const MetaTensor& conv_input,
                                 const MetaTensor& bn1_mean,
                                 const MetaTensor& bn1_inv_std,
                                 const MetaTensor& bn1_gamma,
                                 const MetaTensor& bn1_beta,
                                 const MetaTensor& bn1_input,
                                 const MetaTensor& bn2_mean,
                                 const MetaTensor& bn2_inv_std,
                                 const MetaTensor& bn2_gamma,
                                 const MetaTensor& bn2_beta,
                                 const MetaTensor& bn2_input,
                                 const std::vector<int>& paddings,
                                 const std::vector<int>& dilations,
                                 const std::vector<int>& strides,
                                 const std::string& padding_algorithm,
                                 int groups,
                                 const std::string& data_format,
                                 bool fuse_shortcut,
                                 bool fuse_dual,
                                 bool fuse_add,
                                 bool exhaustive_search,
                                 MetaTensor* grad_weight,
                                 MetaTensor* grad_bn1_input,
                                 MetaTensor* grad_bn1_gamma,
                                 MetaTensor* grad_bn1_beta,
                                 MetaTensor* grad_bn2_input,
                                 MetaTensor* grad_bn2_gamma,
                                 MetaTensor* grad_bn2_beta);

void SqueezeExcitationInferMeta(const MetaTensor& x,
const MetaTensor& filter,
Expand Down
7 changes: 4 additions & 3 deletions paddle/phi/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,10 @@ if(WITH_CUTLASS)
endif()

if(NOT WITH_CUDNN_FRONTEND)
list(REMOVE_ITEM kernel_cu
"fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu"
"fusion/gpu/fused_scale_bias_add_relu_kernel.cu")
list(
REMOVE_ITEM kernel_cu "fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu"
"fusion/gpu/fused_scale_bias_add_relu_kernel.cu"
"fusion/gpu/fused_dconv_drelu_dbn_kernel.cu")
endif()

set(cc_search_pattern
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/kernels/autotune/cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ std::string AlgorithmTypeString(int64_t algo_type) {
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kScaleBiasAddRelu)) {
return "scale_bias_add_relu";
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kDgradDreluBnBwdWeight)) {
return "dgrad_drelu_bnbwdweight";
} else if (algo_type == static_cast<int64_t>(AlgorithmType::kDbnApply)) {
return "dbn_apply";
} else if (algo_type == static_cast<int64_t>(AlgorithmType::kBnActWgrad)) {
return "bn_act_wgrad";
}
#endif
return std::to_string(algo_type);
Expand Down
5 changes: 4 additions & 1 deletion paddle/phi/kernels/autotune/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ enum class AlgorithmType {
kScaleBiasReluConvBNstats = 13,
kBNFinalize = 14,
kScaleBiasAddRelu = 15,
kAlgorithmCount = 16
kDgradDreluBnBwdWeight = 16,
kDbnApply = 17,
kBnActWgrad = 18,
kAlgorithmCount = 19
#endif
};

Expand Down
Loading