From 8bf9dd8bf1495592dde967a4cf95a4f8bc0f88f9 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Sat, 16 Nov 2024 14:27:53 +0800 Subject: [PATCH] feat: add static method for read model info from gguf (#87) * feat: add static method for read model info from gguf * feat(ts): rename method & update example * chore: docgen * fix(ios): add missing free context * fix(example): revert unnecessary change * feat(android): implement modelInfo method --- .../main/java/com/rnllama/LlamaContext.java | 4 ++ .../src/main/java/com/rnllama/RNLlama.java | 29 ++++++++ android/src/main/jni.cpp | 69 +++++++++++++++++- .../java/com/rnllama/RNLlamaModule.java | 5 ++ .../java/com/rnllama/RNLlamaModule.java | 5 ++ cpp/rn-llama.hpp | 56 +++++++++++++++ docs/API/README.md | 37 +++++++--- docs/API/classes/LlamaContext.md | 30 ++++---- docs/API/classes/SchemaGrammarConverter.md | 32 ++++----- example/src/App.tsx | 70 ++++++++++++------- ios/RNLlama.mm | 8 +++ ios/RNLlamaContext.h | 3 + ios/RNLlamaContext.mm | 39 +++++++++++ src/NativeRNLlama.ts | 2 + src/index.ts | 12 ++++ 15 files changed, 335 insertions(+), 66 deletions(-) diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index 337ed04..36b1a1c 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -358,6 +358,10 @@ private static String getCpuFeatures() { } } + protected static native WritableMap modelInfo( + String model, + String[] skip + ); protected static native long initContext( String model, boolean embedding, diff --git a/android/src/main/java/com/rnllama/RNLlama.java b/android/src/main/java/com/rnllama/RNLlama.java index eb02755..1f02f2d 100644 --- a/android/src/main/java/com/rnllama/RNLlama.java +++ b/android/src/main/java/com/rnllama/RNLlama.java @@ -42,6 +42,35 @@ public void setContextLimit(double limit, Promise promise) { promise.resolve(null); } + public void modelInfo(final String model, final ReadableArray skip, final Promise promise) { + new AsyncTask() { + private Exception exception; + + @Override + protected WritableMap doInBackground(Void... voids) { + try { + String[] skipArray = new String[skip.size()]; + for (int i = 0; i < skip.size(); i++) { + skipArray[i] = skip.getString(i); + } + return LlamaContext.modelInfo(model, skipArray); + } catch (Exception e) { + exception = e; + } + return null; + } + + @Override + protected void onPostExecute(WritableMap result) { + if (exception != null) { + promise.reject(exception); + return; + } + promise.resolve(result); + } + }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR); + } + public void initContext(double id, final ReadableMap params, final Promise promise) { final int contextId = (int) id; AsyncTask task = new AsyncTask() { diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 128aa3f..3beb203 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -9,8 +9,9 @@ #include #include #include "llama.h" -#include "rn-llama.hpp" +#include "llama-impl.h" #include "ggml.h" +#include "rn-llama.hpp" #define UNUSED(x) (void)(x) #define TAG "RNLLAMA_ANDROID_JNI" @@ -132,6 +133,72 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject v env->CallVoidMethod(map, putArrayMethod, jKey, value); } +JNIEXPORT jobject JNICALL +Java_com_rnllama_LlamaContext_modelInfo( + JNIEnv *env, + jobject thiz, + jstring model_path_str, + jobjectArray skip +) { + UNUSED(thiz); + + const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr); + + std::vector skip_vec; + int skip_len = env->GetArrayLength(skip); + for (int i = 0; i < skip_len; i++) { + jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i); + const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr); + skip_vec.push_back(skip_chars); + env->ReleaseStringUTFChars(skip_str, skip_chars); + } + + struct lm_gguf_init_params params = { + /*.no_alloc = */ false, + /*.ctx = */ NULL, + }; + struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params); + + if (!ctx) { + LOGI("%s: failed to load '%s'\n", __func__, model_path_chars); + return nullptr; + } + + auto info = createWriteableMap(env); + putInt(env, info, "version", lm_gguf_get_version(ctx)); + putInt(env, info, "alignment", lm_gguf_get_alignment(ctx)); + putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx)); + { + const int n_kv = lm_gguf_get_n_kv(ctx); + + for (int i = 0; i < n_kv; ++i) { + const char * key = lm_gguf_get_key(ctx, i); + + bool skipped = false; + if (skip_len > 0) { + for (int j = 0; j < skip_len; j++) { + if (skip_vec[j] == key) { + skipped = true; + break; + } + } + } + + if (skipped) { + continue; + } + + const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i); + putString(env, info, key, value.c_str()); + } + } + + env->ReleaseStringUTFChars(model_path_str, model_path_chars); + lm_gguf_free(ctx); + + return reinterpret_cast(info); +} + struct callback_context { JNIEnv *env; rnllama::llama_rn_context *llama; diff --git a/android/src/newarch/java/com/rnllama/RNLlamaModule.java b/android/src/newarch/java/com/rnllama/RNLlamaModule.java index 5bab9b1..a41aa05 100644 --- a/android/src/newarch/java/com/rnllama/RNLlamaModule.java +++ b/android/src/newarch/java/com/rnllama/RNLlamaModule.java @@ -37,6 +37,11 @@ public void setContextLimit(double limit, Promise promise) { rnllama.setContextLimit(limit, promise); } + @ReactMethod + public void modelInfo(final String model, final ReadableArray skip, final Promise promise) { + rnllama.modelInfo(model, skip, promise); + } + @ReactMethod public void initContext(double id, final ReadableMap params, final Promise promise) { rnllama.initContext(id, params, promise); diff --git a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java index 2719515..4f01542 100644 --- a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +++ b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java @@ -38,6 +38,11 @@ public void setContextLimit(double limit, Promise promise) { rnllama.setContextLimit(limit, promise); } + @ReactMethod + public void modelInfo(final String model, final ReadableArray skip, final Promise promise) { + rnllama.modelInfo(model, skip, promise); + } + @ReactMethod public void initContext(double id, final ReadableMap params, final Promise promise) { rnllama.initContext(id, params, promise); diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp index 612f671..1da49e6 100644 --- a/cpp/rn-llama.hpp +++ b/cpp/rn-llama.hpp @@ -4,11 +4,67 @@ #include #include #include "common.h" +#include "ggml.h" #include "llama.h" +#include "llama-impl.h" #include "sampling.h" namespace rnllama { +static std::string lm_gguf_data_to_str(enum lm_gguf_type type, const void * data, int i) { + switch (type) { + case LM_GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]); + case LM_GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]); + case LM_GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]); + case LM_GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]); + case LM_GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]); + case LM_GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]); + case LM_GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]); + case LM_GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]); + case LM_GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]); + case LM_GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]); + case LM_GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false"; + default: return "unknown type: {}"; // TODO + } +} + +static std::string lm_gguf_kv_to_str(const struct lm_gguf_context * ctx_gguf, int i) { + const enum lm_gguf_type type = lm_gguf_get_kv_type(ctx_gguf, i); + + switch (type) { + case LM_GGUF_TYPE_STRING: + return lm_gguf_get_val_str(ctx_gguf, i); + case LM_GGUF_TYPE_ARRAY: + { + const enum lm_gguf_type arr_type = lm_gguf_get_arr_type(ctx_gguf, i); + int arr_n = lm_gguf_get_arr_n(ctx_gguf, i); + const void * data = lm_gguf_get_arr_data(ctx_gguf, i); + std::stringstream ss; + ss << "["; + for (int j = 0; j < arr_n; j++) { + if (arr_type == LM_GGUF_TYPE_STRING) { + std::string val = lm_gguf_get_arr_str(ctx_gguf, i, j); + // escape quotes + replace_all(val, "\\", "\\\\"); + replace_all(val, "\"", "\\\""); + ss << '"' << val << '"'; + } else if (arr_type == LM_GGUF_TYPE_ARRAY) { + ss << "???"; + } else { + ss << lm_gguf_data_to_str(arr_type, data, j); + } + if (j < arr_n - 1) { + ss << ", "; + } + } + ss << "]"; + return ss.str(); + } + default: + return lm_gguf_data_to_str(type, lm_gguf_get_val_data(ctx_gguf, i), 0); + } +} + static void llama_batch_clear(llama_batch *batch) { batch->n_tokens = 0; } diff --git a/docs/API/README.md b/docs/API/README.md index 8dde999..2dc0cee 100644 --- a/docs/API/README.md +++ b/docs/API/README.md @@ -20,6 +20,7 @@ llama.rn - [convertJsonSchemaToGrammar](README.md#convertjsonschematogrammar) - [initLlama](README.md#initllama) +- [loadLlamaModelInfo](README.md#loadllamamodelinfo) - [releaseAllLlama](README.md#releaseallllama) - [setContextLimit](README.md#setcontextlimit) @@ -43,7 +44,7 @@ llama.rn #### Defined in -[index.ts:52](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L52) +[index.ts:52](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L52) ___ @@ -53,7 +54,7 @@ ___ #### Defined in -[index.ts:44](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L44) +[index.ts:44](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L44) ___ @@ -63,7 +64,7 @@ ___ #### Defined in -[index.ts:42](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L42) +[index.ts:42](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L42) ___ @@ -80,7 +81,7 @@ ___ #### Defined in -[index.ts:32](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L32) +[index.ts:32](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L32) ## Functions @@ -104,7 +105,7 @@ ___ #### Defined in -[grammar.ts:824](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L824) +[grammar.ts:824](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L824) ___ @@ -125,7 +126,27 @@ ___ #### Defined in -[index.ts:196](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L196) +[index.ts:208](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L208) + +___ + +### loadLlamaModelInfo + +▸ **loadLlamaModelInfo**(`model`): `Promise`<`Object`\> + +#### Parameters + +| Name | Type | +| :------ | :------ | +| `model` | `string` | + +#### Returns + +`Promise`<`Object`\> + +#### Defined in + +[index.ts:202](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L202) ___ @@ -139,7 +160,7 @@ ___ #### Defined in -[index.ts:233](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L233) +[index.ts:245](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L245) ___ @@ -159,4 +180,4 @@ ___ #### Defined in -[index.ts:188](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L188) +[index.ts:188](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L188) diff --git a/docs/API/classes/LlamaContext.md b/docs/API/classes/LlamaContext.md index d5ffb30..e6dd8a3 100644 --- a/docs/API/classes/LlamaContext.md +++ b/docs/API/classes/LlamaContext.md @@ -42,7 +42,7 @@ #### Defined in -[index.ts:73](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L73) +[index.ts:73](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L73) ## Properties @@ -52,7 +52,7 @@ #### Defined in -[index.ts:65](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L65) +[index.ts:65](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L65) ___ @@ -62,7 +62,7 @@ ___ #### Defined in -[index.ts:63](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L63) +[index.ts:63](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L63) ___ @@ -78,7 +78,7 @@ ___ #### Defined in -[index.ts:69](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L69) +[index.ts:69](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L69) ___ @@ -88,7 +88,7 @@ ___ #### Defined in -[index.ts:67](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L67) +[index.ts:67](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L67) ## Methods @@ -111,7 +111,7 @@ ___ #### Defined in -[index.ts:163](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L163) +[index.ts:163](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L163) ___ @@ -132,7 +132,7 @@ ___ #### Defined in -[index.ts:110](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L110) +[index.ts:110](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L110) ___ @@ -152,7 +152,7 @@ ___ #### Defined in -[index.ts:155](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L155) +[index.ts:155](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L155) ___ @@ -172,7 +172,7 @@ ___ #### Defined in -[index.ts:159](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L159) +[index.ts:159](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L159) ___ @@ -192,7 +192,7 @@ ___ #### Defined in -[index.ts:99](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L99) +[index.ts:99](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L99) ___ @@ -214,7 +214,7 @@ Load cached prompt & completion state from a file. #### Defined in -[index.ts:83](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L83) +[index.ts:83](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L83) ___ @@ -228,7 +228,7 @@ ___ #### Defined in -[index.ts:183](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L183) +[index.ts:183](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L183) ___ @@ -252,7 +252,7 @@ Save current cached prompt & completion state to a file. #### Defined in -[index.ts:92](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L92) +[index.ts:92](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L92) ___ @@ -266,7 +266,7 @@ ___ #### Defined in -[index.ts:147](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L147) +[index.ts:147](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L147) ___ @@ -286,4 +286,4 @@ ___ #### Defined in -[index.ts:151](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L151) +[index.ts:151](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L151) diff --git a/docs/API/classes/SchemaGrammarConverter.md b/docs/API/classes/SchemaGrammarConverter.md index 8a16036..15778b4 100644 --- a/docs/API/classes/SchemaGrammarConverter.md +++ b/docs/API/classes/SchemaGrammarConverter.md @@ -46,7 +46,7 @@ #### Defined in -[grammar.ts:211](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L211) +[grammar.ts:211](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L211) ## Properties @@ -56,7 +56,7 @@ #### Defined in -[grammar.ts:201](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L201) +[grammar.ts:201](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L201) ___ @@ -66,7 +66,7 @@ ___ #### Defined in -[grammar.ts:203](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L203) +[grammar.ts:203](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L203) ___ @@ -76,7 +76,7 @@ ___ #### Defined in -[grammar.ts:199](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L199) +[grammar.ts:199](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L199) ___ @@ -90,7 +90,7 @@ ___ #### Defined in -[grammar.ts:207](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L207) +[grammar.ts:207](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L207) ___ @@ -100,7 +100,7 @@ ___ #### Defined in -[grammar.ts:209](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L209) +[grammar.ts:209](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L209) ___ @@ -114,7 +114,7 @@ ___ #### Defined in -[grammar.ts:205](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L205) +[grammar.ts:205](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L205) ## Methods @@ -135,7 +135,7 @@ ___ #### Defined in -[grammar.ts:693](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L693) +[grammar.ts:693](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L693) ___ @@ -156,7 +156,7 @@ ___ #### Defined in -[grammar.ts:224](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L224) +[grammar.ts:224](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L224) ___ @@ -179,7 +179,7 @@ ___ #### Defined in -[grammar.ts:710](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L710) +[grammar.ts:710](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L710) ___ @@ -200,7 +200,7 @@ ___ #### Defined in -[grammar.ts:312](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L312) +[grammar.ts:312](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L312) ___ @@ -220,7 +220,7 @@ ___ #### Defined in -[grammar.ts:518](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L518) +[grammar.ts:518](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L518) ___ @@ -241,7 +241,7 @@ ___ #### Defined in -[grammar.ts:323](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L323) +[grammar.ts:323](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L323) ___ @@ -255,7 +255,7 @@ ___ #### Defined in -[grammar.ts:813](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L813) +[grammar.ts:813](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L813) ___ @@ -276,7 +276,7 @@ ___ #### Defined in -[grammar.ts:247](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L247) +[grammar.ts:247](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L247) ___ @@ -297,4 +297,4 @@ ___ #### Defined in -[grammar.ts:529](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L529) +[grammar.ts:529](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L529) diff --git a/example/src/App.tsx b/example/src/App.tsx index 4a8e9d2..dfe84e6 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -8,8 +8,13 @@ import { Chat, darkTheme } from '@flyerhq/react-native-chat-ui' import type { MessageType } from '@flyerhq/react-native-chat-ui' import json5 from 'json5' import ReactNativeBlobUtil from 'react-native-blob-util' -// eslint-disable-next-line import/no-unresolved -import { initLlama, LlamaContext, convertJsonSchemaToGrammar } from 'llama.rn' +import type { LlamaContext } from 'llama.rn' +import { + initLlama, + loadLlamaModelInfo, + convertJsonSchemaToGrammar, + // eslint-disable-next-line import/no-unresolved +} from 'llama.rn' import { Bubble } from './Bubble' const { dirs } = ReactNativeBlobUtil.fs @@ -81,37 +86,50 @@ export default function App() { }) } + // Example: Get model info without initializing context + const getModelInfo = async (model: string) => { + const t0 = Date.now() + const info = await loadLlamaModelInfo(model) + console.log(`Model info (took ${Date.now() - t0}ms): `, info) + } + const handleInitContext = async (file: DocumentPickerResponse) => { await handleReleaseContext() + await getModelInfo(file.uri) const msgId = addSystemMessage('Initializing context...') - initLlama({ - model: file.uri, - use_mlock: true, - n_gpu_layers: Platform.OS === 'ios' ? 0 : 0, // > 0: enable GPU - // embedding: true, - }, (progress) => { - setMessages((msgs) => { - const index = msgs.findIndex((msg) => msg.id === msgId) - if (index >= 0) { - return msgs.map((msg, i) => { - if (msg.type == 'text' && i === index) { - return { - ...msg, - text: `Initializing context... ${progress}%`, + const t0 = Date.now() + initLlama( + { + model: file.uri, + use_mlock: true, + n_gpu_layers: Platform.OS === 'ios' ? 0 : 0, // > 0: enable GPU + // embedding: true, + }, + (progress) => { + setMessages((msgs) => { + const index = msgs.findIndex((msg) => msg.id === msgId) + if (index >= 0) { + return msgs.map((msg, i) => { + if (msg.type == 'text' && i === index) { + return { + ...msg, + text: `Initializing context... ${progress}%`, + } } - } - return msg - }) - } - return msgs - }) - }) + return msg + }) + } + return msgs + }) + }, + ) .then((ctx) => { + const t1 = Date.now() setContext(ctx) addSystemMessage( - `Context initialized! \n\nGPU: ${ctx.gpu ? 'YES' : 'NO'} (${ - ctx.reasonNoGPU - })\nChat Template: ${ + `Context initialized!\n\nLoad time: ${t1 - t0}ms\nGPU: ${ + ctx.gpu ? 'YES' : 'NO' + } (${ctx.reasonNoGPU})\nChat Template: ${ ctx.model.isChatTemplateSupported ? 'YES' : 'NO' }\n\n` + 'You can use the following commands:\n\n' + diff --git a/ios/RNLlama.mm b/ios/RNLlama.mm index cef6630..d36dff7 100644 --- a/ios/RNLlama.mm +++ b/ios/RNLlama.mm @@ -21,6 +21,14 @@ @implementation RNLlama resolve(nil); } +RCT_EXPORT_METHOD(modelInfo:(NSString *)path + withSkip:(NSArray *)skip + withResolver:(RCTPromiseResolveBlock)resolve + withRejecter:(RCTPromiseRejectBlock)reject) +{ + resolve([RNLlamaContext modelInfo:path skip:skip]); +} + RCT_EXPORT_METHOD(initContext:(double)contextId withContextParams:(NSDictionary *)contextParams withResolver:(RCTPromiseResolveBlock)resolve diff --git a/ios/RNLlamaContext.h b/ios/RNLlamaContext.h index 3d70a54..922f8d1 100644 --- a/ios/RNLlamaContext.h +++ b/ios/RNLlamaContext.h @@ -1,5 +1,7 @@ #ifdef __cplusplus #import "llama.h" +#import "llama-impl.h" +#import "ggml.h" #import "rn-llama.hpp" #endif @@ -14,6 +16,7 @@ rnllama::llama_rn_context * llama; } ++ (NSDictionary *)modelInfo:(NSString *)path skip:(NSArray *)skip; + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsigned int progress))onProgress; - (void)interruptLoad; - (bool)isMetalEnabled; diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index 1e4ad35..36249ef 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -3,6 +3,45 @@ @implementation RNLlamaContext ++ (NSDictionary *)modelInfo:(NSString *)path skip:(NSArray *)skip { + struct lm_gguf_init_params params = { + /*.no_alloc = */ false, + /*.ctx = */ NULL, + }; + + struct lm_gguf_context * ctx = lm_gguf_init_from_file([path UTF8String], params); + + if (!ctx) { + NSLog(@"%s: failed to load '%s'\n", __func__, [path UTF8String]); + return @{}; + } + + NSMutableDictionary *info = [[NSMutableDictionary alloc] init]; + + info[@"version"] = @(lm_gguf_get_version(ctx)); + info[@"alignment"] = @(lm_gguf_get_alignment(ctx)); + info[@"data_offset"] = @(lm_gguf_get_data_offset(ctx)); + + // kv + { + const int n_kv = lm_gguf_get_n_kv(ctx); + + for (int i = 0; i < n_kv; ++i) { + const char * key = lm_gguf_get_key(ctx, i); + + if (skip && [skip containsObject:[NSString stringWithUTF8String:key]]) { + continue; + } + const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i); + info[[NSString stringWithUTF8String:key]] = [NSString stringWithUTF8String:value.c_str()]; + } + } + + lm_gguf_free(ctx); + + return info; +} + + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsigned int progress))onProgress { // llama_backend_init(false); common_params defaultParams; diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts index c0eda5d..5427d9c 100644 --- a/src/NativeRNLlama.ts +++ b/src/NativeRNLlama.ts @@ -120,6 +120,8 @@ export type NativeLlamaChatMessage = { export interface Spec extends TurboModule { setContextLimit(limit: number): Promise + + modelInfo(path: string, skip?: string[]): Promise initContext(contextId: number, params: NativeContextParams): Promise getFormattedChat( diff --git a/src/index.ts b/src/index.ts index ccce03b..a03c2b8 100644 --- a/src/index.ts +++ b/src/index.ts @@ -193,6 +193,18 @@ let contextIdCounter = 0 const contextIdRandom = () => process.env.NODE_ENV === 'test' ? 0 : Math.floor(Math.random() * 100000) +const modelInfoSkip = [ + // Large fields + 'tokenizer.ggml.tokens', + 'tokenizer.ggml.token_type', + 'tokenizer.ggml.merges' +] +export async function loadLlamaModelInfo(model: string): Promise { + let path = model + if (path.startsWith('file://')) path = path.slice(7) + return RNLlama.modelInfo(path, modelInfoSkip) +} + export async function initLlama( { model, is_model_asset: isModelAsset, ...rest }: ContextParams, onProgress?: (progress: number) => void,