From afe7d6cf271132fa2f977bdd1a823dcc7119b9e7 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Wed, 20 Sep 2023 06:34:40 +0000 Subject: [PATCH] add API for ir_compare and move it from namespace optim to ir_utils --- .../auto_schedule/search_space/search_state.cc | 7 +++---- .../auto_schedule/search_space/search_state.h | 4 ++-- paddle/cinn/ir/test/ir_compare_test.cc | 18 +++++++----------- paddle/cinn/ir/utils/ir_compare.cc | 8 ++++++++ paddle/cinn/ir/utils/ir_compare.h | 6 ++++++ paddle/cinn/ir/utils/ir_visitor.cc | 3 +-- 6 files changed, 27 insertions(+), 19 deletions(-) diff --git a/paddle/cinn/auto_schedule/search_space/search_state.cc b/paddle/cinn/auto_schedule/search_space/search_state.cc index 96ace0f505d7fa..c16bf628402913 100644 --- a/paddle/cinn/auto_schedule/search_space/search_state.cc +++ b/paddle/cinn/auto_schedule/search_space/search_state.cc @@ -133,11 +133,10 @@ bool SearchStateEqual::operator()(const SearchState& lhs, // compare exprs size firstly if (lhs_exprs.size() != rhs_exprs.size()) return false; - // compare every expr one by one with ir::IrEqualVisitor + // compare every expr one by one with ir::ir_utils::IrEqualVisitor for (int i = 0; i < lhs_exprs.size(); ++i) { - ir::IrEqualVisitor compartor( - /*allow_name_suffix_diff=*/true); // ignore suffix difference in name - if (!compartor.Compare(lhs_exprs[i], rhs_exprs[i])) return false; + if (!ir::ir_utils::IRCompare(lhs_exprs[i], rhs_exprs[i], true)) + return false; } return true; } diff --git a/paddle/cinn/auto_schedule/search_space/search_state.h b/paddle/cinn/auto_schedule/search_space/search_state.h index 7991fb9540188e..b3f45c5cd746c7 100644 --- a/paddle/cinn/auto_schedule/search_space/search_state.h +++ b/paddle/cinn/auto_schedule/search_space/search_state.h @@ -70,8 +70,8 @@ struct SearchStateHash { size_t operator()(const SearchState& s) const; }; -// SearchStateHash equal functor, use ir::IrEqualVisitor to compare their AST -// struct and fields +// SearchStateHash equal functor, use ir::ir_utils::IrEqualVisitor to compare +// their AST struct and fields struct SearchStateEqual { bool operator()(const SearchState& lhs, const SearchState& rhs) const; }; diff --git a/paddle/cinn/ir/test/ir_compare_test.cc b/paddle/cinn/ir/test/ir_compare_test.cc index a1bca0cd5373f3..cc9ce438221a2e 100644 --- a/paddle/cinn/ir/test/ir_compare_test.cc +++ b/paddle/cinn/ir/test/ir_compare_test.cc @@ -23,7 +23,7 @@ namespace cinn { namespace ir { - +namespace ir_utils { TEST(TestIrCompare, SingleFunction) { Target target = common::DefaultHostTarget(); @@ -128,20 +128,16 @@ TEST(TestIrCompare, SingleFunction) { ASSERT_EQ(func2_str, utils::GetStreamCnt(funcs_2.front())); ASSERT_EQ(func3_str, utils::GetStreamCnt(funcs_3.front())); - IrEqualVisitor compartor; // they are different at the name of root ScheduleBlock - ASSERT_TRUE(compartor.Compare(funcs_1.front(), funcs_2.front())); + ASSERT_TRUE(IRCompare(funcs_1.front(), funcs_2.front())); // compare with itself - ASSERT_TRUE(compartor.Compare(funcs_1.front(), funcs_1.front())); - IrEqualVisitor compartor_allow_suffix_diff(true); + ASSERT_TRUE(IRCompare(funcs_1.front(), funcs_1.front())); // they are euqal if allowing suffix of name different - ASSERT_TRUE( - compartor_allow_suffix_diff.Compare(funcs_1.front(), funcs_2.front())); + ASSERT_TRUE(IRCompare(funcs_1.front(), funcs_2.front(), true)); - ASSERT_FALSE(compartor.Compare(funcs_1.front(), funcs_3.front())); - ASSERT_FALSE( - compartor_allow_suffix_diff.Compare(funcs_1.front(), funcs_3.front())); + ASSERT_FALSE(IRCompare(funcs_1.front(), funcs_3.front())); + ASSERT_FALSE(IRCompare(funcs_1.front(), funcs_3.front(), true)); } - +} // namespace ir_utils } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/utils/ir_compare.cc b/paddle/cinn/ir/utils/ir_compare.cc index c303262d04fbd1..87324be608048d 100644 --- a/paddle/cinn/ir/utils/ir_compare.cc +++ b/paddle/cinn/ir/utils/ir_compare.cc @@ -22,6 +22,8 @@ namespace cinn { namespace ir { +namespace ir_utils { + bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) { if (lhs.get() == rhs.get()) { // the same object, including both are null return true; @@ -358,5 +360,11 @@ bool IrEqualVisitor::Visit(const ScheduleBlockRealize* lhs, const Expr* other) { Compare(lhs->schedule_block, rhs->schedule_block); } +bool IRCompare(const Expr& lhs, const Expr& rhs, bool allow_name_suffix_diff) { + IrEqualVisitor ir_equal_visitor(allow_name_suffix_diff); + return ir_equal_visitor.Compare(lhs, rhs); +} + +} // namespace ir_utils } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/utils/ir_compare.h b/paddle/cinn/ir/utils/ir_compare.h index 9e4b335857b985..d41e6db0441a7b 100644 --- a/paddle/cinn/ir/utils/ir_compare.h +++ b/paddle/cinn/ir/utils/ir_compare.h @@ -20,6 +20,7 @@ namespace cinn { namespace ir { +namespace ir_utils { // Determine whether two ir AST trees are euqal by comparing their struct and // fields of each node through dfs visitor @@ -47,5 +48,10 @@ class IrEqualVisitor : public IRVisitorRequireReImpl { bool allow_name_suffix_diff_ = false; }; +bool IRCompare(const Expr& lhs, + const Expr& rhs, + bool allow_name_suffix_diff = false); + +} // namespace ir_utils } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/utils/ir_visitor.cc b/paddle/cinn/ir/utils/ir_visitor.cc index 9ef6a78df1fcd5..f55259be2c6415 100644 --- a/paddle/cinn/ir/utils/ir_visitor.cc +++ b/paddle/cinn/ir/utils/ir_visitor.cc @@ -23,8 +23,7 @@ namespace ir { bool operator==(Expr a, Expr b) { if (a.get() == b.get()) return true; - IrEqualVisitor cmp; - return cmp.Compare(a, b); + return ir_utils::IRCompare(a, b); } bool operator!=(Expr a, Expr b) { return !(a == b); }