NUMA-aware KV cache buffer type (experimental) #11580

Draft · wants to merge 1 commit into master
1 change: 1 addition & 0 deletions ggml/include/ggml-backend.h
@@ -348,6 +348,7 @@ extern "C" {
    // CPU buffer types are always available
    GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
    GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
    GGML_API ggml_backend_buffer_type_t ggml_backend_numa_buffer_type(void);

#ifdef __cplusplus
}
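For orientation (not part of the diff), a minimal usage sketch of the new buffer type, built on the existing ggml-backend helpers; the wrapper name `alloc_numa_kv_buffer` is made up for illustration:

```cpp
// Sketch: allocate a host buffer whose pages follow first-touch NUMA placement.
#include "ggml-backend.h"

static ggml_backend_buffer_t alloc_numa_kv_buffer(size_t size) {
    ggml_backend_buffer_type_t buft = ggml_backend_numa_buffer_type();
    ggml_backend_buffer_t      buf  = ggml_backend_buft_alloc_buffer(buft, size);
    if (buf == NULL) {
        return NULL; // mmap failed inside the buffer type's alloc_buffer
    }
    // no pages are bound to a NUMA node yet; whichever thread writes a page
    // first (e.g. a worker in the compute threadpool) decides its placement
    return buf;
}
```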
81 changes: 81 additions & 0 deletions ggml/src/ggml-backend.cpp
@@ -2000,3 +2000,84 @@ ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size)
    GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned");
    return ggml_backend_buffer_init(ggml_backend_cpu_buffer_from_ptr_type(), ggml_backend_cpu_buffer_from_ptr_i, ptr, size);
}

// NUMA buffer interface - similar to CPU, but with pages allocated according to a NUMA first-touch policy

#include <sys/mman.h>

static void ggml_backend_numa_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    if (munmap((char *) buffer->context, buffer->size)) {
        GGML_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
    }
}

static void ggml_backend_numa_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    // hint that the current contents are no longer needed, so the pages may be
    // dropped and re-placed by first touch the next time they are written
    if (posix_madvise(buffer->context, buffer->size, POSIX_MADV_DONTNEED)) {
        GGML_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_DONTNEED) failed: %s\n",
                strerror(errno));
    }

    GGML_UNUSED(value);
}

static const struct ggml_backend_buffer_i ggml_backend_numa_buffer_i = {
    /* .free_buffer   = */ ggml_backend_numa_buffer_free_buffer,
    /* .get_base      = */ ggml_backend_cpu_buffer_get_base,
    /* .init_tensor   = */ NULL, // no initialization required
    /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
    /* .set_tensor    = */ ggml_backend_cpu_buffer_set_tensor,
    /* .get_tensor    = */ ggml_backend_cpu_buffer_get_tensor,
    /* .cpy_tensor    = */ ggml_backend_cpu_buffer_cpy_tensor,
    /* .clear         = */ ggml_backend_numa_buffer_clear,
    /* .reset         = */ NULL,
};

// NUMA buffer type - similar to CPU, but with pages allocated according to a NUMA first-touch policy

static const char * ggml_backend_numa_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "NUMA";

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_t ggml_backend_numa_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    int flags = MAP_SHARED | MAP_ANONYMOUS;
    // the mapping is not populated here, so the NUMA placement of each page is
    // decided by the first thread that touches it (first-touch policy)
    void * data = mmap(NULL, size, PROT_READ|PROT_WRITE, flags, -1, 0);
    if (data == MAP_FAILED) {
        GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, size);
        return NULL;
    }
    if (posix_madvise(data, size, POSIX_MADV_RANDOM)) {
        GGML_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n",
                strerror(errno));
    }

    return ggml_backend_buffer_init(buft, ggml_backend_numa_buffer_i, data, size);
}

static size_t ggml_backend_numa_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return TENSOR_ALIGNMENT;

    GGML_UNUSED(buft);
}

static bool ggml_backend_numa_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
    return true;

    GGML_UNUSED(buft);
}

ggml_backend_buffer_type_t ggml_backend_numa_buffer_type(void) {
    static struct ggml_backend_buffer_type ggml_backend_numa_buffer_type = {
        /* .iface    = */ {
            /* .get_name       = */ ggml_backend_numa_buffer_type_get_name,
            /* .alloc_buffer   = */ ggml_backend_numa_buffer_type_alloc_buffer,
            /* .get_alignment  = */ ggml_backend_numa_buffer_type_get_alignment,
            /* .get_max_size   = */ NULL, // defaults to SIZE_MAX
            /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
            /* .is_host        = */ ggml_backend_numa_buffer_type_is_host,
        },
        /* .device  = */ NULL, // FIXME ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context = */ NULL,
    };

    return &ggml_backend_numa_buffer_type;
}
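As background (not part of the PR), the first-touch behavior the allocator relies on can be observed with a small standalone program. This sketch assumes Linux with the libnuma development headers installed; `move_pages(2)` called with a NULL node array reports which node each page currently lives on:

```cpp
// Standalone sketch: an anonymous mapping lands on the NUMA node of the
// thread that first writes to it. Build with: g++ first_touch.cpp -lnuma
#include <sys/mman.h>
#include <numaif.h>   // move_pages
#include <unistd.h>
#include <cstdio>
#include <cstring>

int main() {
    const size_t page = (size_t) sysconf(_SC_PAGESIZE);
    void * p = mmap(NULL, page, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0);
    if (p == MAP_FAILED) { perror("mmap"); return 1; }

    memset(p, 0xAB, page); // first touch: the calling thread's node gets the page

    void * pages[1]  = { p };
    int    status[1] = { -1 };
    if (move_pages(0, 1, pages, NULL, status, 0) == 0) {
        printf("page resides on NUMA node %d\n", status[0]);
    } else {
        perror("move_pages");
    }
    munmap(p, page);
    return 0;
}
```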
13 changes: 12 additions & 1 deletion src/llama-kv-cache.cpp
@@ -71,6 +71,13 @@ bool llama_kv_cache_init(
    cache.k_l.reserve(n_layer);
    cache.v_l.reserve(n_layer);

    auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
    auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
    bool is_numa = is_numa_fn();
    if (!offload && is_numa) {
        LLAMA_LOG_INFO("%s: NUMA usage detected, using NUMA-aware buffer for KV cache\n", __func__);
    }

    for (int i = 0; i < n_layer; i++) {
        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
@@ -82,7 +89,11 @@
            auto * dev = model.dev_layer(i);
            buft = ggml_backend_dev_buffer_type(dev);
        } else {
            if (is_numa) {
                buft = ggml_backend_numa_buffer_type();
            } else {
                buft = ggml_backend_cpu_buffer_type();
            }
        }
        ggml_context * ctx = ctx_for_buft(buft);

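One reviewer-style note on the detection code above: `ggml_backend_reg_get_proc_address` returns NULL when the backend does not export the requested symbol, so a slightly more defensive variant of the lookup (a sketch, not what the diff currently does) could be:

```cpp
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
// fall back to the regular CPU buffer type if the symbol is missing
const bool is_numa = is_numa_fn != nullptr && is_numa_fn();
```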