diff --git a/qrb_inference_manager/include/qnn_inference/qnn_inference_impl.hpp b/qrb_inference_manager/include/qnn_inference/qnn_inference_impl.hpp index aa48f16..e30ac33 100644 --- a/qrb_inference_manager/include/qnn_inference/qnn_inference_impl.hpp +++ b/qrb_inference_manager/include/qnn_inference/qnn_inference_impl.hpp @@ -95,9 +95,11 @@ class QnnInterface QNN_SYSTEM_INTERFACE_VER_TYPE qnn_system_interface; private: - void * dlopen_helper(const char * file_name, const int flags); template - inline T resolve_symbol(void * lib_handle, const char * sym); + inline T get_function_from_lib(const std::string & lib_name, + const int open_flag, + void ** lib_handle, + const char * func_name); bool init_qnn_backend_interface(const std::string & backend_option, void ** backend_handle); bool init_qnn_graph_interface(const std::string & model_path, void ** model_handle); bool init_qnn_system_interface(const std::string & qnn_syslib_path, void ** sys_lib_handle); diff --git a/qrb_inference_manager/src/qnn_inference/qnn_inference_impl.cpp b/qrb_inference_manager/src/qnn_inference/qnn_inference_impl.cpp index 7f9a2fb..df782c9 100644 --- a/qrb_inference_manager/src/qnn_inference/qnn_inference_impl.cpp +++ b/qrb_inference_manager/src/qnn_inference/qnn_inference_impl.cpp @@ -34,134 +34,77 @@ QnnInterface::QnnInterface(const std::string & backend_option, } } -void * QnnInterface::dlopen_helper(const char * file_name, const int flags) -{ - int real_flags = 0; - - if (flags & 0x0001) { // resolve undefined symbols before return. MUST be specified - real_flags |= RTLD_NOW; - } - - if (flags & 0x0002) { // optional, but the default specified - real_flags |= RTLD_LOCAL; - } - - if (flags & 0x0004) { // optional, resolve symbol globally - real_flags |= RTLD_GLOBAL; - } - - return ::dlopen(file_name, real_flags); -} - template -inline T QnnInterface::resolve_symbol(void * lib_handle, const char * sym) -{ - void * ptr = nullptr; - if (lib_handle == (void *)(0x4)) { // specify this address to distingiush from NULL pointer - ptr = ::dlsym(RTLD_DEFAULT, sym); - } else { - ptr = ::dlsym(lib_handle, sym); +inline T QnnInterface::get_function_from_lib(const std::string & lib_name, + const int open_flag, + void ** lib_handle, + const char * func_name) +{ + if (nullptr == *lib_handle) { + *lib_handle = ::dlopen(lib_name.c_str(), open_flag); + if (nullptr == *lib_handle) { + QRB_ERROR("Unable to open lib!"); + return nullptr; + } } - if (ptr == nullptr) { - QRB_ERROR("Unable to access symbol", sym); + auto temp_lib_handle = *lib_handle; + auto func_ptr = ::dlsym(temp_lib_handle, func_name); + if (nullptr == func_ptr) { + QRB_ERROR("Unable to get function: ", func_name); + return nullptr; } - return reinterpret_cast(ptr); + + return reinterpret_cast(func_ptr); } -/// @brief get the qnn interface from backend file +/// @brief get the qnn interface from backend library /// @param backend_option file path of libQnnBackend.so /// @param backend_handle pointer point to qnn interface /// @return true of false bool QnnInterface::init_qnn_backend_interface(const std::string & backend_option, void ** backend_handle) { - void * lib_backend_handle = dlopen_helper(backend_option.c_str(), 0x0001 | 0x0004); - - if (nullptr == lib_backend_handle) { - QRB_ERROR("Unable to load backend!"); - return false; - } - - if (nullptr != backend_handle) { - *backend_handle = lib_backend_handle; - } - - qnn_interface_providers_func get_interface_providers = nullptr; - get_interface_providers = - resolve_symbol(lib_backend_handle, "QnnInterface_getProviders"); - - if (nullptr == get_interface_providers) { - return false; - } + auto get_interface_providers = get_function_from_lib( + backend_option, RTLD_NOW | RTLD_GLOBAL, backend_handle, "QnnInterface_getProviders"); QnnInterface_t ** interface_providers = nullptr; uint32_t num_providers = 0; + if (QNN_SUCCESS != get_interface_providers((const QnnInterface_t ***)&interface_providers, &num_providers)) { QRB_ERROR("Failed to get interface providers!"); return false; } - if (nullptr == interface_providers) { - QRB_ERROR("Failed to get interface providers: null interface providers received!"); - return false; - } - - if (0 == num_providers) { - QRB_ERROR("Failed to get interface providers: 0 interface providers!"); + if (interface_providers == nullptr || num_providers == 0) { + QRB_ERROR("Invalid interface providers retrieved!"); return false; } - bool found_valid_interface{ false }; for (size_t i = 0; i < num_providers; i++) { if (QNN_API_VERSION_MAJOR == interface_providers[i]->apiVersion.coreApiVersion.major && QNN_API_VERSION_MINOR <= interface_providers[i]->apiVersion.coreApiVersion.minor) { - found_valid_interface = true; this->interface = interface_providers[i]->QNN_INTERFACE_VER_NAME; - break; + return true; } } - if (!found_valid_interface) { - QRB_ERROR("Unable to find a valid interface!"); - backend_handle = nullptr; - return false; - } - - return true; + QRB_ERROR("Unable to find a valid interface!"); + return false; } -/// @brief get the qnn interface from model file +/// @brief get the qnn interface from model lib /// @param model_path /// @param model_handle pointer point to qnn interface /// @return true or false bool QnnInterface::init_qnn_graph_interface(const std::string & model_path, void ** model_handle) { - void * lib_model_handle = dlopen_helper(model_path.c_str(), 0x0001 | 0x0002); - - if (nullptr == lib_model_handle) { - QRB_ERROR("Unable to load model!"); - return false; - } + this->compose_graphs = get_function_from_lib( + model_path, RTLD_NOW | RTLD_LOCAL, model_handle, "QnnModel_composeGraphs"); - if (nullptr != model_handle) { - *model_handle = lib_model_handle; - } - - std::string model_prepare_func = "QnnModel_composeGraphs"; - this->compose_graphs = - resolve_symbol(lib_model_handle, model_prepare_func.c_str()); - if (nullptr == this->compose_graphs) { - return false; - } - - std::string model_free_func = "QnnModel_freeGraphsInfo"; - this->free_graph_info = - resolve_symbol(lib_model_handle, model_free_func.c_str()); - if (nullptr == this->free_graph_info) { - return false; - } + this->free_graph_info = get_function_from_lib( + model_path, RTLD_NOW | RTLD_LOCAL, model_handle, "QnnModel_freeGraphsInfo"); return true; } @@ -173,52 +116,34 @@ bool QnnInterface::init_qnn_graph_interface(const std::string & model_path, void bool QnnInterface::init_qnn_system_interface(const std::string & qnn_syslib_path, void ** sys_lib_handle) { - void * lib_sys_handle = dlopen_helper(qnn_syslib_path.c_str(), 0x0001 | 0x0002); - if (nullptr == lib_sys_handle) { - QRB_ERROR("Unable to load system library!"); - return false; - } - - if (nullptr != sys_lib_handle) { - *sys_lib_handle = lib_sys_handle; - } - - qnn_sys_interface_providers_func get_sys_interface_providers = nullptr; - get_sys_interface_providers = resolve_symbol( - lib_sys_handle, "QnnSystemInterface_getProviders"); - if (nullptr == get_sys_interface_providers) { - return false; - } + auto get_sys_interface_providers = get_function_from_lib( + qnn_syslib_path, RTLD_NOW | RTLD_LOCAL, sys_lib_handle, "QnnSystemInterface_getProviders"); QnnSystemInterface_t ** sys_interface_providers = nullptr; uint32_t number_of_providers = 0; + if (QNN_SUCCESS != get_sys_interface_providers( (const QnnSystemInterface_t ***)&sys_interface_providers, &number_of_providers)) { QRB_ERROR("Failed to get system interface providers."); return false; } + if (nullptr == sys_interface_providers || number_of_providers == 0) { QRB_ERROR("Failed to get system interface providers: null interface providers received."); return false; } - bool found_valid_sys_interface = false; for (size_t i = 0; i < number_of_providers; i++) { if (QNN_SYSTEM_API_VERSION_MAJOR == sys_interface_providers[i]->systemApiVersion.major && QNN_SYSTEM_API_VERSION_MINOR <= sys_interface_providers[i]->systemApiVersion.minor) { - found_valid_sys_interface = true; qnn_system_interface = sys_interface_providers[i]->QNN_SYSTEM_INTERFACE_VER_NAME; - break; + return true; } } - if (false == found_valid_sys_interface) { - QRB_ERROR("Unable to find a valid system interface."); - return false; - } - - return true; + QRB_ERROR("Unable to find a valid system interface."); + return false; } QnnTensor::QnnTensor(uint32_t num_of_input_tensors, uint32_t num_of_output_tensors)