Skip to content

Commit

Permalink
[TIR] Keep trivial LetStmt in tir.Simplify when used in buffer decl
Browse files Browse the repository at this point in the history
Prior to this commit, any trivial let binding of `var1 = var2` is
inlined.  However, buffer definitions are not updated, so this can
result in dangling `tir::Var` instances.  This commit updates the
`tir.Simplify` pass to keep trivial let bindings if they are used as
part of a buffer definition.

Ideally, the trivial `LetStmt` variable would be inlined into the
buffer definition as well as other expressions.  However, because a
buffer may be implicitly declared, the first usage may be within a
constrained context.  If that happens, the simplified shape/strides
expression cannot be used to update the buffer definition, as that
simplification is not valid at all possible usage points of the
buffer.

```python
for i in range(n):
    elem_offset = i
    view = T.Buffer(1, data=buf, elem_offset = elem_offset)
    if i == 0:
        # First occurrence in TIR is here, where elem_offset would
        # simplify to zero.
        view[0] = 1
    else:
        # But the same buffer is used here, where elem_offset doesn't
        # simplify to zero.
        view[0] = 2
```

This will be resolvable after apache#14778
lands, requiring all buffers to be declared with `DeclBuffer` prior to
usage.

```python
for i in range(n):
    elem_offset = i
    # All variables used by the DeclBuffer are valid across the entire
    # body of the DeclBuffer.
    view = T.decl_buffer(1, data=buf, elem_offset = elem_offset)
    if i == 0:
        view[0] = 1
    else:
        view[0] = 2
```
  • Loading branch information
Lunderberg committed May 30, 2023
1 parent 4fdf1d1 commit 6dce15b
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 9 deletions.
88 changes: 79 additions & 9 deletions src/tir/transforms/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../tir/analysis/control_flow_graph.h"
#include "../../tir/analysis/var_use_def_analysis.h"

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -91,6 +92,46 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
}
};

/* \brief Utility function to collect vars that should be retained */
std::unordered_set<const VarNode*> CollectVarsUsedInBufferDefinition(const Stmt& stmt) {
struct Visitor : StmtExprVisitor {
using StmtExprVisitor::VisitExpr_;
using StmtExprVisitor::VisitStmt_;

void VisitExpr_(const BufferLoadNode* op) override {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) override {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}

void VisitBuffer(const Buffer& buf) {
// Collect variables that should remain defined
VarUseDefAnalyzer usage(Array<Var>{});
usage(buf->data);
for (const auto& dim : buf->shape) {
usage(dim);
}
for (const auto& dim : buf->strides) {
usage(dim);
}
usage(buf->elem_offset);

// Track for use in LetStmtNode mutator
for (const auto& var : usage.undefined_) {
used_in_buffer_def_.insert(var.get());
}
}
std::unordered_set<const VarNode*> used_in_buffer_def_;
};

Visitor visitor;
visitor(stmt);
return visitor.used_in_buffer_def_;
}

class SimplifyConfig : public Attrs {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs, SimplifyConfigNode);
Expand All @@ -110,16 +151,24 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
config->propagate_knowns_to_simplify_expressions) {
touch_pattern = ControlFlowGraph(stmt);
}
StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern));

std::unordered_set<const VarNode*> used_in_buffer_def = CollectVarsUsedInBufferDefinition(stmt);
StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
std::move(used_in_buffer_def));
return simplifier(std::move(stmt));
}

private:
explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config,
std::optional<ControlFlowGraph> touch_pattern)
: IRMutatorWithAnalyzer(analyzer), config_(config), touch_pattern_(touch_pattern) {}
std::optional<ControlFlowGraph> touch_pattern,
std::unordered_set<const VarNode*> used_in_buffer_def)
: IRMutatorWithAnalyzer(analyzer),
config_(config),
touch_pattern_(touch_pattern),
used_in_buffer_def_(used_in_buffer_def) {}

using Parent = IRMutatorWithAnalyzer;
using Parent::VisitExpr_;
using Parent::VisitStmt;
using Parent::VisitStmt_;

Expand Down Expand Up @@ -159,18 +208,36 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

Stmt VisitStmt_(const LetStmtNode* op) override {
PrimExpr value = this->VisitExpr(op->value);
if (CanInlineLetStmt(op)) {
// it is fine to discard the let binding
// because the call to simplify will always inline the var.
bool can_inline = CanInlineLetStmt(op);
if (can_inline) {
// It is usually fine to discard the let binding because the
// call to simplify will always inline the var.
//
// The exception is when the variable is used in a Buffer's
// definition, as these are not updated by the simplification.
// After DeclBuffer is required prior to use of a buffer,
// simplifying can update the buffer definition as well. The
// buffer can only be updated at its point of definition,
// because the points of use may occur within contexts that
// allow for additional simplifications (e.g. a buffer of shape
// [i,j] whose first use occurs within "if i==1" should not have
// its shape simplified to [1,j]).
analyzer_->Bind(op->var, value);
return this->VisitStmt(op->body);
} else if (SideEffect(op->value) <= CallEffectKind::kPure) {
// Even if we aren't replacing all occurrences, they may be
// necessary for proving conditional statements.
non_inlined_bindings_.Set(op->var, value);
}
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {

// TODO(Lunderberg): Update the Buffer object as part of
// DeclBuffer updates, which will first require
// https://github.com/apache/tvm/pull/14778.
bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());

if (can_inline && !used_in_buffer_def) {
return body;
} else if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = this->CopyOnWrite(op);
Expand Down Expand Up @@ -207,8 +274,10 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return Parent::VisitExpr_(op);
}

PrimExpr VisitExpr_(const BufferLoadNode* op) override { return Parent::VisitExpr_(op); }

// eliminate useless stores
Stmt VisitStmt_(const BufferStoreNode* op) final {
Stmt VisitStmt_(const BufferStoreNode* op) override {
BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) {
if (load->buffer->data.same_as(store->buffer->data) &&
Expand Down Expand Up @@ -260,6 +329,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

Map<Var, PrimExpr> non_inlined_bindings_;
Optional<Stmt> current_stmt_{NullOpt};
std::unordered_set<const VarNode*> used_in_buffer_def_;
};

} // namespace arith
Expand Down
44 changes: 44 additions & 0 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,5 +1689,49 @@ def expected(A: T.Buffer(1, "int32")):
A[0] = 12


class TestSimplifyTrivialLetBufferVar(BaseBeforeAfter):
"""A LetStmt used in a buffer definition should be retained"""

def before(A_ptr: T.handle("float32")):
A_ptr_redef: T.handle("float32") = A_ptr
A = T.decl_buffer(1, "float32", data=A_ptr_redef)
A[0] = 42.0

expected = before


class TestSimplifyTrivialLetElemOffset(BaseBeforeAfter):
"""A LetStmt used in a buffer definition should be retained"""

def before(A_ptr: T.handle("float32"), A_offset: T.int32):
A_offset_redef = A_offset
A = T.decl_buffer(1, "float32", elem_offset=A_offset_redef, data=A_ptr)
A[0] = 42.0

expected = before


class TestSimplifyTrivialLetShape(BaseBeforeAfter):
"""A LetStmt used in a buffer definition should be retained"""

def before(A_ptr: T.handle("float32"), A_size: T.int32):
A_size_redef = A_size
A = T.decl_buffer([A_size_redef], "float32", data=A_ptr)
A[0] = 42.0

expected = before


class TestSimplifyTrivialLetStride(BaseBeforeAfter):
"""A LetStmt used in a buffer definition should be retained"""

def before(A_ptr: T.handle("float32"), A_stride: T.int32):
A_stride_redef = A_stride
A = T.decl_buffer(1, "float32", strides=[A_stride_redef], data=A_ptr)
A[0] = 42.0

expected = before


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 6dce15b

Please sign in to comment.