Skip to content

Commit

Permalink
Avoid Var-to-Var Let binding in AOTExecutorCodegen
Browse files Browse the repository at this point in the history
Prior to #14951, these can have
erroneous simplifications when used in buffer definitions.
  • Loading branch information
Lunderberg committed Jun 2, 2023
1 parent e7bfa2d commit af1628f
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include <vector>

#include "../../target/source/codegen_source_base.h"
#include "../../tir/transforms/ir_utils.h"
#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../op/memory/device_copy.h"
Expand Down Expand Up @@ -505,18 +506,34 @@ class AOTExecutorCodegen : public MixedModeVisitor {
* copy-on-write fashion.
*/
void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) {
std::vector<tir::Stmt> let_nest;

// Define intermediate DLTensor to load/store the data
tir::Buffer tmp_read =
tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read");
tir::Buffer tmp_write =
tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write");
te::Var loop_idx("i", DataType::Int(32));
auto retval_i = tir::BufferLoad(tmp_read, {loop_idx});

// Re-use in/out as the buffer var, if possible
if (auto opt = out.as<tir::Var>()) {
tmp_write.CopyOnWrite()->data = opt.value();
} else {
let_nest.push_back(tir::LetStmt(tmp_write->data, out, tir::Evaluate(0)));
}
if (auto opt = in.as<tir::Var>()) {
tmp_read.CopyOnWrite()->data = opt.value();
} else {
let_nest.push_back(tir::LetStmt(tmp_read->data, in, tir::Evaluate(0)));
}

// Copy the variable from the input to the output
tir::Stmt copy = tir::For(
loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial,
tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx}));
stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy));
te::Var loop_idx("i", DataType::Int(32));
tir::Stmt copy = tir::BufferStore(tmp_write, tir::BufferLoad(tmp_read, {loop_idx}), {loop_idx});
copy = tir::For(loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()),
tir::ForKind::kSerial, copy);
copy = tir::MergeNest(let_nest, copy);

stmts_.push_back(copy);
}

/*
Expand Down

0 comments on commit af1628f

Please sign in to comment.