Skip to content

Commit

Permalink
Squashed apache#3368 (WIP): [ARITH] Migrate simplifier to new infra
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrechanik-h committed Jun 19, 2019
1 parent ddfd8d6 commit d615928
Show file tree
Hide file tree
Showing 15 changed files with 79 additions and 75 deletions.
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS
file(GLOB TOPI_SRCS
topi/src/*.cc
)
file(GLOB_RECURSE HALIDEIR_SRCS 3rdparty/HalideIR/src/*.cpp)
file(GLOB_RECURSE HALIDEIR_SRCS
3rdparty/HalideIR/src/base/*.cpp
3rdparty/HalideIR/src/ir/*.cpp
3rdparty/HalideIR/src/tvm/*.cpp
)
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
file(GLOB RUNTIME_SRCS
src/runtime/*.cc
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,14 @@ class Analyzer {
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
/*!
* \brief Whether can we prove condition.
*
* \param cond The expression to be proved.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProve(const Expr& cond);
};

//-----------------------------------------------
Expand Down
1 change: 0 additions & 1 deletion include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_

#include <arithmetic/Simplify.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand Down
11 changes: 11 additions & 0 deletions src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,16 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
return false;
}

bool Analyzer::CanProve(const Expr& expr) {
if (const auto* ptr = expr.as<ir::UIntImm>()) {
return ptr->value != 0;
}
auto res = this->rewrite_simplify(expr);
if (const auto* ptr = res.as<ir::UIntImm>()) {
return ptr->value != 0;
}
return false;
}

} // namespace arith
} // namespace tvm
4 changes: 3 additions & 1 deletion src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class BoundDeducer: public IRVisitor {
}

// always use relax bound
bool divided = can_prove(result % operand == 0);
bool divided = analyzer_.CanProve(result % operand == 0);
result = result / operand;
// since system will round down when not divided
// eg. 2/4 -> 0; -2/4 -> -1
Expand Down Expand Up @@ -180,6 +180,8 @@ class BoundDeducer: public IRVisitor {
ExprIntSetMap expr_map_;
std::vector<const Node*> path_;
size_t iter_{0};
// internal analzyer
Analyzer analyzer_;
};

class BoundDeduceInputChecker: public IRVisitor {
Expand Down
3 changes: 3 additions & 0 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,9 @@ Mutate_(const Mod* op, const Expr& self) {
if (mod->coeff % c1val == 0 &&
CanProveGreaterEqual(x.Eval(), 0)) {
return (mod->base % c1).Eval();
} else if (mod->coeff % c1val == 0 &&
mod->base % c1val == 0) {
return make_zero(ret.type());
}
}
}
Expand Down
40 changes: 7 additions & 33 deletions src/arithmetic/stmt_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <tvm/arithmetic.h>
#include "arithmetic/Simplify.h"

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -145,42 +144,17 @@ Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
return analyzer.canonical_simplify(expr);
}

template<typename T>
T Simplify_(T a, Map<Var, Range> vrange) {
using namespace HalideIR::Internal;
Scope<Interval> rscope;
Expr Simplify(Expr expr, Map<Var, Range> vrange) {
arith::Analyzer analyzer;
for (auto kv : vrange) {
Range r = kv.second;
rscope.push(
kv.first.get(),
Interval(r->min,
simplify(r->min + r->extent - make_const(r->min.type(), 1))));
}
return HalideIR::Internal::simplify(a, true, rscope);
}


Expr Simplify(Expr a, Map<Var, Range> vrange) {
// Simplify top level reduce.
if (const Reduce* r = a.as<Reduce>()) {
Array<Expr> new_source;
for (auto& e : r->source) {
new_source.push_back(Simplify_(e, vrange));
}
Expr new_condition = Simplify_(r->condition, vrange);
if (r->source.same_as(new_source) &&
r->condition.same_as(new_condition)) {
return a;
} else {
return Reduce::make(
r->combiner, new_source, r->axis, new_condition, r->value_index);
}
analyzer.Bind(kv.first, kv.second);
}
return Simplify_(a, vrange);
return analyzer.canonical_simplify(expr);
}

Stmt Simplify(Stmt a, Map<Var, Range> vrange) {
return Simplify_(a, vrange);
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::CanonicalStmtSimplifier().CanonicalSimplify(
stmt, vrange);
}
} // namespace ir
} // namespace tvm
5 changes: 3 additions & 2 deletions src/lang/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -26,6 +26,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <iterator>
#include <stack>
#include "../arithmetic/compute_expr.h"

namespace tvm {
Expand Down
6 changes: 3 additions & 3 deletions src/op/scan_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -80,7 +80,7 @@ Operation ScanOpNode::make(std::string name,
for (size_t i = 0; i < init.size(); ++i) {
CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
CHECK_EQ(init[i]->dtype, update[i]->dtype);
CHECK(can_prove(init[i]->shape[0] == axis->dom->min))
CHECK(prove_equal(init[i]->shape[0], axis->dom->min))
<< "init.shape[0] need to match scan_axis.dom.min";
CHECK(prove_equal(
state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
Expand Down
16 changes: 8 additions & 8 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,9 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min());
if (!can_prove(body_begin == min)) {
if (!analyzer_.CanProve(body_begin == min)) {
Expr cond = (body_begin - min >= 0);
if (!can_prove(cond)) {
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
Expand All @@ -529,10 +529,10 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
if (!can_prove(middle_interval.max() == max)) {
if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
if (!can_prove(cond)) {
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max);
Expand All @@ -554,7 +554,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
// Generating code for middle subrange
if (!partition_thread_scope) {
Stmt mid_stmt;
if (!can_prove(body_begin >= post_doubt_begin)) {
if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) {
// [body_begin, post_doubt_begin)
Stmt simplified_body = ConditionEliminator(cond_set, cond_value).Mutate(body);
Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
Expand All @@ -576,8 +576,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
s = AppendStmts(s, post_stmt);
} else {
Expr cond = const_true();
if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin);
if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin);
if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
s = ThreadPartitionInserter(cond_set, cond).Mutate(stmt);
}
s = ConvertSSA(s);
Expand All @@ -587,7 +587,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) {
const For *for_node = static_cast<const For*>(node);
CHECK(for_node);
if (can_prove(extent == make_const(Int(32), 1))) {
if (analyzer_.CanProve(extent == make_const(Int(32), 1))) {
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(Int(32), 0)}});
} else {
Expand Down
7 changes: 4 additions & 3 deletions src/pass/narrow_channel_access.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -200,7 +200,7 @@ class ChannelAccessRewriter : public IRMutator {
Expr base = linear_eq[1];
if (!is_zero(base)) return body;
Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent);
if (!can_prove(left >= 0)) return body;
if (!analyzer_.CanProve(left >= 0)) return body;
// rewrite access index.
ChannelAccessIndexRewriter rw(
ch->handle_var.get(), var * coeff, read_access);
Expand Down Expand Up @@ -233,6 +233,7 @@ class ChannelAccessRewriter : public IRMutator {
return body;
}

arith::Analyzer analyzer_;
std::vector<RewriteEntry> tasks_;
};

Expand Down
8 changes: 5 additions & 3 deletions src/pass/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -606,7 +606,7 @@ class StoragePlanRewriter : public IRMutator {
}
// transform to alloc bytes
auto type_bits = alloc_type.bits() * alloc_type.lanes();
bool divided = can_prove(combo_size % type_bits == 0);
bool divided = analyzer_.CanProve(combo_size % type_bits == 0);
combo_size = combo_size / type_bits;
// round up for can not divided
if (!divided) {
Expand Down Expand Up @@ -920,6 +920,8 @@ class StoragePlanRewriter : public IRMutator {
std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
// The allocations
std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
// analyzer
arith::Analyzer analyzer_;
};

// Turn alloc into vector alloc
Expand Down
9 changes: 6 additions & 3 deletions src/pass/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -132,11 +133,11 @@ class Vectorizer : public IRMutator {
if (lanes != 1) {
const Ramp* b_ramp = b.as<Ramp>();
const Ramp* a_ramp = a.as<Ramp>();
if (a_ramp && b.type().lanes() == 1 && can_prove(b > 0)) {
if (a_ramp && b.type().lanes() == 1 && analyzer_.CanProve(b > 0)) {
return Ramp::make(
a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
}
if (b_ramp && a.type().lanes() == 1 && can_prove(a > 0)) {
if (b_ramp && a.type().lanes() == 1 && analyzer_.CanProve(a > 0)) {
return Ramp::make(
b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
}
Expand Down Expand Up @@ -186,7 +187,7 @@ class Vectorizer : public IRMutator {
Expr stride = this->Mutate(op->stride);
if (base.type().lanes() > 1 && stride.type().lanes() == 1) {
const Ramp* base_ramp = base.as<Ramp>();
if (can_prove(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) {
if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) {
return Ramp::make(base_ramp->base, stride, op->lanes * base_ramp->lanes);
}
}
Expand Down Expand Up @@ -423,6 +424,8 @@ class Vectorizer : public IRMutator {
}

private:
// analyzer
arith::Analyzer analyzer_;
// variable to be replaced
Var var_;
// the lanes.
Expand Down
Loading

0 comments on commit d615928

Please sign in to comment.