Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax][VM] Re-implementation of callback functions #16573

Merged
merged 1 commit into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 3 additions & 25 deletions include/tvm/runtime/relax_vm/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ enum class Opcode {
Ret = 2U,
Goto = 3U,
If = 4U,
CallFromRegister = 5U,
};

/*! \brief A single virtual machine instruction.
Expand Down Expand Up @@ -184,15 +183,10 @@ struct Instruction {
/*! \brief The instruction opcode. */
Opcode op;
union {
struct /* Call, CallFromRegister */ {
struct /* Call */ {
/*! \brief The destination register. */
RegName dst;
/*! \brief The index of the function.
*
* For `OpCode::Call`, this is an index into the table of static
* functions. For `OpCode::CallFromRegister`, this is an index
* of a register.
*/
/*! \brief The index into the packed function table. */
Index func_idx;
/*! \brief The number of arguments to the packed function. */
Index num_args;
Expand All @@ -214,43 +208,27 @@ struct Instruction {
Index false_offset;
};
};

/*!
* \brief Construct a Call instruction.
* \param func_idx The index of the function to call within the
* static function table
* \param func_idx The index of the function to call.
* \param num_args The number of arguments.
* \param args The input arguments.
* \param dst The destination register.
* \return The call instruction.
*/
static Instruction Call(Index func_idx, Index num_args, Arg* args, RegName dst);

/*!
* \brief Construct a Call instruction.
* \param func_idx The index of the function to call within the
* current stack frame's registers.
* \param num_args The number of arguments.
* \param args The input arguments.
* \param dst The destination register.
* \return The call instruction.
*/
static Instruction CallFromRegister(Index func_idx, Index num_args, Arg* args, RegName dst);

/*!
* \brief Construct a return instruction.
* \param result The register containing the return value.
* \return The return instruction.
*/
static Instruction Ret(RegName result);

/*!
* \brief Construct a goto instruction.
* \param pc_offset The register containing the jump offset.
* \return The goto instruction.
*/
static Instruction Goto(RegName pc_offset);

/*!
* \brief Construct an If instruction.
* \param cond The register containing the cond value.
Expand Down
10 changes: 9 additions & 1 deletion src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,15 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
void EmitNormalCall(const Call& call_node, RegName dst_reg) {
Instruction::Arg func = VisitExpr(call_node->op);
std::vector<Instruction::Arg> args = VisitArray(call_node->args);
builder_->EmitCall(func, args, dst_reg);

if (func.kind() == vm::Instruction::ArgKind::kFuncIdx) {
builder_->EmitCall(func, args, dst_reg);
} else {
std::vector<Instruction::Arg> closure_args = {
Instruction::Arg::Register(Instruction::kVMRegister), func};
std::copy(args.begin(), args.end(), std::back_inserter(closure_args));
builder_->EmitCall("vm.builtin.invoke_closure", closure_args, dst_reg);
}
}

// Emits call to packed function `name` with arguments copied over from `call_node` args
Expand Down
20 changes: 4 additions & 16 deletions src/relax/backend/vm/exec_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,20 +138,10 @@ void ExecBuilderNode::EndFunction(const std::string& func_name) {

void ExecBuilderNode::EmitCall(vm::Instruction::Arg func, std::vector<vm::Instruction::Arg> args,
vm::RegName dst) {
Opcode op_code;
if (func.kind() == vm::Instruction::ArgKind::kFuncIdx) {
op_code = Opcode::Call;
} else if (func.kind() == vm::Instruction::ArgKind::kRegister) {
op_code = Opcode::CallFromRegister;
} else {
LOG(FATAL) << "VM instruction for a function must be either "
<< "kFuncIdx (static function ) "
<< "or kRegister (function passed as parameter), "
<< "but instead found " << func.kind();
}
ICHECK(func.kind() == vm::Instruction::ArgKind::kFuncIdx);
// store instruction
exec_->instr_offset.push_back(exec_->instr_data.size());
exec_->instr_data.push_back(static_cast<ExecWord>(op_code));
exec_->instr_data.push_back(static_cast<ExecWord>(Opcode::Call));
exec_->instr_data.push_back(dst);
exec_->instr_data.push_back(func.value());
exec_->instr_data.push_back(args.size());
Expand Down Expand Up @@ -238,8 +228,7 @@ void ExecBuilderNode::CheckExecutable() {
for (size_t idx = start_instr; idx < end_instr; ++idx) {
Instruction instr = exec_->GetInstruction(idx);
switch (instr.op) {
case Opcode::Call:
case Opcode::CallFromRegister: {
case Opcode::Call: {
check_func_defined(Instruction::Arg::FuncIdx(instr.func_idx));
for (int i = 0; i < instr.num_args; ++i) {
check_reg_defined(instr.args[i]);
Expand Down Expand Up @@ -291,8 +280,7 @@ void ExecBuilderNode::Formalize() {
for (size_t idx = start_instr; idx < end_instr; ++idx) {
Instruction instr = this->exec_->GetInstruction(idx);
switch (instr.op) {
case Opcode::Call:
case Opcode::CallFromRegister: {
case Opcode::Call: {
// rewrite args
for (int i = 0; i < instr.num_args; ++i) {
if (instr.args[i].kind() == Instruction::ArgKind::kRegister &&
Expand Down
11 changes: 0 additions & 11 deletions src/runtime/relax_vm/bytecode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,6 @@ Instruction Instruction::Call(Index func_idx, Index num_args, Instruction::Arg*
return instr;
}

Instruction Instruction::CallFromRegister(Index func_idx, Index num_args, Instruction::Arg* args,
RegName dst) {
Instruction instr;
instr.op = Opcode::CallFromRegister;
instr.dst = dst;
instr.func_idx = func_idx;
instr.num_args = num_args;
instr.args = args;
return instr;
}

Instruction Instruction::Ret(RegName result) {
Instruction instr;
instr.op = Opcode::Ret;
Expand Down
8 changes: 0 additions & 8 deletions src/runtime/relax_vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,6 @@ Instruction Executable::GetInstruction(Index i) const {
ExecWord* args = const_cast<ExecWord*>(&instr_data[offset + 4]);
return Instruction::Call(func_idx, num_args, reinterpret_cast<Instruction::Arg*>(args), dst);
}
case Opcode::CallFromRegister: {
RegName dst = instr_data[offset + 1];
Index func_idx = instr_data[offset + 2];
Index num_args = instr_data[offset + 3];
ExecWord* args = const_cast<ExecWord*>(&instr_data[offset + 4]);
return Instruction::CallFromRegister(func_idx, num_args,
reinterpret_cast<Instruction::Arg*>(args), dst);
}
case Opcode::Ret: {
RegName result = instr_data[offset + 1];
return Instruction::Ret(result);
Expand Down
39 changes: 12 additions & 27 deletions src/runtime/relax_vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,9 @@ class VirtualMachineImpl : public VirtualMachine {
/*!
* \brief Run call instruction.
* \param curr_frame The current frame.
* \param callable The callable object, either PackedFunc or closure
* \param inst The call instruction.
*/
virtual void RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, Instruction inst);
virtual void RunInstrCall(VMFrame* curr_frame, Instruction inst);

/*! \brief Run VM dispatch loop. */
void RunLoop();
Expand Down Expand Up @@ -507,18 +506,14 @@ void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module,
//------------------------------------------
void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args,
TVMRetValue* rv) {
ICHECK(closure_or_packedfunc.defined())
<< "InvokeClosurePacked requires the callable object to be defined";

// run packed call if it is a packed func.
if (auto* packed = closure_or_packedfunc.as<PackedFunc::ContainerType>()) {
packed->CallPacked(args, rv);
return;
}
// run closure call.
auto* clo = closure_or_packedfunc.as<VMClosureObj>();
ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc, "
<< "but received " << closure_or_packedfunc->GetTypeKey();
ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc ";

std::vector<TVMValue> values(args.size() + 1);
std::vector<int> tcodes(args.size() + 1);
Expand Down Expand Up @@ -600,8 +595,6 @@ Optional<VMClosure> VirtualMachineImpl::GetClosureInternal(const String& func_na
auto impl = PackedFunc([gf_idx](TVMArgs args, TVMRetValue* rv) {
// Per convention, ctx ptr is a VirtualMachine*
VirtualMachine* ctx_ptr = static_cast<VirtualMachine*>(args[0].operator void*());
ICHECK(ctx_ptr) << "Context pointer for relax VM closure should be a VirtualMachine*, "
<< "but was NULL";

std::vector<RegType> inputs(args.size() - 1);
for (size_t i = 0; i < inputs.size(); ++i) {
Expand Down Expand Up @@ -651,7 +644,7 @@ RegType VirtualMachineImpl::InvokeBytecode(Index gf_idx, const std::vector<RegTy
auto guard = PushFrame(this->pc_, gfunc);
// Get new frame and set the caller info.
VMFrame* curr_frame = frames_.back().get();
if (curr_instr.op == Opcode::Call || curr_instr.op == Opcode::CallFromRegister) {
if (curr_instr.op == Opcode::Call) {
curr_frame->caller_return_register = curr_instr.dst;
}

Expand Down Expand Up @@ -695,12 +688,8 @@ void VirtualMachineImpl::InitFuncPool() {
}
}

void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable,
Instruction instr) {
ICHECK(callable.defined()) << "RunInstrCall requires the callable object to be defined";
auto func_name = instr.op == Opcode::Call ? GetFuncName(instr.func_idx) : "<dynamic>";

DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << func_name;
void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) {
DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << GetFuncName(instr.func_idx);
int args_begin_offset = instrument_ != nullptr ? 4 : 0;
// Use the call arg stack from the current frame to increase reuse
// and avoid re-allocation
Expand Down Expand Up @@ -746,11 +735,11 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, const ObjectRef& call
ICHECK_LT(static_cast<size_t>(instr.func_idx), this->func_pool_.size());

if (instrument_ == nullptr) {
this->InvokeClosurePacked(callable, args, &ret);
this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
} else {
// insert light-weight instrument callback
setter(0, callable);
setter(1, func_name);
setter(0, func_pool_[instr.func_idx]);
setter(1, GetFuncName(instr.func_idx));
setter(2, true);
setter(3, nullptr);
TVMRetValue rv;
Expand All @@ -769,7 +758,7 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, const ObjectRef& call
ret_kind = rv;
}
if (ret_kind != static_cast<int>(VMInstrumentReturnKind::kSkipRun)) {
this->InvokeClosurePacked(callable, args, &ret);
this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
setter(2, false);
setter(3, ret);
instrument_.CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), &rv);
Expand All @@ -793,11 +782,7 @@ void VirtualMachineImpl::RunLoop() {
Instruction instr = exec_->GetInstruction(pc_);
switch (instr.op) {
case Opcode::Call: {
this->RunInstrCall(curr_frame, func_pool_[instr.func_idx], instr);
break;
}
case Opcode::CallFromRegister: {
this->RunInstrCall(curr_frame, ReadRegister(curr_frame, instr.func_idx), instr);
this->RunInstrCall(curr_frame, instr);
break;
}
case Opcode::Ret: {
Expand Down Expand Up @@ -1015,7 +1000,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
}

protected:
void RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, Instruction inst) override {
void RunInstrCall(VMFrame* curr_frame, Instruction inst) override {
bool profiling = false;
if (prof_ && prof_->IsRunning()) {
auto f_name = GetFuncName(inst.func_idx);
Expand Down Expand Up @@ -1051,7 +1036,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
}
}

VirtualMachineImpl::RunInstrCall(curr_frame, callable, inst);
VirtualMachineImpl::RunInstrCall(curr_frame, inst);

if (profiling) {
prof_->StopCall();
Expand Down
Loading