Skip to content

Commit

Permalink
[TIR][Analysis][Arith] Implement basic data-flow analysis (#13130)
Browse files Browse the repository at this point in the history
An optional utility to track known buffer values through a TIR PrimFunc, allowing simplifications based on known values.

* Updated documentation following review comments

* Unit tests for rewrites, including negative numerators for div/mod

* Fix linting error

* Added brief description on what a control graph is

* Updates based on review comments

* Updated T.assume(expr) to T.evaluate(T.assume(expr))
  • Loading branch information
Lunderberg authored Nov 16, 2022
1 parent b4d4b82 commit a80cdc2
Show file tree
Hide file tree
Showing 14 changed files with 3,403 additions and 42 deletions.
31 changes: 31 additions & 0 deletions include/tvm/tir/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/packed_func.h>

#include <ostream>

namespace tvm {
namespace tir {
/*!
Expand Down Expand Up @@ -92,6 +94,35 @@ enum class CallEffectKind : int {
kControlJump = 6,
};

inline std::ostream& operator<<(std::ostream& os, CallEffectKind side_effect) {
switch (side_effect) {
case CallEffectKind::kExprAnnotation:
return os << "kExprAnnotation";

case CallEffectKind::kPure:
return os << "kPure";

case CallEffectKind::kReadState:
return os << "kReadState";

case CallEffectKind::kUpdateState:
return os << "kUpdateState";

case CallEffectKind::kSpecialCallArg:
return os << "kSpecialCallArg";

case CallEffectKind::kEmbedInfo:
return os << "kEmbedInfo";

case CallEffectKind::kControlJump:
return os << "kControlJump";

default:
LOG(FATAL) << "Unknown CallEffectKind: " << static_cast<int>(side_effect);
return os;
}
}

/*! \brief Use integer to record the kind. */
using TCallEffectKind = Integer;

Expand Down
26 changes: 20 additions & 6 deletions src/arith/conjunctive_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,14 @@ void AndOfOrs::TrySimplifyOr(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) {
Key& a = *a_ptr;
Key& b = *b_ptr;
PrimExpr joint = GetExpr(a) || GetExpr(b);
PrimExpr simplified = analyzer->Simplify(joint);
PrimExpr simplified = analyzer->rewrite_simplify(joint);
if (!ExprDeepEqual()(simplified, joint)) {
if (auto* simplified_or = simplified.as<OrNode>()) {
a = GetKey(simplified_or->a);
b = GetKey(simplified_or->b);
} else {
a = GetKey(simplified);
b = key_false_;
a = key_false_;
b = GetKey(simplified);
}
}
}
Expand All @@ -264,14 +264,14 @@ void AndOfOrs::TrySimplifyAnd(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) {
Key& a = *a_ptr;
Key& b = *b_ptr;
PrimExpr joint = GetExpr(a) && GetExpr(b);
PrimExpr simplified = analyzer->Simplify(joint);
PrimExpr simplified = analyzer->rewrite_simplify(joint);
if (!ExprDeepEqual()(simplified, joint)) {
if (auto* simplified_and = simplified.as<AndNode>()) {
a = GetKey(simplified_and->a);
b = GetKey(simplified_and->b);
} else {
a = GetKey(simplified);
b = key_true_;
a = key_true_;
b = GetKey(simplified);
}
}
}
Expand Down Expand Up @@ -362,6 +362,20 @@ void AndOfOrs::SimplifyAcrossChunks(Analyzer* analyzer) {
// (A or B) and (A or C) => A or (B and C)
auto& key_i = i_chunk[i_distinct_index.value()];
auto& key_j = j_chunk[j_distinct_index.value()];

// When attempting to simplify (B and C), the analyzer may
// assume that A is false.
PrimExpr known = [&]() {
PrimExpr known = Bool(true);
for (const auto& key : i_chunk) {
if (&key != &key_i) {
known = known && analyzer->Simplify(!GetExpr(key));
}
}
return known;
}();

With<ConstraintContext> context(analyzer, known);
TrySimplifyAnd(&key_i, &key_j, analyzer);
}
}
Expand Down
39 changes: 29 additions & 10 deletions src/arith/constraint_extract.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,42 @@
namespace tvm {
namespace arith {

void CollectConstraints(const PrimExpr& expr, Analyzer* analyzer, std::vector<PrimExpr>* collect) {
collect->push_back(expr);
template <typename F>
void CollectConstraints(PrimExpr expr, F callback, bool keep_composite_constraints) {
if (keep_composite_constraints) {
callback(expr);
}

PVar<PrimExpr> x, y;
if ((x && y).Match(expr)) {
CollectConstraints(x.Eval(), analyzer, collect);
CollectConstraints(y.Eval(), analyzer, collect);
} else if ((!(x || y)).Match(expr)) {
CollectConstraints(analyzer->rewrite_simplify(tir::Not(x.Eval())), analyzer, collect);
CollectConstraints(analyzer->rewrite_simplify(tir::Not(y.Eval())), analyzer, collect);
CollectConstraints(x.Eval(), callback, keep_composite_constraints);
CollectConstraints(y.Eval(), callback, keep_composite_constraints);
} else if (!keep_composite_constraints) {
callback(expr);
}
}

std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr, bool keep_composite_constraints) {
std::vector<PrimExpr> out;
CollectConstraints(
expr, [&](const PrimExpr& part) { out.push_back(part); }, keep_composite_constraints);
return out;
}

template <typename F>
void CollectComponents(PrimExpr expr, F callback) {
PVar<PrimExpr> x, y;
if ((x || y).Match(expr)) {
CollectComponents(x.Eval(), callback);
CollectComponents(y.Eval(), callback);
} else {
callback(expr);
}
}

std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr) {
std::vector<PrimExpr> ExtractComponents(const PrimExpr& expr) {
std::vector<PrimExpr> out;
Analyzer analyzer;
CollectConstraints(expr, &analyzer, &out);
CollectComponents(expr, [&](const PrimExpr& part) { out.push_back(part); });
return out;
}

Expand Down
31 changes: 30 additions & 1 deletion src/arith/constraint_extract.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,35 @@ namespace arith {
* Example: `i==5 || j==3` => `[i==5 || j==3]`
* Example: `!(i>5 || j==3)` => `[!(i==5 || j==3), i<=5, j!=3]`
*
* If `keep_composite_constraints` is true (default), a constraint
* that can be decomposed will be included in the output. If false,
* they will be excluded.
*
* Example, removing composite: `!(i>5 || j==3)` => `[i<=5, j!=3]`
*
* Intended for use in bounds analysis or simplification within a
* conditional, or identifying independent conditionals that may be
* hoisted.
*
* \param expr The expression to be analyzers
*
* \param keep_composite_constraints Whether to include composite
* constraints in the output.
*
* \returns A vector of independent constraints
*/
std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr,
bool keep_composite_constraints = true);

/* \brief Returns components that are false if the expression is false.
*
* Utility to break up a boolean expression into independent
* components.
*
* Example: `i==5 || j==3` => `[i==5, j==3]`
* Example: `i==5 && j==3` => `[i==5 && j==3]`
* Example: `!(i>5 && j==3)` => `[i<=5, j!=3]`
*
* Intended for use in bounds analysis or simplification within a
* conditional, or identifying independent conditionals that may be
* hoisted.
Expand All @@ -50,7 +79,7 @@ namespace arith {
*
* \returns A vector of independent constraints
*/
std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr);
std::vector<PrimExpr> ExtractComponents(const PrimExpr& expr);

} // namespace arith
} // namespace tvm
Expand Down
6 changes: 5 additions & 1 deletion src/arith/ir_visitor_with_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ class IRVisitorWithAnalyzer : public tir::StmtExprVisitor {
/*! \brief internal analyzer field. */
arith::Analyzer analyzer_;

private:
/*! \brief Extract a constraint from a conditional statement
*
* Intended for preparing argument for use in
* `With<ConstraintContext>`.
*/
PrimExpr ExtractRealCondition(PrimExpr condition) const;
};

Expand Down
53 changes: 51 additions & 2 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
// we will compare the already simplified result with the constraint,
// so simplify the constraint as well
PrimExpr new_constraint = operator()(constraint);
for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) {
for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint, false)) {
if (SideEffect(subconstraint) <= CallEffectKind::kPure) {
literal_constraints_.push_back(subconstraint);
PrimExpr negation;
Expand Down Expand Up @@ -1734,7 +1734,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<IntImm> c1, c2, c3;
PVar<int> lanes;

if (op->dtype.lanes() != 1) {
Expand All @@ -1761,6 +1761,55 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {

TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2);
TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2);

TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) == c3, x == c1 * c2 + c3);
TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) == c3 && floordiv(x, c2) == c1, x == c1 * c2 + c3);

TVM_TRY_RECURSIVE_REWRITE_IF(0 <= x - y * c1 &&
x - y * c1<c1, y == floordiv(x, c1), c1.Eval()->value> 0);
TVM_TRY_RECURSIVE_REWRITE_IF(x - y * c1 < c1 && 0 <= x - y * c1, y == floordiv(x, c1),
c1.Eval()->value > 0);

TVM_TRY_RECURSIVE_REWRITE(c1 < x - y * c1 && x - y * c1 <= 0, y == floordiv(x, c1));
TVM_TRY_RECURSIVE_REWRITE(x - y * c1 < c1 && 0 <= x - y * c1, y == floordiv(x, c1));
TVM_TRY_RECURSIVE_REWRITE_IF(0 <= x + y * c2 && x + y * c2 < c1, y == floordiv(x, c1),
c2.Eval()->value == -c1.Eval()->value);
TVM_TRY_RECURSIVE_REWRITE_IF(x + y * c2 < c1 && 0 <= x + y * c2, y == floordiv(x, c1),
c2.Eval()->value == -c1.Eval()->value);

TVM_TRY_RECURSIVE_REWRITE_IF(x < c1 && floormod(x, c2) < c3,
x < c1 - c2 + c3 && floormod(x, c2) < c3,
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_RECURSIVE_REWRITE_IF(
x < c1 && floormod(x, c2) < c3, x < c1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3,
(c1.Eval()->value % c2.Eval()->value + c2.Eval()->value) % c2.Eval()->value >
c3.Eval()->value);

TVM_TRY_RECURSIVE_REWRITE_IF(x <= c1 && floormod(x, c2) < c3,
x < c1 + 1 - c2 + c3 && floormod(x, c2) < c3,
(c1.Eval()->value + 1) % c2.Eval()->value == 0);
TVM_TRY_RECURSIVE_REWRITE_IF(
x <= c1 && floormod(x, c2) < c3, x < c1 + 1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3,
(((c1.Eval()->value + 1) % c2.Eval()->value) + c2.Eval()->value) % c2.Eval()->value >
c3.Eval()->value);

TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) < c3,
c1 * c2 <= x && x < c1 * c2 + c3);
TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) < c3 && floordiv(x, c2) == c1,
c1 * c2 <= x && x < c1 * c2 + c3);
TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) <= c3,
c1 * c2 <= x && x <= c1 * c2 + c3);
TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) <= c3 && floordiv(x, c2) == c1,
c1 * c2 <= x && x <= c1 * c2 + c3);

TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && c3 <= floormod(x, c2),
c1 * c2 + c3 <= x && x < (c1 + 1) * c2);
TVM_TRY_RECURSIVE_REWRITE(c3 <= floormod(x, c2) && floordiv(x, c2) == c1,
c1 * c2 + c3 <= x && x < (c1 + 1) * c2);
TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && c3 < floormod(x, c2),
c1 * c2 + c3 < x && x < (c1 + 1) * c2);
TVM_TRY_RECURSIVE_REWRITE(c3 < floormod(x, c2) && floordiv(x, c2) == c1,
c1 * c2 + c3 < x && x < (c1 + 1) * c2);
return ret;
}

Expand Down
2 changes: 1 addition & 1 deletion src/arith/transitive_comparison_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimEx

void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
std::vector<Comparison>* vec) {
for (const auto& subexpr : ExtractConstraints(expr)) {
for (const auto& subexpr : ExtractConstraints(expr, false)) {
if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
if (auto cmp = FromExpr(subexpr)) {
vec->push_back(cmp.value());
Expand Down
90 changes: 90 additions & 0 deletions src/arith/unwrap_vector_expr.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file unwrap_vector_expr.cc
* \brief Utility for tracking currently active constraints
*/

#include "unwrap_vector_expr.h"

#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/op.h>

#include <unordered_map>

namespace tvm {
namespace arith {

using namespace tir;

class Scalarizer : public ExprMutator {
public:
explicit Scalarizer(PrimExpr lane) : lane_(lane) {}

PrimExpr VisitExpr_(const RampNode* op) final { return op->base + lane_ * op->stride; }

PrimExpr VisitExpr_(const BroadcastNode* op) final { return op->value; }

PrimExpr VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);

auto it = let_var_remap_.find(op);
if (it != let_var_remap_.end()) {
return it->second;
} else {
return ExprMutator::VisitExpr_(op);
}
}
PrimExpr VisitExpr_(const LetNode* op) final {
if (op->value.dtype().lanes() == 1) {
return ExprMutator::VisitExpr_(op);
}

auto it = let_var_remap_.find(op->var.get());
ICHECK(it == let_var_remap_.end()) << "Duplicate binding of variable " << op->var;

Var new_var(op->var->name_hint + "_scalar", op->var.dtype().element_of());
let_var_remap_[op->var.get()] = new_var;

PrimExpr value = this->VisitExpr(op->value);
PrimExpr body = this->VisitExpr(op->body);

let_var_remap_.erase(op->var.get());
return Let(op->var, value, body);
}

private:
// The lane to extract
PrimExpr lane_;

// Let binding
std::unordered_map<const VarNode*, Var> let_var_remap_;
};

PrimExpr UnwrapVectorExpr(const PrimExpr& vector_expr, const PrimExpr& lane) {
return Scalarizer(lane)(vector_expr);
}

} // namespace arith
} // namespace tvm
Loading

0 comments on commit a80cdc2

Please sign in to comment.