Skip to content

Commit

Permalink
feat: expose flash_attn / cache_type_k / cache_type_v
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Nov 17, 2024
1 parent 3c80478 commit 4ce8ff8
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 41 deletions.
9 changes: 9 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
// int n_gpu_layers, // TODO: Support this
params.hasKey("n_gpu_layers") ? params.getInt("n_gpu_layers") : 0,
// boolean flash_attn,
params.hasKey("flash_attn") ? params.getBoolean("flash_attn") : false,

This comment has been minimized.

Copy link
@Vali-98

Vali-98 Nov 17, 2024

Contributor

Does it make sense to expose this on the Android side? I do not believe flash_attn is implemented on Android aside from the CPU implementation used for testing.

This comment has been minimized.

Copy link
@jhen0409

jhen0409 Nov 17, 2024

Author Member

I have commented on the TS side that it is only recommended for GPU devices, because it is slower on CPU, so yes, it is just for testing purposes.

// String cache_type_k,
params.hasKey("cache_type_k") ? params.getString("cache_type_k") : "f16",
// String cache_type_v,
params.hasKey("cache_type_v") ? params.getString("cache_type_v") : "f16",
// boolean use_mlock,
params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
// boolean use_mmap,
Expand Down Expand Up @@ -382,6 +388,9 @@ protected static native long initContext(
int n_batch,
int n_threads,
int n_gpu_layers, // TODO: Support this
boolean flash_attn,
String cache_type_k,
String cache_type_v,
boolean use_mlock,
boolean use_mmap,
boolean vocab_only,
Expand Down
11 changes: 11 additions & 0 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ Java_com_rnllama_LlamaContext_initContext(
jint n_batch,
jint n_threads,
jint n_gpu_layers, // TODO: Support this
jboolean flash_attn,
jstring cache_type_k,
jstring cache_type_v,
jboolean use_mlock,
jboolean use_mmap,
jboolean vocab_only,
Expand Down Expand Up @@ -271,6 +274,12 @@ Java_com_rnllama_LlamaContext_initContext(
defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

defaultParams.n_gpu_layers = n_gpu_layers;
defaultParams.flash_attn = flash_attn;

const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
defaultParams.cache_type_k = cache_type_k_chars;
defaultParams.cache_type_v = cache_type_v_chars;

defaultParams.use_mlock = use_mlock;
defaultParams.use_mmap = use_mmap;
Expand Down Expand Up @@ -314,6 +323,8 @@ Java_com_rnllama_LlamaContext_initContext(

env->ReleaseStringUTFChars(model_path_str, model_path_chars);
env->ReleaseStringUTFChars(lora_str, lora_chars);
env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);

LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
if (is_model_loaded) {
Expand Down
20 changes: 10 additions & 10 deletions docs/API/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ llama.rn

#### Defined in

[index.ts:58](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L58)
[index.ts:58](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L58)

___

Expand All @@ -55,7 +55,7 @@ ___

#### Defined in

[index.ts:49](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L49)
[index.ts:49](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L49)

___

Expand All @@ -65,7 +65,7 @@ ___

#### Defined in

[index.ts:43](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L43)
[index.ts:43](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L43)

___

Expand All @@ -75,7 +75,7 @@ ___

#### Defined in

[index.ts:47](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L47)
[index.ts:47](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L47)

___

Expand All @@ -92,7 +92,7 @@ ___

#### Defined in

[index.ts:33](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L33)
[index.ts:33](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L33)

## Functions

Expand All @@ -116,7 +116,7 @@ ___

#### Defined in

[grammar.ts:824](https://github.com/mybigday/llama.rn/blob/38fa660/src/grammar.ts#L824)
[grammar.ts:824](https://github.com/mybigday/llama.rn/blob/68acf1a/src/grammar.ts#L824)

___

Expand All @@ -137,7 +137,7 @@ ___

#### Defined in

[index.ts:225](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L225)
[index.ts:225](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L225)

___

Expand All @@ -157,7 +157,7 @@ ___

#### Defined in

[index.ts:210](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L210)
[index.ts:210](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L210)

___

Expand All @@ -171,7 +171,7 @@ ___

#### Defined in

[index.ts:275](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L275)
[index.ts:275](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L275)

___

Expand All @@ -191,4 +191,4 @@ ___

#### Defined in

[index.ts:196](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L196)
[index.ts:196](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L196)
30 changes: 15 additions & 15 deletions docs/API/classes/LlamaContext.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

#### Defined in

[index.ts:79](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L79)
[index.ts:79](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L79)

## Properties

Expand All @@ -52,7 +52,7 @@

#### Defined in

[index.ts:71](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L71)
[index.ts:71](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L71)

___

Expand All @@ -62,7 +62,7 @@ ___

#### Defined in

[index.ts:69](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L69)
[index.ts:69](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L69)

___

Expand All @@ -78,7 +78,7 @@ ___

#### Defined in

[index.ts:75](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L75)
[index.ts:75](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L75)

___

Expand All @@ -88,7 +88,7 @@ ___

#### Defined in

[index.ts:73](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L73)
[index.ts:73](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L73)

## Methods

Expand All @@ -111,7 +111,7 @@ ___

#### Defined in

[index.ts:171](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L171)
[index.ts:171](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L171)

___

Expand All @@ -132,7 +132,7 @@ ___

#### Defined in

[index.ts:115](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L115)
[index.ts:115](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L115)

___

Expand All @@ -152,7 +152,7 @@ ___

#### Defined in

[index.ts:160](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L160)
[index.ts:160](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L160)

___

Expand All @@ -173,7 +173,7 @@ ___

#### Defined in

[index.ts:164](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L164)
[index.ts:164](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L164)

___

Expand All @@ -194,7 +194,7 @@ ___

#### Defined in

[index.ts:105](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L105)
[index.ts:105](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L105)

___

Expand All @@ -216,7 +216,7 @@ Load cached prompt & completion state from a file.

#### Defined in

[index.ts:89](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L89)
[index.ts:89](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L89)

___

Expand All @@ -230,7 +230,7 @@ ___

#### Defined in

[index.ts:191](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L191)
[index.ts:191](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L191)

___

Expand All @@ -254,7 +254,7 @@ Save current cached prompt & completion state to a file.

#### Defined in

[index.ts:98](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L98)
[index.ts:98](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L98)

___

Expand All @@ -268,7 +268,7 @@ ___

#### Defined in

[index.ts:152](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L152)
[index.ts:152](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L152)

___

Expand All @@ -288,4 +288,4 @@ ___

#### Defined in

[index.ts:156](https://github.com/mybigday/llama.rn/blob/38fa660/src/index.ts#L156)
[index.ts:156](https://github.com/mybigday/llama.rn/blob/68acf1a/src/index.ts#L156)
Loading

0 comments on commit 4ce8ff8

Please sign in to comment.