Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][Analysis][Arith] Implement basic data-flow analysis #13130

Merged
merged 8 commits into from
Nov 16, 2022
31 changes: 31 additions & 0 deletions include/tvm/tir/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/packed_func.h>

#include <ostream>

namespace tvm {
namespace tir {
/*!
Expand Down Expand Up @@ -92,6 +94,35 @@ enum class CallEffectKind : int {
kControlJump = 6,
};

inline std::ostream& operator<<(std::ostream& os, CallEffectKind side_effect) {
switch (side_effect) {
case CallEffectKind::kExprAnnotation:
return os << "kExprAnnotation";

case CallEffectKind::kPure:
return os << "kPure";

case CallEffectKind::kReadState:
return os << "kReadState";

case CallEffectKind::kUpdateState:
return os << "kUpdateState";

case CallEffectKind::kSpecialCallArg:
return os << "kSpecialCallArg";

case CallEffectKind::kEmbedInfo:
return os << "kEmbedInfo";

case CallEffectKind::kControlJump:
return os << "kControlJump";

default:
LOG(FATAL) << "Unknown CallEffectKind: " << static_cast<int>(side_effect);
return os;
}
}

/*! \brief Use integer to record the kind. */
using TCallEffectKind = Integer;

Expand Down
26 changes: 20 additions & 6 deletions src/arith/conjunctive_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,14 @@ void AndOfOrs::TrySimplifyOr(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) {
Key& a = *a_ptr;
Key& b = *b_ptr;
PrimExpr joint = GetExpr(a) || GetExpr(b);
PrimExpr simplified = analyzer->Simplify(joint);
PrimExpr simplified = analyzer->rewrite_simplify(joint);
if (!ExprDeepEqual()(simplified, joint)) {
if (auto* simplified_or = simplified.as<OrNode>()) {
a = GetKey(simplified_or->a);
b = GetKey(simplified_or->b);
} else {
a = GetKey(simplified);
b = key_false_;
a = key_false_;
b = GetKey(simplified);
Comment on lines +257 to +258
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a bug fix? Do we have a test to reproduce if so?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not so much as bug as a performance improvement and consistency with some steps in the data-flow portions.

The performance improvement comes from this calling scope, which iterates across pairs of expressions, where b is always later in the loop. By placing the simplified expression in b instead of a, this allows later calls to TrySimplifyOr to further simplify using the simplified expression as an input. As a result, repeated passes across all pairs are avoided.

The consistency comes from the ControlFlowBlock's use of ExprDeepEqual to check for convergence. Because the rewrite simplifier doesn't produce a canonical form (e.g. The expressions i > 0 && j > 0 and j>0 && i>0 are equivalent, but neither will be simplified to the other), there were a few cases where I wanted to preserve a specific order when simplifying.

}
}
}
Expand All @@ -264,14 +264,14 @@ void AndOfOrs::TrySimplifyAnd(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) {
Key& a = *a_ptr;
Key& b = *b_ptr;
PrimExpr joint = GetExpr(a) && GetExpr(b);
PrimExpr simplified = analyzer->Simplify(joint);
PrimExpr simplified = analyzer->rewrite_simplify(joint);
if (!ExprDeepEqual()(simplified, joint)) {
if (auto* simplified_and = simplified.as<AndNode>()) {
a = GetKey(simplified_and->a);
b = GetKey(simplified_and->b);
} else {
a = GetKey(simplified);
b = key_true_;
a = key_true_;
b = GetKey(simplified);
csullivan marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down Expand Up @@ -362,6 +362,20 @@ void AndOfOrs::SimplifyAcrossChunks(Analyzer* analyzer) {
// (A or B) and (A or C) => A or (B and C)
auto& key_i = i_chunk[i_distinct_index.value()];
auto& key_j = j_chunk[j_distinct_index.value()];

// When attempting to simplify (B and C), the analyzer may
// assume that A is false.
PrimExpr known = [&]() {
PrimExpr known = Bool(true);
for (const auto& key : i_chunk) {
if (&key != &key_i) {
known = known && analyzer->Simplify(!GetExpr(key));
}
}
return known;
}();

With<ConstraintContext> context(analyzer, known);
TrySimplifyAnd(&key_i, &key_j, analyzer);
}
}
Expand Down
39 changes: 29 additions & 10 deletions src/arith/constraint_extract.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,42 @@
namespace tvm {
namespace arith {

void CollectConstraints(const PrimExpr& expr, Analyzer* analyzer, std::vector<PrimExpr>* collect) {
collect->push_back(expr);
template <typename F>
void CollectConstraints(PrimExpr expr, F callback, bool keep_composite_constraints) {
if (keep_composite_constraints) {
callback(expr);
}

PVar<PrimExpr> x, y;
if ((x && y).Match(expr)) {
CollectConstraints(x.Eval(), analyzer, collect);
CollectConstraints(y.Eval(), analyzer, collect);
} else if ((!(x || y)).Match(expr)) {
CollectConstraints(analyzer->rewrite_simplify(tir::Not(x.Eval())), analyzer, collect);
CollectConstraints(analyzer->rewrite_simplify(tir::Not(y.Eval())), analyzer, collect);
CollectConstraints(x.Eval(), callback, keep_composite_constraints);
CollectConstraints(y.Eval(), callback, keep_composite_constraints);
} else if (!keep_composite_constraints) {
callback(expr);
}
}

std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr, bool keep_composite_constraints) {
std::vector<PrimExpr> out;
CollectConstraints(
expr, [&](const PrimExpr& part) { out.push_back(part); }, keep_composite_constraints);
return out;
}

template <typename F>
void CollectComponents(PrimExpr expr, F callback) {
PVar<PrimExpr> x, y;
if ((x || y).Match(expr)) {
CollectComponents(x.Eval(), callback);
CollectComponents(y.Eval(), callback);
} else {
callback(expr);
}
}

std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr) {
std::vector<PrimExpr> ExtractComponents(const PrimExpr& expr) {
std::vector<PrimExpr> out;
Analyzer analyzer;
CollectConstraints(expr, &analyzer, &out);
CollectComponents(expr, [&](const PrimExpr& part) { out.push_back(part); });
return out;
}

Expand Down
31 changes: 30 additions & 1 deletion src/arith/constraint_extract.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,35 @@ namespace arith {
* Example: `i==5 || j==3` => `[i==5 || j==3]`
* Example: `!(i>5 || j==3)` => `[!(i==5 || j==3), i<=5, j!=3]`
*
* If `keep_composite_constraints` is true (default), a constraint
* that can be decomposed will be included in the output. If false,
* they will be excluded.
*
* Example, removing composite: `!(i>5 || j==3)` => `[i<=5, j!=3]`
*
* Intended for use in bounds analysis or simplification within a
* conditional, or identifying independent conditionals that may be
* hoisted.
*
* \param expr The expression to be analyzers
*
* \param keep_composite_constraints Whether to include composite
* constraints in the output.
*
* \returns A vector of independent constraints
*/
std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr,
bool keep_composite_constraints = true);

/* \brief Returns components that are false if the expression is false.
*
* Utility to break up a boolean expression into independent
* components.
*
* Example: `i==5 || j==3` => `[i==5, j==3]`
* Example: `i==5 && j==3` => `[i==5 && j==3]`
* Example: `!(i>5 && j==3)` => `[i<=5, j!=3]`
*
* Intended for use in bounds analysis or simplification within a
* conditional, or identifying independent conditionals that may be
* hoisted.
Expand All @@ -50,7 +79,7 @@ namespace arith {
*
* \returns A vector of independent constraints
*/
std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr);
std::vector<PrimExpr> ExtractComponents(const PrimExpr& expr);

} // namespace arith
} // namespace tvm
Expand Down
6 changes: 5 additions & 1 deletion src/arith/ir_visitor_with_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ class IRVisitorWithAnalyzer : public tir::StmtExprVisitor {
/*! \brief internal analyzer field. */
arith::Analyzer analyzer_;

private:
/*! \brief Extract a constraint from a conditional statement
*
* Intended for preparing argument for use in
* `With<ConstraintContext>`.
*/
PrimExpr ExtractRealCondition(PrimExpr condition) const;
};

Expand Down
53 changes: 51 additions & 2 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
// we will compare the already simplified result with the constraint,
// so simplify the constraint as well
PrimExpr new_constraint = operator()(constraint);
for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) {
for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint, false)) {
if (SideEffect(subconstraint) <= CallEffectKind::kPure) {
literal_constraints_.push_back(subconstraint);
PrimExpr negation;
Expand Down Expand Up @@ -1734,7 +1734,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<IntImm> c1, c2, c3;
PVar<int> lanes;

if (op->dtype.lanes() != 1) {
Expand All @@ -1761,6 +1761,55 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {

TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2);
TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add tests for these. Tests using negative numbers too would be great.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, and updated. There were a couple of the rewrite rules that had incorrect behavior for negative numerators, so those are now fixed.

TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) == c3, x == c1 * c2 + c3);
TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) == c3 && floordiv(x, c2) == c1, x == c1 * c2 + c3);

TVM_TRY_RECURSIVE_REWRITE_IF(0 <= x - y * c1 &&
x - y * c1<c1, y == floordiv(x, c1), c1.Eval()->value> 0);
TVM_TRY_RECURSIVE_REWRITE_IF(x - y * c1 < c1 && 0 <= x - y * c1, y == floordiv(x, c1),
c1.Eval()->value > 0);

TVM_TRY_RECURSIVE_REWRITE(c1 < x - y * c1 && x - y * c1 <= 0, y == floordiv(x, c1));
TVM_TRY_RECURSIVE_REWRITE(x - y * c1 < c1 && 0 <= x - y * c1, y == floordiv(x, c1));
TVM_TRY_RECURSIVE_REWRITE_IF(0 <= x + y * c2 && x + y * c2 < c1, y == floordiv(x, c1),
c2.Eval()->value == -c1.Eval()->value);
TVM_TRY_RECURSIVE_REWRITE_IF(x + y * c2 < c1 && 0 <= x + y * c2, y == floordiv(x, c1),
c2.Eval()->value == -c1.Eval()->value);

TVM_TRY_RECURSIVE_REWRITE_IF(x < c1 && floormod(x, c2) < c3,
x < c1 - c2 + c3 && floormod(x, c2) < c3,
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_RECURSIVE_REWRITE_IF(
x < c1 && floormod(x, c2) < c3, x < c1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3,
(c1.Eval()->value % c2.Eval()->value + c2.Eval()->value) % c2.Eval()->value >
c3.Eval()->value);

TVM_TRY_RECURSIVE_REWRITE_IF(x <= c1 && floormod(x, c2) < c3,
x < c1 + 1 - c2 + c3 && floormod(x, c2) < c3,
(c1.Eval()->value + 1) % c2.Eval()->value == 0);
TVM_TRY_RECURSIVE_REWRITE_IF(
x <= c1 && floormod(x, c2) < c3, x < c1 + 1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3,
(((c1.Eval()->value + 1) % c2.Eval()->value) + c2.Eval()->value) % c2.Eval()->value >
c3.Eval()->value);

TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) < c3,
c1 * c2 <= x && x < c1 * c2 + c3);
TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) < c3 && floordiv(x, c2) == c1,
c1 * c2 <= x && x < c1 * c2 + c3);
TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) <= c3,
c1 * c2 <= x && x <= c1 * c2 + c3);
TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) <= c3 && floordiv(x, c2) == c1,
c1 * c2 <= x && x <= c1 * c2 + c3);

TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && c3 <= floormod(x, c2),
c1 * c2 + c3 <= x && x < (c1 + 1) * c2);
TVM_TRY_RECURSIVE_REWRITE(c3 <= floormod(x, c2) && floordiv(x, c2) == c1,
c1 * c2 + c3 <= x && x < (c1 + 1) * c2);
TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && c3 < floormod(x, c2),
c1 * c2 + c3 < x && x < (c1 + 1) * c2);
TVM_TRY_RECURSIVE_REWRITE(c3 < floormod(x, c2) && floordiv(x, c2) == c1,
c1 * c2 + c3 < x && x < (c1 + 1) * c2);
return ret;
}

Expand Down
2 changes: 1 addition & 1 deletion src/arith/transitive_comparison_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimEx

void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
std::vector<Comparison>* vec) {
for (const auto& subexpr : ExtractConstraints(expr)) {
for (const auto& subexpr : ExtractConstraints(expr, false)) {
if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
if (auto cmp = FromExpr(subexpr)) {
vec->push_back(cmp.value());
Expand Down
90 changes: 90 additions & 0 deletions src/arith/unwrap_vector_expr.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file unwrap_vector_expr.cc
* \brief Utility for tracking currently active constraints
*/

#include "unwrap_vector_expr.h"

#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/op.h>

#include <unordered_map>

namespace tvm {
namespace arith {

using namespace tir;

class Scalarizer : public ExprMutator {
public:
explicit Scalarizer(PrimExpr lane) : lane_(lane) {}

PrimExpr VisitExpr_(const RampNode* op) final { return op->base + lane_ * op->stride; }

PrimExpr VisitExpr_(const BroadcastNode* op) final { return op->value; }

PrimExpr VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);

auto it = let_var_remap_.find(op);
if (it != let_var_remap_.end()) {
return it->second;
} else {
return ExprMutator::VisitExpr_(op);
}
}
PrimExpr VisitExpr_(const LetNode* op) final {
if (op->value.dtype().lanes() == 1) {
return ExprMutator::VisitExpr_(op);
}

auto it = let_var_remap_.find(op->var.get());
ICHECK(it == let_var_remap_.end()) << "Duplicate binding of variable " << op->var;

Var new_var(op->var->name_hint + "_scalar", op->var.dtype().element_of());
let_var_remap_[op->var.get()] = new_var;

PrimExpr value = this->VisitExpr(op->value);
PrimExpr body = this->VisitExpr(op->body);

let_var_remap_.erase(op->var.get());
return Let(op->var, value, body);
}

private:
// The lane to extract
PrimExpr lane_;

// Let binding
std::unordered_map<const VarNode*, Var> let_var_remap_;
};

PrimExpr UnwrapVectorExpr(const PrimExpr& vector_expr, const PrimExpr& lane) {
return Scalarizer(lane)(vector_expr);
}

} // namespace arith
} // namespace tvm
Loading