From 6ef22f5631f21656ea8fdc7a58a270972ebf2640 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg <Lunderberg@users.noreply.github.com>
Date: Thu, 15 Jun 2023 14:47:03 -0400
Subject: [PATCH] [AOT] Avoid Var-to-Var Let binding in AOTExecutorCodegen (#15033)

Prior to https://github.com/apache/tvm/pull/14951, these could have
erroneous simplifications when used in buffer definitions. While they
no longer cause issues with correctness, they are unnecessary in this
case.
---
 src/relay/backend/aot_executor_codegen.cc | 29 ++++++++++++++++++-----
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 4001c870ef3f..945290f70265 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -45,6 +45,7 @@
 #include <vector>
 
 #include "../../target/source/codegen_source_base.h"
+#include "../../tir/transforms/ir_utils.h"
 #include "../op/annotation/annotation.h"
 #include "../op/call/call.h"
 #include "../op/memory/device_copy.h"
@@ -505,18 +506,34 @@ class AOTExecutorCodegen : public MixedModeVisitor {
    * copy-on-write fashion.
    */
   void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) {
+    std::vector<tir::Stmt> let_nest;
+
     // Define intermediate DLTensor to load/store the data
     tir::Buffer tmp_read =
         tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read");
     tir::Buffer tmp_write =
         tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write");
-    te::Var loop_idx("i", DataType::Int(32));
-    auto retval_i = tir::BufferLoad(tmp_read, {loop_idx});
+
+    // Re-use in/out as the buffer var, if possible
+    if (auto opt = out.as<tir::Var>()) {
+      tmp_write.CopyOnWrite()->data = opt.value();
+    } else {
+      let_nest.push_back(tir::LetStmt(tmp_write->data, out, tir::Evaluate(0)));
+    }
+    if (auto opt = in.as<tir::Var>()) {
+      tmp_read.CopyOnWrite()->data = opt.value();
+    } else {
+      let_nest.push_back(tir::LetStmt(tmp_read->data, in, tir::Evaluate(0)));
+    }
+
     // Copy the variable from the input to the output
-    tir::Stmt copy = tir::For(
-        loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial,
-        tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx}));
-    stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy));
+    te::Var loop_idx("i", DataType::Int(32));
+    tir::Stmt copy = tir::BufferStore(tmp_write, tir::BufferLoad(tmp_read, {loop_idx}), {loop_idx});
+    copy = tir::For(loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()),
+                    tir::ForKind::kSerial, copy);
+    copy = tir::MergeNest(let_nest, copy);
+
+    stmts_.push_back(copy);
   }
 
   /*