Skip to content

Commit

Permalink
feat: add static method for read model info from gguf (#87)
Browse files Browse the repository at this point in the history
* feat: add static method for read model info from gguf

* feat(ts): rename method & update example

* chore: docgen

* fix(ios): add missing free context

* fix(example): revert unnecessary change

* feat(android): implement modelInfo method
  • Loading branch information
jhen0409 authored Nov 16, 2024
1 parent c1d15a3 commit 8bf9dd8
Show file tree
Hide file tree
Showing 15 changed files with 335 additions and 66 deletions.
4 changes: 4 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ private static String getCpuFeatures() {
}
}

protected static native WritableMap modelInfo(
String model,
String[] skip
);
protected static native long initContext(
String model,
boolean embedding,
Expand Down
29 changes: 29 additions & 0 deletions android/src/main/java/com/rnllama/RNLlama.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,35 @@ public void setContextLimit(double limit, Promise promise) {
promise.resolve(null);
}

public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;

@Override
protected WritableMap doInBackground(Void... voids) {
try {
String[] skipArray = new String[skip.size()];
for (int i = 0; i < skip.size(); i++) {
skipArray[i] = skip.getString(i);
}
return LlamaContext.modelInfo(model, skipArray);
} catch (Exception e) {
exception = e;
}
return null;
}

@Override
protected void onPostExecute(WritableMap result) {
if (exception != null) {
promise.reject(exception);
return;
}
promise.resolve(result);
}
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
}

public void initContext(double id, final ReadableMap params, final Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
Expand Down
69 changes: 68 additions & 1 deletion android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
#include <thread>
#include <unordered_map>
#include "llama.h"
#include "rn-llama.hpp"
#include "llama-impl.h"
#include "ggml.h"
#include "rn-llama.hpp"

#define UNUSED(x) (void)(x)
#define TAG "RNLLAMA_ANDROID_JNI"
Expand Down Expand Up @@ -132,6 +133,72 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject v
env->CallVoidMethod(map, putArrayMethod, jKey, value);
}

JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_modelInfo(
JNIEnv *env,
jobject thiz,
jstring model_path_str,
jobjectArray skip
) {
UNUSED(thiz);

const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);

std::vector<std::string> skip_vec;
int skip_len = env->GetArrayLength(skip);
for (int i = 0; i < skip_len; i++) {
jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i);
const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr);
skip_vec.push_back(skip_chars);
env->ReleaseStringUTFChars(skip_str, skip_chars);
}

struct lm_gguf_init_params params = {
/*.no_alloc = */ false,
/*.ctx = */ NULL,
};
struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params);

if (!ctx) {
LOGI("%s: failed to load '%s'\n", __func__, model_path_chars);
return nullptr;
}

auto info = createWriteableMap(env);
putInt(env, info, "version", lm_gguf_get_version(ctx));
putInt(env, info, "alignment", lm_gguf_get_alignment(ctx));
putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx));
{
const int n_kv = lm_gguf_get_n_kv(ctx);

for (int i = 0; i < n_kv; ++i) {
const char * key = lm_gguf_get_key(ctx, i);

bool skipped = false;
if (skip_len > 0) {
for (int j = 0; j < skip_len; j++) {
if (skip_vec[j] == key) {
skipped = true;
break;
}
}
}

if (skipped) {
continue;
}

const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i);
putString(env, info, key, value.c_str());
}
}

env->ReleaseStringUTFChars(model_path_str, model_path_chars);
lm_gguf_free(ctx);

return reinterpret_cast<jobject>(info);
}

struct callback_context {
JNIEnv *env;
rnllama::llama_rn_context *llama;
Expand Down
5 changes: 5 additions & 0 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ public void setContextLimit(double limit, Promise promise) {
rnllama.setContextLimit(limit, promise);
}

@ReactMethod
public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
rnllama.modelInfo(model, skip, promise);
}

@ReactMethod
public void initContext(double id, final ReadableMap params, final Promise promise) {
rnllama.initContext(id, params, promise);
Expand Down
5 changes: 5 additions & 0 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ public void setContextLimit(double limit, Promise promise) {
rnllama.setContextLimit(limit, promise);
}

@ReactMethod
public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
rnllama.modelInfo(model, skip, promise);
}

@ReactMethod
public void initContext(double id, final ReadableMap params, final Promise promise) {
rnllama.initContext(id, params, promise);
Expand Down
56 changes: 56 additions & 0 deletions cpp/rn-llama.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,67 @@
#include <sstream>
#include <iostream>
#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "llama-impl.h"
#include "sampling.h"

namespace rnllama {

static std::string lm_gguf_data_to_str(enum lm_gguf_type type, const void * data, int i) {
switch (type) {
case LM_GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]);
case LM_GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]);
case LM_GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]);
case LM_GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]);
case LM_GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]);
case LM_GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]);
case LM_GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]);
case LM_GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]);
case LM_GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]);
case LM_GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]);
case LM_GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false";
default: return "unknown type: {}"; // TODO
}
}

static std::string lm_gguf_kv_to_str(const struct lm_gguf_context * ctx_gguf, int i) {
const enum lm_gguf_type type = lm_gguf_get_kv_type(ctx_gguf, i);

switch (type) {
case LM_GGUF_TYPE_STRING:
return lm_gguf_get_val_str(ctx_gguf, i);
case LM_GGUF_TYPE_ARRAY:
{
const enum lm_gguf_type arr_type = lm_gguf_get_arr_type(ctx_gguf, i);
int arr_n = lm_gguf_get_arr_n(ctx_gguf, i);
const void * data = lm_gguf_get_arr_data(ctx_gguf, i);
std::stringstream ss;
ss << "[";
for (int j = 0; j < arr_n; j++) {
if (arr_type == LM_GGUF_TYPE_STRING) {
std::string val = lm_gguf_get_arr_str(ctx_gguf, i, j);
// escape quotes
replace_all(val, "\\", "\\\\");
replace_all(val, "\"", "\\\"");
ss << '"' << val << '"';
} else if (arr_type == LM_GGUF_TYPE_ARRAY) {
ss << "???";
} else {
ss << lm_gguf_data_to_str(arr_type, data, j);
}
if (j < arr_n - 1) {
ss << ", ";
}
}
ss << "]";
return ss.str();
}
default:
return lm_gguf_data_to_str(type, lm_gguf_get_val_data(ctx_gguf, i), 0);
}
}

static void llama_batch_clear(llama_batch *batch) {
batch->n_tokens = 0;
}
Expand Down
37 changes: 29 additions & 8 deletions docs/API/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ llama.rn

- [convertJsonSchemaToGrammar](README.md#convertjsonschematogrammar)
- [initLlama](README.md#initllama)
- [loadLlamaModelInfo](README.md#loadllamamodelinfo)
- [releaseAllLlama](README.md#releaseallllama)
- [setContextLimit](README.md#setcontextlimit)

Expand All @@ -43,7 +44,7 @@ llama.rn

#### Defined in

[index.ts:52](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L52)
[index.ts:52](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L52)

___

Expand All @@ -53,7 +54,7 @@ ___

#### Defined in

[index.ts:44](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L44)
[index.ts:44](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L44)

___

Expand All @@ -63,7 +64,7 @@ ___

#### Defined in

[index.ts:42](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L42)
[index.ts:42](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L42)

___

Expand All @@ -80,7 +81,7 @@ ___

#### Defined in

[index.ts:32](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L32)
[index.ts:32](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L32)

## Functions

Expand All @@ -104,7 +105,7 @@ ___

#### Defined in

[grammar.ts:824](https://github.com/mybigday/llama.rn/blob/41b779f/src/grammar.ts#L824)
[grammar.ts:824](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L824)

___

Expand All @@ -125,7 +126,27 @@ ___

#### Defined in

[index.ts:196](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L196)
[index.ts:208](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L208)

___

### loadLlamaModelInfo

**loadLlamaModelInfo**(`model`): `Promise`<`Object`\>

#### Parameters

| Name | Type |
| :------ | :------ |
| `model` | `string` |

#### Returns

`Promise`<`Object`\>

#### Defined in

[index.ts:202](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L202)

___

Expand All @@ -139,7 +160,7 @@ ___

#### Defined in

[index.ts:233](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L233)
[index.ts:245](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L245)

___

Expand All @@ -159,4 +180,4 @@ ___

#### Defined in

[index.ts:188](https://github.com/mybigday/llama.rn/blob/41b779f/src/index.ts#L188)
[index.ts:188](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L188)
Loading

0 comments on commit 8bf9dd8

Please sign in to comment.