Skip to content

Commit

Permalink
feat: update embedding method (#88)
Browse files Browse the repository at this point in the history
* feat: update embedding method

* docs(api): build

* fix(example): revert unnecessary change
  • Loading branch information
jhen0409 authored Nov 16, 2024
1 parent 8bf9dd8 commit 6190f57
Show file tree
Hide file tree
Showing 14 changed files with 251 additions and 102 deletions.
24 changes: 21 additions & 3 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
params.getString("model"),
// boolean embedding,
params.hasKey("embedding") ? params.getBoolean("embedding") : false,
// int embd_normalize,
params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1,
// int n_ctx,
params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
// int n_batch,
Expand All @@ -66,9 +68,14 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
// float rope_freq_scale
params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f,
// int pooling_type,
params.hasKey("pooling_type") ? params.getInt("pooling_type") : -1,
// LoadProgressCallback load_progress_callback
params.hasKey("use_progress_callback") ? new LoadProgressCallback(this) : null
);
if (this.context == -1) {
throw new IllegalStateException("Failed to initialize context");
}
this.modelDetails = loadModelDetails(this.context);
this.reactContext = reactContext;
}
Expand Down Expand Up @@ -258,11 +265,16 @@ public String detokenize(ReadableArray tokens) {
return detokenize(this.context, toks);
}

public WritableMap getEmbedding(String text) {
public WritableMap getEmbedding(String text, ReadableMap params) {
if (isEmbeddingEnabled(this.context) == false) {
throw new IllegalStateException("Embedding is not enabled");
}
WritableMap result = embedding(this.context, text);
WritableMap result = embedding(
this.context,
text,
// int embd_normalize,
params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1
);
if (result.hasKey("error")) {
throw new IllegalStateException(result.getString("error"));
}
Expand Down Expand Up @@ -365,6 +377,7 @@ protected static native WritableMap modelInfo(
protected static native long initContext(
String model,
boolean embedding,
int embd_normalize,
int n_ctx,
int n_batch,
int n_threads,
Expand All @@ -376,6 +389,7 @@ protected static native long initContext(
float lora_scaled,
float rope_freq_base,
float rope_freq_scale,
int pooling_type,
LoadProgressCallback load_progress_callback
);
protected static native void interruptLoad(long contextPtr);
Expand Down Expand Up @@ -429,7 +443,11 @@ protected static native WritableMap doCompletion(
protected static native WritableArray tokenize(long contextPtr, String text);
protected static native String detokenize(long contextPtr, int[] tokens);
protected static native boolean isEmbeddingEnabled(long contextPtr);
protected static native WritableMap embedding(long contextPtr, String text);
protected static native WritableMap embedding(
long contextPtr,
String text,
int embd_normalize
);
protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
protected static native void freeContext(long contextPtr);
protected static native void logToAndroid();
Expand Down
4 changes: 2 additions & 2 deletions android/src/main/java/com/rnllama/RNLlama.java
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ protected void onPostExecute(String result) {
tasks.put(task, "detokenize-" + contextId);
}

public void embedding(double id, final String text, final Promise promise) {
public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;
Expand All @@ -361,7 +361,7 @@ protected WritableMap doInBackground(Void... voids) {
if (context == null) {
throw new Exception("Context not found");
}
return context.getEmbedding(text);
return context.getEmbedding(text, params);
} catch (Exception e) {
exception = e;
}
Expand Down
58 changes: 51 additions & 7 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,15 @@ static inline void pushDouble(JNIEnv *env, jobject arr, double value) {
env->CallVoidMethod(arr, pushDoubleMethod, value);
}

// Method to push string into WritableArray
static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
jmethodID pushStringMethod = env->GetMethodID(mapClass, "pushString", "(Ljava/lang/String;)V");

jstring jValue = env->NewStringUTF(value);
env->CallVoidMethod(arr, pushStringMethod, jValue);
}

// Method to push WritableMap into WritableArray
static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
Expand Down Expand Up @@ -213,6 +222,7 @@ Java_com_rnllama_LlamaContext_initContext(
jobject thiz,
jstring model_path_str,
jboolean embedding,
jint embd_normalize,
jint n_ctx,
jint n_batch,
jint n_threads,
Expand All @@ -224,6 +234,7 @@ Java_com_rnllama_LlamaContext_initContext(
jfloat lora_scaled,
jfloat rope_freq_base,
jfloat rope_freq_scale,
jint pooling_type,
jobject load_progress_callback
) {
UNUSED(thiz);
Expand All @@ -238,11 +249,22 @@ Java_com_rnllama_LlamaContext_initContext(
const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
defaultParams.model = model_path_chars;

defaultParams.embedding = embedding;

defaultParams.n_ctx = n_ctx;
defaultParams.n_batch = n_batch;

if (pooling_type != -1) {
defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
}

defaultParams.embedding = embedding;
if (embd_normalize != -1) {
defaultParams.embd_normalize = embd_normalize;
}
if (embedding) {
// For non-causal models, batch size must be equal to ubatch size
defaultParams.n_ubatch = defaultParams.n_batch;
}

int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
Expand Down Expand Up @@ -291,16 +313,21 @@ Java_com_rnllama_LlamaContext_initContext(

bool is_model_loaded = llama->loadModel(defaultParams);

env->ReleaseStringUTFChars(model_path_str, model_path_chars);
env->ReleaseStringUTFChars(lora_str, lora_chars);

LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
if (is_model_loaded) {
if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
llama_free(llama->ctx);
return -1;
}
context_map[(long) llama->ctx] = llama;
} else {
llama_free(llama->ctx);
}

env->ReleaseStringUTFChars(model_path_str, model_path_chars);
env->ReleaseStringUTFChars(lora_str, lora_chars);

return reinterpret_cast<jlong>(llama->ctx);
}

Expand Down Expand Up @@ -745,10 +772,21 @@ Java_com_rnllama_LlamaContext_isEmbeddingEnabled(

JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_embedding(
JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
JNIEnv *env, jobject thiz,
jlong context_ptr,
jstring text,
jint embd_normalize
) {
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];

common_params embdParams;
embdParams.embedding = true;
embdParams.embd_normalize = llama->params.embd_normalize;
if (embd_normalize != -1) {
embdParams.embd_normalize = embd_normalize;
}

const char *text_chars = env->GetStringUTFChars(text, nullptr);

llama->rewind();
Expand All @@ -769,14 +807,20 @@ Java_com_rnllama_LlamaContext_embedding(
llama->loadPrompt();
llama->doCompletion();

std::vector<float> embedding = llama->getEmbedding();
std::vector<float> embedding = llama->getEmbedding(embdParams);

auto embeddings = createWritableArray(env);
for (const auto &val : embedding) {
pushDouble(env, embeddings, (double) val);
}
putArray(env, result, "embedding", embeddings);

auto promptTokens = createWritableArray(env);
for (const auto &tok : llama->embd) {
pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str());
}
putArray(env, result, "prompt_tokens", promptTokens);

env->ReleaseStringUTFChars(text, text_chars);
return result;
}
Expand Down
4 changes: 2 additions & 2 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ public void detokenize(double id, final ReadableArray tokens, final Promise prom
}

@ReactMethod
public void embedding(double id, final String text, final Promise promise) {
rnllama.embedding(id, text, promise);
public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
rnllama.embedding(id, text, params, promise);
}

@ReactMethod
Expand Down
4 changes: 2 additions & 2 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ public void detokenize(double id, final ReadableArray tokens, final Promise prom
}

@ReactMethod
public void embedding(double id, final String text, final Promise promise) {
rnllama.embedding(id, text, promise);
public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
rnllama.embedding(id, text, params, promise);
}

@ReactMethod
Expand Down
17 changes: 9 additions & 8 deletions cpp/rn-llama.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,28 +595,29 @@ struct llama_rn_context
return token_with_probs;
}

std::vector<float> getEmbedding()
std::vector<float> getEmbedding(common_params &embd_params)
{
static const int n_embd = llama_n_embd(llama_get_model(ctx));
if (!params.embedding)
if (!embd_params.embedding)
{
LOG_WARNING("embedding disabled, embedding: %s", params.embedding);
LOG_WARNING("embedding disabled, embedding: %s", embd_params.embedding);
return std::vector<float>(n_embd, 0.0f);
}
float *data;

if(params.pooling_type == 0){
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
printf("pooling_type: %d\n", pooling_type);
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
data = llama_get_embeddings(ctx);
}
else {
} else {
data = llama_get_embeddings_seq(ctx, 0);
}

if(!data) {
if (!data) {
return std::vector<float>(n_embd, 0.0f);
}
std::vector<float> embedding(data, data + n_embd), out(data, data + n_embd);
common_embd_normalize(embedding.data(), out.data(), n_embd, params.embd_normalize);
common_embd_normalize(embedding.data(), out.data(), n_embd, embd_params.embd_normalize);
return out;
}

Expand Down
33 changes: 22 additions & 11 deletions docs/API/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ llama.rn
- [BenchResult](README.md#benchresult)
- [CompletionParams](README.md#completionparams)
- [ContextParams](README.md#contextparams)
- [EmbeddingParams](README.md#embeddingparams)
- [TokenData](README.md#tokendata)

### Functions
Expand Down Expand Up @@ -44,7 +45,7 @@ llama.rn

#### Defined in

[index.ts:52](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L52)
[index.ts:57](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L57)

___

Expand All @@ -54,17 +55,27 @@ ___

#### Defined in

[index.ts:44](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L44)
[index.ts:49](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L49)

___

### ContextParams

Ƭ **ContextParams**: `NativeContextParams`
Ƭ **ContextParams**: `Omit`<`NativeContextParams`, ``"pooling_type"``\> & { `pooling_type?`: ``"none"`` \| ``"mean"`` \| ``"cls"`` \| ``"last"`` \| ``"rank"`` }

#### Defined in

[index.ts:42](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L42)
[index.ts:43](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L43)

___

### EmbeddingParams

Ƭ **EmbeddingParams**: `NativeEmbeddingParams`

#### Defined in

[index.ts:47](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L47)

___

Expand All @@ -81,7 +92,7 @@ ___

#### Defined in

[index.ts:32](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L32)
[index.ts:33](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L33)

## Functions

Expand All @@ -105,7 +116,7 @@ ___

#### Defined in

[grammar.ts:824](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L824)
[grammar.ts:824](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L824)

___

Expand All @@ -117,7 +128,7 @@ ___

| Name | Type |
| :------ | :------ |
| `«destructured»` | `NativeContextParams` |
| `«destructured»` | [`ContextParams`](README.md#contextparams) |
| `onProgress?` | (`progress`: `number`) => `void` |

#### Returns
Expand All @@ -126,7 +137,7 @@ ___

#### Defined in

[index.ts:208](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L208)
[index.ts:225](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L225)

___

Expand All @@ -146,7 +157,7 @@ ___

#### Defined in

[index.ts:202](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L202)
[index.ts:210](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L210)

___

Expand All @@ -160,7 +171,7 @@ ___

#### Defined in

[index.ts:245](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L245)
[index.ts:269](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L269)

___

Expand All @@ -180,4 +191,4 @@ ___

#### Defined in

[index.ts:188](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L188)
[index.ts:196](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L196)
Loading

0 comments on commit 6190f57

Please sign in to comment.