Skip to content

Commit

Permalink
[ARITH][SCHEDULE] Update schedule to use the new analyzer (#3466)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Jul 1, 2019
1 parent dfc4f97 commit 79e071c
Show file tree
Hide file tree
Showing 14 changed files with 152 additions and 40 deletions.
18 changes: 18 additions & 0 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,24 @@ class Analyzer {
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
/*!
* \brief Whether can we prove condition.
*
* \param cond The expression to be proved.
* \return The result.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProve(const Expr& cond);
/*!
* \brief Simplify expr.
*
* \param expr The expression to be simplified.
* \return The result.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
Expr Simplify(const Expr& expr);
};

//-----------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
Expand Down
28 changes: 28 additions & 0 deletions src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include <tvm/ir.h>
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>

namespace tvm {
namespace arith {
Expand All @@ -49,8 +50,13 @@ void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
}

void Analyzer::Bind(const VarExpr& v, const Range& range) {
CHECK(range.defined());
Var var(v.node_);
this->const_int_bound.Bind(var, range);
if (is_one(range->extent)) {
this->rewrite_simplify.Update(var, range->min);
this->canonical_simplify.Update(var, range->min);
}
// skip modular_set
// skip rewrite simplify
}
Expand Down Expand Up @@ -82,5 +88,27 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
return false;
}

bool Analyzer::CanProve(const Expr& expr) {
if (const auto* ptr = expr.as<ir::UIntImm>()) {
return ptr->value != 0;
}
auto res = this->rewrite_simplify(expr);
if (const auto* ptr = res.as<ir::UIntImm>()) {
return ptr->value != 0;
}
res = this->canonical_simplify(expr);
if (const auto* ptr = res.as<ir::UIntImm>()) {
return ptr->value != 0;
}
return false;
}

Expr Analyzer::Simplify(const Expr& expr) {
if (is_const(expr)) return expr;
auto res = this->rewrite_simplify(expr);
res = this->canonical_simplify(res);
return res;
}

} // namespace arith
} // namespace tvm
1 change: 1 addition & 0 deletions src/arithmetic/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ class SumExprNode : public CanonicalExprNode {
rhs.CopyOnWrite()->scale += lhs->scale;
lhs.CopyOnWrite()->scale = 0;
} else if (lhs->lower_factor == rhs->upper_factor &&
rhs->scale != 0 &&
lhs->scale % rhs->scale == 0 &&
lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor) {
// Rules used in the proof:
Expand Down
25 changes: 20 additions & 5 deletions src/arithmetic/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,23 @@ ConstIntBound::ConstIntBound(
node_ = std::move(node);
}

inline void PrintBoundValue(std::ostream& os, int64_t val) {
if (val == ConstIntBound::kPosInf) {
os << "pos_inf";
} else if (val == ConstIntBound::kNegInf) {
os << "neg_inf";
} else {
os << val;
}
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ConstIntBoundNode *op, IRPrinter *p) {
p->stream << "ConstIntBound"
<< "[" << op->min_value << ", "
<< op->max_value << ']';
.set_dispatch<ConstIntBoundNode>([](const ConstIntBoundNode* op, IRPrinter* p) {
p->stream << "ConstIntBound[";
PrintBoundValue(p->stream, op->min_value);
p->stream << ',';
PrintBoundValue(p->stream, op->max_value);
p->stream << ']';
});

// internal entry for const int bound
Expand Down Expand Up @@ -95,7 +107,10 @@ class ConstIntBoundAnalyzer::Impl :
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(it->second == info)
<< "var \'" << var << "\' already updated.";
<< "Trying to update var \'" << var << "\'"
<< " with a different const bound: "
<< "original=" << ConstIntBound(it->second.min_value, it->second.max_value)
<< ", new=" << ConstIntBound(info.min_value, info.max_value);
}
}
var_map_[var] = info;
Expand Down
19 changes: 18 additions & 1 deletion src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,14 @@ TryCompare(const Expr& x, int64_t val) {
void RewriteSimplifier::Impl::
Update(const Var& var, const Expr& info, bool override) {
if (!override) {
CHECK(!var_map_.count(var));
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(Equal(it->second, info))
<< "Trying to update var \'" << var << "\'"
<< " with a different value: "
<< "original=" << it->second
<< ", new=" << info;
}
}
var_map_[var] = info;
}
Expand Down Expand Up @@ -199,6 +206,9 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_RECURSIVE_REWRITE(x + c1 + y, (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x + (c1 + y), (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE((y % c1) + x * c1, x * c1 + (y % c1));

TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x);
TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x);
}

// condition rules.
Expand Down Expand Up @@ -477,6 +487,10 @@ Mutate_(const Div* op, const Expr& self) {
}
}

TVM_TRY_REWRITE(x / x, OneWithTypeLike(x));
TVM_TRY_REWRITE(x * c1 / x, c1);
TVM_TRY_REWRITE(c1 * x / x, c1);

// Rules involving 2-operands.
TVM_TRY_REWRITE_IF((x * c1 + y) / c2, x * (c1 / c2) + y / c2,
c1.Eval()->value >= 0 &&
Expand Down Expand Up @@ -684,6 +698,9 @@ Mutate_(const Mod* op, const Expr& self) {
if (mod->coeff % c1val == 0 &&
CanProveGreaterEqual(x.Eval(), 0)) {
return (mod->base % c1).Eval();
} else if (mod->coeff % c1val == 0 &&
mod->base % c1val == 0) {
return make_zero(ret.type());
}
}
}
Expand Down
9 changes: 7 additions & 2 deletions src/arithmetic/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -121,6 +121,11 @@ class RewriteSimplifier::Impl : public IRMutator {
PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 0);
}

template<typename TA>
PConstWithTypeLike<TA> OneWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 1);
}
};


Expand Down
22 changes: 17 additions & 5 deletions src/schedule/bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -213,6 +213,8 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
// Prepare context
GraphContext ctx;
Array<Operation> roots;
arith::Analyzer analyzer;

for (Operation op : sch->outputs) {
roots.push_back(sch->stage_map[op]->op);
}
Expand All @@ -233,16 +235,26 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
for (size_t i = sch->stages.size(); i != 0; --i) {
const Stage& stage = sch->stages[i - 1];
InferRootBound(stage, ctx, &ret);

// bind bound of root iter vars.
for (auto iv : stage->op->root_iter_vars()) {
auto it = ret.find(iv);
if (it != ret.end()) {
analyzer.Bind(iv->var, it->second);
}
}

// pass down to get bound of all iter vars.
PassDownDomain(stage, &ret);
PassDownDomain(stage, &ret, &analyzer);
for (IterVar iv : stage->env_threads) {
CHECK(iv->dom.defined());
ret[iv] = iv->dom;
}
}
for (auto& p : ret) {
ret[p.first] = Range::make_by_min_extent(ir::Simplify(p.second->min),
ir::Simplify(p.second->extent));
ret[p.first] = Range::make_by_min_extent(
analyzer.Simplify(p.second->min),
analyzer.Simplify(p.second->extent));
}
return Map<IterVar, Range>(ret.begin(), ret.end());
}
Expand Down
41 changes: 20 additions & 21 deletions src/schedule/message_passing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -34,24 +34,17 @@ namespace schedule {
using namespace ir;
using namespace arith;

// result = ceil((a / b)), both a and b are positive integer
inline Expr DivCeil(Expr a, Expr b) {
return ir::Simplify((a + b - 1) / b);
}

inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}

void Update(std::unordered_map<IterVar, Range>* p_state,
const IterVar& iv,
Range r) {
Range r,
Analyzer* analyzer) {
auto it = p_state->find(iv);
if (it == p_state->end()) {
(*p_state)[iv] = r;
analyzer->Bind(iv->var, r);
} else {
bool match = is_zero(it->second->min);
if (!prove_equal(r->extent, it->second->extent)) match = false;
bool match = is_zero(it->second->min) &&
analyzer->CanProve(r->extent - it->second->extent == 0);
CHECK(match)
<< iv
<< " domain already inferred,"
Expand All @@ -62,7 +55,12 @@ void Update(std::unordered_map<IterVar, Range>* p_state,

void PassDownDomain(const Stage& stage,
std::unordered_map<IterVar, Range>* p_state,
arith::Analyzer* actx,
bool allow_missing) {
auto ceil_div = [actx](Expr a, Expr b) {
return actx->Simplify((a + (b - 1)) / b);
};

auto& state = *p_state;
// forwar iteration on relations
for (IterVarRelation rel : stage->relations) {
Expand All @@ -74,15 +72,16 @@ void PassDownDomain(const Stage& stage,
CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent);
if (r->factor.defined()) {
Update(p_state, r->inner, Range::make_by_min_extent(0, r->factor));
Update(p_state, r->inner,
Range::make_by_min_extent(0, r->factor), actx);
Update(p_state, r->outer,
Range::make_by_min_extent(
0, DivCeil(range_parent->extent, r->factor)));
0, ceil_div(range_parent->extent, r->factor)), actx);
} else {
Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts));
Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx);
Update(p_state, r->inner,
Range::make_by_min_extent(
0, DivCeil(range_parent->extent, r->nparts)));
0, ceil_div(range_parent->extent, r->nparts)), actx);
}
} else if (const FuseNode* r = rel.as<FuseNode>()) {
if (!state.count(r->outer) || !state.count(r->inner)) {
Expand All @@ -100,9 +99,9 @@ void PassDownDomain(const Stage& stage,
}
Update(p_state, r->rebased,
Range::make_by_min_extent(
0, state.at(r->parent)->extent));
0, state.at(r->parent)->extent), actx);
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
Update(p_state, s->iter, Range::make_by_min_extent(0, 1));
Update(p_state, s->iter, Range::make_by_min_extent(0, 1), actx);
} else {
LOG(FATAL) << "unknown relation type";
}
Expand All @@ -111,7 +110,7 @@ void PassDownDomain(const Stage& stage,
for (auto kv : stage->iter_var_attrs) {
if (kv.second->bind_thread.defined()) {
CHECK(state.count(kv.first));
Update(p_state, kv.second->bind_thread, state.at(kv.first));
Update(p_state, kv.second->bind_thread, state.at(kv.first), actx);
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/schedule/message_passing.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -43,11 +43,13 @@ namespace schedule {
*
* \param stage The stage to operate on.
* \param p_state The state of the message passing.
* \param analyzer Analyzer context, storing information about bounds in p_state.
* \param allow_missing Whether allow missing value.
*/
void PassDownDomain(
const Stage& stage,
std::unordered_map<IterVar, Range>* p_state,
arith::Analyzer* analyzer,
bool allow_missing = false);

/*!
Expand Down
Loading

0 comments on commit 79e071c

Please sign in to comment.