diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index ab96d6e69d14..e67dee3c37c4 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -686,10 +686,10 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { return phi; } -llvm::BasicBlock* CodeGenCPU::MakeCallPacked(const Array& args, llvm::Value** rvalue, - llvm::Value** ret_tcode, const DataType& r_type, - const int64_t begin, const int64_t end) { - using llvm::BasicBlock; +CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& args, + const DataType& r_type, + const int64_t begin, const int64_t end) { + PackedCall pc; std::string func_name = args[0].as()->value; llvm::Value* handle = GetPackedFuncHandle(func_name); // call the function @@ -702,66 +702,69 @@ llvm::BasicBlock* CodeGenCPU::MakeCallPacked(const Array& args, llvm:: llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); llvm::Value* ret_value = builder_->CreateInBoundsGEP( builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); - *ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + llvm::Value* ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); #else auto call_callee = RuntimeTVMFuncCall(); #endif - BasicBlock* end_block = CheckCallSuccess(builder_->CreateCall( - call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, *ret_tcode})); + llvm::Value* call = builder_->CreateCall( + call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, ret_tcode}); + llvm::BasicBlock* end_block = CheckCallSuccess(call); + + // Load the return value and cast it to the designated type (r_type). DataType r_api_type = tir::APIType(r_type); llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()); #if TVM_LLVM_VERSION >= 110 - *rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8)); + llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8)); #else - *rvalue = builder_->CreateAlignedLoad(load_ptr, 8); + llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8); #endif - *rvalue = CreateCast(r_api_type, r_type, *rvalue); - return end_block; + pc.ret_value = CreateCast(r_api_type, r_type, rvalue); + + // Load the return type code. +#if TVM_LLVM_VERSION >= 110 + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); +#else + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, 8); +#endif + + pc.end_block = end_block; + return pc; } llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) { ICHECK_EQ(op->args.size(), 5U); - llvm::Value* rvalue = nullptr; - llvm::Value* ret_tcode = nullptr; - MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); - return rvalue; + PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, + op->args[4].as()->value); + return pc.ret_value; } llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { - using llvm::BasicBlock; ICHECK_EQ(op->args.size(), 6U); - llvm::Value* rvalue = nullptr; - llvm::Value* ret_tcode = nullptr; - BasicBlock* end_block = - MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, + op->args[4].as()->value); // Get traced value. llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. - BasicBlock* update_block = BasicBlock::Create(*ctx_, "update_block", function_); + llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx_, "update_block", function_); // The continue_block handles case when we need to return original // traced value. - BasicBlock* continue_block = BasicBlock::Create(*ctx_, "continue_block", function_); -#if TVM_LLVM_VERSION >= 110 - llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); -#else - llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8); -#endif + llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx_, "continue_block", function_); + // Check the ret_type_code and create cmp instruction. llvm::Value* cmp = - builder_->CreateICmpNE(ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr)); + builder_->CreateICmpNE(pc.ret_tcode, llvm::ConstantInt::get(t_int_, kTVMNullptr)); builder_->CreateCondBr(cmp, update_block, continue_block); builder_->SetInsertPoint(update_block); builder_->CreateBr(continue_block); builder_->SetInsertPoint(continue_block); // The return value depends on from what bb we come from. llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2); - phi_rvalue->addIncoming(rvalue, update_block); - phi_rvalue->addIncoming(traced_value, end_block); + phi_rvalue->addIncoming(pc.ret_value, update_block); + phi_rvalue->addIncoming(traced_value, pc.end_block); return phi_rvalue; } diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index d08bd639e131..30e61ea63f12 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -110,8 +110,12 @@ class CodeGenCPU : public CodeGenLLVM { void UnpackClosureData(llvm::Value* cdata, const Array& fields, std::unordered_map* vmap); // Make packed call. - llvm::BasicBlock* MakeCallPacked(const Array& args, llvm::Value** rvalue, - llvm::Value** ret_tcode, const DataType& r_type, + struct PackedCall { + llvm::Value* ret_value; + llvm::Value* ret_tcode; + llvm::BasicBlock* end_block; + }; + PackedCall MakeCallPackedLowered(const Array& args, const DataType& r_type, const int64_t begin, const int64_t end); // create call into tvm packed function. llvm::Value* CreateCallPacked(const CallNode* op); diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index d9d0d1f3d6a4..d8a64102f9cd 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -97,8 +97,12 @@ class CodeGenHexagon final : public CodeGenLLVM { std::unordered_map func_handle_map_; // Make packed call. - llvm::BasicBlock* MakeCallPacked(const Array& args, llvm::Value** rvalue, - llvm::Value** ret_tcode, const DataType& r_type, + struct PackedCall { + llvm::Value* ret_value; + llvm::Value* ret_tcode; + llvm::BasicBlock* end_block; + }; + PackedCall MakeCallPackedLowered(const Array& args, const DataType& r_type, const int64_t begin, const int64_t end); // create call into tvm packed function. llvm::Value* CreateCallPacked(const CallNode* op); @@ -296,11 +300,11 @@ llvm::Value* CodeGenHexagon::RuntimeTVMAPISetLastError() { return GetContextPtr(gv_tvm_api_set_last_error_); } -llvm::BasicBlock* CodeGenHexagon::MakeCallPacked(const Array& args, llvm::Value** rvalue, - llvm::Value** ret_tcode, const DataType& r_type, - const int64_t begin, const int64_t end) { - using llvm::BasicBlock; - // using namespace tir; +CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const Array& args, + const DataType& r_type, + const int64_t begin, + const int64_t end) { + PackedCall pc; std::string func_name = args[0].as()->value; llvm::Value* handle = GetPackedFuncHandle(func_name); // call the function @@ -313,25 +317,37 @@ llvm::BasicBlock* CodeGenHexagon::MakeCallPacked(const Array& args, ll llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); llvm::Value* ret_value = builder_->CreateInBoundsGEP( builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); - *ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + llvm::Value* ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); #else auto call_callee = RuntimeTVMFuncCall(); #endif - BasicBlock* end_block = CheckCallSuccess(builder_->CreateCall( - call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, *ret_tcode})); + llvm::Value* call = builder_->CreateCall( + call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, ret_tcode}); + llvm::BasicBlock* end_block = CheckCallSuccess(call); + + // Load the return value and cast it to the designated type (r_type). DataType r_api_type = tir::APIType(r_type); + llvm::Value* load_ptr = + builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()); #if TVM_LLVM_VERSION >= 110 - *rvalue = builder_->CreateAlignedLoad( - builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()), - llvm::Align(8)); + llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8)); #else - *rvalue = builder_->CreateAlignedLoad( - builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()), 8); + llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8); #endif - *rvalue = CreateCast(r_api_type, r_type, *rvalue); - return end_block; + pc.ret_value = CreateCast(r_api_type, r_type, rvalue); + + // Load the return type code. +#if TVM_LLVM_VERSION >= 110 + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); +#else + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, 8); +#endif + + pc.end_block = end_block; + return pc; } llvm::Value* CodeGenHexagon::GetPackedFuncHandle(const std::string& fname) { @@ -417,44 +433,34 @@ llvm::Value* CodeGenHexagon::CreateCallPacked(const CallNode* op) { } ICHECK_EQ(op->args.size(), 5U); - llvm::Value* rvalue = nullptr; - llvm::Value* ret_tcode = nullptr; - MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); - return rvalue; + PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, + op->args[4].as()->value); + return pc.ret_value; } llvm::Value* CodeGenHexagon::CreateCallTracePacked(const CallNode* op) { - using llvm::BasicBlock; ICHECK_EQ(op->args.size(), 6U); - llvm::Value* rvalue = nullptr; - llvm::Value* ret_tcode = nullptr; - BasicBlock* end_block = - MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, + op->args[4].as()->value); // Get traced value. llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. - BasicBlock* update_block = BasicBlock::Create(*ctx_, "update_block", function_); + llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx_, "update_block", function_); // The continue_block handles case when we need to return original // traced value. - BasicBlock* continue_block = BasicBlock::Create(*ctx_, "continue_block", function_); -#if TVM_LLVM_VERSION >= 110 - llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); -#else - llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8); -#endif + llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx_, "continue_block", function_); + // Check the ret_type_code and create cmp instruction. llvm::Value* cmp = - builder_->CreateICmpNE(ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr)); + builder_->CreateICmpNE(pc.ret_tcode, llvm::ConstantInt::get(t_int_, kTVMNullptr)); builder_->CreateCondBr(cmp, update_block, continue_block); builder_->SetInsertPoint(update_block); builder_->CreateBr(continue_block); builder_->SetInsertPoint(continue_block); // The return value depends on from what bb we come from. llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2); - phi_rvalue->addIncoming(rvalue, update_block); - phi_rvalue->addIncoming(traced_value, end_block); + phi_rvalue->addIncoming(pc.ret_value, update_block); + phi_rvalue->addIncoming(traced_value, pc.end_block); return phi_rvalue; }