tensor: add bind buffer
chraac committed Sep 18, 2024
1 parent c1bd94c commit 74d5016
Showing 4 changed files with 36 additions and 23 deletions.
ggml/src/ggml-qnn/buffer.hpp (2 additions, 2 deletions)

@@ -8,8 +8,8 @@
 namespace qnn {
 class ggml_qnn_rpc_buffer {
 public:
-    ggml_qnn_rpc_buffer(std::shared_ptr<qnn_instance> qnn_instance, size_t size, uint32_t rank, uint32_t *dimensions,
-                        Qnn_DataType_t data_type) :
+    ggml_qnn_rpc_buffer(std::shared_ptr<qnn_instance> qnn_instance, const size_t size, const uint32_t rank,
+                        uint32_t *dimensions, Qnn_DataType_t data_type) :
         _qnn_instance(qnn_instance), _size(size) {
 
         _qnn_rpc_buffer = static_cast<uint8_t *>(qnn_instance->alloc_rpcmem(size, alignof(void *)));

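The constructor now takes `size` and `rank` by const value; only the qualifiers and parameter order change. A minimal allocation sketch follows; apart from the constructor and the `is_valid()` check visible elsewhere in this commit, the instance setup, dimensions, and data type are illustrative assumptions:

    // Hypothetical sketch: allocate an RPC-memory-backed buffer for a 4x16
    // float tensor. `instance` is assumed to be an already-initialized
    // std::shared_ptr<qnn::qnn_instance>.
    uint32_t dims[] = { 4, 16 };
    auto rpc_buf = std::make_unique<qnn::ggml_qnn_rpc_buffer>(
        instance, size_t(4 * 16 * sizeof(float)), /* rank */ 2, dims,
        QNN_DATATYPE_FLOAT_32);
    if (!rpc_buf->is_valid()) {
        // the underlying alloc_rpcmem() call failed; the buffer is unusable
    }
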
ggml/src/ggml-qnn/op-config.cpp (2 additions, 2 deletions)

@@ -149,13 +149,13 @@ bool ggml_qnn_op_config_base::bind_output_tensors(const ggml_tensor_array_t &ten
 
 void ggml_qnn_op_config_base::unbind_input_tensors() {
     for (auto &tensor : _tensor_inputs) {
-        tensor->unbind_ggml_tensor();
+        tensor->unbind();
     }
 }
 
 void ggml_qnn_op_config_base::unbind_output_tensors() {
     for (auto &tensor : _tensor_outputs) {
-        tensor->unbind_ggml_tensor();
+        tensor->unbind();
     }
 }
 
ggml/src/ggml-qnn/qnn-lib.hpp (1 addition, 1 deletion)

@@ -636,7 +636,7 @@ class qnn_instance {
         return mem_fd;
     }
 
-    Qnn_MemHandle_t register_rpcmem(void *p_data, uint32_t rank, uint32_t *dimensions, Qnn_DataType_t data_type) {
+    Qnn_MemHandle_t register_rpcmem(void *p_data, const uint32_t rank, uint32_t *dimensions, Qnn_DataType_t data_type) {
         if (!p_data) {
             QNN_LOG_WARN("invalid param\n");
             return nullptr;

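For context, a hedged sketch of how `register_rpcmem` might be called; the `alloc_rpcmem` call mirrors the one in buffer.hpp above, while the size, rank, and data type are illustrative assumptions:

    // Hypothetical sketch: allocate rpcmem, then register it with QNN.
    uint32_t dims[] = { 4, 16 };
    void *p_data = instance->alloc_rpcmem(4 * 16 * sizeof(float), alignof(void *));
    Qnn_MemHandle_t handle = instance->register_rpcmem(p_data, /* rank */ 2, dims, QNN_DATATYPE_FLOAT_32);
    if (!handle) {
        // register_rpcmem logs "invalid param" and returns nullptr on a null p_data
    }
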
ggml/src/ggml-qnn/tensor.hpp (31 additions, 18 deletions)
@@ -2,6 +2,7 @@
 #pragma once
 
 #include <atomic>
+#include <cstddef>
 #include <cstdint>
 #include <memory>
 #include <string>

@@ -57,15 +58,14 @@ class ggml_qnn_tensor {
         return true;
     }
 
-    bool bind_ggml_tensor(ggml_tensor *tensor) {
-        if (_tensor) {
-            if (_tensor != tensor) {
-                QNN_LOG_WARN("tensor %s has been bound to another ggml tensor %s", _tensor_name.c_str(),
-                             ggml_get_name(_tensor));
+    bool bind_buffer(uint8_t *buffer, const size_t buffer_size) {
+        if (_buffer) {
+            if (_buffer != buffer) {
+                QNN_LOG_WARN("tensor %s has been bound to another buffer %p", _tensor_name.c_str(), _buffer);
                 return false;
             }
-            QNN_LOG_INFO("tensor %s already bound to same ggml tensor %s", _tensor_name.c_str(),
-                         ggml_get_name(_tensor));
+
+            QNN_LOG_INFO("tensor %s already bound to same ggml tensor %p", _tensor_name.c_str(), _buffer);
             return true;
         }

@@ -78,7 +78,7 @@
         if (should_use_mem_handle()) {
             if (!_qnn_rpc_buffer) {
                 auto qnn_rpc_buffer = std::make_unique<ggml_qnn_rpc_buffer>(
-                    _qnn_instance, ggml_nbytes(tensor), QNN_TENSOR_GET_RANK(_qnn_tensor),
+                    _qnn_instance, buffer_size, QNN_TENSOR_GET_RANK(_qnn_tensor),
                     QNN_TENSOR_GET_DIMENSIONS(_qnn_tensor), QNN_TENSOR_GET_DATA_TYPE(_qnn_tensor));
                 if (!qnn_rpc_buffer->is_valid()) {
                     QNN_LOG_WARN("alloc rpc mem failed, tensor %s", _tensor_name.c_str());

@@ -93,30 +93,41 @@
             QNN_LOG_DEBUG("tensor %s, use mem handle %p", _tensor_name.c_str(), QNN_TENSOR_GET_MEM_HANDLE(_qnn_tensor));
         } else {
             QNN_TENSOR_SET_MEM_TYPE(_qnn_tensor, QNN_TENSORMEMTYPE_RAW);
-            Qnn_ClientBuffer_t client_buf = { tensor->data, get_ggml_tensor_data_size(tensor) };
+            Qnn_ClientBuffer_t client_buf = { buffer, (uint32_t)buffer_size };
             QNN_TENSOR_SET_CLIENT_BUF(_qnn_tensor, client_buf);
             QNN_LOG_DEBUG("tensor %s, use client buffer %p size %d", _tensor_name.c_str(), client_buf.data,
                           (int)client_buf.dataSize);
         }
 
-        _tensor = tensor;
+        _buffer = buffer;
+        _buffer_size = buffer_size;
 
         if (!write_to_qnn_tensor()) {
             QNN_LOG_WARN("write to qnn tensor failed, tensor %s", _tensor_name.c_str());
             return false;
         }
 
-        QNN_LOG_DEBUG("bind tensor %s to ggml tensor %s", _tensor_name.c_str(), ggml_get_name(tensor));
+        QNN_LOG_DEBUG("bind tensor %s to buffer: %p, size: %d", _tensor_name.c_str(), buffer, (int)buffer_size);
         return true;
     }
 
+    bool bind_ggml_tensor(ggml_tensor *tensor) {
+        if (!bind_buffer(reinterpret_cast<uint8_t *>(tensor->data), ggml_nbytes(tensor))) {
+            QNN_LOG_WARN("Failed to bind tensor: %s to ggml tensor: %s", _tensor_name.c_str(), ggml_get_name(tensor));
+            return false;
+        }
+
+        QNN_LOG_DEBUG("Bind tensor %s to ggml tensor %s", _tensor_name.c_str(), ggml_get_name(tensor));
+        return true;
+    }
+
-    bool unbind_ggml_tensor() {
+    bool unbind() {
         if (!_graph_handle) {
             QNN_LOG_WARN("tensor %s not bound to any graph", _tensor_name.c_str());
             return false;
         }
 
-        if (!_tensor) {
+        if (!_buffer) {
             QNN_LOG_DEBUG("tensor %s not bound to ggml tensor", _tensor_name.c_str());
             return true;
         }

@@ -133,8 +144,9 @@
             QNN_LOG_DEBUG("tensor %s, clear client buffer", _tensor_name.c_str());
         }
 
-        QNN_LOG_DEBUG("unbind tensor: %s from ggml tensor: %s", _tensor_name.c_str(), ggml_get_name(_tensor));
-        _tensor = nullptr;
+        QNN_LOG_DEBUG("unbind tensor: %s from buffer: %p, size: %d", _tensor_name.c_str(), _buffer, (int)_buffer_size);
+        _buffer = nullptr;
+        _buffer_size = 0;
         return true;
     }
 
@@ -150,7 +162,7 @@
 
         if (should_use_mem_handle()) {
             if (_qnn_rpc_buffer) {
-                memcpy(_qnn_rpc_buffer->get_buffer(), _tensor->data, ggml_nbytes(_tensor));
+                memcpy(_qnn_rpc_buffer->get_buffer(), _buffer, _buffer_size);
             } else {
                 QNN_LOG_WARN("tensor %s: can't find rpcmem from qnn mem handle\n", _tensor_name.c_str());
                 return false;

@@ -171,7 +183,7 @@
 
         if (should_use_mem_handle()) {
             if (_qnn_rpc_buffer) {
-                memcpy(_tensor->data, _qnn_rpc_buffer->get_buffer(), ggml_nbytes(_tensor));
+                memcpy(_buffer, _qnn_rpc_buffer->get_buffer(), _buffer_size);
             } else {
                 QNN_LOG_WARN("can't find rpcmem from qnn mem handle\n");
                 return false;

@@ -217,7 +229,8 @@
     bool should_use_mem_handle() const { return _device == QNN_BACKEND_NPU; }
 
     std::string _tensor_name;
-    const ggml_tensor *_tensor;
+    uint8_t *_buffer = nullptr;
+    size_t _buffer_size = 0;
     QNNBackend _device;
     std::shared_ptr<qnn_instance> _qnn_instance;
     Qnn_Tensor_t _qnn_tensor = qnn_tensor_init(kDefaultQnnTensorVersion);

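Taken together, the tensor.hpp changes split binding into a generic `bind_buffer(uint8_t *, size_t)` plus a thin `bind_ggml_tensor()` wrapper, and rename `unbind_ggml_tensor()` to `unbind()` since unbinding no longer cares what backs the memory. A hedged usage sketch; the construction of `qnn_tensor` and the graph setup sit outside this diff and are assumed:

    // Sketch: bind a ggml tensor for one execution, then rebind a raw buffer.
    // `qnn_tensor` is assumed to be a ggml_qnn_tensor already attached to a
    // graph; `src` is an initialized ggml_tensor *.
    if (!qnn_tensor.bind_ggml_tensor(src)) {  // forwards src->data and ggml_nbytes(src) to bind_buffer()
        return false;
    }
    // ... execute the graph; on NPU, bind_buffer() staged the data into rpcmem
    // via write_to_qnn_tensor(), otherwise QNN reads the client buffer in place ...
    qnn_tensor.unbind();  // resets _buffer to nullptr and _buffer_size to 0

    // The same tensor object can now be bound directly to a raw staging buffer:
    std::vector<uint8_t> staging(ggml_nbytes(src));
    if (!qnn_tensor.bind_buffer(staging.data(), staging.size())) {
        return false;
    }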
