From 845a9f0245454619378a20a23fcb991a044648fb Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Tue, 31 Dec 2024 09:35:55 +0800 Subject: [PATCH] feat: add getModelInfo method (#63) --- lib/binding.ts | 1 + src/LlamaContext.cpp | 41 +++++++++++++++++++++++++++ src/LlamaContext.h | 2 ++ test/__snapshots__/index.test.ts.snap | 26 +++++++++++++++++ test/index.test.ts | 2 ++ 5 files changed, 72 insertions(+) diff --git a/lib/binding.ts b/lib/binding.ts index a07e5f9..271dddc 100644 --- a/lib/binding.ts +++ b/lib/binding.ts @@ -54,6 +54,7 @@ export type EmbeddingResult = { export interface LlamaContext { new (options: LlamaModelOptions): LlamaContext getSystemInfo(): string + getModelInfo(): object getFormattedChat(messages: ChatMessage[]): string completion(options: LlamaCompletionOptions, callback?: (token: LlamaCompletionToken) => void): Promise stopCompletion(): void diff --git a/src/LlamaContext.cpp b/src/LlamaContext.cpp index 9c94384..94ed7f4 100644 --- a/src/LlamaContext.cpp +++ b/src/LlamaContext.cpp @@ -25,6 +25,9 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) { {InstanceMethod<&LlamaContext::GetSystemInfo>( "getSystemInfo", static_cast(napi_enumerable)), + InstanceMethod<&LlamaContext::GetModelInfo>( + "getModelInfo", + static_cast(napi_enumerable)), InstanceMethod<&LlamaContext::GetFormattedChat>( "getFormattedChat", static_cast(napi_enumerable)), @@ -102,6 +105,44 @@ Napi::Value LlamaContext::GetSystemInfo(const Napi::CallbackInfo &info) { return Napi::String::New(info.Env(), _info); } +bool validateModelChatTemplate(const struct llama_model * model) { + std::vector model_template(2048, 0); // longest known template is about 1200 bytes + std::string template_key = "tokenizer.chat_template"; + int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); + if (res >= 0) { + llama_chat_message chat[] = {{"user", "test"}}; + std::string tmpl = std::string(model_template.data(), model_template.size()); + int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0); + return chat_res > 0; + } + return res > 0; +} + +// getModelInfo(): object +Napi::Value LlamaContext::GetModelInfo(const Napi::CallbackInfo &info) { + char desc[1024]; + auto model = _sess->model(); + llama_model_desc(model, desc, sizeof(desc)); + + int count = llama_model_meta_count(model); + Napi::Object metadata = Napi::Object::New(info.Env()); + for (int i = 0; i < count; i++) { + char key[256]; + llama_model_meta_key_by_index(model, i, key, sizeof(key)); + char val[2048]; + llama_model_meta_val_str_by_index(model, i, val, sizeof(val)); + + metadata.Set(key, val); + } + Napi::Object details = Napi::Object::New(info.Env()); + details.Set("desc", desc); + details.Set("nParams", llama_model_n_params(model)); + details.Set("size", llama_model_size(model)); + details.Set("isChatTemplateSupported", validateModelChatTemplate(model)); + details.Set("metadata", metadata); + return details; +} + // getFormattedChat(messages: [{ role: string, content: string }]): string Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); diff --git a/src/LlamaContext.h b/src/LlamaContext.h index f53f3ed..b0ef374 100644 --- a/src/LlamaContext.h +++ b/src/LlamaContext.h @@ -9,6 +9,7 @@ class LlamaContext : public Napi::ObjectWrap { private: Napi::Value GetSystemInfo(const Napi::CallbackInfo &info); + Napi::Value GetModelInfo(const Napi::CallbackInfo &info); Napi::Value GetFormattedChat(const Napi::CallbackInfo &info); Napi::Value Completion(const Napi::CallbackInfo &info); void StopCompletion(const Napi::CallbackInfo &info); @@ -20,6 +21,7 @@ class LlamaContext : public Napi::ObjectWrap { Napi::Value Release(const Napi::CallbackInfo &info); std::string _info; + Napi::Object _meta; LlamaSessionPtr _sess = nullptr; LlamaCompletionWorker *_wip = nullptr; }; diff --git a/test/__snapshots__/index.test.ts.snap b/test/__snapshots__/index.test.ts.snap index 9bb711f..96f363c 100644 --- a/test/__snapshots__/index.test.ts.snap +++ b/test/__snapshots__/index.test.ts.snap @@ -421,3 +421,29 @@ exports[`work fine 1`] = ` "truncated": false, } `; + +exports[`work fine: model info 1`] = ` +{ + "desc": "llama ?B F16", + "isChatTemplateSupported": false, + "metadata": { + "general.architecture": "llama", + "general.file_type": "1", + "general.name": "LLaMA v2", + "llama.attention.head_count": "2", + "llama.attention.head_count_kv": "2", + "llama.attention.layer_norm_rms_epsilon": "0.000010", + "llama.block_count": "1", + "llama.context_length": "4096", + "llama.embedding_length": "8", + "llama.feed_forward_length": "32", + "llama.rope.dimension_count": "4", + "tokenizer.ggml.bos_token_id": "1", + "tokenizer.ggml.eos_token_id": "2", + "tokenizer.ggml.model": "llama", + "tokenizer.ggml.unknown_token_id": "0", + }, + "nParams": 513048, + "size": 1026144, +} +`; diff --git a/test/index.test.ts b/test/index.test.ts index e0fbeda..b881462 100644 --- a/test/index.test.ts +++ b/test/index.test.ts @@ -5,6 +5,8 @@ import { loadModel } from '../lib' it('work fine', async () => { let tokens = '' const model = await loadModel({ model: path.resolve(__dirname, './tiny-random-llama.gguf') }) + const info = model.getModelInfo() + expect(info).toMatchSnapshot('model info') const result = await model.completion({ prompt: 'My name is Merve and my favorite', n_samples: 1,