【PIR】add batch_norm_grad_grad in pir #59373

Merged · 4 commits · Nov 29, 2023
Changes from 2 commits
142 changes: 142 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -1239,6 +1239,147 @@ struct AddNOpTranscriber : public OpTranscriber {
}
};

struct BatchNormDoubleGradOpTranscriber : public OpTranscriber {
pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
const OpDesc& op_desc) override {
std::string target_op_name = "pd_op.batch_norm_double_grad";
const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) {
IR_THROW(
"Op batch_norm_grad_grad should have corresponding OpInfo "
"pd_op.batch_norm_double_grad.");
}

return op_info;
}

std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
pir::IrContext* ctx,
const OpDesc& op_desc,
const OpOutputInfoList& output_infos) {
OpOutputMapping arg_to_idx;
OpOutputTypeList op_output_types = {};

auto& type_translator = TypeTranslator::instance();
auto& op_normalizer = OpNameNormalizer::instance();

const BlockDesc* block = op_desc.Block();

for (const auto& info : output_infos) {
size_t cur_output_idx = op_output_types.size();
std::string legacy_output_name =
op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
if (info.name == "scale_grad") {
legacy_output_name = "DScale";
} else if (info.name == "x_grad") {
legacy_output_name = "DX";
} else if (info.name == "grad_out_grad") {
legacy_output_name = "DDY";
}
VLOG(10) << "[op:" << op_desc.Type() << "][output]" << info.name << " "
<< legacy_output_name;

// return empty type if this arg is optional and not shown in OpDesc
if (!op_desc.HasOutput(legacy_output_name)) {
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "] optional " << info.name << " :"
<< info.type_name << " " << legacy_output_name;
IR_ENFORCE(info.optional,
"Op %s arg %s should be optional if it can be empty",
op_desc.Type(),
legacy_output_name);
op_output_types.emplace_back(nullptr);
continue;
}

const auto& legacy_output_vars = op_desc.Output(legacy_output_name);
bool is_vector = (info.type_name.find("VectorType") != std::string::npos);

VLOG(10) << "[op:" << op_desc.Type() << "][output]" << info.name << " "
<< legacy_output_name << " " << legacy_output_vars.size() << " "
<< is_vector;

// Specially process TensorArray: we cannot distinguish it from
// Vector<DenseTensor> by other conditions, yet we cannot support it
// the same way as Vector<DenseTensor>.
if (legacy_output_vars.size() == 1) {
VarDesc* var = block->FindVarRecursive(legacy_output_vars[0]);
IR_ENFORCE(var != nullptr,
"[op:%s] Output %s should not be null",
op_desc.Type(),
legacy_output_vars[0]);
if (var->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) {
pir::Type translated_var_type =
type_translator[var->GetType()](ctx, *var);
op_output_types.push_back(translated_var_type);
arg_to_idx[var->Name()] = {cur_output_idx, 0};
continue;
}
}

// if src type is Tensor
if (!is_vector) {
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << info.name << " :"
<< info.type_name << " " << legacy_output_name << " "
<< legacy_output_vars.size();
if (legacy_output_vars.empty()) {
op_output_types.emplace_back(nullptr);
continue;
}

auto& var_name = legacy_output_vars[0];
VarDesc* var = block->FindVarRecursive(var_name);
IR_ENFORCE(var != nullptr,
"[op:%s] Output %s should not be null",
op_desc.Type(),
var_name);
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << info.name
<< " var: " << var_name << " type: " << var->GetType();

pir::Type translated_var_type =
type_translator[var->GetType()](ctx, *var);

arg_to_idx[var_name] = {cur_output_idx, 0};
op_output_types.push_back(translated_var_type);

// if src type is Vector<Tensor>
} else {
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << info.name << " :"
<< info.type_name << " var: " << legacy_output_name;
std::vector<pir::Type> types;
for (IdxInVector idx_in_vec = 0; idx_in_vec < legacy_output_vars.size();
idx_in_vec++) {
const auto& var_name = legacy_output_vars[idx_in_vec];
if (var_name == kEmptyVarName) {
types.emplace_back(nullptr);
arg_to_idx[var_name] = {cur_output_idx, idx_in_vec};
continue;
}
VarDesc* var = block->FindVarRecursive(var_name);
IR_ENFORCE(var != nullptr,
"[op:%s] Output %s should not be null",
op_desc.Type(),
var_name);
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << info.name
<< " var: " << var_name << " type: " << var->GetType();
pir::Type translated_var_type =
type_translator[var->GetType()](ctx, *var);
types.push_back(translated_var_type);
arg_to_idx[var_name] = {cur_output_idx, idx_in_vec};
}
pir::Type vec_type = pir::VectorType::get(ctx, types);
op_output_types.push_back(vec_type);
}
}
return {op_output_types, arg_to_idx};
}
};

struct TrilAndTriuOpTranscriber : public OpTranscriber {
pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
const OpDesc& op_desc) override {
@@ -2483,6 +2624,7 @@ OpTranslator::OpTranslator() {
general_handler = OpTranscriber();
special_handlers["add_n"] = AddNOpTranscriber();
special_handlers["assign_value"] = AssignValueOpTranscriber();
special_handlers["batch_norm_grad_grad"] = BatchNormDoubleGradOpTranscriber();
special_handlers["range"] = ArangeOpTranscriber();
special_handlers["cast"] = CastOpTranscriber();
special_handlers["data"] = DataOpTranscriber();
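For context (not part of this diff): batch_norm_grad_grad is the double-backward (second-order gradient) of batch_norm, and the transcriber above maps it to pd_op.batch_norm_double_grad when a legacy program is translated to PIR. Below is a minimal sketch, assuming the public paddle.grad and paddle.nn.BatchNorm2D APIs, of the dygraph double-grad pattern that produces such an op once the program is captured and translated:

```python
import paddle

# A minimal sketch (not from this PR) of the double-backward pattern that
# produces a batch_norm_grad_grad op, which the transcriber above maps to
# pd_op.batch_norm_double_grad during translation to PIR.
x = paddle.randn([2, 3, 8, 8])
x.stop_gradient = False
bn = paddle.nn.BatchNorm2D(3)

y = bn(x)
# First-order gradient; keep the graph so it can be differentiated again.
(dx,) = paddle.grad(y.sum(), x, create_graph=True)
# Second-order gradient: this is what the double-grad kernel computes.
(ddx,) = paddle.grad(dx.sum(), x)
print(ddx.shape)
```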
8 changes: 7 additions & 1 deletion paddle/phi/api/yaml/op_compat.yaml
@@ -318,13 +318,19 @@
{auc : AUC, stat_pos_out : StatPosOut, stat_neg_out : StatNegOut}

 - op : batch_norm
-  backward : batch_norm_grad
+  backward : batch_norm_grad, batch_norm_double_grad(batch_norm_grad_grad)
   inputs:
     x : X
     mean : Mean
     variance : Variance
     scale : Scale
     bias : Bias
+    out_mean : OutMean
+    out_variance : OutVariance
+    grad_x_grad : DDX
+    grad_scale_grad : DDScale
+    grad_bias_grad : DDBias
+    grad_out : DY
   outputs :
     out : Y
     mean_out: MeanOut
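As an illustration only (not code from this PR), the six input-name pairs added in the yaml above, together with the three output names hard-coded in BatchNormDoubleGradOpTranscriber, give the following PIR-to-legacy argument-name mapping, written here as plain Python dicts:

```python
# Illustration only: the name pairs added in the yaml diff above, plus the
# output names hard-coded in BatchNormDoubleGradOpTranscriber.
double_grad_input_names = {
    "out_mean": "OutMean",
    "out_variance": "OutVariance",
    "grad_x_grad": "DDX",
    "grad_scale_grad": "DDScale",
    "grad_bias_grad": "DDBias",
    "grad_out": "DY",
}
double_grad_output_names = {
    "x_grad": "DX",            # hard-coded fallback in the transcriber
    "scale_grad": "DScale",    # hard-coded fallback in the transcriber
    "grad_out_grad": "DDY",    # hard-coded fallback in the transcriber
}
```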
1 change: 0 additions & 1 deletion test/dygraph_to_static/CMakeLists.txt
@@ -66,7 +66,6 @@ endif()

 set(DISABLE_PIR_PT_MODES
     test_error
-    test_gradname_parse
     test_seq2seq
     test_pylayer
     test_save_inference_model