feat: add progress callback in initLlama (#82)
* feat(ios): add progress callback in initLlama

* feat(android): add progress callback in initLlama

* fix(ts): skip random context id on testing
jhen0409 authored Nov 4, 2024
1 parent 192c9ae commit 41b779f
Showing 19 changed files with 236 additions and 58 deletions.
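At the JS level, this commit adds an optional second argument to initLlama: a callback that receives model-load progress as an integer percentage, backed by the new @RNLlama_onInitContextProgress native event. A minimal usage sketch, assuming the initLlama export from llama.rn as used in example/src/App.tsx further down; the model path is a placeholder:

```ts
import { initLlama } from 'llama.rn'

async function loadModel() {
  const context = await initLlama(
    {
      model: '/path/to/model.gguf', // placeholder path to a local GGUF model
      use_mlock: true,
      n_gpu_layers: 0, // > 0 enables GPU offload where supported
    },
    (progress) => {
      // progress is an integer percentage (0-100) reported while the model loads
      console.log(`Initializing context... ${progress}%`)
    },
  )
  return context
}
```

Releasing a context before loading finishes now also interrupts the load on the native side (interruptLoad), as the Android and iOS diffs below show.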
33 changes: 30 additions & 3 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -37,6 +37,7 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
}
Log.d(NAME, "Setting log callback");
logToAndroid();
eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
this.id = id;
this.context = initContext(
// String model,
@@ -64,11 +65,16 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
// float rope_freq_base,
params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
// float rope_freq_scale
params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f
params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f,
// LoadProgressCallback load_progress_callback
params.hasKey("use_progress_callback") ? new LoadProgressCallback(this) : null
);
this.modelDetails = loadModelDetails(this.context);
this.reactContext = reactContext;
eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
}

public void interruptLoad() {
interruptLoad(this.context);
}

public long getContext() {
@@ -87,6 +93,25 @@ public String getFormattedChat(ReadableArray messages, String chatTemplate) {
return getFormattedChat(this.context, msgs, chatTemplate == null ? "" : chatTemplate);
}

private void emitLoadProgress(int progress) {
WritableMap event = Arguments.createMap();
event.putInt("contextId", LlamaContext.this.id);
event.putInt("progress", progress);
eventEmitter.emit("@RNLlama_onInitContextProgress", event);
}

private static class LoadProgressCallback {
LlamaContext context;

public LoadProgressCallback(LlamaContext context) {
this.context = context;
}

void onLoadProgress(int progress) {
context.emitLoadProgress(progress);
}
}

private void emitPartialCompletion(WritableMap tokenResult) {
WritableMap event = Arguments.createMap();
event.putInt("contextId", LlamaContext.this.id);
@@ -346,8 +371,10 @@ protected static native long initContext(
String lora,
float lora_scaled,
float rope_freq_base,
float rope_freq_scale
float rope_freq_scale,
LoadProgressCallback load_progress_callback
);
protected static native void interruptLoad(long contextPtr);
protected static native WritableMap loadModelDetails(
long contextPtr
);
14 changes: 9 additions & 5 deletions android/src/main/java/com/rnllama/RNLlama.java
@@ -42,21 +42,24 @@ public void setContextLimit(double limit, Promise promise) {
promise.resolve(null);
}

public void initContext(final ReadableMap params, final Promise promise) {
public void initContext(double id, final ReadableMap params, final Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;

@Override
protected WritableMap doInBackground(Void... voids) {
try {
int id = Math.abs(new Random().nextInt());
LlamaContext llamaContext = new LlamaContext(id, reactContext, params);
LlamaContext context = contexts.get(contextId);
if (context != null) {
throw new Exception("Context already exists");
}
LlamaContext llamaContext = new LlamaContext(contextId, reactContext, params);
if (llamaContext.getContext() == 0) {
throw new Exception("Failed to initialize context");
}
contexts.put(id, llamaContext);
contexts.put(contextId, llamaContext);
WritableMap result = Arguments.createMap();
result.putInt("contextId", id);
result.putBoolean("gpu", false);
result.putString("reasonNoGPU", "Currently not supported");
result.putMap("model", llamaContext.getModelDetails());
@@ -393,6 +396,7 @@ protected Void doInBackground(Void... voids) {
if (context == null) {
throw new Exception("Context " + id + " not found");
}
context.interruptLoad();
context.stopCompletion();
AsyncTask completionTask = null;
for (AsyncTask task : tasks.keySet()) {
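Note that the native module no longer generates a random context id itself: the TS layer now passes one into initContext(id, params), and the call is rejected if that id is already registered. A purely hypothetical sketch of the id selection on the JS side (the real logic lives in src/index.ts, which is not part of the visible diff):

```ts
// Hypothetical helper only; the actual id handling lives in src/index.ts.
let nextTestContextId = 0

const getNewContextId = (): number =>
  process.env.NODE_ENV === 'test'
    ? nextTestContextId++ // deterministic ids under test ("skip random context id on testing")
    : Math.floor(Math.random() * 100000) // random id in normal use, mirroring the old native behavior
```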
48 changes: 47 additions & 1 deletion android/src/main/jni.cpp
@@ -132,6 +132,11 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject v
env->CallVoidMethod(map, putArrayMethod, jKey, value);
}

struct callback_context {
JNIEnv *env;
rnllama::llama_rn_context *llama;
jobject callback;
};

std::unordered_map<long, rnllama::llama_rn_context *> context_map;

@@ -151,7 +156,8 @@ Java_com_rnllama_LlamaContext_initContext(
jstring lora_str,
jfloat lora_scaled,
jfloat rope_freq_base,
jfloat rope_freq_scale
jfloat rope_freq_scale,
jobject load_progress_callback
) {
UNUSED(thiz);

@@ -190,6 +196,32 @@ Java_com_rnllama_LlamaContext_initContext(
defaultParams.rope_freq_scale = rope_freq_scale;

auto llama = new rnllama::llama_rn_context();
llama->is_load_interrupted = false;
llama->loading_progress = 0;

if (load_progress_callback != nullptr) {
defaultParams.progress_callback = [](float progress, void * user_data) {
callback_context *cb_ctx = (callback_context *)user_data;
JNIEnv *env = cb_ctx->env;
auto llama = cb_ctx->llama;
jobject callback = cb_ctx->callback;
int percentage = (int) (100 * progress);
if (percentage > llama->loading_progress) {
llama->loading_progress = percentage;
jclass callback_class = env->GetObjectClass(callback);
jmethodID onLoadProgress = env->GetMethodID(callback_class, "onLoadProgress", "(I)V");
env->CallVoidMethod(callback, onLoadProgress, percentage);
}
return !llama->is_load_interrupted;
};

callback_context *cb_ctx = new callback_context;
cb_ctx->env = env;
cb_ctx->llama = llama;
cb_ctx->callback = env->NewGlobalRef(load_progress_callback);
defaultParams.progress_callback_user_data = cb_ctx;
}

bool is_model_loaded = llama->loadModel(defaultParams);

LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
@@ -205,6 +237,20 @@ Java_com_rnllama_LlamaContext_initContext(
return reinterpret_cast<jlong>(llama->ctx);
}


JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_interruptLoad(
JNIEnv *env,
jobject thiz,
jlong context_ptr
) {
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];
if (llama) {
llama->is_load_interrupted = true;
}
}

JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_loadModelDetails(
JNIEnv *env,
4 changes: 2 additions & 2 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
@@ -38,8 +38,8 @@ public void setContextLimit(double limit, Promise promise) {
}

@ReactMethod
public void initContext(final ReadableMap params, final Promise promise) {
rnllama.initContext(params, promise);
public void initContext(double id, final ReadableMap params, final Promise promise) {
rnllama.initContext(id, params, promise);
}

@ReactMethod
4 changes: 2 additions & 2 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
@@ -39,8 +39,8 @@ public void setContextLimit(double limit, Promise promise) {
}

@ReactMethod
public void initContext(final ReadableMap params, final Promise promise) {
rnllama.initContext(params, promise);
public void initContext(double id, final ReadableMap params, final Promise promise) {
rnllama.initContext(id, params, promise);
}

@ReactMethod
3 changes: 3 additions & 0 deletions cpp/common.cpp
@@ -1001,6 +1001,9 @@ struct llama_model_params common_model_params_to_llama(const common_params & par
mparams.kv_overrides = params.kv_overrides.data();
}

mparams.progress_callback = params.progress_callback;
mparams.progress_callback_user_data = params.progress_callback_user_data;

return mparams;
}

3 changes: 3 additions & 0 deletions cpp/common.h
@@ -283,6 +283,9 @@ struct common_params {
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data

llama_progress_callback progress_callback;
void * progress_callback_user_data;

std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V

19 changes: 11 additions & 8 deletions cpp/rn-llama.hpp
@@ -158,9 +158,12 @@ struct llama_rn_context
common_params params;

llama_model *model = nullptr;
float loading_progress = 0;
bool is_load_interrupted = false;

llama_context *ctx = nullptr;
common_sampler *ctx_sampling = nullptr;

int n_ctx;

bool truncated = false;
@@ -367,7 +370,7 @@ struct llama_rn_context
n_eval = params.n_batch;
}
if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval)))
{
{
LOG_ERROR("failed to eval, n_eval: %d, n_past: %d, n_threads: %d, embd: %s",
n_eval,
n_past,
@@ -378,7 +381,7 @@ struct llama_rn_context
return result;
}
n_past += n_eval;

if(is_interrupted) {
LOG_INFO("Decoding Interrupted");
embd.resize(n_past);
@@ -400,19 +403,19 @@ struct llama_rn_context
candidates.reserve(llama_n_vocab(model));

result.tok = common_sampler_sample(ctx_sampling, ctx, -1);

llama_token_data_array cur_p = *common_sampler_get_candidates(ctx_sampling);

const int32_t n_probs = params.sparams.n_probs;

// deprecated
/*if (params.sparams.temp <= 0 && n_probs > 0)
{
// For llama_sample_token_greedy we need to sort candidates
llama_sampler_init_softmax();
}*/


for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
{
@@ -542,14 +545,14 @@ struct llama_rn_context
return std::vector<float>(n_embd, 0.0f);
}
float *data;

if(params.pooling_type == 0){
data = llama_get_embeddings(ctx);
}
else {
data = llama_get_embeddings_seq(ctx, 0);
}

if(!data) {
return std::vector<float>(n_embd, 0.0f);
}
2 changes: 1 addition & 1 deletion example/ios/.xcode.env.local
@@ -1 +1 @@
export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1730514789911-0.16979892623603998/node
export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1730697817603-0.6786179339916347/node
19 changes: 18 additions & 1 deletion example/src/App.tsx
@@ -64,6 +64,7 @@ export default function App() {
metadata: { system: true, ...metadata },
}
addMessage(textMessage)
return textMessage.id
}

const handleReleaseContext = async () => {
Expand All @@ -82,12 +83,28 @@ export default function App() {

const handleInitContext = async (file: DocumentPickerResponse) => {
await handleReleaseContext()
addSystemMessage('Initializing context...')
const msgId = addSystemMessage('Initializing context...')
initLlama({
model: file.uri,
use_mlock: true,
n_gpu_layers: Platform.OS === 'ios' ? 0 : 0, // > 0: enable GPU
// embedding: true,
}, (progress) => {
setMessages((msgs) => {
const index = msgs.findIndex((msg) => msg.id === msgId)
if (index >= 0) {
return msgs.map((msg, i) => {
if (msg.type == 'text' && i === index) {
return {
...msg,
text: `Initializing context... ${progress}%`,
}
}
return msg
})
}
return msgs
})
})
.then((ctx) => {
setContext(ctx)
23 changes: 17 additions & 6 deletions ios/RNLlama.mm
@@ -21,10 +21,17 @@ @implementation RNLlama
resolve(nil);
}

RCT_EXPORT_METHOD(initContext:(NSDictionary *)contextParams
RCT_EXPORT_METHOD(initContext:(double)contextId
withContextParams:(NSDictionary *)contextParams
withResolver:(RCTPromiseResolveBlock)resolve
withRejecter:(RCTPromiseRejectBlock)reject)
{
NSNumber *contextIdNumber = [NSNumber numberWithDouble:contextId];
if (llamaContexts[contextIdNumber] != nil) {
reject(@"llama_error", @"Context already exists", nil);
return;
}

if (llamaDQueue == nil) {
llamaDQueue = dispatch_queue_create("com.rnllama", DISPATCH_QUEUE_SERIAL);
}
@@ -38,19 +45,19 @@ @implementation RNLlama
return;
}

RNLlamaContext *context = [RNLlamaContext initWithParams:contextParams];
RNLlamaContext *context = [RNLlamaContext initWithParams:contextParams onProgress:^(unsigned int progress) {
dispatch_async(dispatch_get_main_queue(), ^{
[self sendEventWithName:@"@RNLlama_onInitContextProgress" body:@{ @"contextId": @(contextId), @"progress": @(progress) }];
});
}];
if (![context isModelLoaded]) {
reject(@"llama_cpp_error", @"Failed to load the model", nil);
return;
}

double contextId = (double) arc4random_uniform(1000000);

NSNumber *contextIdNumber = [NSNumber numberWithDouble:contextId];
[llamaContexts setObject:context forKey:contextIdNumber];

resolve(@{
@"contextId": contextIdNumber,
@"gpu": @([context isMetalEnabled]),
@"reasonNoGPU": [context reasonNoMetal],
@"model": [context modelInfo],
@@ -125,6 +132,7 @@ @implementation RNLlama

- (NSArray *)supportedEvents {
return@[
@"@RNLlama_onInitContextProgress",
@"@RNLlama_onToken",
];
}
@@ -260,6 +268,9 @@ - (NSArray *)supportedEvents {
reject(@"llama_error", @"Context not found", nil);
return;
}
if (![context isModelLoaded]) {
[context interruptLoad];
}
[context stopCompletion];
dispatch_barrier_sync(llamaDQueue, ^{});
[context invalidate];
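Both platforms now emit the same @RNLlama_onInitContextProgress event carrying a contextId and an integer progress value. The JS code that turns this event into the callback passed to initLlama lives in src/index.ts, which is not among the files shown here; a hypothetical sketch of how such a subscription could be wired with React Native's event emitter:

```ts
import { NativeEventEmitter, NativeModules } from 'react-native'

const EVENT_ON_INIT_CONTEXT_PROGRESS = '@RNLlama_onInitContextProgress'
const emitter = new NativeEventEmitter(NativeModules.RNLlama)

// Hypothetical wiring: forward progress events for a given context to its callback.
export function subscribeToLoadProgress(
  contextId: number,
  onProgress: (progress: number) => void,
): () => void {
  const subscription = emitter.addListener(
    EVENT_ON_INIT_CONTEXT_PROGRESS,
    (event: { contextId: number; progress: number }) => {
      if (event.contextId === contextId) onProgress(event.progress)
    },
  )
  return () => subscription.remove() // unsubscribe once loading has finished or failed
}
```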