Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【CINN】add IRCompare Interface for ir_equal_visitor and move it from namespace optim to ir_utils #57531

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions paddle/cinn/auto_schedule/search_space/search_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,10 @@ bool SearchStateEqual::operator()(const SearchState& lhs,
// compare exprs size firstly
if (lhs_exprs.size() != rhs_exprs.size()) return false;

// compare every expr one by one with ir::IrEqualVisitor
// compare every expr one by one with ir::ir_utils::IrEqualVisitor
for (int i = 0; i < lhs_exprs.size(); ++i) {
ir::IrEqualVisitor compartor(
/*allow_name_suffix_diff=*/true); // ignore suffix difference in name
if (!compartor.Compare(lhs_exprs[i], rhs_exprs[i])) return false;
if (!ir::ir_utils::IRCompare(lhs_exprs[i], rhs_exprs[i], true))
return false;
}
return true;
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/auto_schedule/search_space/search_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ struct SearchStateHash {
size_t operator()(const SearchState& s) const;
};

// SearchStateHash equal functor, use ir::IrEqualVisitor to compare their AST
// struct and fields
// SearchStateHash equal functor, use ir::ir_utils::IrEqualVisitor to compare
// their AST struct and fields
struct SearchStateEqual {
bool operator()(const SearchState& lhs, const SearchState& rhs) const;
};
Expand Down
18 changes: 7 additions & 11 deletions paddle/cinn/ir/test/ir_compare_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

namespace cinn {
namespace ir {

namespace ir_utils {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个test不用加namespace吧?

TEST(TestIrCompare, SingleFunction) {
Target target = common::DefaultHostTarget();

Expand Down Expand Up @@ -128,20 +128,16 @@ TEST(TestIrCompare, SingleFunction) {
ASSERT_EQ(func2_str, utils::GetStreamCnt(funcs_2.front()));
ASSERT_EQ(func3_str, utils::GetStreamCnt(funcs_3.front()));

IrEqualVisitor compartor;
// they are different at the name of root ScheduleBlock
ASSERT_TRUE(compartor.Compare(funcs_1.front(), funcs_2.front()));
ASSERT_TRUE(IRCompare(funcs_1.front(), funcs_2.front()));
// compare with itself
ASSERT_TRUE(compartor.Compare(funcs_1.front(), funcs_1.front()));
IrEqualVisitor compartor_allow_suffix_diff(true);
ASSERT_TRUE(IRCompare(funcs_1.front(), funcs_1.front()));
// they are euqal if allowing suffix of name different
ASSERT_TRUE(
compartor_allow_suffix_diff.Compare(funcs_1.front(), funcs_2.front()));
ASSERT_TRUE(IRCompare(funcs_1.front(), funcs_2.front(), true));

ASSERT_FALSE(compartor.Compare(funcs_1.front(), funcs_3.front()));
ASSERT_FALSE(
compartor_allow_suffix_diff.Compare(funcs_1.front(), funcs_3.front()));
ASSERT_FALSE(IRCompare(funcs_1.front(), funcs_3.front()));
ASSERT_FALSE(IRCompare(funcs_1.front(), funcs_3.front(), true));
}

} // namespace ir_utils
} // namespace ir
} // namespace cinn
8 changes: 8 additions & 0 deletions paddle/cinn/ir/utils/ir_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
namespace cinn {
namespace ir {

namespace ir_utils {

bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) {
if (lhs.get() == rhs.get()) { // the same object, including both are null
return true;
Expand Down Expand Up @@ -358,5 +360,11 @@ bool IrEqualVisitor::Visit(const ScheduleBlockRealize* lhs, const Expr* other) {
Compare(lhs->schedule_block, rhs->schedule_block);
}

bool IRCompare(const Expr& lhs, const Expr& rhs, bool allow_name_suffix_diff) {
IrEqualVisitor ir_equal_visitor(allow_name_suffix_diff);
return ir_equal_visitor.Compare(lhs, rhs);
}

} // namespace ir_utils
} // namespace ir
} // namespace cinn
6 changes: 6 additions & 0 deletions paddle/cinn/ir/utils/ir_compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

namespace cinn {
namespace ir {
namespace ir_utils {

// Determine whether two ir AST trees are euqal by comparing their struct and
// fields of each node through dfs visitor
Expand Down Expand Up @@ -47,5 +48,10 @@ class IrEqualVisitor : public IRVisitorRequireReImpl<bool, const Expr*> {
bool allow_name_suffix_diff_ = false;
};

bool IRCompare(const Expr& lhs,
const Expr& rhs,
bool allow_name_suffix_diff = false);

} // namespace ir_utils
} // namespace ir
} // namespace cinn
3 changes: 1 addition & 2 deletions paddle/cinn/ir/utils/ir_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ namespace ir {

bool operator==(Expr a, Expr b) {
if (a.get() == b.get()) return true;
IrEqualVisitor cmp;
return cmp.Compare(a, b);
return ir_utils::IRCompare(a, b);
}

bool operator!=(Expr a, Expr b) { return !(a == b); }
Expand Down