Skip to content

Commit

Permalink
[PIR]Polish GroupOp and Interface code (PaddlePaddle#57829)
Browse files Browse the repository at this point in the history
* [PIR]Polish GroupOp and Interface code

* fix comment
  • Loading branch information
Aurelius84 authored Oct 10, 2023
1 parent c8d0fde commit f393845
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 34 deletions.
11 changes: 6 additions & 5 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,16 @@ void GroupOp::Build(pir::Builder &builder,
argument.output_types = output_types;
}

pir::Block *GroupOp::Block() {
pir::Block *GroupOp::block() {
pir::Region &region = (*this)->region(0);
if (region.empty()) region.emplace_back();
return region.front();
}

std::vector<pir::Operation *> GroupOp::Ops() {
auto *block = this->Block();
return std::vector<pir::Operation *>(block->begin(), block->end());
std::vector<pir::Operation *> GroupOp::ops() {
auto *inner_block = this->block();
return std::vector<pir::Operation *>(inner_block->begin(),
inner_block->end());
}

void GroupOp::Verify() {}
Expand All @@ -54,7 +55,7 @@ void GroupOp::Print(pir::IrPrinter &printer) {
os << " -> ";
printer.PrintOpReturnType(op);
os << " {";
for (auto &sub_op : Ops()) {
for (auto &sub_op : ops()) {
os << "\n";
printer.PrintOperation(sub_op);
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class GroupOp : public pir::Op<GroupOp> {
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Type> &output_types);

pir::Block *Block();
std::vector<pir::Operation *> Ops();
pir::Block *block();
std::vector<pir::Operation *> ops();

void Verify();
void Print(pir::IrPrinter &printer); // NOLINT
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/transforms/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ void ReplaceWithGroupOp(pir::Block* block,
}
// step 2: Replace the old op with GroupOp.
auto new_group_op = builder.Build<cinn::dialect::GroupOp>(output_types);
pir::Block* group_block = new_group_op.Block();
pir::Block* group_block = new_group_op.block();
for (auto* op : group_ops) {
op->MoveTo(group_block, group_block->begin());
}
Expand Down
1 change: 1 addition & 0 deletions paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ void Operation::SetParent(Block *parent, const Block::Iterator &position) {
}

void Operation::MoveTo(Block *block, Block::Iterator position) {
IR_ENFORCE(parent_, "Operation does not have parent");
Operation *op = parent_->Take(this);
block->insert(position, op);
}
Expand Down
4 changes: 1 addition & 3 deletions paddle/pir/core/value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ bool Value::operator!=(const Value &other) const {

bool Value::operator!() const { return impl_ == nullptr; }

bool Value::operator<(const Value &other) const {
return std::hash<Value>{}(*this) < std::hash<Value>{}(other);
}
bool Value::operator<(const Value &other) const { return impl_ < other.impl_; }

Value::operator bool() const { return impl_; }

Expand Down
22 changes: 3 additions & 19 deletions test/cpp/pir/cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,11 @@ if(WITH_TESTING AND WITH_CINN)
convert_to_dialect)
set_tests_properties(test_jit_instruction PROPERTIES LABELS "RUN_TYPE=CINN")

cc_test_old(
test_group_op
SRCS
group_op_test.cc
DEPS
cinn_op_dialect
pir
phi
gtest
glog)
paddle_test(test_group_op SRCS group_op_test.cc DEPS cinn_op_dialect)
set_tests_properties(test_group_op PROPERTIES LABELS "RUN_TYPE=CINN")

cc_test_old(
test_pir_build_cinn_pass
SRCS
build_cinn_pass_test.cc
DEPS
pd_build_cinn_pass
pir_pass
gtest
glog)
paddle_test(test_pir_build_cinn_pass SRCS build_cinn_pass_test.cc DEPS
pd_build_cinn_pass)
set_tests_properties(test_pir_build_cinn_pass PROPERTIES LABELS
"RUN_TYPE=CINN")
endif()
2 changes: 1 addition & 1 deletion test/cpp/pir/cinn/build_cinn_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
CHECK_EQ(origin_program->block()->size(), 1u);
pir::Operation* group_op = origin_program->block()->front();
pir::Block* group_block =
group_op->dyn_cast<cinn::dialect::GroupOp>().Block();
group_op->dyn_cast<cinn::dialect::GroupOp>().block();
CHECK_EQ(group_block->size(), 6u);

std::vector<std::string> op_names = {
Expand Down
6 changes: 3 additions & 3 deletions test/cpp/pir/cinn/group_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ std::shared_ptr<::pir::Program> BuildGroupProgram() {
const std::vector<int64_t> shape = {64, 128};
auto group_op1 = builder.Build<cinn::dialect::GroupOp>(
CreateDenseTensorTypes(phi::make_ddim(shape)));
pir::Block* block1 = group_op1.Block();
pir::Block* block1 = group_op1.block();
builder.SetInsertionPointToEnd(block1);
auto full_op_x = builder.Build<paddle::dialect::FullOp>(
shape, value_one, phi::DataType::FLOAT32, phi::GPUPlace());
Expand All @@ -60,7 +60,7 @@ std::shared_ptr<::pir::Program> BuildGroupProgram() {
builder.SetInsertionPointToEnd(program->block());
auto group_op2 = builder.Build<cinn::dialect::GroupOp>(
CreateDenseTensorTypes(phi::make_ddim(shape)));
pir::Block* block2 = group_op2.Block();
pir::Block* block2 = group_op2.block();
builder.SetInsertionPointToEnd(block2);

auto tan_op_x = builder.Build<paddle::dialect::TanOp>(group_op1->result(0));
Expand All @@ -84,7 +84,7 @@ TEST(GroupOp, TestBuild) {
int i = 0;
for (auto* sub_op : *(program->block())) {
EXPECT_TRUE(sub_op->isa<cinn::dialect::GroupOp>());
EXPECT_EQ(sub_op->dyn_cast<cinn::dialect::GroupOp>().Ops().size(),
EXPECT_EQ(sub_op->dyn_cast<cinn::dialect::GroupOp>().ops().size(),
op_num[i]);
++i;
}
Expand Down

0 comments on commit f393845

Please sign in to comment.