From 798a2bdfcdddfecdebaacecc01a2e065d73aac0f Mon Sep 17 00:00:00 2001 From: winter-wang <1030748926@qq.com> Date: Sun, 17 Sep 2023 02:20:31 +0000 Subject: [PATCH] [PIR] normalize the use of value. 4-3 --- .../pir/dialect/op_generator/op_build_gen.py | 2 +- .../fluid/pir/dialect/operator/ir/manual_op.cc | 16 ++++++++-------- paddle/pir/core/builder.cc | 4 ++-- paddle/pir/core/builder.h | 6 +++--- paddle/pir/core/builtin_op.cc | 4 ++-- paddle/pir/core/op_operand.cc | 9 --------- paddle/pir/core/op_operand.h | 4 +--- paddle/pir/core/operation_utils.h | 14 -------------- paddle/pir/core/value.cc | 4 ---- paddle/pir/core/value.h | 2 +- paddle/pir/dialect/control_flow/ir/cf_ops.cc | 2 +- test/cpp/pir/core/ir_op_test.cc | 3 +-- test/cpp/pir/core/ir_program_test.cc | 12 ++++++------ .../ir_kernel_dialect_pass_test.cc | 2 +- test/cpp/pir/pass/pass_manager_test.cc | 4 ++-- .../pir/pattern_rewrite/pattern_rewrite_test.cc | 2 +- test/cpp/pir/tools/test_op.cc | 2 +- 17 files changed, 31 insertions(+), 61 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index df487f01e99307..3c6004c163d248 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -184,7 +184,7 @@ def GenBuildInserFullForMutableAttribute( def GenBuildInputs(op_input_name_list, op_mutable_attribute_name_list): BUILD_INPUT_TEMPLATE = """ std::vector argument_inputs = {{{inputs_args}}}; - argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); + argument.AddInputs(argument_inputs.begin(), argument_inputs.end()); """ build_input_str = ' VLOG(4) << "Builder construction inputs";\n' input_name_list = op_input_name_list + op_mutable_attribute_name_list diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 3ee3bec97cd89e..6d880ad2226262 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -103,7 +103,7 @@ void AddNOp::Build(pir::Builder &builder, // NOLINT pir::OpResult inputs) { VLOG(4) << "Builder construction inputs"; std::vector argument_inputs = {inputs}; - argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); + argument.AddInputs(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; @@ -179,7 +179,7 @@ void AddN_Op::Build(pir::Builder &builder, pir::OpResult inputs_) { VLOG(4) << "Builder construction inputs"; std::vector argument_inputs = {inputs_}; - argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); + argument.AddInputs(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; @@ -307,7 +307,7 @@ void AddNWithKernelOp::Build(pir::Builder &builder, pir::OpResult inputs_) { VLOG(4) << "Builder construction inputs"; std::vector argument_inputs = {inputs_}; - argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); + argument.AddInputs(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; @@ -477,7 +477,7 @@ void FusedGemmEpilogueOp::Build(pir::Builder &builder, VLOG(4) << "Builder construction inputs"; std::vector argument_inputs = {x_, y_, bias_}; - argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); + argument.AddInputs(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; pir::Attribute attr_trans_x = @@ -732,7 +732,7 @@ void FusedGemmEpilogueGradOp::Build(pir::Builder &builder, VLOG(4) << "Builder construction inputs"; std::vector argument_inputs = { x_, y_, reserve_space_, out_grad_}; - argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); + argument.AddInputs(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; pir::Attribute attr_trans_x = @@ -916,7 +916,7 @@ void SplitGradOp::Build(pir::Builder &builder, VLOG(4) << "Builder construction inputs"; std::vector argument_inputs = {out_grad_, axis_}; - argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); + argument.AddInputs(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; @@ -974,7 +974,7 @@ void SplitGradOp::Build(pir::Builder &builder, pir::OpResult axis_) { VLOG(4) << "Builder construction inputs"; std::vector argument_inputs = {out_grad_, axis_}; - argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); + argument.AddInputs(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; @@ -1095,7 +1095,7 @@ void IfOp::Build(pir::Builder &builder, // NOLINT pir::OpResult cond, std::vector &&output_types) { argument.num_regions = 2; - argument.AddOperand(cond); + argument.AddInput(cond); argument.output_types.swap(output_types); } pir::Block *IfOp::true_block() { diff --git a/paddle/pir/core/builder.cc b/paddle/pir/core/builder.cc index a91428ba99080e..3c12508b6d1550 100644 --- a/paddle/pir/core/builder.cc +++ b/paddle/pir/core/builder.cc @@ -20,8 +20,8 @@ namespace pir { /// Create an operation given the fields represented as an OperationState. -Operation *Builder::Build(OperationArgument &&argument) { - return Insert(Operation::Create(std::move(argument))); +Operation *Builder::Build(const OperationArgument &argument) { + return Insert(Operation::Create(argument)); } /// Creates an operation with the given fields. diff --git a/paddle/pir/core/builder.h b/paddle/pir/core/builder.h index 81e25a0d365f0c..d39df37a14554e 100644 --- a/paddle/pir/core/builder.h +++ b/paddle/pir/core/builder.h @@ -94,7 +94,7 @@ class Builder { Block *block() const { return block_; } /// Creates an operation given the fields represented as an OperationState. - IR_API Operation *Build(OperationArgument &&argument); + IR_API Operation *Build(const OperationArgument &argument); /// Creates an operation with the given fields. IR_API Operation *Build(const std::vector &inputs, @@ -107,8 +107,8 @@ class Builder { OpTy Build(Args &&...args) { OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name())); OpTy::Build(*this, argument, std::forward(args)...); - Operation *op = Build(std::move(argument)); - return op->dyn_cast(); + Operation *op = Build(argument); + return OpTy(op); } IR_API UInt8Type uint8_type(); diff --git a/paddle/pir/core/builtin_op.cc b/paddle/pir/core/builtin_op.cc index c7e8675bab75de..27cbc920f414a6 100644 --- a/paddle/pir/core/builtin_op.cc +++ b/paddle/pir/core/builtin_op.cc @@ -70,7 +70,7 @@ ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) { OperationArgument argument(info); argument.num_regions = 1; argument.AddAttribute("program", PointerAttribute::get(context, pointer)); - Operation *op = Operation::Create(std::move(argument)); + Operation *op = Operation::Create(argument); op->region(0).emplace_back(); return ModuleOp(op); } @@ -140,7 +140,7 @@ void SetParameterOp::Build(Builder &builder, // NOLINT OperationArgument &argument, // NOLINT OpResult parameter, const std::string &name) { - argument.AddOperand(parameter); + argument.AddInput(parameter); argument.AddAttribute(attributes_name[0], pir::StrAttribute::get(builder.ir_context(), name)); } diff --git a/paddle/pir/core/op_operand.cc b/paddle/pir/core/op_operand.cc index b27f02ac23d4ce..c728180f48fbfb 100644 --- a/paddle/pir/core/op_operand.cc +++ b/paddle/pir/core/op_operand.cc @@ -24,20 +24,11 @@ CHECK_NULL_IMPL(OpOpernad, func_name) namespace pir { - -OpOperand::OpOperand(const detail::OpOperandImpl *impl) - : impl_(const_cast(impl)) {} - OpOperand &OpOperand::operator=(const OpOperand &rhs) { impl_ = rhs.impl_; return *this; } -OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) { - if (this->impl_ == impl) return *this; - impl_ = const_cast(impl); - return *this; -} OpOperand::operator bool() const { return impl_ && impl_->source(); } OpOperand OpOperand::next_use() const { diff --git a/paddle/pir/core/op_operand.h b/paddle/pir/core/op_operand.h index 96b355b861ffa1..91636ea9ed8ba8 100644 --- a/paddle/pir/core/op_operand.h +++ b/paddle/pir/core/op_operand.h @@ -35,12 +35,10 @@ class IR_API OpOperand { OpOperand(const OpOperand &other) = default; - OpOperand(const detail::OpOperandImpl *impl); // NOLINT + OpOperand(detail::OpOperandImpl *impl) : impl_(impl) {} // NOLINT OpOperand &operator=(const OpOperand &rhs); - OpOperand &operator=(const detail::OpOperandImpl *impl); - bool operator==(const OpOperand &other) const { return impl_ == other.impl_; } bool operator!=(const OpOperand &other) const { return !operator==(other); } diff --git a/paddle/pir/core/operation_utils.h b/paddle/pir/core/operation_utils.h index fde59c31c012b4..e07c048c6fe141 100644 --- a/paddle/pir/core/operation_utils.h +++ b/paddle/pir/core/operation_utils.h @@ -57,17 +57,10 @@ struct OperationArgument { num_regions(num_regions), successors(successors) {} - // Will be deleted in the next pr. - void AddOperand(OpResult operand) { inputs.emplace_back(operand); } - void AddInput(Value input) { inputs.emplace_back(input.dyn_cast()); } - // Will be deleted in the next pr. - template - void AddOperands(InputIt first, InputIt last); - template void AddInputs(InputIt first, InputIt last); @@ -99,13 +92,6 @@ struct OperationArgument { void AddSuccessor(Block* successor) { successors.emplace_back(successor); } }; -template -void OperationArgument::AddOperands(InputIt first, InputIt last) { - while (first != last) { - inputs.emplace_back(*first++); - } -} - template void OperationArgument::AddInputs(InputIt first, InputIt last) { while (first != last) { diff --git a/paddle/pir/core/value.cc b/paddle/pir/core/value.cc index 8aaab15b29c422..1c7eb83f01036f 100644 --- a/paddle/pir/core/value.cc +++ b/paddle/pir/core/value.cc @@ -30,10 +30,6 @@ #define CHECK_VALUE_NULL_IMPL(func_name) CHECK_NULL_IMPL(Value, func_name) namespace pir { - -Value::Value(const detail::ValueImpl *impl) - : impl_(const_cast(impl)) {} - bool Value::operator==(const Value &other) const { return impl_ == other.impl_; } diff --git a/paddle/pir/core/value.h b/paddle/pir/core/value.h index ee908f2355b085..39a3d052b5078b 100644 --- a/paddle/pir/core/value.h +++ b/paddle/pir/core/value.h @@ -33,7 +33,7 @@ class IR_API Value { public: Value() = default; - Value(const detail::ValueImpl *impl); // NOLINT + Value(detail::ValueImpl *impl) : impl_(impl) {} // NOLINT Value(const Value &other) = default; diff --git a/paddle/pir/dialect/control_flow/ir/cf_ops.cc b/paddle/pir/dialect/control_flow/ir/cf_ops.cc index 7dd72ea12551ed..6147ab8ee40e47 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_ops.cc +++ b/paddle/pir/dialect/control_flow/ir/cf_ops.cc @@ -19,7 +19,7 @@ namespace pir { void YieldOp::Build(Builder &builder, OperationArgument &argument, std::vector &&inputs) { - argument.AddOperands(inputs.begin(), inputs.end()); + argument.AddInputs(inputs.begin(), inputs.end()); } } // namespace pir diff --git a/test/cpp/pir/core/ir_op_test.cc b/test/cpp/pir/core/ir_op_test.cc index 8b5e5173bd78cb..bfc03e66944e9c 100644 --- a/test/cpp/pir/core/ir_op_test.cc +++ b/test/cpp/pir/core/ir_op_test.cc @@ -237,8 +237,7 @@ TEST(op_test, region_test) { argument.output_types = {pir::Float32Type::get(ctx)}; argument.num_regions = 1; - pir::Operation *op3 = pir::Operation::Create(std::move(argument)); - // argument.regions.emplace_back(std::make_unique()); + pir::Operation *op3 = pir::Operation::Create(argument); pir::Region ®ion = op3->region(0); EXPECT_EQ(region.empty(), true); diff --git a/test/cpp/pir/core/ir_program_test.cc b/test/cpp/pir/core/ir_program_test.cc index cafca1c97bdb22..85f608aa117a28 100644 --- a/test/cpp/pir/core/ir_program_test.cc +++ b/test/cpp/pir/core/ir_program_test.cc @@ -44,8 +44,8 @@ class AddOp : public pir::Op { void Verify(); static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT - pir::OpResult l_operand, - pir::OpResult r_operand, + pir::Value l_operand, + pir::Value r_operand, pir::Type sum_type); }; void AddOp::Verify() { @@ -58,11 +58,11 @@ void AddOp::Verify() { } void AddOp::Build(pir::Builder &, pir::OperationArgument &argument, - pir::OpResult l_operand, - pir::OpResult r_operand, + pir::Value l_operand, + pir::Value r_operand, pir::Type sum_type) { - argument.AddOperand(l_operand); - argument.AddOperand(r_operand); + argument.AddInput(l_operand); + argument.AddInput(r_operand); argument.AddOutput(sum_type); } IR_DECLARE_EXPLICIT_TYPE_ID(AddOp) diff --git a/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc b/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc index 9efb3f2329e887..52773cc96e9289 100644 --- a/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc +++ b/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc @@ -164,7 +164,7 @@ TEST(kernel_dialect, legacy_op_test) { "kernel_key", kernel_key); - pir::Operation* op = pir::Operation::Create(std::move(argument)); + pir::Operation* op = pir::Operation::Create(argument); EXPECT_EQ("pd_op.kernel_op", op->dyn_cast().op_name()); EXPECT_EQ("kernel_op", diff --git a/test/cpp/pir/pass/pass_manager_test.cc b/test/cpp/pir/pass/pass_manager_test.cc index ac1b8a6c6d9f38..8095f9e00acbe4 100644 --- a/test/cpp/pir/pass/pass_manager_test.cc +++ b/test/cpp/pir/pass/pass_manager_test.cc @@ -89,8 +89,8 @@ void AddOp::Build(pir::Builder &, pir::OpResult l_operand, pir::OpResult r_operand, pir::Type sum_type) { - argument.AddOperand(l_operand); - argument.AddOperand(r_operand); + argument.AddInput(l_operand); + argument.AddInput(r_operand); argument.AddOutput(sum_type); } IR_DECLARE_EXPLICIT_TYPE_ID(AddOp) diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc index 1bb9dc0cafae73..343a75f8229421 100644 --- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc @@ -583,7 +583,7 @@ void Conv2dFusionOpTest::Build(pir::Builder &builder, VLOG(4) << "Builder construction inputs"; std::vector argument_inputs = { input_, filter_, bias_, residual_}; - argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); + argument.AddInputs(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; std::vector vec_strides; diff --git a/test/cpp/pir/tools/test_op.cc b/test/cpp/pir/tools/test_op.cc index 9802f8827cf6f5..9adce7ea402e98 100644 --- a/test/cpp/pir/tools/test_op.cc +++ b/test/cpp/pir/tools/test_op.cc @@ -28,7 +28,7 @@ void BranchOp::Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, const std::vector &target_operands, pir::Block *target) { - argument.AddOperands(target_operands.begin(), target_operands.end()); + argument.AddInputs(target_operands.begin(), target_operands.end()); argument.AddSuccessor(target); }