From 6dce15b5ca549d9e396bd44f7c5f170066beff9c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 17 May 2023 11:54:31 -0500 Subject: [PATCH] [TIR] Keep trivial LetStmt in tir.Simplify when used in buffer decl Prior to this commit, any trivial let binding of `var1 = var2` was inlined. However, buffer definitions were not updated, so this could result in dangling `tir::Var` instances. This commit updates the `tir.Simplify` pass to keep trivial let bindings if they are used as part of a buffer definition. Ideally, the trivial `LetStmt` variable would be inlined into the buffer definition as well as other expressions. However, because a buffer may be implicitly declared, the first usage may be within a constrained context. If that happens, the simplified shape/strides expression cannot be used to update the buffer definition, because that simplification may not be valid at every usage point of the buffer. ```python for i in range(n): elem_offset = i view = T.Buffer(1, data=buf, elem_offset = elem_offset) if i == 0: # First occurrence in TIR is here, where elem_offset would # simplify to zero. view[0] = 1 else: # But the same buffer is used here, where elem_offset doesn't # simplify to zero. view[0] = 2 ``` This will be resolvable after https://github.com/apache/tvm/pull/14778 lands, which will require all buffers to be declared with `DeclBuffer` prior to usage. ```python for i in range(n): elem_offset = i # All variables used by the DeclBuffer are valid across the entire # body of the DeclBuffer. 
view = T.decl_buffer(1, data=buf, elem_offset = elem_offset) if i == 0: view[0] = 1 else: view[0] = 2 ``` --- src/tir/transforms/simplify.cc | 88 +++++++++++++++++-- .../unittest/test_tir_transform_simplify.py | 44 ++++++++++ 2 files changed, 123 insertions(+), 9 deletions(-) diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index cc088e8f74c6c..130cbe37c1677 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -33,6 +33,7 @@ #include "../../arith/ir_mutator_with_analyzer.h" #include "../../tir/analysis/control_flow_graph.h" +#include "../../tir/analysis/var_use_def_analysis.h" namespace tvm { namespace arith { @@ -91,6 +92,46 @@ struct SimplifyConfigNode : public tvm::AttrsNode { } }; +/* \brief Utility function to collect vars that should be retained */ +std::unordered_set CollectVarsUsedInBufferDefinition(const Stmt& stmt) { + struct Visitor : StmtExprVisitor { + using StmtExprVisitor::VisitExpr_; + using StmtExprVisitor::VisitStmt_; + + void VisitExpr_(const BufferLoadNode* op) override { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + void VisitStmt_(const BufferStoreNode* op) override { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + + void VisitBuffer(const Buffer& buf) { + // Collect variables that should remain defined + VarUseDefAnalyzer usage(Array{}); + usage(buf->data); + for (const auto& dim : buf->shape) { + usage(dim); + } + for (const auto& dim : buf->strides) { + usage(dim); + } + usage(buf->elem_offset); + + // Track for use in LetStmtNode mutator + for (const auto& var : usage.undefined_) { + used_in_buffer_def_.insert(var.get()); + } + } + std::unordered_set used_in_buffer_def_; + }; + + Visitor visitor; + visitor(stmt); + return visitor.used_in_buffer_def_; +} + class SimplifyConfig : public Attrs { public: TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs, SimplifyConfigNode); @@ -110,16 +151,24 @@ class StmtSimplifier : public 
IRMutatorWithAnalyzer { config->propagate_knowns_to_simplify_expressions) { touch_pattern = ControlFlowGraph(stmt); } - StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern)); + + std::unordered_set used_in_buffer_def = CollectVarsUsedInBufferDefinition(stmt); + StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern), + std::move(used_in_buffer_def)); return simplifier(std::move(stmt)); } private: explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config, - std::optional touch_pattern) - : IRMutatorWithAnalyzer(analyzer), config_(config), touch_pattern_(touch_pattern) {} + std::optional touch_pattern, + std::unordered_set used_in_buffer_def) + : IRMutatorWithAnalyzer(analyzer), + config_(config), + touch_pattern_(touch_pattern), + used_in_buffer_def_(used_in_buffer_def) {} using Parent = IRMutatorWithAnalyzer; + using Parent::VisitExpr_; using Parent::VisitStmt; using Parent::VisitStmt_; @@ -159,18 +208,36 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { Stmt VisitStmt_(const LetStmtNode* op) override { PrimExpr value = this->VisitExpr(op->value); - if (CanInlineLetStmt(op)) { - // it is fine to discard the let binding - // because the call to simplify will always inline the var. + bool can_inline = CanInlineLetStmt(op); + if (can_inline) { + // It is usually fine to discard the let binding because the + // call to simplify will always inline the var. + // + // The exception is when the variable is used in a Buffer's + // definition, as these are not updated by the simplification. + // After DeclBuffer is required prior to use of a buffer, + // simplifying can update the buffer definition as well. The + // buffer can only be updated at its point of definition, + // because the points of use may occur within contexts that + // allow for additional simplifications (e.g. a buffer of shape + // [i,j] whose first use occurs within "if i==1" should not have + // its shape simplified to [1,j]). 
analyzer_->Bind(op->var, value); - return this->VisitStmt(op->body); } else if (SideEffect(op->value) <= CallEffectKind::kPure) { // Even if we aren't replacing all occurrences, they may be // necessary for proving conditional statements. non_inlined_bindings_.Set(op->var, value); } Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) { + + // TODO(Lunderberg): Update the Buffer object as part of + // DeclBuffer updates, which will first require + // https://github.com/apache/tvm/pull/14778. + bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get()); + + if (can_inline && !used_in_buffer_def) { + return body; + } else if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = this->CopyOnWrite(op); @@ -207,8 +274,10 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return Parent::VisitExpr_(op); } + PrimExpr VisitExpr_(const BufferLoadNode* op) override { return Parent::VisitExpr_(op); } + // eliminate useless stores - Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt VisitStmt_(const BufferStoreNode* op) override { BufferStore store = Downcast(Parent::VisitStmt_(op)); if (const BufferLoadNode* load = store->value.as()) { if (load->buffer->data.same_as(store->buffer->data) && @@ -260,6 +329,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { Map non_inlined_bindings_; Optional current_stmt_{NullOpt}; + std::unordered_set used_in_buffer_def_; }; } // namespace arith diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index b50035aa69d46..1f25405ec9d1b 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -1689,5 +1689,49 @@ def expected(A: T.Buffer(1, "int32")): A[0] = 12 +class TestSimplifyTrivialLetBufferVar(BaseBeforeAfter): + """A LetStmt used in a buffer definition should be retained""" + + def before(A_ptr: 
T.handle("float32")): + A_ptr_redef: T.handle("float32") = A_ptr + A = T.decl_buffer(1, "float32", data=A_ptr_redef) + A[0] = 42.0 + + expected = before + + +class TestSimplifyTrivialLetElemOffset(BaseBeforeAfter): + """A LetStmt used in a buffer definition should be retained""" + + def before(A_ptr: T.handle("float32"), A_offset: T.int32): + A_offset_redef = A_offset + A = T.decl_buffer(1, "float32", elem_offset=A_offset_redef, data=A_ptr) + A[0] = 42.0 + + expected = before + + +class TestSimplifyTrivialLetShape(BaseBeforeAfter): + """A LetStmt used in a buffer definition should be retained""" + + def before(A_ptr: T.handle("float32"), A_size: T.int32): + A_size_redef = A_size + A = T.decl_buffer([A_size_redef], "float32", data=A_ptr) + A[0] = 42.0 + + expected = before + + +class TestSimplifyTrivialLetStride(BaseBeforeAfter): + """A LetStmt used in a buffer definition should be retained""" + + def before(A_ptr: T.handle("float32"), A_stride: T.int32): + A_stride_redef = A_stride + A = T.decl_buffer(1, "float32", strides=[A_stride_redef], data=A_ptr) + A[0] = 42.0 + + expected = before + + if __name__ == "__main__": tvm.testing.main()