Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Arithmetic analysis #10403

Merged
merged 2 commits into from
Feb 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 100 additions & 75 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,9 @@ class IterMapRewriter : public ExprMutator {
return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
}

IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const PrimExpr& predicate_induced_min,
const PrimExpr& predicate_induced_max) {
IterSumExpr RewriteIterConstraint(const PrimExpr& expr,
const Optional<PrimExpr>& predicate_induced_min,
const Optional<PrimExpr>& predicate_induced_max) {
return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min,
predicate_induced_max);
}
Expand Down Expand Up @@ -494,16 +495,17 @@ class IterMapRewriter : public ExprMutator {
* \param predicate_induced_max Open upper bound from iter constraint, maybe undefined.
* \return The Normalized expression.
*/
IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min,
PrimExpr predicate_induced_max) {
IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, Optional<PrimExpr> predicate_induced_min,
Optional<PrimExpr> predicate_induced_max) {
// normalize to zero base
PrimExpr base = expr->base;
if (!is_zero(base)) {
expr.CopyOnWrite()->base = 0;
if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base;
if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base;
if (predicate_induced_min.defined())
predicate_induced_min = predicate_induced_min.value() - base;
if (predicate_induced_max.defined())
predicate_induced_max = predicate_induced_max.value() - base;
}
if (expr->args.size() < 1) return expr;
Optional<IterSumExpr> opt = TryFuseIters(expr);
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
// scale should be 1
Expand All @@ -522,10 +524,10 @@ class IterMapRewriter : public ExprMutator {
PrimExpr iter_min = mark_offset;
PrimExpr iter_max = iter_min + mark->extent;
if (predicate_induced_min.defined()) {
iter_min = max(predicate_induced_min, iter_min);
iter_min = max(predicate_induced_min.value(), iter_min);
}
if (predicate_induced_max.defined()) {
iter_max = min(predicate_induced_max, iter_max);
iter_max = min(predicate_induced_max.value(), iter_max);
}
if (!is_zero(iter_min)) {
// structured form's offset should be updated
Expand All @@ -536,7 +538,6 @@ class IterMapRewriter : public ExprMutator {
}
mark.CopyOnWrite()->extent = iter_max - iter_min;
sum_fuse_map_[flattened_form] = {mark, iter_min};

// we need to note down the flattened form of constrained iterators
// to check the validity of constraints, see also CheckConstraints()
constrained_iters_flattened_.push_back(flattened_form);
Expand Down Expand Up @@ -771,14 +772,15 @@ class IterMapRewriter : public ExprMutator {
struct IterConstraint {
// The expr of the iter
PrimExpr iter;
// The expr of the lower_bound
PrimExpr lower_bound;
// The expr of the upper_bound
PrimExpr upper_bound;
// The expr of the lower_bound, maybe undefined
Optional<PrimExpr> lower_bound;
// The expr of the upper_bound, maybe undefined
Optional<PrimExpr> upper_bound;
// The size of the iter, which is the number of nodes
size_t expr_size = 0;

IterConstraint(PrimExpr iter, PrimExpr lower_bound, PrimExpr upper_bound, size_t size)
IterConstraint(PrimExpr iter, Optional<PrimExpr> lower_bound, Optional<PrimExpr> upper_bound,
size_t size)
: iter(std::move(iter)),
lower_bound(std::move(lower_bound)),
upper_bound(std::move(upper_bound)),
Expand All @@ -788,11 +790,12 @@ struct IterConstraint {
/*!
* \brief Split the predicate into `(a < b) && (c < d) && ...`
* \param pred The predicate to be split.
* \param input_iters The input iterators.
* \param result The result of predicate split.
* \return A list of IterConstraint, empty if the split failed.
*/
std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
const Map<Var, Range>& input_iters) {
std::vector<IterConstraint> result;
bool MatchBoundConstraints(PrimExpr pred, Map<Var, Range>* input_iters,
std::vector<IterConstraint>* result) {
arith::PVar<PrimExpr> lhs, rhs, rest;
for (;;) {
// try to extract comparisons
Expand Down Expand Up @@ -821,78 +824,94 @@ std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
is_equal = true;
is_finish = true;
} else {
return std::vector<IterConstraint>();
return false;
}
PrimExpr lhs_expr = lhs.Eval();
PrimExpr rhs_expr = rhs.Eval();
// we only accept predicate of integers
if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) &&
(rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) {
return std::vector<IterConstraint>();
return false;
}
// determine iter and bound, if we can not distinguish them simply,
// try divide (lhs - rhs) into itervar aware and itervar free parts
auto f_use_itervar = [&input_iters](const VarNode* v) {
return input_iters.count(GetRef<Var>(v));
return input_iters->count(GetRef<Var>(v));
};
bool bound_at_left;
if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) {
bound_at_left = true;
} else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) {
bound_at_left = false;
} else {
bound_at_left = false; // accumulate bound to rhs
PrimExpr sum_parts = lhs_expr - rhs_expr;
lhs_expr = 0;
rhs_expr = 0;
std::function<void(const PrimExpr&, bool)> f_extract =
[&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) {
if (const AddNode* add = part.as<AddNode>()) {
f_extract(add->a, sign);
f_extract(add->b, sign);
} else if (const SubNode* sub = part.as<SubNode>()) {
f_extract(sub->a, sign);
f_extract(sub->b, !sign);
} else if (UsesVar(part, f_use_itervar)) {
lhs_expr = sign ? lhs_expr + part : lhs_expr - part;
} else {
rhs_expr = sign ? rhs_expr - part : rhs_expr + part;
}
};
f_extract(sum_parts, true);
arith::Analyzer analyzer;
lhs_expr = analyzer.Simplify(lhs_expr);
rhs_expr = analyzer.Simplify(rhs_expr);
}
PrimExpr lower_bound, upper_bound, iter;
if (is_greater) {
if (bound_at_left) {
// bound > iter
upper_bound = is_equal ? lhs_expr + 1 : lhs_expr;
iter = rhs_expr;
if (UsesVar(lhs_expr, f_use_itervar) || UsesVar(rhs_expr, f_use_itervar)) {
// At least it uses one input iter
if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) {
bound_at_left = true;
} else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) {
bound_at_left = false;
} else {
// iter > bound
lower_bound = is_equal ? rhs_expr : rhs_expr + 1;
iter = lhs_expr;
bound_at_left = false; // accumulate bound to rhs
PrimExpr sum_parts = lhs_expr - rhs_expr;
lhs_expr = 0;
rhs_expr = 0;
std::function<void(const PrimExpr&, bool)> f_extract =
[&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) {
if (const AddNode* add = part.as<AddNode>()) {
f_extract(add->a, sign);
f_extract(add->b, sign);
} else if (const SubNode* sub = part.as<SubNode>()) {
f_extract(sub->a, sign);
f_extract(sub->b, !sign);
} else if (UsesVar(part, f_use_itervar)) {
lhs_expr = sign ? lhs_expr + part : lhs_expr - part;
} else {
rhs_expr = sign ? rhs_expr - part : rhs_expr + part;
}
};
f_extract(sum_parts, true);
arith::Analyzer analyzer;
lhs_expr = analyzer.Simplify(lhs_expr);
rhs_expr = analyzer.Simplify(rhs_expr);
}
} else {
if (bound_at_left) {
// bound < iter
lower_bound = is_equal ? lhs_expr : lhs_expr + 1;
iter = rhs_expr;
Optional<PrimExpr> lower_bound = NullOpt, upper_bound = NullOpt;
PrimExpr iter;
if (is_greater) {
if (bound_at_left) {
// bound > iter / bound >= iter
upper_bound = is_equal ? lhs_expr + 1 : lhs_expr;
iter = rhs_expr;
} else {
// iter > bound / iter >= bound
lower_bound = is_equal ? rhs_expr : rhs_expr + 1;
iter = lhs_expr;
}
} else {
// iter < bound
upper_bound = is_equal ? rhs_expr + 1 : rhs_expr;
iter = lhs_expr;
if (bound_at_left) {
// bound < iter / bound <= iter
lower_bound = is_equal ? lhs_expr : lhs_expr + 1;
iter = rhs_expr;
} else {
// iter < bound / iter <= bound
upper_bound = is_equal ? rhs_expr + 1 : rhs_expr;
iter = lhs_expr;
}
}
// If it is a predicate for a single input iter
if (const auto* var_ptr = iter.as<VarNode>()) {
auto it = input_iters->find(GetRef<Var>(var_ptr));
if (it != input_iters->end()) {
PrimExpr iter_min = (*it).second->min;
PrimExpr iter_max = (*it).second->min + (*it).second->extent;
if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value());
if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value());
input_iters->Set(GetRef<Var>(var_ptr), Range(iter_min, iter_max));
}
} else {
result->emplace_back(iter, lower_bound, upper_bound, 0);
}
}
result.emplace_back(iter, lower_bound, upper_bound, 0);
if (is_finish) {
break;
}
pred = rest.Eval();
}
return result;
return true;
}

bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {
Expand All @@ -912,13 +931,14 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
// - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
// - Step1: IterIndependenceChecker checks if the iterator are independent.
if (!IterRangeSanityCheck(input_iters)) return Array<IterSumExpr>();
std::vector<IterConstraint> constraints = MatchBoundConstraints(predicate, input_iters);
if (!is_one(predicate) && constraints.empty()) {
Map<Var, Range> constrained_input_iters = input_iters;
std::vector<IterConstraint> constraints;
if (!is_one(predicate) &&
!MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) {
diag_ctx.Emit(Diagnostic::Error(predicate->span)
<< "Fail to collect constraints from iteration predicate: " << predicate);
return Array<IterSumExpr>();
}

// We have to make sure that when we visit an iterator, all the constraints related to its
// successors in the iter var graph have already been visited. Since the expression of an iterator
// contains the expressions of its successors, we sort the constraints by their sizes.
Expand All @@ -930,10 +950,11 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
constraints.begin(), constraints.end(),
[](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });

IterMapRewriter rewriter(analyzer, input_iters, diag_ctx);
IterMapRewriter rewriter(analyzer, constrained_input_iters, diag_ctx);
// Step0.0: rewrite constraints in the order from size-small ones to size-big ones
for (const IterConstraint& constraint : constraints) {
rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound, constraint.upper_bound);
auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
constraint.upper_bound);
if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
}
if (!rewriter.CheckConstraints()) {
Expand All @@ -945,7 +966,10 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
Array<IterSumExpr> results;
for (PrimExpr value : indices) {
results.push_back(rewriter.Rewrite(value));
if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
if (rewriter.unresolved_count() != 0) {
diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Affine mapping detection failed");
return Array<IterSumExpr>();
}
}
// Step1: IterIndependenceChecker checks if the iterator are independent.
if (!rewriter.CheckMapping(results, require_bijective)) {
Expand Down Expand Up @@ -1306,7 +1330,8 @@ class IterMapToExprNormalizer : public ExprMutator {
} else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) {
return floordiv(source, expr->lower_factor) * expr->scale;
} else {
return floormod(floordiv(source, expr->lower_factor), expr->extent) * expr->scale;
return floordiv(floormod(source, expr->lower_factor * expr->extent), expr->lower_factor) *
expr->scale;
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,18 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
return Everything();
}

/*!
 * \brief Modular-set analysis of a floormod expression.
 *
 * Only a divisor that analyzes to a constant is handled; any other divisor
 * yields Everything(). For a constant divisor c2, floormod(a, c2) differs from
 * a by a multiple of c2, so the returned entry combines a's coefficient with c2
 * through ZeroAwareGCD and reduces a's base by c2.
 * NOTE(review): this reasoning assumes Entry(coeff, base) describes values of
 * the form coeff * k + base -- confirm against the Entry definition.
 *
 * \param op The floormod node being analyzed.
 * \return The entry for the floormod result, or Everything() when the divisor
 *         is not a constant.
 */
Entry VisitExpr_(const FloorModNode* op) final {
  Entry divisor = VisitExpr(op->b);
  if (!divisor.is_const()) {
    return Everything();
  }
  // Keep the name `c2`: ICHECK reports the checked expression text on failure.
  int64_t c2 = divisor.base;
  ICHECK(c2 != 0) << "MathError: the divisor is 0";
  // The dividend is analyzed only once the divisor is known to be a non-zero constant.
  Entry dividend = VisitExpr(op->a);
  int64_t merged_coeff = ZeroAwareGCD(dividend.coeff, c2);
  int64_t reduced_base = dividend.base % c2;
  return Entry(merged_coeff, reduced_base);
}

Entry VisitExpr_(const MinNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Expand Down
12 changes: 12 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x);
// floor div
TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x);
TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2),
c2.Eval()->value > 0);

// canonicalization rule
// will try rewrite again after canonicalization.
Expand Down Expand Up @@ -785,6 +787,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x, floordiv(c2, c1)),
c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
c2.Eval()->value % c1.Eval()->value == 0 &&
CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0));

TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

Expand All @@ -794,6 +801,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(x, floordiv(c2, c1)),
c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
c2.Eval()->value % c1.Eval()->value == 0 &&
CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0));

TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

Expand Down
16 changes: 10 additions & 6 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
for (int i = 0; i < n; i++) {
const PrimExpr& factor = factors[i];
Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i));
substitute_value = substitute_value * factor + var;
if (!is_one(factor)) substitute_value = substitute_value * factor + var;
analyzer.Bind(var, Range::FromMinExtent(0, factor));
new_loop_vars.emplace_back(std::move(var));
}
Expand Down Expand Up @@ -505,11 +505,14 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix);
Array<PrimExpr> substitute_value;
substitute_value.resize(loops.size());
PrimExpr tot = fused_var;
for (int i = static_cast<int>(loops.size()) - 1; i >= 0; i--) {
substitute_value.Set(i, floormod(tot, loops[i]->extent));
tot = floordiv(tot, loops[i]->extent);
}
PrimExpr lower = 1;
for (int i = static_cast<int>(loops.size()) - 1; i > 0; i--) {
substitute_value.Set(i, is_one(loops[i]->extent)
? 0
: floordiv(floormod(fused_var, lower * loops[i]->extent), lower));
lower = lower * loops[i]->extent;
}
substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower));
Stmt new_stmt = loops.back()->body;
Map<Block, Block> opaque_block_reuse;
auto f_substitute = [&](const Var& v) -> Optional<PrimExpr> {
Expand All @@ -534,6 +537,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
return self->stmt2ref.at(new_stmt.get());
}

/*!
* \brief Collect an array of loop srefs into a set
* \param self The schedule state
Expand Down
5 changes: 4 additions & 1 deletion tests/python/unittest/test_arith_intset.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ def test_mod():
ck.verify(
flm(y, 8),
{y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)},
(x * 4 - 8 * fld(x * 4, 8), x * 4 - 8 * fld(x * 4, 8) + 3),
(
z * 8 + x * 4 - 8 * fld(z * 8 + x * 4, 8),
z * 8 + x * 4 + 3 - 8 * fld(z * 8 + x * 4, 8),
),
)
ck1 = IntSetChecker()
ck1.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 2))
Expand Down
Loading