Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Keep trivial LetStmt in tir.Simplify when used in buffer decl #14951

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 79 additions & 9 deletions src/tir/transforms/simplify.cc
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@

#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../tir/analysis/control_flow_graph.h"
#include "../../tir/analysis/var_use_def_analysis.h"

namespace tvm {
namespace arith {
@@ -91,6 +92,46 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
}
};

/* \brief Utility function to collect vars that should be retained */
std::unordered_set<const VarNode*> CollectVarsUsedInBufferDefinition(const Stmt& stmt) {
struct Visitor : StmtExprVisitor {
using StmtExprVisitor::VisitExpr_;
using StmtExprVisitor::VisitStmt_;

void VisitExpr_(const BufferLoadNode* op) override {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) override {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}

void VisitBuffer(const Buffer& buf) {
// Collect variables that should remain defined
VarUseDefAnalyzer usage(Array<Var>{});
usage(buf->data);
for (const auto& dim : buf->shape) {
usage(dim);
}
for (const auto& dim : buf->strides) {
usage(dim);
}
usage(buf->elem_offset);

// Track for use in LetStmtNode mutator
for (const auto& var : usage.undefined_) {
used_in_buffer_def_.insert(var.get());
}
}
std::unordered_set<const VarNode*> used_in_buffer_def_;
};

Visitor visitor;
visitor(stmt);
return visitor.used_in_buffer_def_;
}

class SimplifyConfig : public Attrs {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs, SimplifyConfigNode);
@@ -110,16 +151,24 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
config->propagate_knowns_to_simplify_expressions) {
touch_pattern = ControlFlowGraph(stmt);
}
StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern));

std::unordered_set<const VarNode*> used_in_buffer_def = CollectVarsUsedInBufferDefinition(stmt);
StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
std::move(used_in_buffer_def));
return simplifier(std::move(stmt));
}

private:
explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config,
std::optional<ControlFlowGraph> touch_pattern)
: IRMutatorWithAnalyzer(analyzer), config_(config), touch_pattern_(touch_pattern) {}
std::optional<ControlFlowGraph> touch_pattern,
std::unordered_set<const VarNode*> used_in_buffer_def)
: IRMutatorWithAnalyzer(analyzer),
config_(config),
touch_pattern_(touch_pattern),
used_in_buffer_def_(used_in_buffer_def) {}

using Parent = IRMutatorWithAnalyzer;
using Parent::VisitExpr_;
using Parent::VisitStmt;
using Parent::VisitStmt_;

@@ -159,18 +208,36 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

Stmt VisitStmt_(const LetStmtNode* op) override {
PrimExpr value = this->VisitExpr(op->value);
if (CanInlineLetStmt(op)) {
// it is fine to discard the let binding
// because the call to simplify will always inline the var.
bool can_inline = CanInlineLetStmt(op);
if (can_inline) {
// It is usually fine to discard the let binding because the
// call to simplify will always inline the var.
//
// The exception is when the variable is used in a Buffer's
// definition, as these are not updated by the simplification.
// After DeclBuffer is required prior to use of a buffer,
// simplifying can update the buffer definition as well. The
// buffer can only be updated at its point of definition,
// because the points of use may occur within contexts that
// allow for additional simplifications (e.g. a buffer of shape
// [i,j] whose first use occurs within "if i==1" should not have
// its shape simplified to [1,j]).
analyzer_->Bind(op->var, value);
return this->VisitStmt(op->body);
} else if (SideEffect(op->value) <= CallEffectKind::kPure) {
// Even if we aren't replacing all occurrences, they may be
// necessary for proving conditional statements.
non_inlined_bindings_.Set(op->var, value);
}
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {

// TODO(Lunderberg): Update the Buffer object as part of
// DeclBuffer updates, which will first require
// https://github.com/apache/tvm/pull/14778.
bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());

if (can_inline && !used_in_buffer_def) {
return body;
} else if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = this->CopyOnWrite(op);
@@ -207,8 +274,10 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return Parent::VisitExpr_(op);
}

PrimExpr VisitExpr_(const BufferLoadNode* op) override { return Parent::VisitExpr_(op); }

// eliminate useless stores
Stmt VisitStmt_(const BufferStoreNode* op) final {
Stmt VisitStmt_(const BufferStoreNode* op) override {
BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) {
if (load->buffer->data.same_as(store->buffer->data) &&
@@ -260,6 +329,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

Map<Var, PrimExpr> non_inlined_bindings_;
Optional<Stmt> current_stmt_{NullOpt};
std::unordered_set<const VarNode*> used_in_buffer_def_;
};

} // namespace arith
44 changes: 44 additions & 0 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
@@ -1689,5 +1689,49 @@ def expected(A: T.Buffer(1, "int32")):
A[0] = 12


class TestSimplifyTrivialLetBufferVar(BaseBeforeAfter):
"""A LetStmt used in a buffer definition should be retained"""

def before(A_ptr: T.handle("float32")):
A_ptr_redef: T.handle("float32") = A_ptr
A = T.decl_buffer(1, "float32", data=A_ptr_redef)
A[0] = 42.0

expected = before


class TestSimplifyTrivialLetElemOffset(BaseBeforeAfter):
"""A LetStmt used in a buffer definition should be retained"""

def before(A_ptr: T.handle("float32"), A_offset: T.int32):
A_offset_redef = A_offset
A = T.decl_buffer(1, "float32", elem_offset=A_offset_redef, data=A_ptr)
A[0] = 42.0

expected = before


class TestSimplifyTrivialLetShape(BaseBeforeAfter):
"""A LetStmt used in a buffer definition should be retained"""

def before(A_ptr: T.handle("float32"), A_size: T.int32):
A_size_redef = A_size
A = T.decl_buffer([A_size_redef], "float32", data=A_ptr)
A[0] = 42.0

expected = before


class TestSimplifyTrivialLetStride(BaseBeforeAfter):
"""A LetStmt used in a buffer definition should be retained"""

def before(A_ptr: T.handle("float32"), A_stride: T.int32):
A_stride_redef = A_stride
A = T.decl_buffer(1, "float32", strides=[A_stride_redef], data=A_ptr)
A[0] = 42.0

expected = before


if __name__ == "__main__":
tvm.testing.main()