diff --git a/CMakeLists.txt b/CMakeLists.txt index 6500ba013e28..c23d403bcb6a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -154,7 +154,11 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS file(GLOB TOPI_SRCS topi/src/*.cc ) -file(GLOB_RECURSE HALIDEIR_SRCS 3rdparty/HalideIR/src/*.cpp) +file(GLOB_RECURSE HALIDEIR_SRCS + 3rdparty/HalideIR/src/base/*.cpp + 3rdparty/HalideIR/src/ir/*.cpp + 3rdparty/HalideIR/src/tvm/*.cpp +) list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS}) file(GLOB RUNTIME_SRCS src/runtime/*.cc diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 92f7399a89a5..446c4c0c19a9 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -623,12 +623,15 @@ IntSet Intersect(const Array& sets); * give the domain of each variables. Return undefined IntSet to * represent failure. * + * \note The returned set may be smaller than set that + * contains all possible values of v that satisfies the bound. + * * \param v The target variable to be deduced. * \param cond The conditional expression. * \param hint_map The domain of variable, used to help deduce. * \param relax_map The domain of each variable, used to relax the domain, - * The deduce bound mush implies e for all value in relax_map - * \return An integer set that can cover all the possible values. + * The deduce bound must implies e for all value in relax_map + * \return An integer set that always satisfies the condition. */ IntSet DeduceBound(Expr v, Expr cond, const Map& hint_map, @@ -641,7 +644,7 @@ IntSet DeduceBound(Expr v, Expr cond, * \param hint_map The domain of variable, used to help deduce. * \param relax_map The domain of each variable, used to relax the domain, * The deduce bound mush implies e for all value in relax_map - * \return An integer set that can cover all the possible values. + * \return An integer set that always satisfies the condition. */ IntSet DeduceBound(Expr v, Expr cond, const std::unordered_map& hint_map, diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index e1c92e50e6ad..98dbf6bb6290 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -27,7 +27,6 @@ #ifndef TVM_IR_PASS_H_ #define TVM_IR_PASS_H_ -#include #include #include #include diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 2198aee93478..626fc18c57df 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -106,6 +106,7 @@ bool Analyzer::CanProve(const Expr& expr) { Expr Analyzer::Simplify(const Expr& expr) { if (is_const(expr)) return expr; auto res = this->rewrite_simplify(expr); + if (is_const(res)) return res; res = this->canonical_simplify(res); return res; } diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 395a371f43af..003ba8def761 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -84,11 +84,11 @@ class BoundDeducer: public IRVisitor { void Deduce(); void Visit(const NodeRef& e) final { - if (!success) return; + if (!success_) return; if (e.get() == path_[iter_++]) { IRVisitor::Visit(e); } else { - success = false; + success_ = false; return; } } @@ -111,18 +111,18 @@ class BoundDeducer: public IRVisitor { void Visit_(const Add* op) final { bool left = op->a.get() == path_[iter_]; - result -= left ? op->b : op->a; + result_ -= left ? op->b : op->a; Visit(left ? op->a : op->b); } void Visit_(const Sub* op) final { bool left = op->a.get() == path_[iter_]; if (left) { - result += op->b; + result_ += op->b; } else { - result -= op->a; - result = - result; - is_greater = !is_greater; + result_ -= op->a; + result_ = - result_; + is_greater_ = !is_greater_; } Visit(left ? op->a : op->b); } @@ -130,43 +130,65 @@ class BoundDeducer: public IRVisitor { void Visit_(const Mul* op) final { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; + Expr target_var = left ? op->a : op->b; - SignType sign; + SignType sign_operand; if (operand.type().is_uint()) { - sign = kPositive; + sign_operand = kPositive; } else { - sign = expr_map_[operand].sign_type(); + sign_operand = expr_map_[operand].sign_type(); } - if (sign == SignType::kNegative) { - is_greater = !is_greater; - } else if (sign == SignType::kUnknown) { + if (sign_operand == SignType::kNegative) { + is_greater_ = !is_greater_; + } else if (sign_operand == SignType::kUnknown) { // unable to get the sign of operand - success = false; + success_ = false; return; } - // always use relax bound - bool divided = can_prove(result % operand == 0); - result = result / operand; - // since system will round down when not divided - // eg. 2/4 -> 0; -2/4 -> -1 - // no need fix for !is_greater: - // eg. a <= 2/4 -> a <= 0 - // eg. a <= 0/4 -> a <= 0 - // so just fix for not divided and is_greater - // eg. a >= 2/4 -> a >= 0 + 1 - // eg. a >= 0/4 -> a >= 0 - if (is_greater && !divided) { - result += 1; + bool divided = analyzer_.CanProve(result_ % operand == 0); + + result_ = result_ / operand; + + if (!divided) { + // Handle non-divisible case + // NOTE: this accounts for truc div behavior. + bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative(); + + if (is_greater_) { + result_ += 1; + } else { + // NOTE: this is a bit sutble hack. + // + // condition: + // - x * operand <= result + // - operand > 0 + // - x >= 0 + // + // Then it is fine to deduce that x <= result / operand. + // - if result > 0, this division round down + // - if result < 0, (result / operand) rounds up and may violate the constraint + // however, given that x is always non-negative, + // it is fine to have this relaxed bound, given that the user of deduce bound + // will respect the bound of x + // + // TODO(tvm-team): think about a better API to incorporate constraint of x. + // e.g. specify an interval of x and return a bound + // that is in the interval and satisfies the condition. + if (target_is_non_neg && sign_operand == kPositive) { + // do nothing + } else { + result_ -= 1; + } + } } - Visit(left ? op->a : op->b); } - Expr result; - bool is_greater{true}; - bool success{true}; + Expr result_; + bool is_greater_{true}; + bool success_{true}; private: void Init(); @@ -180,6 +202,8 @@ class BoundDeducer: public IRVisitor { ExprIntSetMap expr_map_; std::vector path_; size_t iter_{0}; + // internal analzyer + Analyzer analyzer_; }; class BoundDeduceInputChecker: public IRVisitor { @@ -202,7 +226,7 @@ class BoundDeduceInputChecker: public IRVisitor { void BoundDeducer::Init() { BoundDeduceInputChecker checker; - if (!checker.Check(this)) success = false; + if (!checker.Check(this)) success_ = false; Transform(); } @@ -211,66 +235,65 @@ void BoundDeducer::Transform() { if (const LT* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a < b -> b >= a + 1 - is_greater = true; + is_greater_ = true; expr_ = op->b; - result = op->a + 1; + result_ = op->a + 1; } else { // a < b -> a <= b - 1 - is_greater = false; + is_greater_ = false; expr_ = op->a; - result = op->b - 1; + result_ = op->b - 1; } } else if (const LE* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a <= b -> b >= a - is_greater = true; + is_greater_ = true; expr_ = op->b; - result = op->a; + result_ = op->a; } else { - is_greater = false; + is_greater_ = false; expr_ = op->a; - result = op->b; + result_ = op->b; } } else if (const GT* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a > b -> b <= a - 1 - is_greater = false; + is_greater_ = false; expr_ = op->b; - result = op->a - 1; + result_ = op->a - 1; } else { // a > b -> a >= b + 1 - is_greater = true; + is_greater_ = true; expr_ = op->a; - result = op->b + 1; + result_ = op->b + 1; } } else if (const GE* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a >= b -> b <= a - is_greater = false; + is_greater_ = false; expr_ = op->b; - result = op->a; + result_ = op->a; } else { - is_greater = true; + is_greater_ = true; expr_ = op->a; - result = op->b; + result_ = op->b; } } else { - success = false; + success_ = false; } } void BoundDeducer::Deduce() { Init(); - if (!success) return; + if (!success_) return; Relax(); - if (!success) return; + if (!success_) return; // get the path path_ = GetPath(target_, expr_); if (!path_.size()) { - success = false; + success_ = false; return; } - expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); Visit(expr_); @@ -278,13 +301,13 @@ void BoundDeducer::Deduce() { void BoundDeducer::Relax() { IntSet a = EvalSet(expr_, relax_map_); - IntSet b = EvalSet(result, relax_map_); + IntSet b = EvalSet(result_, relax_map_); if (a.is_everything() || b.is_everything()) { - success = false; + success_ = false; return; } - expr_ = is_greater ? a.min() : a.max(); - result = is_greater ? b.max() : b.min(); + expr_ = is_greater_ ? a.min() : a.max(); + result_ = is_greater_ ? b.max() : b.min(); } IntSet DeduceBound(Expr v, Expr e, @@ -292,12 +315,12 @@ IntSet DeduceBound(Expr v, Expr e, const std::unordered_map& relax_map) { BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); - if (!d.success) return IntSet::nothing(); + if (!d.success_) return IntSet::nothing(); Expr min = neg_inf(), max = pos_inf(); - if (d.is_greater) { - min = d.result; + if (d.is_greater_) { + min = d.result_; } else { - max = d.result; + max = d.result_; } return IntSet::interval(min, max); } diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index ec50aef5c51e..dc6b80a31c7b 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -155,9 +155,10 @@ template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const Type& rtype = a.type(); - // due to division and mod can have different modes - // only constant fold positive number where rule is fixed. - if (pa && pb && pa->value >= 0 && pb->value > 0) { + if (pa && pb) { + // due to division and mod can have different modes + // NOTE: this will assumes truc div. + CHECK_NE(pb->value, 0) << "Divide by zero"; return IntImm::make(rtype, pa->value / pb->value); } if (pa) { diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index bc8666e893b4..6cc829d07e88 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -155,7 +155,6 @@ Mutate_(const Add* op, const Expr& self) { TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y)); TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z)); - TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y), c1.Eval()->value == -c2.Eval()->value); TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y), diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index 01cb96ee1323..162cb1e5fd16 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -28,7 +28,6 @@ #include #include #include -#include "arithmetic/Simplify.h" namespace tvm { namespace arith { @@ -158,42 +157,18 @@ Expr CanonicalSimplify(Expr expr, Map vrange) { return analyzer.canonical_simplify(expr); } -template -T Simplify_(T a, Map vrange) { - using namespace HalideIR::Internal; - Scope rscope; +Expr Simplify(Expr expr, Map vrange) { + arith::Analyzer analyzer; for (auto kv : vrange) { - Range r = kv.second; - rscope.push( - kv.first.get(), - Interval(r->min, - simplify(r->min + r->extent - make_const(r->min.type(), 1)))); - } - return HalideIR::Internal::simplify(a, true, rscope); -} - - -Expr Simplify(Expr a, Map vrange) { - // Simplify top level reduce. - if (const Reduce* r = a.as()) { - Array new_source; - for (auto& e : r->source) { - new_source.push_back(Simplify_(e, vrange)); - } - Expr new_condition = Simplify_(r->condition, vrange); - if (r->source.same_as(new_source) && - r->condition.same_as(new_condition)) { - return a; - } else { - return Reduce::make( - r->combiner, new_source, r->axis, new_condition, r->value_index); - } + analyzer.Bind(kv.first, kv.second); } - return Simplify_(a, vrange); + expr = analyzer.Simplify(expr); + return expr; } -Stmt Simplify(Stmt a, Map vrange) { - return Simplify_(a, vrange); +Stmt Simplify(Stmt stmt, Map vrange) { + return arith::CanonicalStmtSimplifier().CanonicalSimplify( + stmt, vrange); } } // namespace ir } // namespace tvm diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 8c584c50b3c6..3e0615162a8f 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,6 +26,7 @@ #include #include #include +#include #include "../arithmetic/compute_expr.h" namespace tvm { diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index 42b1331e3736..78f8c82d97db 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -80,7 +80,7 @@ Operation ScanOpNode::make(std::string name, for (size_t i = 0; i < init.size(); ++i) { CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype); CHECK_EQ(init[i]->dtype, update[i]->dtype); - CHECK(can_prove(init[i]->shape[0] == axis->dom->min)) + CHECK(prove_equal(init[i]->shape[0], axis->dom->min)) << "init.shape[0] need to match scan_axis.dom.min"; CHECK(prove_equal( state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 0a5b7410f3cf..33dbaed83b69 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -466,8 +466,13 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Stmt body, bool partition_thread_scope) { using namespace arith; + // include hint of var. + hint_map_.insert({var.get(), IntSet::interval(min, max)}); + PartitionFinder finder(var, hint_map_, relax_map_); finder.Visit(body); + + hint_map_.erase(var.get()); if (finder.partitions.empty()) return Stmt(); arith::IntervalSet for_interval(min, max); @@ -504,9 +509,9 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { body_begin = ir::Simplify(middle_interval.min()); - if (!can_prove(body_begin == min)) { + if (!analyzer_.CanProve(body_begin == min)) { Expr cond = (body_begin - min >= 0); - if (!can_prove(cond)) { + if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop"; body_begin = Max::make(body_begin, min); @@ -529,10 +534,10 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { post_doubt_begin = ir::Simplify(middle_interval.max() + 1); - if (!can_prove(middle_interval.max() == max)) { + if (!analyzer_.CanProve(middle_interval.max() == max)) { // require the extent to be non-negative Expr cond = (max - post_doubt_begin + 1 >= 0); - if (!can_prove(cond)) { + if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop"; post_doubt_begin = Min::make(post_doubt_begin, max); @@ -554,7 +559,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, // Generating code for middle subrange if (!partition_thread_scope) { Stmt mid_stmt; - if (!can_prove(body_begin >= post_doubt_begin)) { + if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) { // [body_begin, post_doubt_begin) Stmt simplified_body = ConditionEliminator(cond_set, cond_value).Mutate(body); Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}}); @@ -576,8 +581,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node, s = AppendStmts(s, post_stmt); } else { Expr cond = const_true(); - if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin); - if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); + if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin); + if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); s = ThreadPartitionInserter(cond_set, cond).Mutate(stmt); } s = ConvertSSA(s); @@ -587,7 +592,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) { const For *for_node = static_cast(node); CHECK(for_node); - if (can_prove(extent == make_const(Int(32), 1))) { + if (analyzer_.CanProve(extent == make_const(Int(32), 1))) { // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(Int(32), 0)}}); } else { diff --git a/src/pass/narrow_channel_access.cc b/src/pass/narrow_channel_access.cc index 731064edb012..57f3baf20e10 100644 --- a/src/pass/narrow_channel_access.cc +++ b/src/pass/narrow_channel_access.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -200,7 +200,7 @@ class ChannelAccessRewriter : public IRMutator { Expr base = linear_eq[1]; if (!is_zero(base)) return body; Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent); - if (!can_prove(left >= 0)) return body; + if (!analyzer_.CanProve(left >= 0)) return body; // rewrite access index. ChannelAccessIndexRewriter rw( ch->handle_var.get(), var * coeff, read_access); @@ -233,6 +233,7 @@ class ChannelAccessRewriter : public IRMutator { return body; } + arith::Analyzer analyzer_; std::vector tasks_; }; diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 806a80ad4dc9..eba1cee8b7c7 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -606,7 +606,7 @@ class StoragePlanRewriter : public IRMutator { } // transform to alloc bytes auto type_bits = alloc_type.bits() * alloc_type.lanes(); - bool divided = can_prove(combo_size % type_bits == 0); + bool divided = analyzer_.CanProve(combo_size % type_bits == 0); combo_size = combo_size / type_bits; // round up for can not divided if (!divided) { @@ -920,6 +920,8 @@ class StoragePlanRewriter : public IRMutator { std::unordered_map alloc_map_; // The allocations std::vector > alloc_vec_; + // analyzer + arith::Analyzer analyzer_; }; // Turn alloc into vector alloc diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index 8c3d383c1529..a48e8b4d7e83 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -132,11 +133,11 @@ class Vectorizer : public IRMutator { if (lanes != 1) { const Ramp* b_ramp = b.as(); const Ramp* a_ramp = a.as(); - if (a_ramp && b.type().lanes() == 1 && can_prove(b > 0)) { + if (a_ramp && b.type().lanes() == 1 && analyzer_.CanProve(b > 0)) { return Ramp::make( a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); } - if (b_ramp && a.type().lanes() == 1 && can_prove(a > 0)) { + if (b_ramp && a.type().lanes() == 1 && analyzer_.CanProve(a > 0)) { return Ramp::make( b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); } @@ -186,7 +187,7 @@ class Vectorizer : public IRMutator { Expr stride = this->Mutate(op->stride); if (base.type().lanes() > 1 && stride.type().lanes() == 1) { const Ramp* base_ramp = base.as(); - if (can_prove(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) { + if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) { return Ramp::make(base_ramp->base, stride, op->lanes * base_ramp->lanes); } } @@ -423,6 +424,8 @@ class Vectorizer : public IRMutator { } private: + // analyzer + arith::Analyzer analyzer_; // variable to be replaced Var var_; // the lanes. diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index a7f974613aa1..0dc82abd9a8f 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -432,9 +432,9 @@ void PassDownBitMaskOr(const Stage& stage, */ void PassUpBoundCheck(const Stage& s, const Map& dom_map, - std::unordered_map* p_state) { + std::unordered_map* p_state, + arith::Analyzer* analyzer) { auto& state = *p_state; - using HalideIR::Internal::can_prove; for (size_t i = s->relations.size(); i != 0; --i) { IterVarRelation rel = s->relations[i - 1]; if (const SplitNode* s = rel.as()) { @@ -447,7 +447,7 @@ void PassUpBoundCheck(const Stage& s, if (outer || inner) { state[s->parent] = true; } else { - if (can_prove(dom_map.at(s->parent)->extent == factor * step)) { + if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) { state[s->parent] = false; } else { state[s->parent] = true; @@ -476,11 +476,13 @@ std::vector MakeBoundCheck( const std::unordered_map& value_map, bool skip_ivar_domain, const std::unordered_set& skip_iter) { + Analyzer analyzer; + std::unordered_map bound_state; for (IterVar iv : stage->leaf_iter_vars) { bound_state[iv] = false; } - PassUpBoundCheck(stage, dom_map, &bound_state); + PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer); std::vector preds; std::unordered_map iset_dmap; @@ -496,7 +498,7 @@ std::vector MakeBoundCheck( Range dom = dom_map.at(iv); Expr value = ComputeExpr(value_map.at(iv), dom->min); Expr vmax = EvalSet(value, iset_dmap).max(); - if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) { + if (vmax.type() != value.type() || !analyzer.CanProve(vmax < dom->extent)) { preds.emplace_back(value < dom->extent); } } @@ -511,10 +513,10 @@ std::vector MakeBoundCheck( Expr vmin = s.min(); Expr vmax = s.max(); // The range of `value` resides in [vmin, vmax] - if (vmin.type() != value.type() || !can_prove(vmin >= 0)) { + if (vmin.type() != value.type() || !analyzer.CanProve(vmin >= 0)) { preds.emplace_back(value >= 0); } - if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) { + if (vmax.type() != value.type() || !analyzer.CanProve(vmax < iv->dom->extent)) { preds.emplace_back(value < iv->dom->extent); } } diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index c5f1b1656dd5..760ed0f233f7 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -740,7 +740,7 @@ Array Schedule::rfactor(const Tensor& tensor, const Reduce* reduce = compute_op->body[idx].as(); CHECK(reduce) << "Can only rfactor non-inline reductions"; predicates.push_back(reduce->condition); - Expr predicate = likely(simplify(arith::ComputeReduce(predicates, Expr()))); + Expr predicate = likely(arith::ComputeReduce(predicates, Expr())); std::unordered_map vsub; diff --git a/tests/cpp/ir_simplify_test.cc b/tests/cpp/ir_simplify_test.cc index 35968f8524de..5a5dc03f0165 100644 --- a/tests/cpp/ir_simplify_test.cc +++ b/tests/cpp/ir_simplify_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,12 +21,6 @@ #include #include #include -#include - -TEST(IRSIMPLIFY, Basic) { - using namespace HalideIR::Internal; - simplify_test(); -} TEST(IRSIMPLIFY, MinMax) { auto x = tvm::var("x"); diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 7fe6f56edea7..d26b508ff262 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -16,6 +16,14 @@ # under the License. import tvm + +def assert_expr_equal(a, b): + res = tvm.ir_pass.Simplify(a - b) + equal = isinstance(res, tvm.expr.IntImm) and res.value == 0 + if not equal: + raise ValueError("{} and {} are not equal".format(a, b)) + + def test_deduce(): a = tvm.var('a') b = tvm.var('b') @@ -29,31 +37,34 @@ def test_deduce(): e0 = (-b)*a+c-d res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) - ans0 = ((d - c) /(b*-1)) - assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + ans0 = ((d - c) /(b*-1) + (-1)) + assert_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + assert_expr_equal(res0.max_value, ans0) e0 = d*a+c-d res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) - ans0 = ((0-c)/d + 1) - assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + ans0 = ((d-c)/d - 1) + assert_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + assert_expr_equal(res0.max_value, ans0) + e1 = (a*4+b < c) res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) - ans1 = (((c - b) + -1)/4) - assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1) + ans1 = (((c - b) + -1)/4 -1) + assert_expr_equal(res1.max_value, ans1) + # expression containing variable a is on rhs e1 = (c > a*4+b) res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) - assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1) + assert_expr_equal(res1.max_value, ans1) + e2 = (tvm.max(5, a * 4) < 0) res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) @@ -66,7 +77,6 @@ def test_deduce(): assert str(res2.max_value) == "neg_inf" assert str(res2.min_value) == "pos_inf" - e3 = (-b)+a*c-d res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) ans3 = 2/c+1 @@ -75,6 +85,7 @@ def test_deduce(): res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) + def test_check(): a = tvm.var('a') b = tvm.var('b') diff --git a/tests/python/unittest/test_pass_basic.py b/tests/python/unittest/test_pass_basic.py index fc76c306731c..b05d75ab2d1e 100644 --- a/tests/python/unittest/test_pass_basic.py +++ b/tests/python/unittest/test_pass_basic.py @@ -24,9 +24,6 @@ def test_simplify(): assert(tvm.ir_pass.Equal(e2, x * 8)) e3 = tvm.ir_pass.Simplify(x - x / 3 * 3) assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3))) - let = tvm.make.Let(x, 1, x + 3) - e4 = tvm.ir_pass.Simplify(let) - assert(tvm.ir_pass.Equal(e4, 4)) def test_verify_ssa():