Skip to content

Commit

Permalink
[Relay][VM] Relay VM memory liveness/lifetime analysis (#10026)
Browse files Browse the repository at this point in the history
* WIP VM memory planning

* tuple projection

* support if

* lint

* remove old comment

* WIP check in attempt at CFG analysis

* rewrite CFG analysis in stages, support ADTs

* lint

* fix small bug in alias elimination, try fix VM profiler error

* update DCE tests since allocations can be DCE'd

* optimize worklist to reduce runtime

* add docs, rename pass to ManifestLifetimes

* add tests, more comments, proper VM profiler fix

* lint

* ci please

* address nits

* retry ci again

* retry ci once again :)

* fix sneaky memory leak due to cyclic refs

* fix didn't work but retry ci anyway

* slightly reduce size of large pretty printer test
  • Loading branch information
altanh authored Feb 5, 2022
1 parent 1b71cae commit 34d70de
Show file tree
Hide file tree
Showing 13 changed files with 866 additions and 15 deletions.
8 changes: 8 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,14 @@ TVM_DLL Pass RelayToTIRTargetHook();
*/
TVM_DLL Pass ManifestAlloc(VirtualDevice cpu_virtual_device);

/*!
* \brief A pass for manifesting variable lifetimes by inserting kill operations when variables
* become dead. This pass should be run after ManifestAlloc, and should not be run more than once.
*
* \return The pass.
*/
TVM_DLL Pass ManifestLifetimes();

/*!
* \brief Uses existing "on_device" and "device_copy" CallNodes to infer the \p VirtualDevice on
* which every Relay sub-expression should run and the result stored. Captures the result of that
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/runtime/vm/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ enum class Opcode {
ShapeOf = 17U,
ReshapeTensor = 18U,
DeviceCopy = 19U,
KillRegister = 20U,
};

/*! \brief A single virtual machine instruction.
Expand Down Expand Up @@ -386,6 +387,8 @@ struct Instruction {
static Instruction DeviceCopy(RegName src, Index src_device_index, Index dst_device_index,
RegName dst);

static Instruction KillRegister(RegName dst);

Instruction();
Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr);
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,14 @@ def PlanDevices(config):
return _ffi_api.PlanDevices(config)


def ManifestLifetimes():
"""
Manifest the lifetimes of variables after allocations have been manifested, by inserting kill
operations once variables become dead.
"""
return _ffi_api.ManifestLifetimes()


def FoldExplicitPadding():
"""
FoldExplicitPadding finds explict padding before an op that can support
Expand Down
10 changes: 8 additions & 2 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
case Opcode::Ret:
case Opcode::Goto:
case Opcode::Fatal:
case Opcode::KillRegister:
break;
}
instructions_.push_back(instr);
Expand Down Expand Up @@ -647,8 +648,10 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
Emit(Instruction::ReshapeTensor(tensor_reg, shape_reg, NewRegister()));
})
.Match("memory.kill",
[](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
LOG(FATAL) << "memory.kill is not yet supported";
[this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
ICHECK_EQ(args.size(), 1u);
this->VisitExpr(args[0]);
Emit(Instruction::KillRegister(this->last_register_));
});
matcher(GetRef<Call>(call_node));
return;
Expand Down Expand Up @@ -993,6 +996,9 @@ transform::Sequential VMCompiler::MemoryOpt(const VirtualDevice& host_virtual_de
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());

// Insert kills to free memory.
pass_seqs.push_back(transform::ManifestLifetimes());

// Lift constants to the top-level of the block to simplify VM code generation.
// TODO(@icemelon9, @jroesch): Remove this pass for now because some
// instructions need to access to constant
Expand Down
Loading

0 comments on commit 34d70de

Please sign in to comment.