Skip to content

Commit

Permalink
[Arith] Implement statistics counters for RewriteSimplifier (#14532)
Browse files Browse the repository at this point in the history
* [Arith] Implement statistics counters for RewriteSimplifier

Previously, so long as `RewriteSimplifier` produces the same output,
unit tests of its behavior would pass.  This could have severe
performance regressions, such as the one resolved in
#14528, which caused the runtime of
two tests to increase from ~1.5 seconds to ~10 minutes each.

This commit implements statistics counts in RewriteSimplifier, which
are exposed through both the C++ and Python APIs, and uses these to
guard against the known performance regression from
#14528.

* lint fixes

* Updates based on review comments

* Consistent int64_t with kMaxRecurDepth

* Removed unused is_currently_visiting_

* Add missing \brief for RewriteSimplifierStatsNode

* Use int64_t in ControlFlowGraph for max simplification steps
  • Loading branch information
Lunderberg authored May 5, 2023
1 parent aa7d2bf commit 1294926
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 8 deletions.
21 changes: 21 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,27 @@ class RewriteSimplifier {
/*! \brief Return the currently enabled extensions */
TVM_DLL Extension GetEnabledExtensions() const;

/*! \brief Return the statistics counters */
TVM_DLL ObjectRef GetStatsCounters() const;

/*! \brief Reset the statistics counters */
TVM_DLL void ResetStatsCounters();

/*! \brief Set the maximum allowed number of rewrite steps
*
* By default, the simplifier may perform as many steps as are
* required. If a positive limit is set, then the simplifier will
* throw an exception when exceeding that number of rewrite steps.
* This allows tests to guard against performance regressions.
*
* Note: To maintain accurate usage counters, `Analyzer` instances
* should be re-used wherever possible. For example, TIR
* transformations should declare a single `Analyzer` that is used
* throughout the pass, and utility functions should receive an
* `Analyzer*` from their calling scope.
*/
TVM_DLL void SetMaximumRewriteSteps(int64_t maximum);

private:
friend class Analyzer;
friend class ConstraintContext;
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def __init__(self):
self._modular_set = _mod("modular_set")
self._simplify = _mod("Simplify")
self._rewrite_simplify = _mod("rewrite_simplify")
self._get_rewrite_simplify_stats = _mod("get_rewrite_simplify_stats")
self._reset_rewrite_simplify_stats = _mod("reset_rewrite_simplify_stats")
self._canonical_simplify = _mod("canonical_simplify")
self._int_set = _mod("int_set")
self._enter_constraint_context = _mod("enter_constraint_context")
Expand Down Expand Up @@ -167,6 +169,13 @@ def rewrite_simplify(self, expr):
"""
return self._rewrite_simplify(expr)

@property
def rewrite_simplify_stats(self):
    """The statistics counters of the rewrite simplifier.

    Returns
    -------
    stats : Object
        The object returned by the analyzer's
        "get_rewrite_simplify_stats" function.
    """
    return self._get_rewrite_simplify_stats()

def reset_rewrite_simplify_stats(self):
    """Reset the statistics counters of the rewrite simplifier."""
    self._reset_rewrite_simplify_stats()

def canonical_simplify(self, expr):
"""Simplify expression via canonicalization.
Expand Down
7 changes: 7 additions & 0 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,13 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
} else if (name == "rewrite_simplify") {
return PackedFunc(
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); });
} else if (name == "get_rewrite_simplify_stats") {
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
*ret = self->rewrite_simplify.GetStatsCounters();
});
} else if (name == "reset_rewrite_simplify_stats") {
return PackedFunc(
[self](TVMArgs args, TVMRetValue* ret) { self->rewrite_simplify.ResetStatsCounters(); });
} else if (name == "canonical_simplify") {
return PackedFunc(
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); });
Expand Down
35 changes: 35 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,33 @@ using namespace tir;

// Macro for a simple rewrite: records the attempt, and if SrcExpr matches
// `ret`, records the performed rewrite and returns the evaluated ResExpr
// from the enclosing function.
// NOTE(review): expands to two statements (a call followed by an `if`), so
// it should not be used as the body of an unbraced if/else/loop.
#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \
RecordAttemptedRewrite(); \
if ((SrcExpr).Match(ret)) { \
RecordRewrite(); \
return (ResExpr).Eval(); \
}

// Macro for rewrite + recursive rewrite of the result: records the attempt,
// and if SrcExpr matches `ret`, records the performed rewrite and returns
// RecursiveRewrite applied to the evaluated ResExpr.
// NOTE(review): expands to two statements (a call followed by an `if`), so
// it should not be used as the body of an unbraced if/else/loop.
#define TVM_TRY_RECURSIVE_REWRITE(SrcExpr, ResExpr) \
RecordAttemptedRewrite(); \
if ((SrcExpr).Match(ret)) { \
RecordRewrite(); \
return RecursiveRewrite((ResExpr).Eval()); \
}

// Macro that rewrites only if CondExpr is true after the match.
// The attempt is always recorded; the rewrite is recorded (and ResExpr
// returned) only when Match succeeds with CondExpr passed as its predicate.
// NOTE(review): expands to two statements (a call followed by an `if`), so
// it should not be used as the body of an unbraced if/else/loop.
#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
RecordAttemptedRewrite(); \
if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \
RecordRewrite(); \
return (ResExpr).Eval(); \
}

// Macro for rewrite + recursive rewrite, only if CondExpr is true after the
// match. The attempt is always recorded; the rewrite is recorded (and the
// recursively rewritten ResExpr returned) only when Match succeeds with
// CondExpr passed as its predicate.
// NOTE(review): expands to two statements (a call followed by an `if`), so
// it should not be used as the body of an unbraced if/else/loop.
#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
RecordAttemptedRewrite(); \
if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \
RecordRewrite(); \
return RecursiveRewrite((ResExpr).Eval()); \
}

Expand Down Expand Up @@ -211,6 +219,11 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val
return CompareResult::kUnknown;
}

// Entry point for visiting any expression: bump the visited-node counter,
// then hand the expression to the base-class mutator.
PrimExpr RewriteSimplifier::Impl::VisitExpr(const PrimExpr& e) {
  stats_.nodes_visited += 1;
  PrimExpr visited = IRMutatorWithAnalyzer::VisitExpr(e);
  return visited;
}

void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) {
if (!can_override) {
auto it = var_map_.find(var);
Expand Down Expand Up @@ -359,6 +372,7 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
literal_constraints_.push_back(Not(negation));
}
}
stats_.constraints_entered++;
size_t new_literal_size = literal_constraints_.size();
auto frecover = [old_literal_size, new_literal_size, this]() {
ICHECK_EQ(literal_constraints_.size(), new_literal_size);
Expand Down Expand Up @@ -2150,9 +2164,30 @@ RewriteSimplifier::Extension RewriteSimplifier::GetEnabledExtensions() const {
return impl_->GetEnabledExtensions();
}

// Public accessor: forwards to the implementation object's GetStatsCounters().
ObjectRef RewriteSimplifier::GetStatsCounters() const { return impl_->GetStatsCounters(); }

// Public accessor: forwards to the implementation object's ResetStatsCounters().
void RewriteSimplifier::ResetStatsCounters() { impl_->ResetStatsCounters(); }

// Public accessor: forwards the rewrite-step limit to the implementation object.
void RewriteSimplifier::SetMaximumRewriteSteps(int64_t maximum) {
impl_->SetMaximumRewriteSteps(maximum);
}

// Allocates the implementation object with `new`; `impl_` owns it as a raw pointer.
RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {}

// Releases the implementation object held in `impl_`.
RewriteSimplifier::~RewriteSimplifier() { delete impl_; }

TVM_REGISTER_NODE_TYPE(RewriteSimplifierStatsNode);

// Printer for RewriteSimplifierStatsNode: writes each counter as
// `name = value`, comma-separated, inside `RewriteSimplifierStats(...)`.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<RewriteSimplifierStatsNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* stats = node.as<RewriteSimplifierStatsNode>();
      auto& os = p->stream;
      os << "RewriteSimplifierStats(nodes_visited = " << stats->nodes_visited;
      os << ", constraints_entered = " << stats->constraints_entered;
      os << ", rewrites_attempted = " << stats->rewrites_attempted;
      os << ", rewrites_performed = " << stats->rewrites_performed;
      os << ", max_recursive_depth = " << stats->max_recursive_depth;
      os << ", num_recursive_rewrites = " << stats->num_recursive_rewrites;
      os << ")";
    });

} // namespace arith
} // namespace tvm
62 changes: 60 additions & 2 deletions src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>

#include <algorithm>
#include <unordered_map>
#include <vector>

Expand All @@ -39,6 +40,41 @@ namespace arith {

using namespace tir;

/*! \brief Usage counters for RewriteSimplifier
*
* These are intended for debug and testing purposes, to ensure that
* PrimExpr simplifications and TIR passes do not require an excessive
* number of rewrite steps.
*/
struct RewriteSimplifierStatsNode : Object {
/*! \brief Number of expression nodes visited by the simplifier */
int64_t nodes_visited{0};
/*! \brief Number of constraints entered */
int64_t constraints_entered{0};
/*! \brief Number of rewrite patterns tried, whether or not they matched */
int64_t rewrites_attempted{0};
/*! \brief Number of rewrite patterns that matched and were applied */
int64_t rewrites_performed{0};
/*! \brief Deepest recursive-rewrite depth reached */
int64_t max_recursive_depth{0};
/*! \brief Number of recursive rewrites requested */
int64_t num_recursive_rewrites{0};

// Expose every counter by name to the attribute visitor.
void VisitAttrs(AttrVisitor* v) {
v->Visit("nodes_visited", &nodes_visited);
v->Visit("constraints_entered", &constraints_entered);
v->Visit("rewrites_attempted", &rewrites_attempted);
v->Visit("rewrites_performed", &rewrites_performed);
v->Visit("max_recursive_depth", &max_recursive_depth);
v->Visit("num_recursive_rewrites", &num_recursive_rewrites);
}

static constexpr const char* _type_key = "arith.RewriteSimplifierStats";
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteSimplifierStatsNode, Object);
};

/*! \brief Managed reference to a RewriteSimplifierStatsNode */
struct RewriteSimplifierStats : ObjectRef {
// Takes the counters by value and stores a newly allocated node
// constructed from them, so the reference holds its own copy.
explicit RewriteSimplifierStats(RewriteSimplifierStatsNode data) {
data_ = make_object<RewriteSimplifierStatsNode>(data);
}

TVM_DEFINE_OBJECT_REF_METHODS(RewriteSimplifierStats, ObjectRef, RewriteSimplifierStatsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(RewriteSimplifierStatsNode);
};

/*!
* \brief Rewrite-based simplifier.
*
Expand All @@ -50,6 +86,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {

explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {}

PrimExpr VisitExpr(const PrimExpr& e) override;

void Update(const Var& var, const PrimExpr& info, bool override_info);
PrimExpr VisitExpr_(const AddNode* op) override;
PrimExpr VisitExpr_(const SubNode* op) override;
Expand Down Expand Up @@ -87,9 +125,27 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
/*! \brief Return the currently enabled extensions */
Extension GetEnabledExtensions() const;

/*! \brief Return the current statistics counters, wrapped in a new RewriteSimplifierStats */
RewriteSimplifierStats GetStatsCounters() const { return RewriteSimplifierStats(stats_); }

/*! \brief Reset the statistics counters to their default-initialized values */
void ResetStatsCounters() { stats_ = {}; }

/*! \brief Set the maximum number of rewrites allowed
*
* The limit is checked in RecordRewrite; a value <= 0 means no limit.
*/
void SetMaximumRewriteSteps(int64_t maximum) { maximum_rewrite_steps_ = maximum; }

protected:
int64_t maximum_rewrite_steps_{0};
RewriteSimplifierStatsNode stats_;

/*! \brief Count one attempted rewrite, whether or not the pattern goes on to match */
void RecordAttemptedRewrite() { stats_.rewrites_attempted++; }
/*! \brief Count one performed rewrite and enforce the optional rewrite limit
*
* The check fails once the number of performed rewrites exceeds
* maximum_rewrite_steps_. A limit <= 0 disables the check.
*/
void RecordRewrite() {
stats_.rewrites_performed++;

// Only enforced when a positive limit has been configured.
ICHECK(maximum_rewrite_steps_ <= 0 || stats_.rewrites_performed <= maximum_rewrite_steps_)
<< "RewriteSimplifier exceeded maximum number of rewrites allowed ("
<< maximum_rewrite_steps_ << ")";
}

// counter to record recursive rewrite depth.
int recur_depth_{0};
int64_t recur_depth_{0};
// internal variable map
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;

Expand All @@ -103,7 +159,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
bool recursively_visiting_boolean_{false};

// maximum number of recursion allowed during a single pass.
static const constexpr int kMaxRecurDepth = 5;
static const constexpr int64_t kMaxRecurDepth = 5;
/*!
* \brief try to compare x against val.
* \param x The expression to be evaluated.
Expand Down Expand Up @@ -177,8 +233,10 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
// we limit maximum depth of recursive rewrite allowed to
// avoid infinite loop
PrimExpr RecursiveRewrite(const PrimExpr& x) {
stats_.num_recursive_rewrites++;
if (recur_depth_ >= kMaxRecurDepth) return x;
++recur_depth_;
stats_.max_recursive_depth = std::max(recur_depth_, stats_.max_recursive_depth);
PrimExpr res = this->VisitExpr(x);
--recur_depth_;
return res;
Expand Down
7 changes: 5 additions & 2 deletions src/tir/analysis/control_flow_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -820,8 +820,9 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph
return buffer_touch;
}

ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits)
: max_revisits_(max_revisits) {
ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, int64_t max_simplification_steps,
size_t max_revisits)
: max_revisits_(max_revisits), max_simplification_steps_(max_simplification_steps) {
ControlFlowGraphBuilder::Build(this, stmt);
ForwardPropagateKnownValues();
BackwardPropagateUnusedValues();
Expand Down Expand Up @@ -1377,6 +1378,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(std::optional<size_t> flow_fr
std::unordered_map<size_t, size_t> visit_count_lookup;

Analyzer analyzer;
analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_);
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
arith::RewriteSimplifier::kTransitivelyProveInequalities |
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
Expand Down Expand Up @@ -1510,6 +1512,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional<size_t> flow_
std::unordered_map<size_t, size_t> visit_count_lookup;

Analyzer analyzer;
analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_);
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
arith::RewriteSimplifier::kTransitivelyProveInequalities |
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
Expand Down
6 changes: 5 additions & 1 deletion src/tir/analysis/control_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ class ControlFlowGraph {
public:
/* \brief Extract the touch pattern from a TIR statement
*/
explicit ControlFlowGraph(const Stmt& stmt, size_t max_revisits = 5);
explicit ControlFlowGraph(const Stmt& stmt, int64_t max_simplification_steps = 0,
size_t max_revisits = 5);

/* \brief Check if a write is overwritten without impacting final results
*
Expand Down Expand Up @@ -655,6 +656,9 @@ class ControlFlowGraph {

/*! \brief The maximum number of revisits while flowing constraints */
size_t max_revisits_;

/*! \brief The maximum number of rewrite steps allowed in the simplifier (no limit if not positive) */
int64_t max_simplification_steps_;
};

} // namespace tir
Expand Down
18 changes: 15 additions & 3 deletions src/tir/transforms/remove_no_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,20 @@ namespace tir {

/*! \brief Configuration attributes for the RemoveNoOp pass ("tir.transform.RemoveNoOpConfig") */
struct RemoveNoOpConfigNode : public tvm::AttrsNode<RemoveNoOpConfigNode> {
// Whether known buffer values are propagated to prove statements are no-ops.
bool use_dataflow_analysis;
// Step limit handed to RewriteSimplifier; the default of 0 leaves it unlimited.
// NOTE(review): the description says "non-zero", but the simplifier appears to
// enforce only positive limits — confirm the behaviour for negative values.
int64_t max_simplification_steps;

TVM_DECLARE_ATTRS(RemoveNoOpConfigNode, "tir.transform.RemoveNoOpConfig") {
TVM_ATTR_FIELD(use_dataflow_analysis)
.describe(
"If true, known buffer values are propagated and used "
"to statically prove statements as no-ops.")
.set_default(false);
TVM_ATTR_FIELD(max_simplification_steps)
.describe(
"If non-zero, RewriteSimplifier will throw an error "
"after the number of steps specified. "
"For use in debug and testing purposes.")
.set_default(0);
}
};

Expand Down Expand Up @@ -291,14 +298,19 @@ Pass RemoveNoOp() {

RemoveNoOpConfig config = ctx->GetConfig<RemoveNoOpConfig>("tir.RemoveNoOp")
.value_or(AttrsWithDefaultValues<RemoveNoOpConfig>());

if (config->use_dataflow_analysis) {
touch_pattern.emplace(f->body);
touch_pattern.emplace(f->body, config->max_simplification_steps);
}

arith::Analyzer analyzer;
analyzer.rewrite_simplify.SetMaximumRewriteSteps(config->max_simplification_steps);

auto* n = f.CopyOnWrite();
n->body = NoOpRemover::Apply(std::move(n->body), &analyzer, std::move(touch_pattern), nullptr);
{
auto* write_ptr = f.CopyOnWrite();
write_ptr->body = NoOpRemover::Apply(std::move(write_ptr->body), &analyzer,
std::move(touch_pattern), nullptr);
}
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {});
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_tir_transform_remove_no_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,14 @@ def main(A: T.Buffer((16), "int32"), B: T.Buffer((16), "int32")) -> None:

class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
use_dataflow_analysis = False
max_simplification_steps = 0

def transform(self):
def inner(mod):
config = {
"tir.RemoveNoOp": {
"use_dataflow_analysis": self.use_dataflow_analysis,
"max_simplification_steps": self.max_simplification_steps,
}
}
with tvm.transform.PassContext(config=config):
Expand Down Expand Up @@ -319,9 +321,16 @@ class TestRemoveOverwrittenPredicatedLoopWithIdenticalCondition(BaseBeforeAfter)
Similar to TestKeepPartiallyOverwrittenLoop, except the first loop
has the same predicate as the second, and can therefore be
removed.
In the past, this test has had performance regressions in which
the runtime increased from a few seconds to nearly ten minutes.
The "max_simplification_steps" parameter is set at twice the
current number of steps required, in order to prevent similar
performance regression.
"""

use_dataflow_analysis = True
max_simplification_steps = 200000

def before(A: T.Buffer(16, "int32")):
for i in T.serial(16):
Expand All @@ -347,9 +356,16 @@ class TestRemoveOverwrittenPredicatedLoopWithProvableCondition(BaseBeforeAfter):
loop's predicate. So long as the regions written in the first
loop are a subset of those written in the second loop, they can be
removed.
In the past, this test has had performance regressions in which
the runtime increased from a few seconds to nearly ten minutes.
The "max_simplification_steps" parameter is set at twice the
current number of steps required, in order to prevent similar
performance regression.
"""

use_dataflow_analysis = True
max_simplification_steps = 200000

def before(A: T.Buffer(16, "int32")):
for i in T.serial(16):
Expand Down

0 comments on commit 1294926

Please sign in to comment.