feat: add getModelInfo method (#63)
jhen0409 authored Dec 31, 2024
1 parent 55b3f9f commit 845a9f0
Showing 5 changed files with 72 additions and 0 deletions.
1 change: 1 addition & 0 deletions lib/binding.ts
@@ -54,6 +54,7 @@ export type EmbeddingResult = {
export interface LlamaContext {
  new (options: LlamaModelOptions): LlamaContext
  getSystemInfo(): string
  getModelInfo(): object
  getFormattedChat(messages: ChatMessage[]): string
  completion(options: LlamaCompletionOptions, callback?: (token: LlamaCompletionToken) => void): Promise<LlamaCompletionResult>
  stopCompletion(): void
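The declaration above types the return value only as object; the concrete shape is built by the native GetModelInfo implementation below. A rough usage sketch, assuming only the loadModel helper from '../lib' that the test at the bottom of this diff already uses, with './model.gguf' as a placeholder path:

import path from 'path'
import { loadModel } from '../lib'

const main = async () => {
  // './model.gguf' is a placeholder; any local GGUF model file works
  const model = await loadModel({ model: path.resolve(__dirname, './model.gguf') })
  const info = model.getModelInfo()
  console.log(info) // desc, nParams, size, isChatTemplateSupported, metadata
}

main()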
41 changes: 41 additions & 0 deletions src/LlamaContext.cpp
@@ -25,6 +25,9 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
      {InstanceMethod<&LlamaContext::GetSystemInfo>(
           "getSystemInfo",
           static_cast<napi_property_attributes>(napi_enumerable)),
       InstanceMethod<&LlamaContext::GetModelInfo>(
           "getModelInfo",
           static_cast<napi_property_attributes>(napi_enumerable)),
       InstanceMethod<&LlamaContext::GetFormattedChat>(
           "getFormattedChat",
           static_cast<napi_property_attributes>(napi_enumerable)),
@@ -102,6 +105,44 @@ Napi::Value LlamaContext::GetSystemInfo(const Napi::CallbackInfo &info) {
  return Napi::String::New(info.Env(), _info);
}

bool validateModelChatTemplate(const struct llama_model * model) {
  std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
  std::string template_key = "tokenizer.chat_template";
  int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
  if (res >= 0) {
    llama_chat_message chat[] = {{"user", "test"}};
    std::string tmpl = std::string(model_template.data(), model_template.size());
    int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
    return chat_res > 0;
  }
  return res > 0;
}

// getModelInfo(): object
Napi::Value LlamaContext::GetModelInfo(const Napi::CallbackInfo &info) {
  char desc[1024];
  auto model = _sess->model();
  llama_model_desc(model, desc, sizeof(desc));

  int count = llama_model_meta_count(model);
  Napi::Object metadata = Napi::Object::New(info.Env());
  for (int i = 0; i < count; i++) {
    char key[256];
    llama_model_meta_key_by_index(model, i, key, sizeof(key));
    char val[2048];
    llama_model_meta_val_str_by_index(model, i, val, sizeof(val));

    metadata.Set(key, val);
  }
  Napi::Object details = Napi::Object::New(info.Env());
  details.Set("desc", desc);
  details.Set("nParams", llama_model_n_params(model));
  details.Set("size", llama_model_size(model));
  details.Set("isChatTemplateSupported", validateModelChatTemplate(model));
  details.Set("metadata", metadata);
  return details;
}

// getFormattedChat(messages: [{ role: string, content: string }]): string
Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
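From the keys set on the details object above, the value returned to JavaScript could be described roughly by the following TypeScript shape (an illustrative sketch only; the binding itself declares the return type as a plain object, and this interface name is not part of the commit):

// Hypothetical shape of the object returned by getModelInfo(),
// inferred from the Napi::Object built in GetModelInfo above.
interface LlamaModelInfo {
  desc: string                      // output of llama_model_desc()
  nParams: number                   // llama_model_n_params()
  size: number                      // llama_model_size(), model size in bytes
  isChatTemplateSupported: boolean  // result of validateModelChatTemplate()
  metadata: Record<string, string>  // GGUF metadata, all values as strings
}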
2 changes: 2 additions & 0 deletions src/LlamaContext.h
@@ -9,6 +9,7 @@ class LlamaContext : public Napi::ObjectWrap<LlamaContext> {

private:
  Napi::Value GetSystemInfo(const Napi::CallbackInfo &info);
  Napi::Value GetModelInfo(const Napi::CallbackInfo &info);
  Napi::Value GetFormattedChat(const Napi::CallbackInfo &info);
  Napi::Value Completion(const Napi::CallbackInfo &info);
  void StopCompletion(const Napi::CallbackInfo &info);
@@ -20,6 +21,7 @@ class LlamaContext : public Napi::ObjectWrap<LlamaContext> {
  Napi::Value Release(const Napi::CallbackInfo &info);

  std::string _info;
  Napi::Object _meta;
  LlamaSessionPtr _sess = nullptr;
  LlamaCompletionWorker *_wip = nullptr;
};
26 changes: 26 additions & 0 deletions test/__snapshots__/index.test.ts.snap
@@ -421,3 +421,29 @@ exports[`work fine 1`] = `
  "truncated": false,
}
`;

exports[`work fine: model info 1`] = `
{
"desc": "llama ?B F16",
"isChatTemplateSupported": false,
"metadata": {
"general.architecture": "llama",
"general.file_type": "1",
"general.name": "LLaMA v2",
"llama.attention.head_count": "2",
"llama.attention.head_count_kv": "2",
"llama.attention.layer_norm_rms_epsilon": "0.000010",
"llama.block_count": "1",
"llama.context_length": "4096",
"llama.embedding_length": "8",
"llama.feed_forward_length": "32",
"llama.rope.dimension_count": "4",
"tokenizer.ggml.bos_token_id": "1",
"tokenizer.ggml.eos_token_id": "2",
"tokenizer.ggml.model": "llama",
"tokenizer.ggml.unknown_token_id": "0",
},
"nParams": 513048,
"size": 1026144,
}
`;
2 changes: 2 additions & 0 deletions test/index.test.ts
@@ -5,6 +5,8 @@ import { loadModel } from '../lib'
it('work fine', async () => {
  let tokens = ''
  const model = await loadModel({ model: path.resolve(__dirname, './tiny-random-llama.gguf') })
  const info = model.getModelInfo()
  expect(info).toMatchSnapshot('model info')
  const result = await model.completion({
    prompt: 'My name is Merve and my favorite',
    n_samples: 1,
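Beyond the snapshot match added above, individual fields could also be asserted directly; a small sketch reusing the same tiny-random-llama.gguf model and the metadata keys visible in the snapshot:

const info = model.getModelInfo() as Record<string, any>
expect(typeof info.desc).toBe('string')
expect(info.metadata['general.architecture']).toBe('llama')
expect(info.isChatTemplateSupported).toBe(false)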
