From d615928406f6172ad8f85deedef16e6fd41189e8 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Wed, 19 Jun 2019 11:55:21 +0300 Subject: [PATCH] Squashed #3368 (WIP): [ARITH] Migrate simplifier to new infra --- CMakeLists.txt | 6 ++++- include/tvm/arithmetic.h | 8 ++++++ include/tvm/ir_pass.h | 1 - src/arithmetic/analyzer.cc | 11 ++++++++ src/arithmetic/bound_deducer.cc | 4 ++- src/arithmetic/rewrite_simplify.cc | 3 +++ src/arithmetic/stmt_simplify.cc | 40 ++++++------------------------ src/lang/buffer.cc | 5 ++-- src/op/scan_op.cc | 6 ++--- src/pass/loop_partition.cc | 16 ++++++------ src/pass/narrow_channel_access.cc | 7 +++--- src/pass/storage_rewrite.cc | 8 +++--- src/pass/vectorize_loop.cc | 9 ++++--- src/schedule/message_passing.cc | 20 ++++++++------- tests/cpp/ir_simplify_test.cc | 10 ++------ 15 files changed, 79 insertions(+), 75 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6500ba013e28..c23d403bcb6a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -154,7 +154,11 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS file(GLOB TOPI_SRCS topi/src/*.cc ) -file(GLOB_RECURSE HALIDEIR_SRCS 3rdparty/HalideIR/src/*.cpp) +file(GLOB_RECURSE HALIDEIR_SRCS + 3rdparty/HalideIR/src/base/*.cpp + 3rdparty/HalideIR/src/ir/*.cpp + 3rdparty/HalideIR/src/tvm/*.cpp +) list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS}) file(GLOB RUNTIME_SRCS src/runtime/*.cc diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index c506268cb14b..6f8155f9de59 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -516,6 +516,14 @@ class Analyzer { * \note Analyzer will call into sub-analyzers to get the result. */ bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound); + /*! + * \brief Whether can we prove condition. + * + * \param cond The expression to be proved. + * + * \note Analyzer will call into sub-analyzers to get the result. + */ + bool CanProve(const Expr& cond); }; //----------------------------------------------- diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index e1c92e50e6ad..98dbf6bb6290 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -27,7 +27,6 @@ #ifndef TVM_IR_PASS_H_ #define TVM_IR_PASS_H_ -#include #include #include #include diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 10a1c7f041c3..fa94fc7e9dae 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -82,5 +82,16 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { return false; } +bool Analyzer::CanProve(const Expr& expr) { + if (const auto* ptr = expr.as()) { + return ptr->value != 0; + } + auto res = this->rewrite_simplify(expr); + if (const auto* ptr = res.as()) { + return ptr->value != 0; + } + return false; +} + } // namespace arith } // namespace tvm diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 395a371f43af..e85c71057e6c 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -147,7 +147,7 @@ class BoundDeducer: public IRVisitor { } // always use relax bound - bool divided = can_prove(result % operand == 0); + bool divided = analyzer_.CanProve(result % operand == 0); result = result / operand; // since system will round down when not divided // eg. 2/4 -> 0; -2/4 -> -1 @@ -180,6 +180,8 @@ class BoundDeducer: public IRVisitor { ExprIntSetMap expr_map_; std::vector path_; size_t iter_{0}; + // internal analzyer + Analyzer analyzer_; }; class BoundDeduceInputChecker: public IRVisitor { diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 4700be63e608..5af84bbe2b9d 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -690,6 +690,9 @@ Mutate_(const Mod* op, const Expr& self) { if (mod->coeff % c1val == 0 && CanProveGreaterEqual(x.Eval(), 0)) { return (mod->base % c1).Eval(); + } else if (mod->coeff % c1val == 0 && + mod->base % c1val == 0) { + return make_zero(ret.type()); } } } diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index 403187eb39fd..89298ed6d101 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -28,7 +28,6 @@ #include #include #include -#include "arithmetic/Simplify.h" namespace tvm { namespace arith { @@ -145,42 +144,17 @@ Expr CanonicalSimplify(Expr expr, Map vrange) { return analyzer.canonical_simplify(expr); } -template -T Simplify_(T a, Map vrange) { - using namespace HalideIR::Internal; - Scope rscope; +Expr Simplify(Expr expr, Map vrange) { + arith::Analyzer analyzer; for (auto kv : vrange) { - Range r = kv.second; - rscope.push( - kv.first.get(), - Interval(r->min, - simplify(r->min + r->extent - make_const(r->min.type(), 1)))); - } - return HalideIR::Internal::simplify(a, true, rscope); -} - - -Expr Simplify(Expr a, Map vrange) { - // Simplify top level reduce. - if (const Reduce* r = a.as()) { - Array new_source; - for (auto& e : r->source) { - new_source.push_back(Simplify_(e, vrange)); - } - Expr new_condition = Simplify_(r->condition, vrange); - if (r->source.same_as(new_source) && - r->condition.same_as(new_condition)) { - return a; - } else { - return Reduce::make( - r->combiner, new_source, r->axis, new_condition, r->value_index); - } + analyzer.Bind(kv.first, kv.second); } - return Simplify_(a, vrange); + return analyzer.canonical_simplify(expr); } -Stmt Simplify(Stmt a, Map vrange) { - return Simplify_(a, vrange); +Stmt Simplify(Stmt stmt, Map vrange) { + return arith::CanonicalStmtSimplifier().CanonicalSimplify( + stmt, vrange); } } // namespace ir } // namespace tvm diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 8c584c50b3c6..3e0615162a8f 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,6 +26,7 @@ #include #include #include +#include #include "../arithmetic/compute_expr.h" namespace tvm { diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index 42b1331e3736..78f8c82d97db 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -80,7 +80,7 @@ Operation ScanOpNode::make(std::string name, for (size_t i = 0; i < init.size(); ++i) { CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype); CHECK_EQ(init[i]->dtype, update[i]->dtype); - CHECK(can_prove(init[i]->shape[0] == axis->dom->min)) + CHECK(prove_equal(init[i]->shape[0], axis->dom->min)) << "init.shape[0] need to match scan_axis.dom.min"; CHECK(prove_equal( state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 0a5b7410f3cf..b79016ecbc16 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -504,9 +504,9 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { body_begin = ir::Simplify(middle_interval.min()); - if (!can_prove(body_begin == min)) { + if (!analyzer_.CanProve(body_begin == min)) { Expr cond = (body_begin - min >= 0); - if (!can_prove(cond)) { + if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop"; body_begin = Max::make(body_begin, min); @@ -529,10 +529,10 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { post_doubt_begin = ir::Simplify(middle_interval.max() + 1); - if (!can_prove(middle_interval.max() == max)) { + if (!analyzer_.CanProve(middle_interval.max() == max)) { // require the extent to be non-negative Expr cond = (max - post_doubt_begin + 1 >= 0); - if (!can_prove(cond)) { + if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop"; post_doubt_begin = Min::make(post_doubt_begin, max); @@ -554,7 +554,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, // Generating code for middle subrange if (!partition_thread_scope) { Stmt mid_stmt; - if (!can_prove(body_begin >= post_doubt_begin)) { + if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) { // [body_begin, post_doubt_begin) Stmt simplified_body = ConditionEliminator(cond_set, cond_value).Mutate(body); Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}}); @@ -576,8 +576,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node, s = AppendStmts(s, post_stmt); } else { Expr cond = const_true(); - if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin); - if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); + if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin); + if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); s = ThreadPartitionInserter(cond_set, cond).Mutate(stmt); } s = ConvertSSA(s); @@ -587,7 +587,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) { const For *for_node = static_cast(node); CHECK(for_node); - if (can_prove(extent == make_const(Int(32), 1))) { + if (analyzer_.CanProve(extent == make_const(Int(32), 1))) { // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(Int(32), 0)}}); } else { diff --git a/src/pass/narrow_channel_access.cc b/src/pass/narrow_channel_access.cc index 731064edb012..57f3baf20e10 100644 --- a/src/pass/narrow_channel_access.cc +++ b/src/pass/narrow_channel_access.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -200,7 +200,7 @@ class ChannelAccessRewriter : public IRMutator { Expr base = linear_eq[1]; if (!is_zero(base)) return body; Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent); - if (!can_prove(left >= 0)) return body; + if (!analyzer_.CanProve(left >= 0)) return body; // rewrite access index. ChannelAccessIndexRewriter rw( ch->handle_var.get(), var * coeff, read_access); @@ -233,6 +233,7 @@ class ChannelAccessRewriter : public IRMutator { return body; } + arith::Analyzer analyzer_; std::vector tasks_; }; diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 806a80ad4dc9..eba1cee8b7c7 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -606,7 +606,7 @@ class StoragePlanRewriter : public IRMutator { } // transform to alloc bytes auto type_bits = alloc_type.bits() * alloc_type.lanes(); - bool divided = can_prove(combo_size % type_bits == 0); + bool divided = analyzer_.CanProve(combo_size % type_bits == 0); combo_size = combo_size / type_bits; // round up for can not divided if (!divided) { @@ -920,6 +920,8 @@ class StoragePlanRewriter : public IRMutator { std::unordered_map alloc_map_; // The allocations std::vector > alloc_vec_; + // analyzer + arith::Analyzer analyzer_; }; // Turn alloc into vector alloc diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index 8c3d383c1529..a48e8b4d7e83 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -132,11 +133,11 @@ class Vectorizer : public IRMutator { if (lanes != 1) { const Ramp* b_ramp = b.as(); const Ramp* a_ramp = a.as(); - if (a_ramp && b.type().lanes() == 1 && can_prove(b > 0)) { + if (a_ramp && b.type().lanes() == 1 && analyzer_.CanProve(b > 0)) { return Ramp::make( a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); } - if (b_ramp && a.type().lanes() == 1 && can_prove(a > 0)) { + if (b_ramp && a.type().lanes() == 1 && analyzer_.CanProve(a > 0)) { return Ramp::make( b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); } @@ -186,7 +187,7 @@ class Vectorizer : public IRMutator { Expr stride = this->Mutate(op->stride); if (base.type().lanes() > 1 && stride.type().lanes() == 1) { const Ramp* base_ramp = base.as(); - if (can_prove(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) { + if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) { return Ramp::make(base_ramp->base, stride, op->lanes * base_ramp->lanes); } } @@ -423,6 +424,8 @@ class Vectorizer : public IRMutator { } private: + // analyzer + arith::Analyzer analyzer_; // variable to be replaced Var var_; // the lanes. diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index ab91944d05a2..fba2b890a98f 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -433,9 +433,9 @@ void PassDownBitMaskOr(const Stage& stage, */ void PassUpBoundCheck(const Stage& s, const Map& dom_map, - std::unordered_map* p_state) { + std::unordered_map* p_state, + arith::Analyzer* analyzer) { auto& state = *p_state; - using HalideIR::Internal::can_prove; for (size_t i = s->relations.size(); i != 0; --i) { IterVarRelation rel = s->relations[i - 1]; if (const SplitNode* s = rel.as()) { @@ -448,7 +448,7 @@ void PassUpBoundCheck(const Stage& s, if (outer || inner) { state[s->parent] = true; } else { - if (can_prove(dom_map.at(s->parent)->extent == factor * step)) { + if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) { state[s->parent] = false; } else { state[s->parent] = true; @@ -477,11 +477,13 @@ std::vector MakeBoundCheck( const std::unordered_map& value_map, bool skip_ivar_domain, const std::unordered_set& skip_iter) { + Analyzer analyzer; + std::unordered_map bound_state; for (IterVar iv : stage->leaf_iter_vars) { bound_state[iv] = false; } - PassUpBoundCheck(stage, dom_map, &bound_state); + PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer); std::vector preds; std::unordered_map iset_dmap; @@ -497,7 +499,7 @@ std::vector MakeBoundCheck( Range dom = dom_map.at(iv); Expr value = ComputeExpr(value_map.at(iv), dom->min); Expr vmax = EvalSet(value, iset_dmap).max(); - if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) { + if (vmax.type() != value.type() || !analyzer.CanProve(vmax < dom->extent)) { preds.emplace_back(value < dom->extent); } } @@ -512,10 +514,10 @@ std::vector MakeBoundCheck( Expr vmin = s.min(); Expr vmax = s.max(); // The range of `value` resides in [vmin, vmax] - if (vmin.type() != value.type() || !can_prove(vmin >= 0)) { + if (vmin.type() != value.type() || !analyzer.CanProve(vmin >= 0)) { preds.emplace_back(value >= 0); } - if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) { + if (vmax.type() != value.type() || !analyzer.CanProve(vmax < iv->dom->extent)) { preds.emplace_back(value < iv->dom->extent); } } diff --git a/tests/cpp/ir_simplify_test.cc b/tests/cpp/ir_simplify_test.cc index 35968f8524de..5a5dc03f0165 100644 --- a/tests/cpp/ir_simplify_test.cc +++ b/tests/cpp/ir_simplify_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,12 +21,6 @@ #include #include #include -#include - -TEST(IRSIMPLIFY, Basic) { - using namespace HalideIR::Internal; - simplify_test(); -} TEST(IRSIMPLIFY, MinMax) { auto x = tvm::var("x");