Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

qrb_inference_manager: #13

Merged
merged 1 commit into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,11 @@ class QnnInterface
QNN_SYSTEM_INTERFACE_VER_TYPE qnn_system_interface;

private:
void * dlopen_helper(const char * file_name, const int flags);
template <class T>
inline T resolve_symbol(void * lib_handle, const char * sym);
inline T get_function_from_lib(const std::string & lib_name,
const int open_flag,
void ** lib_handle,
const char * func_name);
bool init_qnn_backend_interface(const std::string & backend_option, void ** backend_handle);
bool init_qnn_graph_interface(const std::string & model_path, void ** model_handle);
bool init_qnn_system_interface(const std::string & qnn_syslib_path, void ** sys_lib_handle);
Expand Down
153 changes: 39 additions & 114 deletions qrb_inference_manager/src/qnn_inference/qnn_inference_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,134 +34,77 @@ QnnInterface::QnnInterface(const std::string & backend_option,
}
}

void * QnnInterface::dlopen_helper(const char * file_name, const int flags)
{
int real_flags = 0;

if (flags & 0x0001) { // resolve undefined symbols before return. MUST be specified
real_flags |= RTLD_NOW;
}

if (flags & 0x0002) { // optional, but the default specified
real_flags |= RTLD_LOCAL;
}

if (flags & 0x0004) { // optional, resolve symbol globally
real_flags |= RTLD_GLOBAL;
}

return ::dlopen(file_name, real_flags);
}

template <class T>
inline T QnnInterface::resolve_symbol(void * lib_handle, const char * sym)
{
void * ptr = nullptr;
if (lib_handle == (void *)(0x4)) { // specify this address to distingiush from NULL pointer
ptr = ::dlsym(RTLD_DEFAULT, sym);
} else {
ptr = ::dlsym(lib_handle, sym);
inline T QnnInterface::get_function_from_lib(const std::string & lib_name,
const int open_flag,
void ** lib_handle,
const char * func_name)
{
if (nullptr == *lib_handle) {
*lib_handle = ::dlopen(lib_name.c_str(), open_flag);
if (nullptr == *lib_handle) {
QRB_ERROR("Unable to open lib!");
return nullptr;
}
}

if (ptr == nullptr) {
QRB_ERROR("Unable to access symbol", sym);
auto temp_lib_handle = *lib_handle;
auto func_ptr = ::dlsym(temp_lib_handle, func_name);
if (nullptr == func_ptr) {
QRB_ERROR("Unable to get function: ", func_name);
return nullptr;
}
return reinterpret_cast<T>(ptr);

return reinterpret_cast<T>(func_ptr);
}

/// @brief get the qnn interface from backend file
/// @brief get the qnn interface from backend library
/// @param backend_option file path of libQnnBackend.so
/// @param backend_handle pointer point to qnn interface
/// @return true of false
bool QnnInterface::init_qnn_backend_interface(const std::string & backend_option,
void ** backend_handle)
{
void * lib_backend_handle = dlopen_helper(backend_option.c_str(), 0x0001 | 0x0004);

if (nullptr == lib_backend_handle) {
QRB_ERROR("Unable to load backend!");
return false;
}

if (nullptr != backend_handle) {
*backend_handle = lib_backend_handle;
}

qnn_interface_providers_func get_interface_providers = nullptr;
get_interface_providers =
resolve_symbol<qnn_interface_providers_func>(lib_backend_handle, "QnnInterface_getProviders");

if (nullptr == get_interface_providers) {
return false;
}
auto get_interface_providers = get_function_from_lib<qnn_interface_providers_func>(
backend_option, RTLD_NOW | RTLD_GLOBAL, backend_handle, "QnnInterface_getProviders");

QnnInterface_t ** interface_providers = nullptr;
uint32_t num_providers = 0;

if (QNN_SUCCESS !=
get_interface_providers((const QnnInterface_t ***)&interface_providers, &num_providers)) {
QRB_ERROR("Failed to get interface providers!");
return false;
}

if (nullptr == interface_providers) {
QRB_ERROR("Failed to get interface providers: null interface providers received!");
return false;
}

if (0 == num_providers) {
QRB_ERROR("Failed to get interface providers: 0 interface providers!");
if (interface_providers == nullptr || num_providers == 0) {
QRB_ERROR("Invalid interface providers retrieved!");
return false;
}

bool found_valid_interface{ false };
for (size_t i = 0; i < num_providers; i++) {
if (QNN_API_VERSION_MAJOR == interface_providers[i]->apiVersion.coreApiVersion.major &&
QNN_API_VERSION_MINOR <= interface_providers[i]->apiVersion.coreApiVersion.minor) {
found_valid_interface = true;
this->interface = interface_providers[i]->QNN_INTERFACE_VER_NAME;
break;
return true;
}
}

if (!found_valid_interface) {
QRB_ERROR("Unable to find a valid interface!");
backend_handle = nullptr;
return false;
}

return true;
QRB_ERROR("Unable to find a valid interface!");
return false;
}

/// @brief get the qnn interface from model file
/// @brief get the qnn interface from model lib
/// @param model_path
/// @param model_handle pointer point to qnn interface
/// @return true or false
bool QnnInterface::init_qnn_graph_interface(const std::string & model_path, void ** model_handle)
{
void * lib_model_handle = dlopen_helper(model_path.c_str(), 0x0001 | 0x0002);

if (nullptr == lib_model_handle) {
QRB_ERROR("Unable to load model!");
return false;
}
this->compose_graphs = get_function_from_lib<compose_graphs_func>(
model_path, RTLD_NOW | RTLD_LOCAL, model_handle, "QnnModel_composeGraphs");

if (nullptr != model_handle) {
*model_handle = lib_model_handle;
}

std::string model_prepare_func = "QnnModel_composeGraphs";
this->compose_graphs =
resolve_symbol<compose_graphs_func>(lib_model_handle, model_prepare_func.c_str());
if (nullptr == this->compose_graphs) {
return false;
}

std::string model_free_func = "QnnModel_freeGraphsInfo";
this->free_graph_info =
resolve_symbol<free_graph_info_func>(lib_model_handle, model_free_func.c_str());
if (nullptr == this->free_graph_info) {
return false;
}
this->free_graph_info = get_function_from_lib<free_graph_info_func>(
model_path, RTLD_NOW | RTLD_LOCAL, model_handle, "QnnModel_freeGraphsInfo");

return true;
}
Expand All @@ -173,52 +116,34 @@ bool QnnInterface::init_qnn_graph_interface(const std::string & model_path, void
bool QnnInterface::init_qnn_system_interface(const std::string & qnn_syslib_path,
void ** sys_lib_handle)
{
void * lib_sys_handle = dlopen_helper(qnn_syslib_path.c_str(), 0x0001 | 0x0002);
if (nullptr == lib_sys_handle) {
QRB_ERROR("Unable to load system library!");
return false;
}

if (nullptr != sys_lib_handle) {
*sys_lib_handle = lib_sys_handle;
}

qnn_sys_interface_providers_func get_sys_interface_providers = nullptr;
get_sys_interface_providers = resolve_symbol<qnn_sys_interface_providers_func>(
lib_sys_handle, "QnnSystemInterface_getProviders");
if (nullptr == get_sys_interface_providers) {
return false;
}
auto get_sys_interface_providers = get_function_from_lib<qnn_sys_interface_providers_func>(
qnn_syslib_path, RTLD_NOW | RTLD_LOCAL, sys_lib_handle, "QnnSystemInterface_getProviders");

QnnSystemInterface_t ** sys_interface_providers = nullptr;
uint32_t number_of_providers = 0;

if (QNN_SUCCESS !=
get_sys_interface_providers(
(const QnnSystemInterface_t ***)&sys_interface_providers, &number_of_providers)) {
QRB_ERROR("Failed to get system interface providers.");
return false;
}

if (nullptr == sys_interface_providers || number_of_providers == 0) {
QRB_ERROR("Failed to get system interface providers: null interface providers received.");
return false;
}

bool found_valid_sys_interface = false;
for (size_t i = 0; i < number_of_providers; i++) {
if (QNN_SYSTEM_API_VERSION_MAJOR == sys_interface_providers[i]->systemApiVersion.major &&
QNN_SYSTEM_API_VERSION_MINOR <= sys_interface_providers[i]->systemApiVersion.minor) {
found_valid_sys_interface = true;
qnn_system_interface = sys_interface_providers[i]->QNN_SYSTEM_INTERFACE_VER_NAME;
break;
return true;
}
}

if (false == found_valid_sys_interface) {
QRB_ERROR("Unable to find a valid system interface.");
return false;
}

return true;
QRB_ERROR("Unable to find a valid system interface.");
return false;
}

QnnTensor::QnnTensor(uint32_t num_of_input_tensors, uint32_t num_of_output_tensors)
Expand Down
Loading