[NewIR] support elementwise operations with axis!=-1 #55699
Merged
kangguangli merged 5 commits into PaddlePaddle:develop from kangguangli:fix_elementwise_op on Jul 31, 2023
Changes from all commits (5 commits):
25e4d81  support elementwise with axis!=-1 (kangguangli)
88ac97a  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (kangguangli)
598e124  fix coverage ci (kangguangli)
a5cf728  fix bug (kangguangli)
618f369  remove print (kangguangli)
@@ -287,7 +287,8 @@ struct OpTranscriber {
       const OpAttributeInfoList& op_attr_infos,
       const OpDesc& op_desc);

-  virtual void RecordOpResultMapping(TranslationContext* param_map,
+  virtual void RecordOpResultMapping(ir::IrContext* ctx,
+                                     TranslationContext* param_map,
                                      const OpDesc& op_desc,
                                      ir::Operation* operation,
                                      const OpOutputMapping& arg_to_idx);
@@ -597,15 +598,16 @@ ir::AttributeMap OpTranscriber::TranslateOpAttribute(
   return attribute_map;
 }

-void OpTranscriber::RecordOpResultMapping(TranslationContext* param_map,
+void OpTranscriber::RecordOpResultMapping(ir::IrContext* ctx,
+                                          TranslationContext* param_map,
                                           const OpDesc& op_desc,
                                           ir::Operation* operation,
                                           const OpOutputMapping& arg_to_idx) {
   for (const auto& n : op_desc.Outputs()) {
     auto& name = n.first;
     VLOG(10) << "[output recording]"
              << "[" << op_desc.Type() << "]" << name;
-    auto& args = n.second;
+    const auto& args = n.second;
     size_t idx_in_vector = 0;
     for (const auto& arg_name : args) {
       if (arg_name == kEmptyVarName) {
@@ -674,7 +676,7 @@ ir::Operation* OpTranscriber::operator()(ir::IrContext* ctx,
   program->block()->push_back(operation);

   VLOG(4) << "[general op][" << op_desc.Type() << "] opearation insertion end.";
-  this->RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
+  this->RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx);

   return operation;
 }
@@ -843,7 +845,7 @@ struct AssignValueOpTranscriber : public OpTranscriber {
     ir::Operation* operation = ir::Operation::Create(
         op_inputs, attribute_map, op_output_types, op_info);
     program->block()->push_back(operation);
-    RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
+    RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx);

     VLOG(10) << "[op assign_value] translation finished";
@@ -1260,6 +1262,192 @@ struct ReduceOpTranscriber : public OpTranscriber {
   }
 };

+struct ElementwiseTranscriber : public OpTranscriber {
+  std::vector<ir::OpResult> GenerateOperationInput(
+      ir::IrContext* ctx,
+      TranslationContext* param_map,
+      const OpDesc& op_desc,
+      const std::string& normalized_op_name,
+      const OpInputInfoList& input_infos,
+      ir::Program* program) override {
+    int axis = paddle::get<int>(op_desc.GetAttr("axis"));
+
+    if (axis == -1) {
+      return OpTranscriber::GenerateOperationInput(
+          ctx, param_map, op_desc, normalized_op_name, input_infos, program);
+    }
+
+    auto x_names = op_desc.Input("X", true);
+    IR_ENFORCE(x_names.size() == 1,
+               "Expected op[%s]'s input X has only 1 variable, but got %d",
+               op_desc.Type(),
+               x_names.size());
+    auto x_name = x_names[0];
+    IR_ENFORCE(param_map->count(x_name) > 0,
+               "Expected op[%s]'s input %s has been parsed",
+               op_desc.Type(),
+               x_name);
+    auto x_defining_info = param_map->at(x_name);
+    if (x_defining_info.generated_by_vector) {
+      InsertSliceOperationForTarget(
+          ctx, param_map, program, x_defining_info, x_name);
+      x_defining_info = param_map->at(x_name);
+    }
+    ir::OpResult x_value = x_defining_info.value;
+    IR_ENFORCE(x_value,
+               "Expected op[%s]'s input %s is not null",
+               op_desc.Type(),
+               x_name);
+    ir::Type x_type = x_value.type();
+    IR_ENFORCE(x_type.isa<dialect::DenseTensorType>(),
+               "Expected op[%s]'s input %s is DenseTensor but got %s",
+               op_desc.Type(),
+               x_name,
+               x_type);
+    dialect::DenseTensorType x_tensor_type =
+        x_type.dyn_cast<dialect::DenseTensorType>();
+    std::vector<int64_t> x_shape = phi::vectorize(x_tensor_type.dims());
+
+    auto y_names = op_desc.Input("Y", true);
+    IR_ENFORCE(y_names.size() == 1,
+               "Expected op[%s]'s input Y has only 1 variable, but got %d",
+               op_desc.Type(),
+               y_names.size());
+    auto y_name = y_names[0];
+    IR_ENFORCE(param_map->count(y_name) > 0,
+               "Expected op[%s]'s input %s has been parsed",
+               op_desc.Type(),
+               y_name);
+    auto y_defining_info = param_map->at(y_name);
+    if (y_defining_info.generated_by_vector) {
+      InsertSliceOperationForTarget(
+          ctx, param_map, program, y_defining_info, y_name);
+      y_defining_info = param_map->at(y_name);
+    }
+    ir::OpResult y_value = y_defining_info.value;
+    IR_ENFORCE(y_value,
+               "Expected op[%s]'s input %s is not null",
+               op_desc.Type(),
+               y_name);
+    ir::Type y_type = y_value.type();
+    IR_ENFORCE(y_type.isa<dialect::DenseTensorType>(),
+               "Expected op[%s]'s input %s is DenseTensor but got %s",
+               op_desc.Type(),
+               y_name,
+               y_type);
+    dialect::DenseTensorType y_tensor_type =
+        y_type.dyn_cast<dialect::DenseTensorType>();
+    std::vector<int64_t> y_shape = phi::vectorize(y_tensor_type.dims());
+
+    if (axis < 0) {
+      axis += x_shape.size();
+    }
+
+    int append_size = x_shape.size() - axis - 1 - y_shape.size();
+    if (append_size < 0) {  // which means x.rank <= y.rank, mostly
+                            // x.rank=y.rank
+      return {x_value, y_value};
+    }
+    IR_ENFORCE(append_size >= 0,
+               "Expected op[%s] have append size >= 0 with axis=%d but got %d",
+               op_desc.Type(),
+               axis,
+               append_size);
+
+    ir::Builder builder(ctx, program->block());
+    ir::OpResult y_new;
+    if (std::find(y_shape.begin(), y_shape.end(), -1) == y_shape.end()) {
+      std::vector<int64_t> y_new_shape(y_shape);
+      for (int i = 0; i <= append_size; i++) {
Review comment on the line above: "Not important, but ++i seems a bit more efficient; you could look into the difference between the two." A suggested change to use ++i instead of i++ was attached.
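For readers following along, a tiny illustration of the reviewer's point; the BigIter type below is hypothetical and not from the PR. For a plain int counter like the one above, i++ and ++i compile to the same code, but for iterator-like class types post-increment has to materialize a copy, which is why ++i is the conventional default.

struct BigIter {
  int pos = 0;
  BigIter& operator++() { ++pos; return *this; }  // pre-increment: returns *this, no copy
  BigIter operator++(int) {                       // post-increment: must return the old value,
    BigIter old = *this;                          // which forces a copy
    ++pos;
    return old;
  }
};

int main() {
  BigIter it;
  ++it;   // no temporary
  it++;   // constructs and discards a temporary copy
  return 0;
}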
+        y_new_shape.push_back(1);
+      }
+      dialect::Reshape_Op reshape_op =
+          builder.Build<dialect::Reshape_Op>(y_value, y_new_shape);
+      y_new = reshape_op.out();
+      VLOG(6) << "[" << op_desc.Type() << "] y_shape change from "
+              << y_tensor_type.dims() << " to " << phi::make_ddim(y_new_shape);
+    } else {
+      auto shape_op = builder.Build<dialect::ShapeOp>(y_value);
+      auto append_shape_op = builder.Build<dialect::FullIntArrayOp>(
+          std::vector<int64_t>(append_size, 1),
+          phi::DataType::INT64,
+          phi::CPUPlace());
+      auto y_true_shape_op = builder.Build<ir::CombineOp>(
+          std::vector<ir::OpResult>{shape_op.out(), append_shape_op.out()});
+      auto concat_op =
+          builder.Build<dialect::ConcatOp>(y_true_shape_op.out(), 0);
+      auto y_new_shape = concat_op.out();
+      auto reshape_op =
+          builder.Build<dialect::Reshape_Op>(y_value, y_new_shape);
+      y_new = reshape_op.out();
+    }
+    return {x_value, y_new};
+  }
+};
+
+struct ElementwiseGradTranscriber : public OpTranscriber {
+  void RecordOpResultMapping(ir::IrContext* ctx,
+                             TranslationContext* param_map,
+                             const OpDesc& op_desc,
+                             ir::Operation* operation,
+                             const OpOutputMapping& arg_to_idx) override {
+    OpTranscriber::RecordOpResultMapping(
+        ctx, param_map, op_desc, operation, arg_to_idx);
+
+    int axis = paddle::get<int>(op_desc.GetAttr("axis"));
+    if (axis == -1) {
+      return;
+    }
+
+    const auto& y_grad_output = op_desc.Output("Y@GRAD");
+    if (y_grad_output.size() < 1) {
+      return;
+    }
+    IR_ENFORCE(
+        y_grad_output.size() == 1,
+        "Expected op[%s]'s output Y@GRAD has only 1 variable, but got %d",
+        op_desc.Type(),
+        y_grad_output.size());
+    const auto& y_grad_var_name = y_grad_output[0];
+
+    auto idx_iter = arg_to_idx.find(y_grad_var_name);
+    if (idx_iter == arg_to_idx.end()) {
+      IR_THROW("op[%s] should have got its y_grad", op_desc.Type());
+    }
+    auto idx = idx_iter->second;
+    VLOG(10) << "[output recording]"
+             << "[" << op_desc.Type() << "]" << y_grad_var_name << " " << idx;
+
+    auto y_names = op_desc.Input("Y", true);
+    auto y_name = y_names[0];
+    IR_ENFORCE(param_map->count(y_name) > 0,
+               "Expected op[%s]'s input %s has been parsed",
+               op_desc.Type(),
+               y_name);
+    auto y_defining_info = param_map->at(y_name);
+    ir::OpResult y_value = y_defining_info.value;
+    IR_ENFORCE(y_value,
+               "Expected op[%s]'s input %s is not null",
+               op_desc.Type(),
+               y_name);
+    ir::Type y_type = y_value.type();
+    IR_ENFORCE(y_type.isa<dialect::DenseTensorType>(),
+               "Expected op[%s]'s input %s is DenseTensor but got %s",
+               op_desc.Type(),
+               y_name,
+               y_type);
+    dialect::DenseTensorType y_tensor_type =
+        y_type.dyn_cast<dialect::DenseTensorType>();
+    std::vector<int64_t> y_shape = phi::vectorize(y_tensor_type.dims());
+
+    ir::OpResult value = operation->result(idx);
+    ir::Builder builder(ctx, operation->GetParent());
+    auto reshape_op = builder.Build<dialect::Reshape_Op>(value, y_shape);
+    (*param_map)[y_grad_var_name] =
+        VariableDefiningInfo(reshape_op.out(), false, -1);
+  }
+};
+
 OpTranslator::OpTranslator() {
   general_handler = OpTranscriber();
   special_handlers["add_n"] = AddNOpTranscriber();
@@ -1278,6 +1466,25 @@ OpTranslator::OpTranslator() {
   special_handlers["shaddow_output"] = ShaddowOutputOpTranscriber();
   special_handlers["split"] = SplitOpTranscriber();
   special_handlers["sum"] = AddNOpTranscriber();
+
+  // special handler for elementwise ops with axis != -1
+  // note(lyk): maybe we should do this by a pass, which seems more reasonable
+  special_handlers["elementwise_add"] = ElementwiseTranscriber();
+  special_handlers["elementwise_sub"] = ElementwiseTranscriber();
+  special_handlers["elementwise_mul"] = ElementwiseTranscriber();
+  special_handlers["elementwise_div"] = ElementwiseTranscriber();
+  special_handlers["elementwise_max"] = ElementwiseTranscriber();
+  special_handlers["elementwise_min"] = ElementwiseTranscriber();
+  special_handlers["elementwise_mod"] = ElementwiseTranscriber();
+  special_handlers["elementwise_floordiv"] = ElementwiseTranscriber();
+  special_handlers["elementwise_add_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_sub_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_mul_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_div_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_max_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_min_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_mod_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_floordiv_grad"] = ElementwiseGradTranscriber();
 }

 }  // namespace translator
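For context outside the diff: the essence of the change is a shape-alignment rule. Legacy elementwise_* ops let Y match X starting at the attribute axis, while the new IR elementwise ops only support right-aligned (NumPy-style) broadcasting, so ElementwiseTranscriber reshapes Y by appending trailing 1s (and ElementwiseGradTranscriber reshapes Y@GRAD back to Y's original shape). Below is a minimal standalone sketch of the static-shape rule; the helper name and the main() driver are illustrative, not code from the PR.

#include <cstdint>
#include <iostream>
#include <vector>

// Compute the shape Y must be reshaped to so that right-aligned broadcasting
// against X reproduces the legacy axis-based alignment.
std::vector<int64_t> AlignYForBroadcast(const std::vector<int64_t>& x_shape,
                                        std::vector<int64_t> y_shape,
                                        int axis) {
  if (axis < 0) axis += static_cast<int>(x_shape.size());
  // Same quantity as in the PR: trailing dims of X that Y does not cover.
  int append_size = static_cast<int>(x_shape.size()) - axis - 1 -
                    static_cast<int>(y_shape.size());
  if (append_size < 0) return y_shape;  // Y already reaches X's last dim
  for (int i = 0; i <= append_size; ++i) y_shape.push_back(1);
  return y_shape;
}

int main() {
  // X: [10, 3, 4, 5], Y: [3, 4], axis = 1  ->  Y reshaped to [3, 4, 1], which
  // right-aligned broadcasting then matches against dims 1..2 of X.
  auto y_new = AlignYForBroadcast({10, 3, 4, 5}, {3, 4}, 1);
  for (auto d : y_new) std::cout << d << " ";  // prints: 3 4 1
  std::cout << "\n";
  return 0;
}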
Review comment: Going forward, keep an eye on whether auto should really be auto& in places like these; besides the copy overhead, more importantly it can sometimes trigger hidden bugs (ran into one of those together with zhiqiu before).
Reply: OK, I'll keep an eye on that going forward.
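A small self-contained illustration of the auto vs auto& point above; the map and key here are hypothetical stand-ins, not the translator's real types. Plain auto deduces a value type and copies the element, so later writes silently go to the copy, while auto& binds to the stored element.

#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> param_map{{"x", 1}};

  auto v = param_map.at("x");    // deduces int: copies the stored value
  v = 42;                        // modifies only the local copy
  assert(param_map.at("x") == 1);

  auto& r = param_map.at("x");   // deduces int&: aliases the stored value
  r = 42;                        // updates the map entry itself
  assert(param_map.at("x") == 42);
  return 0;
}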