feat: expose flash_attn / cache_type_k / cache_type_v
jhen0409 committed Jan 11, 2025
Parent: c571221 · Commit: ff47142
Showing 2 changed files with 45 additions and 0 deletions.
lib/binding.ts (19 additions, 0 deletions)
@@ -15,6 +15,25 @@ export type LlamaModelOptions = {
   n_ubatch?: number
   n_threads?: number
   n_gpu_layers?: number
+  flash_attn?: boolean
+  cache_type_k?:
+    | 'f16'
+    | 'f32'
+    | 'q8_0'
+    | 'q4_0'
+    | 'q4_1'
+    | 'iq4_nl'
+    | 'q5_0'
+    | 'q5_1'
+  cache_type_v?:
+    | 'f16'
+    | 'f32'
+    | 'q8_0'
+    | 'q4_0'
+    | 'q4_1'
+    | 'iq4_nl'
+    | 'q5_0'
+    | 'q5_1'
   use_mlock?: boolean
   use_mmap?: boolean
   vocab_only?: boolean
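
For context, here is a minimal usage sketch (not part of the commit) showing the new options at the call site. The model and n_ctx fields follow the construct() comment in src/LlamaContext.cpp below; the model path is hypothetical. Note that llama.cpp generally requires flash_attn to be enabled before the V cache can be quantized.

    import type { LlamaModelOptions } from './lib/binding'

    // Hypothetical options object; field names come from LlamaModelOptions
    // and the construct() comment in src/LlamaContext.cpp.
    const options: LlamaModelOptions = {
      model: './models/model-q4_0.gguf', // hypothetical path
      n_ctx: 4096,
      flash_attn: true,     // llama.cpp expects this when the V cache is quantized
      cache_type_k: 'q8_0', // ~8.5 bits per element vs. 16 for the 'f16' default
      cache_type_v: 'q8_0',
    }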
src/LlamaContext.cpp (26 additions, 0 deletions)
@@ -1,3 +1,4 @@
+#include "ggml.h"
 #include "LlamaContext.h"
 #include "DetokenizeWorker.h"
 #include "DisposeWorker.h"
@@ -60,6 +61,27 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
   exports.Set("LlamaContext", func);
 }
 
+const std::vector<ggml_type> kv_cache_types = {
+    GGML_TYPE_F32,
+    GGML_TYPE_F16,
+    GGML_TYPE_BF16,
+    GGML_TYPE_Q8_0,
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1,
+    GGML_TYPE_IQ4_NL,
+    GGML_TYPE_Q5_0,
+    GGML_TYPE_Q5_1,
+};
+
+static ggml_type kv_cache_type_from_str(const std::string & s) {
+  for (const auto & type : kv_cache_types) {
+    if (ggml_type_name(type) == s) {
+      return type;
+    }
+  }
+  throw std::runtime_error("Unsupported cache type: " + s);
+}
+
 // construct({ model, embedding, n_ctx, n_batch, n_threads, n_gpu_layers,
 // use_mlock, use_mmap }): LlamaContext throws error
 LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
@@ -96,6 +118,10 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
   params.cpuparams.n_threads =
       get_option<int32_t>(options, "n_threads", cpu_get_num_math() / 2);
   params.n_gpu_layers = get_option<int32_t>(options, "n_gpu_layers", -1);
+  params.flash_attn = get_option<bool>(options, "flash_attn", false);
+  params.cache_type_k = kv_cache_type_from_str(get_option<std::string>(options, "cache_type_k", "f16").c_str());
+  params.cache_type_v = kv_cache_type_from_str(get_option<std::string>(options, "cache_type_v", "f16").c_str());
+
   params.use_mlock = get_option<bool>(options, "use_mlock", false);
   params.use_mmap = get_option<bool>(options, "use_mmap", true);
   params.numa =
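
One observation on the parsing path above (not part of the commit): kv_cache_type_from_str compares the option string against ggml_type_name(), so the accepted spellings are exactly ggml's canonical type names, "f16" is the default, and any other string throws std::runtime_error("Unsupported cache type: ..."). Because the C++ list also contains GGML_TYPE_BF16, 'bf16' is accepted at runtime even though the TypeScript unions in lib/binding.ts omit it. A sketch of the mismatch, assuming a hypothetical import of the class registered in Init() above:

    import { LlamaContext, type LlamaModelOptions } from './lib/binding' // hypothetical import path

    // 'bf16' passes the runtime check in kv_cache_type_from_str but is missing
    // from the TS unions, so a cast is needed; adding 'bf16' to cache_type_k /
    // cache_type_v in LlamaModelOptions would align the two sides.
    const ctx = new LlamaContext({
      model: './models/model-q4_0.gguf', // hypothetical path
      cache_type_k: 'bf16',
    } as unknown as LlamaModelOptions)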
