diff --git a/apps/hexagon_api/CMakeLists.txt b/apps/hexagon_api/CMakeLists.txt
index 82c4b5b66d4c..5234be3c1a15 100644
--- a/apps/hexagon_api/CMakeLists.txt
+++ b/apps/hexagon_api/CMakeLists.txt
@@ -131,6 +131,7 @@ ExternalProject_Add(hexagon_tvm_runtime_rpc
     "-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}"
     "-DUSE_ALTERNATIVE_LINKER=OFF"
     "-DUSE_CUSTOM_LOGGING=ON"
+    "-DUSE_HEXAGON_QHL=ON"
     "${GTEST_FLAG}"
   INSTALL_COMMAND ""
   BUILD_ALWAYS ON
diff --git a/cmake/config.cmake b/cmake/config.cmake
index b9a3aaef7d7e..4cd10f104a83 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -322,6 +322,9 @@ set(USE_HEXAGON_RPC OFF)
 # Valid values are v65, v66, v68, v69.
 set(USE_HEXAGON_ARCH "v66")
 
+# Whether to build with the QHL math library shipped in the Hexagon SDK
+set(USE_HEXAGON_QHL OFF)
+
 # Whether to use ONNX codegen
 set(USE_TARGET_ONNX OFF)
 
diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake
index 6e9b7dc70cbf..c08ea5eb1df1 100644
--- a/cmake/modules/Hexagon.cmake
+++ b/cmake/modules/Hexagon.cmake
@@ -89,6 +89,10 @@ endif()
 
 # From here on, USE_HEXAGON is assumed to be TRUE.
 
+if(BUILD_FOR_HOST AND USE_HEXAGON_QHL)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_QHL")
+endif()
+
 function(add_android_paths)
   get_hexagon_sdk_property("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}"
     SDK_INCLUDE SDK_INCLUDE_DIRS
@@ -148,8 +152,39 @@ if(BUILD_FOR_HEXAGON)
   include_directories(SYSTEM ${SDK_INCLUDE_DIRS} ${QURT_INCLUDE_DIRS})
 
   set(USE_CUSTOM_LOGGING ON) # To use a custom logger
-endif()
 
+# QHL support: compile the wrapper sources and link the prebuilt QHL math archives.
+  if(USE_HEXAGON_QHL)
+    file_glob_append(TVM_QHL_WRAPPER_SRCS
+      "${TVMRT_SOURCE_DIR}/hexagon/qhl/*.cc"
+    )
+
+    include_directories(
+      "${USE_HEXAGON_SDK}/libs/qhl_hvx/inc/qhmath_hvx"
+      "${USE_HEXAGON_SDK}/libs/qhl_hvx/inc/internal/"
+
+      "${USE_HEXAGON_SDK}/libs/qhl/inc/qhmath"
+      "${USE_HEXAGON_SDK}/libs/qhl/src/internal/"
+    )
+    # APPEND_STRING concatenates without a separator, so the value starts with a space.
+    set_property(SOURCE ${TVM_QHL_WRAPPER_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -Wno-narrowing -mhvx -mhvx-length=128B")
+
+    # Both archives are taken from the same prebuilt variant directory, which is
+    # pinned to toolv84/v68 independently of USE_HEXAGON_ARCH.
+    # NOTE(review): confirm v68 archives are intended for every supported arch.
+    set(QHL_PREBUILT_VARIANT "hexagon_toolv84_v68")
+    foreach(QHL_ARCHIVE
+        "${USE_HEXAGON_SDK}/libs/qhl_hvx/prebuilt/${QHL_PREBUILT_VARIANT}/libqhmath_hvx.a"
+        "${USE_HEXAGON_SDK}/libs/qhl/prebuilt/${QHL_PREBUILT_VARIANT}/libqhmath.a")
+      # Report a missing archive at configure time instead of as a link error.
+      if(NOT EXISTS "${QHL_ARCHIVE}")
+        message(SEND_ERROR "USE_HEXAGON_QHL is ON but ${QHL_ARCHIVE} does not exist")
+      endif()
+      list(APPEND TVM_RUNTIME_LINKER_LIBS -Wl,--whole-archive "${QHL_ARCHIVE}" -Wl,--no-whole-archive)
+    endforeach()
+
+  endif()
+endif()
 
 if(USE_HEXAGON_RPC)
   function(build_rpc_idl)
@@ -238,5 +273,4 @@ if(USE_HEXAGON_RPC)
   endif()
 endif() # USE_HEXAGON_RPC
 
-
-list(APPEND RUNTIME_SRCS ${RUNTIME_HEXAGON_SRCS})
+list(APPEND RUNTIME_SRCS ${RUNTIME_HEXAGON_SRCS} ${TVM_QHL_WRAPPER_SRCS})
diff --git a/src/runtime/hexagon/qhl/qhl_wrapper.cc b/src/runtime/hexagon/qhl/qhl_wrapper.cc
new file mode 100644
index 000000000000..df188c8907e5
--- /dev/null
+++ b/src/runtime/hexagon/qhl/qhl_wrapper.cc
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#if defined(__hexagon__)
+#include
+#include
+#include
+
+#define restrict __restrict__
+#define LOG2VLEN 7
+
+// QHL functions with 1 input arg
+#define TVM_QHL_WRAPPER_DECL_1IP(NAME) HVX_Vector tvm_vect_##NAME(HVX_Vector input);
+
+// QHL functions with 2 input args
+#define TVM_QHL_WRAPPER_DECL_2IP(NAME) HVX_Vector tvm_vect_##NAME(HVX_Vector ip1, HVX_Vector ip2);
+
+// The AHF macros define tvm_vect_<NAME> by forwarding to wrapper_api<__fp16> with
+// the QHL routine NAME and its stringified name (used in the failure message).
+#define TVM_QHL_WRAPPER_AHF_1IP(NAME) \
+  HVX_Vector tvm_vect_##NAME(HVX_Vector input) { return wrapper_api<__fp16>(input, NAME, #NAME); }
+
+#define TVM_QHL_WRAPPER_AHF_2IP(NAME) \
+  HVX_Vector tvm_vect_##NAME(HVX_Vector ip1, HVX_Vector ip2) { \
+    return wrapper_api<__fp16>(ip1, ip2, NAME, #NAME); \
+  }
+
+extern "C" {
+#include "hvx_internal.h"
+#include "qhmath_hvx.h"
+#include "qhmath_hvx_vector.h"
+using qhlFptr = int (*)(__fp16*, __fp16*, uint32_t);
+using qhlFptr2 = int (*)(__fp16*, __fp16*, __fp16*, uint32_t);
+TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_ceil_ahf)
+TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_cos_ahf)
+TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_exp_ahf)
+TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_floor_ahf)
+TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_sin_ahf)
+TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_sigmoid_ahf)
+TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_sqrt_ahf)
+TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_tan_ahf)
+TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_tanh_ahf)
+
+// QHL functions with 2 input args
+TVM_QHL_WRAPPER_DECL_2IP(qhmath_hvx_pow_ahf)
+}
+// Runs the one-input QHL routine `qhl_api` on `input` and returns the vector it
+// wrote. 64 is passed as the routine's uint32_t argument (presumably the number
+// of elements; confirm against the QHL headers). A non-zero status is fatal.
+template
+HVX_Vector wrapper_api(HVX_Vector input, qhlFptr qhl_api, const char* qhl_api_name) {
+  HVX_Vector output;
+  int32_t res = qhl_api(reinterpret_cast(&input), reinterpret_cast(&output), 64);
+  if (res != 0) LOG(FATAL) << "Error. 
Failed execution of " << qhl_api_name << " Error=" << res;
+  return output;
+}
+
+template
+HVX_Vector wrapper_api(HVX_Vector ip1, HVX_Vector ip2, qhlFptr2 qhl_api, const char* qhl_api_name) {
+  HVX_Vector output;
+  int32_t res = qhl_api(reinterpret_cast(&ip1), reinterpret_cast(&ip2),
+                        reinterpret_cast(&output), 64);
+  if (res != 0) LOG(FATAL) << "Error. Failed execution of " << qhl_api_name << " Error=" << res;
+  return output;
+}
+
+TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_ceil_ahf);
+TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_cos_ahf);
+TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_exp_ahf);
+TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_floor_ahf);
+TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_sin_ahf);
+TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_sigmoid_ahf);
+TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_sqrt_ahf);
+TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_tan_ahf);
+TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_tanh_ahf);
+
+TVM_QHL_WRAPPER_AHF_2IP(qhmath_hvx_pow_ahf);
+
+#endif
diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc
index 7b0081869a27..cab77697164d 100644
--- a/src/target/llvm/codegen_hexagon.cc
+++ b/src/target/llvm/codegen_hexagon.cc
@@ -79,6 +79,12 @@ class CodeGenHexagon final : public CodeGenCPU {
 
   llvm::Module* GetModulePtr() const { return module_.get(); }
 
+  llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array& args,
+                                bool skip_first_arg) override;
+
+  llvm::Value* CreateCallExternQHL(Type ret_type, String global_symbol, const Array& args,
+                                   bool skip_first_arg);
+
   uint64_t GetTypeSizeInBits(llvm::Type* type) const {
 #if TVM_LLVM_VERSION >= 100
     return data_layout_->getTypeSizeInBits(type).getFixedSize();
@@ -98,6 +104,15 @@ class CodeGenHexagon final : public CodeGenCPU {
   llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
   llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);
 
+  bool IsQHLFunction(const std::string& func);
+
+  std::vector fqhl_list_ = {
+      "tvm_vect_qhmath_hvx_cos_ahf", "tvm_vect_qhmath_hvx_tanh_ahf",
+      "tvm_vect_qhmath_hvx_sigmoid_ahf", "tvm_vect_qhmath_hvx_sin_ahf",
+      "tvm_vect_qhmath_hvx_sqrt_ahf", "tvm_vect_qhmath_hvx_exp_ahf",
+      "tvm_vect_qhmath_hvx_tan_ahf", "tvm_vect_qhmath_hvx_floor_ahf",
+      "tvm_vect_qhmath_hvx_ceil_ahf", "tvm_vect_qhmath_hvx_pow_ahf"};
+
   llvm::Value* VectorLookupLoad(Buffer buffer, DataType buffer_type, Array index);
   llvm::Value* Intrinsic(llvm::Intrinsic::ID, llvm::ArrayRef args);
 };
@@ -127,6 +142,61 @@ void CodeGenHexagon::InitTarget(llvm::TargetMachine* tm) {
   CodeGenLLVM::InitTarget(tm);
 }
 
+// Emits one call to the QHL wrapper `global_symbol` per native-vector-sized
+// slice of the operands and concatenates the per-slice results. The lane
+// count is taken from args[1] and rounded up to a whole number of slices.
+llvm::Value* CodeGenHexagon::CreateCallExternQHL(Type ret_type, String global_symbol,
+                                                 const Array& args, bool skip_first_arg) {
+  int num_lanes = args[1].dtype().lanes();
+  int vector_length = native_vector_bits_ / args[1].dtype().bits();
+  num_lanes = ((num_lanes + vector_length - 1) / vector_length) * vector_length;
+  std::vector vect_split;
+  for (int i = 0; i < num_lanes / vector_length; ++i) {
+    std::vector sub_vect_val;
+    std::vector arg_types;
+    for (size_t k = skip_first_arg; k < args.size(); ++k)
+      sub_vect_val.push_back(
+          CodeGenLLVM::CreateVecSlice(MakeValue(args[k]), i * vector_length, vector_length));
+    for (llvm::Value* v : sub_vect_val) {
+      arg_types.push_back(v->getType());
+    }
+    llvm::FunctionType* ftype = llvm::FunctionType::get(arg_types[0], arg_types, false);
+    llvm::Function* f = module_->getFunction(MakeStringRef(global_symbol));
+    if (f == nullptr) {
+      f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
+                                 MakeStringRef(global_symbol), module_.get());
+    }
+#if TVM_LLVM_VERSION >= 90
+    auto ext_callee = llvm::FunctionCallee(f);
+#else
+    auto ext_callee = f;
+#endif
+    vect_split.push_back(builder_->CreateCall(ext_callee, sub_vect_val));
+  }
+  return CodeGenLLVM::CreateVecConcat(vect_split);
+}
+
+// Returns true if `func` is one of the wrapper names listed in fqhl_list_.
+bool CodeGenHexagon::IsQHLFunction(const std::string& func) {
+  return std::find(fqhl_list_.begin(), fqhl_list_.end(), func) != fqhl_list_.end();
+}
+
+// Routes calls to the QHL wrappers listed in fqhl_list_ through
+// CreateCallExternQHL when the operand has more lanes than one native vector;
+// every other extern call is forwarded to CodeGenCPU unchanged.
+llvm::Value* CodeGenHexagon::CreateCallExtern(Type ret_type, String global_symbol,
+                                              const Array& args, bool skip_first_arg) {
+  // Inspect args[1] only for QHL wrappers, so an extern call whose `args` has
+  // fewer than two entries is never indexed out of range.
+  if (IsQHLFunction(global_symbol) && args.size() > 1) {
+    int num_lanes = args[1].dtype().lanes();
+    int vector_length = native_vector_bits_ / args[1].dtype().bits();
+    if (num_lanes > vector_length)
+      return CreateCallExternQHL(ret_type, global_symbol, args, skip_first_arg);
+  }
+  return CodeGenCPU::CreateCallExtern(ret_type, global_symbol, args, skip_first_arg);
+}
+
 llvm::GlobalVariable* CodeGenHexagon::InitContextPtr(llvm::Type* p_type, std::string name) {
   llvm::GlobalVariable* gv = new llvm::GlobalVariable(
       *module_, p_type, false, llvm::GlobalValue::LinkOnceAnyLinkage, nullptr, name);
diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc
index a6f5eae4a561..c96245e1399c 100644
--- a/src/target/llvm/intrin_rule_hexagon.cc
+++ b/src/target/llvm/intrin_rule_hexagon.cc
@@ -20,17 +20,61 @@
 #ifdef TVM_LLVM_VERSION
 
 #include
+#include
 #include
 
 #include "intrin_rule_llvm.h"
 
+#define TVM_REGISTER_QHL_OP_FP16(INTRIN_FUNC, WRAPPER_FUNC, NUM_SIGN) \
+  std::string tvm_qhl_ahf_##INTRIN_FUNC = WRAPPER_FUNC; \
+  TVM_REGISTER_OP("tir." 
#INTRIN_FUNC) \
+      .set_attr( \
+          "hexagon.FLowerIntrinsic", \
+          DispatchTVMQHLWrapperFp16);
+
 namespace tvm {
 namespace codegen {
 namespace llvm {
 using tir::FLowerIntrinsic;
 
-TVM_REGISTER_OP("tir.exp").set_attr(
-    "hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
+inline PrimExpr TVMExternCall(const tir::CallNode* call, const std::string& fname) {
+  Array new_args = {tir::StringImm(fname)};
+  for (PrimExpr arg : call->args) {
+    new_args.push_back(arg);
+  }
+  return tir::Call(call->dtype, tir::builtin::call_pure_extern(), new_args);
+}
+
+template
+inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) {
+  using namespace tir;
+  const CallNode* call = e.as();
+  ICHECK(call != nullptr);
+  Array new_args;
+#if ENABLE_QHL
+  // Use QHL only if the current target has "+hvx-qfloat" (or no target is defined)
+  const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent");
+  ICHECK(f != nullptr);
+  const auto ret = (*f)(true);
+  const Target t = ret.AsObjectRef();
+  bool useqhl = true;
+  if (t.defined()) {
+    const std::string tstring = t->str();
+    useqhl = tstring.find("+hvx-qfloat") != std::string::npos;
+  }
+
+  // Only vectorized (lanes > 1) float16 operands are routed to the QHL wrapper
+  const PrimExpr& x = call->args[0];
+  if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
+    return TVMExternCall(call, tvm_wrapper);
+  }
+#endif
+  new_args.push_back(IntImm(DataType::UInt(32), id));
+  new_args.push_back(IntImm(DataType::UInt(32), num_sign));
+  new_args.insert(new_args.end(), call->args.begin(), call->args.end());
+  return tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), new_args);
+}
 
 TVM_REGISTER_OP("tir.fma").set_attr(
     "hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
@@ -38,18 +82,6 @@ TVM_REGISTER_OP("tir.fma").set_attr(
 TVM_REGISTER_OP("tir.log").set_attr(
     "hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
 
-TVM_REGISTER_OP("tir.sqrt")
-    .set_attr("hexagon.FLowerIntrinsic",
-              DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
-
-TVM_REGISTER_OP("tir.floor")
-    .set_attr("hexagon.FLowerIntrinsic",
-              DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
-
-TVM_REGISTER_OP("tir.ceil")
-    .set_attr("hexagon.FLowerIntrinsic",
-              DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
-
 TVM_REGISTER_OP("tir.trunc")
     .set_attr("hexagon.FLowerIntrinsic",
               DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
@@ -62,12 +94,117 @@ TVM_REGISTER_OP("tir.round")
     .set_attr("hexagon.FLowerIntrinsic",
               DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
 
-TVM_REGISTER_OP("tir.pow").set_attr(
-    "hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 1>);
-
 TVM_REGISTER_OP("tir.ctpop")
     .set_attr("hexagon.FLowerIntrinsic",
               DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
 
+TVM_REGISTER_OP("tir.tanh")
+    .set_attr("hexagon.FLowerIntrinsic", [](const PrimExpr& e) {
+      const tir::CallNode* call = e.as();
+      ICHECK(call != nullptr);
+      const PrimExpr& x = call->args[0];
+
+#if ENABLE_QHL
+      // Use QHL only if the current target has "+hvx-qfloat" (or no target is defined)
+      const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent");
+      ICHECK(f != nullptr);
+      const auto ret = (*f)(true);
+      const Target t = ret.AsObjectRef();
+      bool useqhl = true;
+      if (t.defined()) {
+        const std::string tstring = t->str();
+        useqhl = tstring.find("+hvx-qfloat") != std::string::npos;
+      }
+
+      // Only vectorized (lanes > 1) float16 operands are routed to the QHL wrapper
+      if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
+        std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf");
+        return TVMExternCall(call, tvm_wrapper);
+      }
+#endif
+      PrimExpr one = tir::make_const(x.dtype(), 1);
+      PrimExpr two = tir::make_const(x.dtype(), 2);
+      PrimExpr neg_two = tir::make_const(x.dtype(), -2);
+
+      PrimExpr exp_neg2x = exp(neg_two * x);
+      PrimExpr exp_pos2x = exp(two * x);
+
+      PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
+      PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
+      PrimExpr tanh_x = tir::Select(x >= tir::make_zero(x.dtype()), tanh_pos, tanh_neg);
+      return tanh_x;
+    });
+
+TVM_REGISTER_OP("tir.tan").set_attr(
+    "hexagon.FLowerIntrinsic", [](const PrimExpr& e) {
+      const tir::CallNode* call = e.as();
+      ICHECK(call != nullptr);
+      const PrimExpr& x = call->args[0];
+#if ENABLE_QHL
+      // Use QHL only if the current target has "+hvx-qfloat" (or no target is defined)
+      const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent");
+      ICHECK(f != nullptr);
+      const auto ret = (*f)(true);
+      const Target t = ret.AsObjectRef();
+      bool useqhl = true;
+      if (t.defined()) {
+        const std::string tstring = t->str();
+        useqhl = tstring.find("+hvx-qfloat") != std::string::npos;
+      }
+
+      // Only vectorized (lanes > 1) float16 operands are routed to the QHL wrapper
+      if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
+        std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf");
+        return TVMExternCall(call, tvm_wrapper);
+      }
+#endif
+      PrimExpr tan_x = sin(x) / cos(x);
+      return tan_x;
+    });
+
+TVM_REGISTER_OP("tir.nearbyint")
+    .set_attr("hexagon.FLowerIntrinsic",
+              DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
+
+TVM_REGISTER_OP("tir.sigmoid")
+    .set_attr("hexagon.FLowerIntrinsic", [](const PrimExpr& e) {
+      const tir::CallNode* call = e.as();
+      ICHECK(call != nullptr);
+      const PrimExpr& x = call->args[0];
+#if ENABLE_QHL
+      // Use QHL only if the current target has "+hvx-qfloat" (or no target is defined)
+      const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent");
+      ICHECK(f != nullptr);
+      const auto ret = (*f)(true);
+      const Target t = ret.AsObjectRef();
+      bool useqhl = true;
+      if (t.defined()) {
+        const std::string tstring = t->str();
+        useqhl = tstring.find("+hvx-qfloat") != std::string::npos;
+      }
+
+      // Only vectorized (lanes > 1) float16 operands are routed to the QHL wrapper
+      if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
+        std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf");
+        return TVMExternCall(call, tvm_wrapper);
+      }
+#endif
+      PrimExpr one = tir::make_const(x.dtype(), 1);
+      return one / (one + exp(-x));
+    });
+
+TVM_REGISTER_QHL_OP_FP16(ceil, "tvm_vect_qhmath_hvx_ceil_ahf", 1)
+
+TVM_REGISTER_QHL_OP_FP16(cos, "tvm_vect_qhmath_hvx_cos_ahf", 1)
+
+TVM_REGISTER_QHL_OP_FP16(exp, "tvm_vect_qhmath_hvx_exp_ahf", 1)
+
+TVM_REGISTER_QHL_OP_FP16(floor, "tvm_vect_qhmath_hvx_floor_ahf", 1)
+
+TVM_REGISTER_QHL_OP_FP16(sin, "tvm_vect_qhmath_hvx_sin_ahf", 1)
+
+TVM_REGISTER_QHL_OP_FP16(pow, "tvm_vect_qhmath_hvx_pow_ahf", 2)
+
+TVM_REGISTER_QHL_OP_FP16(sqrt, "tvm_vect_qhmath_hvx_sqrt_ahf", 1)
 } // namespace llvm
 } // namespace codegen
diff --git a/tests/scripts/task_config_build_hexagon.sh b/tests/scripts/task_config_build_hexagon.sh
index a943d72e3635..2f84bed23a30 100755
--- a/tests/scripts/task_config_build_hexagon.sh
+++ b/tests/scripts/task_config_build_hexagon.sh
@@ -33,3 +33,4 @@ echo set\(USE_HEXAGON "ON"\) >> config.cmake
 echo set\(USE_HEXAGON_SDK "${HEXAGON_SDK_ROOT}"\) >> config.cmake
 echo set\(USE_CCACHE OFF\) >> config.cmake
 echo set\(SUMMARIZE ON\) >> config.cmake
+echo set\(USE_HEXAGON_QHL ON\) >> config.cmake