diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 1261d9971762b..a3c8b65ab4388 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -45,6 +45,7 @@ #include #include "../../target/source/codegen_source_base.h" +#include "../../tir/transforms/ir_utils.h" #include "../op/annotation/annotation.h" #include "../op/call/call.h" #include "../op/memory/device_copy.h" @@ -505,18 +506,34 @@ class AOTExecutorCodegen : public MixedModeVisitor { * copy-on-write fashion. */ void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) { + std::vector let_nest; + // Define intermediate DLTensor to load/store the data tir::Buffer tmp_read = tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read"); tir::Buffer tmp_write = tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write"); - te::Var loop_idx("i", DataType::Int(32)); - auto retval_i = tir::BufferLoad(tmp_read, {loop_idx}); + + // Re-use in/out as the buffer var, if possible + if (auto opt = out.as()) { + tmp_write.CopyOnWrite()->data = opt.value(); + } else { + let_nest.push_back(tir::LetStmt(tmp_write->data, out, tir::Evaluate(0))); + } + if (auto opt = in.as()) { + tmp_read.CopyOnWrite()->data = opt.value(); + } else { + let_nest.push_back(tir::LetStmt(tmp_read->data, in, tir::Evaluate(0))); + } + // Copy the variable from the input to the output - tir::Stmt copy = tir::For( - loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial, - tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx})); - stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy)); + te::Var loop_idx("i", DataType::Int(32)); + tir::Stmt copy = tir::BufferStore(tmp_write, tir::BufferLoad(tmp_read, {loop_idx}), {loop_idx}); + copy = tir::For(loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), + tir::ForKind::kSerial, copy); + copy = tir::MergeNest(let_nest, copy); + + stmts_.push_back(copy); } /*