Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HEXAGON] QCOM hexagon library (qhl) #12149

Merged
merged 4 commits into from
Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions apps/hexagon_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ ExternalProject_Add(hexagon_tvm_runtime_rpc
"-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}"
"-DUSE_ALTERNATIVE_LINKER=OFF"
"-DUSE_CUSTOM_LOGGING=ON"
"-DUSE_HEXAGON_QHL=ON"
"${GTEST_FLAG}"
INSTALL_COMMAND ""
BUILD_ALWAYS ON
Expand Down
3 changes: 3 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,9 @@ set(USE_HEXAGON_RPC OFF)
# Valid values are v65, v66, v68, v69.
set(USE_HEXAGON_ARCH "v66")

# Whether to use QHL (the Qualcomm Hexagon library)
set(USE_HEXAGON_QHL OFF)

# Whether to use ONNX codegen
set(USE_TARGET_ONNX OFF)

Expand Down
28 changes: 25 additions & 3 deletions cmake/modules/Hexagon.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ endif()

# From here on, USE_HEXAGON is assumed to be TRUE.

# When both BUILD_FOR_HOST and USE_HEXAGON_QHL are set, append -DENABLE_QHL to
# the C++ compile flags so that sources can test the ENABLE_QHL macro.
if(BUILD_FOR_HOST AND USE_HEXAGON_QHL)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_QHL")
endif()

function(add_android_paths)
get_hexagon_sdk_property("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}"
SDK_INCLUDE SDK_INCLUDE_DIRS
Expand Down Expand Up @@ -148,8 +152,27 @@ if(BUILD_FOR_HEXAGON)
include_directories(SYSTEM ${SDK_INCLUDE_DIRS} ${QURT_INCLUDE_DIRS})

set(USE_CUSTOM_LOGGING ON) # To use a custom logger
endif()

# QHL support: when USE_HEXAGON_QHL is set, compile the TVM wrapper sources
# under hexagon/qhl/ and link the prebuilt QHL math archives from the SDK.
if(USE_HEXAGON_QHL)
# Collect every wrapper source; TVM_QHL_WRAPPER_SRCS is appended to
# RUNTIME_SRCS at the end of this module.
file_glob_append(TVM_QHL_WRAPPER_SRCS
"${TVMRT_SOURCE_DIR}/hexagon/qhl/*.cc"
)

# Public and internal header directories of the HVX and scalar QHL libraries
# inside the Hexagon SDK tree.
include_directories(
"${USE_HEXAGON_SDK}/libs/qhl_hvx/inc/qhmath_hvx"
"${USE_HEXAGON_SDK}/libs/qhl_hvx/inc/internal/"

"${USE_HEXAGON_SDK}/libs/qhl/inc/qhmath"
"${USE_HEXAGON_SDK}/libs/qhl/src/internal/"
)
# The wrapper sources are built with HVX enabled in 128-byte vector mode and
# with narrowing diagnostics disabled.
# NOTE(review): APPEND_STRING concatenates without a separator; if these
# sources already carry COMPILE_FLAGS the first flag may fuse with the
# previous one — confirm.
set_property(SOURCE ${TVM_QHL_WRAPPER_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-narrowing -mhvx -mhvx-length=128B")

# Link both static archives with --whole-archive so that all of their objects
# are included in the runtime.
# NOTE(review): the prebuilt directory name hard-codes "hexagon_toolv84_v68"
# regardless of USE_HEXAGON_ARCH — confirm this is intended for other
# architectures/toolchain versions.
list(APPEND TVM_RUNTIME_LINKER_LIBS -Wl,--whole-archive ${USE_HEXAGON_SDK}/libs/qhl_hvx/prebuilt/hexagon_toolv84_v68/libqhmath_hvx.a -Wl,--no-whole-archive)
list(APPEND TVM_RUNTIME_LINKER_LIBS -Wl,--whole-archive ${USE_HEXAGON_SDK}/libs/qhl/prebuilt/hexagon_toolv84_v68/libqhmath.a -Wl,--no-whole-archive)

endif()
endif()

if(USE_HEXAGON_RPC)
function(build_rpc_idl)
Expand Down Expand Up @@ -238,5 +261,4 @@ if(USE_HEXAGON_RPC)
endif()
endif() # USE_HEXAGON_RPC


list(APPEND RUNTIME_SRCS ${RUNTIME_HEXAGON_SRCS})
list(APPEND RUNTIME_SRCS ${RUNTIME_HEXAGON_SRCS} ${TVM_QHL_WRAPPER_SRCS})
89 changes: 89 additions & 0 deletions src/runtime/hexagon/qhl/qhl_wrapper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#if defined(__hexagon__)
#include <hexagon_types.h>
#include <stdio.h>
#include <tvm/runtime/logging.h>

// Map `restrict` to the compiler's `__restrict__` extension.
// NOTE(review): presumably needed because the C QHL headers included below use
// the C99 `restrict` qualifier, which is not a C++ keyword — confirm.
#define restrict __restrict__
// 2^7 = 128. Not referenced in this file; presumably consumed by the QHL
// headers as log2 of the HVX vector length in bytes — TODO confirm.
#define LOG2VLEN 7

// QHL functions with 1 input arg
// Declares `HVX_Vector tvm_vect_<NAME>(HVX_Vector)`; the expansion already ends
// with a semicolon.
#define TVM_QHL_WRAPPER_DECL_1IP(NAME) HVX_Vector tvm_vect_##NAME(HVX_Vector input);

// QHL functions with 2 input args
// Declares `HVX_Vector tvm_vect_<NAME>(HVX_Vector, HVX_Vector)`; the expansion
// already ends with a semicolon.
#define TVM_QHL_WRAPPER_DECL_2IP(NAME) HVX_Vector tvm_vect_##NAME(HVX_Vector ip1, HVX_Vector ip2);

// Defines tvm_vect_<NAME> for a 1-input routine by forwarding to
// wrapper_api<__fp16>, passing the QHL function NAME and its stringified name
// (used in the failure message).
#define TVM_QHL_WRAPPER_AHF_1IP(NAME) \
HVX_Vector tvm_vect_##NAME(HVX_Vector input) { return wrapper_api<__fp16>(input, NAME, #NAME); }

// Defines tvm_vect_<NAME> for a 2-input routine by forwarding to the
// two-operand wrapper_api<__fp16> overload.
#define TVM_QHL_WRAPPER_AHF_2IP(NAME) \
HVX_Vector tvm_vect_##NAME(HVX_Vector ip1, HVX_Vector ip2) { \
return wrapper_api<__fp16>(ip1, ip2, NAME, #NAME); \
}

// Everything in this block has C language linkage: the QHL headers, the
// function-pointer aliases, and the tvm_vect_* wrapper declarations (so the
// wrappers are emitted under their unmangled names).
extern "C" {
#include "hvx_internal.h"
#include "qhmath_hvx.h"
#include "qhmath_hvx_vector.h"
// Signature of a 1-input QHL fp16 array routine; wrapper_api calls it as
// (input, output, 64) and treats a non-zero return as failure.
using qhlFptr = int (*)(__fp16*, __fp16*, uint32_t);
// Signature of a 2-input QHL fp16 array routine; wrapper_api calls it as
// (input1, input2, output, 64) and treats a non-zero return as failure.
using qhlFptr2 = int (*)(__fp16*, __fp16*, __fp16*, uint32_t);
// QHL functions with 1 input arg
TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_ceil_ahf)
TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_cos_ahf)
TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_exp_ahf)
TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_floor_ahf)
TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_sin_ahf)
TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_sigmoid_ahf)
TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_sqrt_ahf)
TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_tan_ahf)
TVM_QHL_WRAPPER_DECL_1IP(qhmath_hvx_tanh_ahf)

// QHL functions with 2 input args
TVM_QHL_WRAPPER_DECL_2IP(qhmath_hvx_pow_ahf)
}
/*!
 * \brief Apply a single-input QHL array routine to one HVX vector.
 *
 * The vector arguments are reinterpreted in place as arrays of T and handed to
 * the QHL routine together with an element count of 64.
 *
 * \param input        Vector holding the operand values.
 * \param qhl_api      QHL routine to invoke; a non-zero return means failure.
 * \param qhl_api_name Name of the routine, used only in the failure message.
 * \return The vector written by the QHL routine.
 */
template <typename T>
HVX_Vector wrapper_api(HVX_Vector input, qhlFptr qhl_api, const char* qhl_api_name) {
  HVX_Vector result;
  T* const src = reinterpret_cast<T*>(&input);
  T* const dst = reinterpret_cast<T*>(&result);
  // NOTE(review): 64 presumably is the number of 2-byte elements in a 128-byte
  // HVX vector — confirm.
  const int32_t status = qhl_api(src, dst, 64);
  if (status != 0) {
    LOG(FATAL) << "Error. Failed execution of " << qhl_api_name << " Error=" << status;
  }
  return result;
}

/*!
 * \brief Apply a two-input QHL array routine to a pair of HVX vectors.
 *
 * The vector arguments are reinterpreted in place as arrays of T and handed to
 * the QHL routine together with an element count of 64.
 *
 * \param ip1          Vector holding the first operand values.
 * \param ip2          Vector holding the second operand values.
 * \param qhl_api      QHL routine to invoke; a non-zero return means failure.
 * \param qhl_api_name Name of the routine, used only in the failure message.
 * \return The vector written by the QHL routine.
 */
template <typename T>
HVX_Vector wrapper_api(HVX_Vector ip1, HVX_Vector ip2, qhlFptr2 qhl_api, const char* qhl_api_name) {
  HVX_Vector output;
  // NOTE(review): 64 presumably is the number of 2-byte elements in a 128-byte
  // HVX vector — confirm.
  const int32_t res = qhl_api(reinterpret_cast<T*>(&ip1), reinterpret_cast<T*>(&ip2),
                              reinterpret_cast<T*>(&output), 64);
  if (res != 0) {
    // The leading space before "Error=" keeps the code separated from the API
    // name, matching the message of the single-input overload.
    LOG(FATAL) << "Error. Failed execution of " << qhl_api_name << " Error=" << res;
  }
  return output;
}

// Definitions of the tvm_vect_* wrappers for 1-input QHL routines. Each macro
// expands to a complete function definition, so no trailing semicolon is
// needed.
TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_ceil_ahf)
TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_cos_ahf)
TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_exp_ahf)
TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_floor_ahf)
TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_sin_ahf)
TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_sigmoid_ahf)
TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_sqrt_ahf)
TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_tan_ahf)
TVM_QHL_WRAPPER_AHF_1IP(qhmath_hvx_tanh_ahf)

// Definition of the tvm_vect_* wrapper for the 2-input QHL routine.
TVM_QHL_WRAPPER_AHF_2IP(qhmath_hvx_pow_ahf)

#endif
59 changes: 59 additions & 0 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ class CodeGenHexagon final : public CodeGenCPU {

llvm::Module* GetModulePtr() const { return module_.get(); }

/*!
 * \brief Emit a call to an external symbol.
 *
 * Calls to symbols recognised by IsQHLFunction whose lane count exceeds the
 * native vector lane count are routed to CreateCallExternQHL; every other call
 * is forwarded to CodeGenCPU::CreateCallExtern.
 */
llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg) override;

/*!
 * \brief Emit a QHL wrapper call split into native-vector-sized pieces.
 *
 * Each argument is sliced into chunks of native vector length, the external
 * symbol is called once per chunk, and the results are concatenated.
 */
llvm::Value* CreateCallExternQHL(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg);

uint64_t GetTypeSizeInBits(llvm::Type* type) const {
#if TVM_LLVM_VERSION >= 100
return data_layout_->getTypeSizeInBits(type).getFixedSize();
Expand All @@ -98,6 +104,15 @@ class CodeGenHexagon final : public CodeGenCPU {
llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);

/*! \brief Return true if \p func is one of the names listed in fqhl_list_. */
bool IsQHLFunction(const std::string& func);

/*!
 * \brief Names of the external tvm_vect_* QHL wrapper functions that
 *        IsQHLFunction recognises.
 */
std::vector<std::string> fqhl_list_ = {
"tvm_vect_qhmath_hvx_cos_ahf", "tvm_vect_qhmath_hvx_tanh_ahf",
"tvm_vect_qhmath_hvx_sigmoid_ahf", "tvm_vect_qhmath_hvx_sin_ahf",
"tvm_vect_qhmath_hvx_sqrt_ahf", "tvm_vect_qhmath_hvx_exp_ahf",
"tvm_vect_qhmath_hvx_tan_ahf", "tvm_vect_qhmath_hvx_floor_ahf",
"tvm_vect_qhmath_hvx_ceil_ahf", "tvm_vect_qhmath_hvx_pow_ahf"};

llvm::Value* VectorLookupLoad(Buffer buffer, DataType buffer_type, Array<PrimExpr> index);
llvm::Value* Intrinsic(llvm::Intrinsic::ID, llvm::ArrayRef<llvm::Value*> args);
};
Expand Down Expand Up @@ -127,6 +142,50 @@ void CodeGenHexagon::InitTarget(llvm::TargetMachine* tm) {
CodeGenLLVM::InitTarget(tm);
}

/*!
 * \brief Emit a QHL wrapper call over a vector wider than the native vector.
 *
 * The lane count and element width are taken from args[1]. Every call argument
 * is cut into slices of native vector length, the external symbol is invoked
 * once per slice position, and the per-slice results are concatenated.
 *
 * \param ret_type       Unused here; the callee's return type is taken from
 *                       the type of the first sliced argument.
 * \param global_symbol  Name of the external function to call.
 * \param args           Call arguments.
 * \param skip_first_arg When true, args[0] is not passed to the callee.
 * \return The concatenation of all per-slice call results.
 */
llvm::Value* CodeGenHexagon::CreateCallExternQHL(Type ret_type, String global_symbol,
                                                 const Array<PrimExpr>& args, bool skip_first_arg) {
  const int total_lanes = args[1].dtype().lanes();
  const int lanes_per_vector = native_vector_bits_ / args[1].dtype().bits();
  // Number of native vectors needed to cover all lanes (rounded up).
  const int num_chunks = (total_lanes + lanes_per_vector - 1) / lanes_per_vector;
  const size_t first_arg = skip_first_arg ? 1 : 0;

  std::vector<llvm::Value*> chunk_results;
  for (int chunk = 0; chunk < num_chunks; ++chunk) {
    // Slice every operand at this chunk's offset and record its LLVM type.
    std::vector<llvm::Value*> call_args;
    std::vector<llvm::Type*> call_arg_types;
    for (size_t k = first_arg; k < args.size(); ++k) {
      llvm::Value* slice = CodeGenLLVM::CreateVecSlice(MakeValue(args[k]),
                                                       chunk * lanes_per_vector, lanes_per_vector);
      call_args.push_back(slice);
      call_arg_types.push_back(slice->getType());
    }

    // The callee returns the same type as its first (sliced) operand.
    llvm::FunctionType* callee_type =
        llvm::FunctionType::get(call_arg_types[0], call_arg_types, false);

    // Reuse the declaration if the module already has one; otherwise add it.
    llvm::Function* callee = module_->getFunction(MakeStringRef(global_symbol));
    if (!callee) {
      callee = llvm::Function::Create(callee_type, llvm::Function::ExternalLinkage,
                                      MakeStringRef(global_symbol), module_.get());
    }
#if TVM_LLVM_VERSION >= 90
    auto ext_callee = llvm::FunctionCallee(callee);
#else
    auto ext_callee = callee;
#endif
    chunk_results.push_back(builder_->CreateCall(ext_callee, call_args));
  }
  return CodeGenLLVM::CreateVecConcat(chunk_results);
}

/*! \brief Return true if \p func exactly matches an entry of fqhl_list_. */
bool CodeGenHexagon::IsQHLFunction(const std::string& func) {
  for (const std::string& qhl_name : fqhl_list_) {
    if (qhl_name == func) {
      return true;
    }
  }
  return false;
}

/*!
 * \brief Emit a call to an external symbol, splitting wide QHL calls.
 *
 * If the symbol is a known QHL wrapper and the lane count of args[1] exceeds
 * the number of lanes in a native vector, the call is emitted by
 * CreateCallExternQHL. All other calls go to CodeGenCPU::CreateCallExtern.
 *
 * \param ret_type       Return type of the call.
 * \param global_symbol  Name of the external function.
 * \param args           Call arguments.
 * \param skip_first_arg Forwarded unchanged to the selected implementation.
 * \return The LLVM value produced by the call.
 */
llvm::Value* CodeGenHexagon::CreateCallExtern(Type ret_type, String global_symbol,
                                              const Array<PrimExpr>& args, bool skip_first_arg) {
  // Only inspect args[1] for QHL symbols, and only when it exists: ordinary
  // extern calls may carry fewer than two entries, and indexing past the end
  // of the array would abort code generation for them.
  if (IsQHLFunction(global_symbol) && args.size() > 1) {
    const int num_lanes = args[1].dtype().lanes();
    const int vector_length = native_vector_bits_ / args[1].dtype().bits();
    if (num_lanes > vector_length) {
      return CreateCallExternQHL(ret_type, global_symbol, args, skip_first_arg);
    }
  }
  return CodeGenCPU::CreateCallExtern(ret_type, global_symbol, args, skip_first_arg);
}

llvm::GlobalVariable* CodeGenHexagon::InitContextPtr(llvm::Type* p_type, std::string name) {
llvm::GlobalVariable* gv = new llvm::GlobalVariable(
*module_, p_type, false, llvm::GlobalValue::LinkOnceAnyLinkage, nullptr, name);
Expand Down
Loading