Skip to content

Commit

Permalink
[LLVM] Refactor MakeCallPacked, NFC (apache#9118)
Browse files Browse the repository at this point in the history
Change the interface for `MakeCallPacked` in `CodeGenCPU` and in
`CodeGenHexagon` to encapsulate the multiple returned values into
a single structure. This should help readability, but also it will
make the upcoming adoption of opaque pointers a bit easier.
  • Loading branch information
Krzysztof Parzyszek authored and ylc committed Jan 13, 2022
1 parent 4e4c06d commit d73da13
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 72 deletions.
67 changes: 35 additions & 32 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -686,10 +686,10 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
return phi;
}

llvm::BasicBlock* CodeGenCPU::MakeCallPacked(const Array<PrimExpr>& args, llvm::Value** rvalue,
llvm::Value** ret_tcode, const DataType& r_type,
const int64_t begin, const int64_t end) {
using llvm::BasicBlock;
CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>& args,
const DataType& r_type,
const int64_t begin, const int64_t end) {
PackedCall pc;
std::string func_name = args[0].as<StringImmNode>()->value;
llvm::Value* handle = GetPackedFuncHandle(func_name);
// call the function
Expand All @@ -702,66 +702,69 @@ llvm::BasicBlock* CodeGenCPU::MakeCallPacked(const Array<PrimExpr>& args, llvm::
llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin));
llvm::Value* ret_value = builder_->CreateInBoundsGEP(
builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end));
*ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));
llvm::Value* ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));

#if TVM_LLVM_VERSION >= 90
auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
#else
auto call_callee = RuntimeTVMFuncCall();
#endif
BasicBlock* end_block = CheckCallSuccess(builder_->CreateCall(
call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, *ret_tcode}));
llvm::Value* call = builder_->CreateCall(
call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, ret_tcode});
llvm::BasicBlock* end_block = CheckCallSuccess(call);

// Load the return value and cast it to the designated type (r_type).
DataType r_api_type = tir::APIType(r_type);
llvm::Value* load_ptr =
builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo());
#if TVM_LLVM_VERSION >= 110
*rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8));
llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8));
#else
*rvalue = builder_->CreateAlignedLoad(load_ptr, 8);
llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8);
#endif
*rvalue = CreateCast(r_api_type, r_type, *rvalue);
return end_block;
pc.ret_value = CreateCast(r_api_type, r_type, rvalue);

// Load the return type code.
#if TVM_LLVM_VERSION >= 110
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8));
#else
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, 8);
#endif

pc.end_block = end_block;
return pc;
}

llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) {
ICHECK_EQ(op->args.size(), 5U);
llvm::Value* rvalue = nullptr;
llvm::Value* ret_tcode = nullptr;
MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImmNode>()->value);
return rvalue;
PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImmNode>()->value);
return pc.ret_value;
}

llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) {
using llvm::BasicBlock;
ICHECK_EQ(op->args.size(), 6U);
llvm::Value* rvalue = nullptr;
llvm::Value* ret_tcode = nullptr;
BasicBlock* end_block =
MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImmNode>()->value);
PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImmNode>()->value);
// Get traced value.
llvm::Value* traced_value = MakeValue(op->args[5]);
// The update_block handles case when we need to update the return value.
BasicBlock* update_block = BasicBlock::Create(*ctx_, "update_block", function_);
llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx_, "update_block", function_);
// The continue_block handles case when we need to return original
// traced value.
BasicBlock* continue_block = BasicBlock::Create(*ctx_, "continue_block", function_);
#if TVM_LLVM_VERSION >= 110
llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8));
#else
llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8);
#endif
llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx_, "continue_block", function_);

// Check the ret_type_code and create cmp instruction.
llvm::Value* cmp =
builder_->CreateICmpNE(ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr));
builder_->CreateICmpNE(pc.ret_tcode, llvm::ConstantInt::get(t_int_, kTVMNullptr));
builder_->CreateCondBr(cmp, update_block, continue_block);
builder_->SetInsertPoint(update_block);
builder_->CreateBr(continue_block);
builder_->SetInsertPoint(continue_block);
// The return value depends on from what bb we come from.
llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2);
phi_rvalue->addIncoming(rvalue, update_block);
phi_rvalue->addIncoming(traced_value, end_block);
phi_rvalue->addIncoming(pc.ret_value, update_block);
phi_rvalue->addIncoming(traced_value, pc.end_block);
return phi_rvalue;
}

Expand Down
8 changes: 6 additions & 2 deletions src/target/llvm/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,12 @@ class CodeGenCPU : public CodeGenLLVM {
void UnpackClosureData(llvm::Value* cdata, const Array<Var>& fields,
std::unordered_map<const VarNode*, llvm::Value*>* vmap);
// Make packed call.
llvm::BasicBlock* MakeCallPacked(const Array<PrimExpr>& args, llvm::Value** rvalue,
llvm::Value** ret_tcode, const DataType& r_type,
struct PackedCall {
llvm::Value* ret_value;
llvm::Value* ret_tcode;
llvm::BasicBlock* end_block;
};
PackedCall MakeCallPackedLowered(const Array<PrimExpr>& args, const DataType& r_type,
const int64_t begin, const int64_t end);
// create call into tvm packed function.
llvm::Value* CreateCallPacked(const CallNode* op);
Expand Down
82 changes: 44 additions & 38 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,12 @@ class CodeGenHexagon final : public CodeGenLLVM {
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;

// Make packed call.
llvm::BasicBlock* MakeCallPacked(const Array<PrimExpr>& args, llvm::Value** rvalue,
llvm::Value** ret_tcode, const DataType& r_type,
struct PackedCall {
llvm::Value* ret_value;
llvm::Value* ret_tcode;
llvm::BasicBlock* end_block;
};
PackedCall MakeCallPackedLowered(const Array<PrimExpr>& args, const DataType& r_type,
const int64_t begin, const int64_t end);
// create call into tvm packed function.
llvm::Value* CreateCallPacked(const CallNode* op);
Expand Down Expand Up @@ -296,11 +300,11 @@ llvm::Value* CodeGenHexagon::RuntimeTVMAPISetLastError() {
return GetContextPtr(gv_tvm_api_set_last_error_);
}

llvm::BasicBlock* CodeGenHexagon::MakeCallPacked(const Array<PrimExpr>& args, llvm::Value** rvalue,
llvm::Value** ret_tcode, const DataType& r_type,
const int64_t begin, const int64_t end) {
using llvm::BasicBlock;
// using namespace tir;
CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const Array<PrimExpr>& args,
const DataType& r_type,
const int64_t begin,
const int64_t end) {
PackedCall pc;
std::string func_name = args[0].as<StringImmNode>()->value;
llvm::Value* handle = GetPackedFuncHandle(func_name);
// call the function
Expand All @@ -313,25 +317,37 @@ llvm::BasicBlock* CodeGenHexagon::MakeCallPacked(const Array<PrimExpr>& args, ll
llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin));
llvm::Value* ret_value = builder_->CreateInBoundsGEP(
builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end));
*ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));
llvm::Value* ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));

#if TVM_LLVM_VERSION >= 90
auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
#else
auto call_callee = RuntimeTVMFuncCall();
#endif
BasicBlock* end_block = CheckCallSuccess(builder_->CreateCall(
call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, *ret_tcode}));
llvm::Value* call = builder_->CreateCall(
call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, ret_tcode});
llvm::BasicBlock* end_block = CheckCallSuccess(call);

// Load the return value and cast it to the designated type (r_type).
DataType r_api_type = tir::APIType(r_type);
llvm::Value* load_ptr =
builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo());
#if TVM_LLVM_VERSION >= 110
*rvalue = builder_->CreateAlignedLoad(
builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()),
llvm::Align(8));
llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8));
#else
*rvalue = builder_->CreateAlignedLoad(
builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()), 8);
llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8);
#endif
*rvalue = CreateCast(r_api_type, r_type, *rvalue);
return end_block;
pc.ret_value = CreateCast(r_api_type, r_type, rvalue);

// Load the return type code.
#if TVM_LLVM_VERSION >= 110
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8));
#else
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, 8);
#endif

pc.end_block = end_block;
return pc;
}

llvm::Value* CodeGenHexagon::GetPackedFuncHandle(const std::string& fname) {
Expand Down Expand Up @@ -417,44 +433,34 @@ llvm::Value* CodeGenHexagon::CreateCallPacked(const CallNode* op) {
}

ICHECK_EQ(op->args.size(), 5U);
llvm::Value* rvalue = nullptr;
llvm::Value* ret_tcode = nullptr;
MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImmNode>()->value);
return rvalue;
PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImmNode>()->value);
return pc.ret_value;
}

llvm::Value* CodeGenHexagon::CreateCallTracePacked(const CallNode* op) {
using llvm::BasicBlock;
ICHECK_EQ(op->args.size(), 6U);
llvm::Value* rvalue = nullptr;
llvm::Value* ret_tcode = nullptr;
BasicBlock* end_block =
MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImmNode>()->value);
PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImmNode>()->value);
// Get traced value.
llvm::Value* traced_value = MakeValue(op->args[5]);
// The update_block handles case when we need to update the return value.
BasicBlock* update_block = BasicBlock::Create(*ctx_, "update_block", function_);
llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx_, "update_block", function_);
// The continue_block handles case when we need to return original
// traced value.
BasicBlock* continue_block = BasicBlock::Create(*ctx_, "continue_block", function_);
#if TVM_LLVM_VERSION >= 110
llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8));
#else
llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8);
#endif
llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx_, "continue_block", function_);

// Check the ret_type_code and create cmp instruction.
llvm::Value* cmp =
builder_->CreateICmpNE(ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr));
builder_->CreateICmpNE(pc.ret_tcode, llvm::ConstantInt::get(t_int_, kTVMNullptr));
builder_->CreateCondBr(cmp, update_block, continue_block);
builder_->SetInsertPoint(update_block);
builder_->CreateBr(continue_block);
builder_->SetInsertPoint(continue_block);
// The return value depends on from what bb we come from.
llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2);
phi_rvalue->addIncoming(rvalue, update_block);
phi_rvalue->addIncoming(traced_value, end_block);
phi_rvalue->addIncoming(pc.ret_value, update_block);
phi_rvalue->addIncoming(traced_value, pc.end_block);
return phi_rvalue;
}

Expand Down

0 comments on commit d73da13

Please sign in to comment.