From 6ef22f5631f21656ea8fdc7a58a270972ebf2640 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg <Lunderberg@users.noreply.github.com>
Date: Thu, 15 Jun 2023 14:47:03 -0400
Subject: [PATCH] [AOT] Avoid Var-to-Var Let binding in AOTExecutorCodegen
 (#15033)

Prior to https://github.com/apache/tvm/pull/14951, these can have
erroneous simplifications when used in buffer definitions.  While they
no longer cause issues with correct-ness, they are unnecessary in this case.
---
 src/relay/backend/aot_executor_codegen.cc | 29 ++++++++++++++++++-----
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 4001c870ef3f..945290f70265 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -45,6 +45,7 @@
 #include <vector>
 
 #include "../../target/source/codegen_source_base.h"
+#include "../../tir/transforms/ir_utils.h"
 #include "../op/annotation/annotation.h"
 #include "../op/call/call.h"
 #include "../op/memory/device_copy.h"
@@ -505,18 +506,34 @@ class AOTExecutorCodegen : public MixedModeVisitor {
    * copy-on-write fashion.
    */
   void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) {
+    std::vector<tir::Stmt> let_nest;
+
     // Define intermediate DLTensor to load/store the data
     tir::Buffer tmp_read =
         tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read");
     tir::Buffer tmp_write =
         tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write");
-    te::Var loop_idx("i", DataType::Int(32));
-    auto retval_i = tir::BufferLoad(tmp_read, {loop_idx});
+
+    // Re-use in/out as the buffer var, if possible
+    if (auto opt = out.as<tir::Var>()) {
+      tmp_write.CopyOnWrite()->data = opt.value();
+    } else {
+      let_nest.push_back(tir::LetStmt(tmp_write->data, out, tir::Evaluate(0)));
+    }
+    if (auto opt = in.as<tir::Var>()) {
+      tmp_read.CopyOnWrite()->data = opt.value();
+    } else {
+      let_nest.push_back(tir::LetStmt(tmp_read->data, in, tir::Evaluate(0)));
+    }
+
     // Copy the variable from the input to the output
-    tir::Stmt copy = tir::For(
-        loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial,
-        tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx}));
-    stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy));
+    te::Var loop_idx("i", DataType::Int(32));
+    tir::Stmt copy = tir::BufferStore(tmp_write, tir::BufferLoad(tmp_read, {loop_idx}), {loop_idx});
+    copy = tir::For(loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()),
+                    tir::ForKind::kSerial, copy);
+    copy = tir::MergeNest(let_nest, copy);
+
+    stmts_.push_back(copy);
   }
 
   /*