Skip to content

Commit

Permalink
fix: wait task finished on stop/release method
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Sep 7, 2023
1 parent 120e67e commit f61e18c
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 15 deletions.
62 changes: 54 additions & 8 deletions android/src/main/java/com/rnllama/RNLlama.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@
import java.io.PushbackInputStream;

public class RNLlama implements LifecycleEventListener {
public static final String NAME = "RNLlama";

private ReactApplicationContext reactContext;

public RNLlama(ReactApplicationContext reactContext) {
reactContext.addLifecycleEventListener(this);
this.reactContext = reactContext;
}

private HashMap<AsyncTask, String> tasks = new HashMap<>();

private HashMap<Integer, LlamaContext> contexts = new HashMap<>();

private int llamaContextLimit = 1;
Expand All @@ -39,7 +43,7 @@ public void setContextLimit(double limit, Promise promise) {
}

public void initContext(final ReadableMap params, final Promise promise) {
new AsyncTask<Void, Void, WritableMap>() {
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;

@Override
Expand Down Expand Up @@ -69,13 +73,15 @@ protected void onPostExecute(WritableMap result) {
return;
}
promise.resolve(result);
tasks.remove(this);
}
}.execute();
tasks.put(task, "initContext");
}

public void completion(double id, final ReadableMap params, final Promise promise) {
final int contextId = (int) id;
new AsyncTask<Void, Void, WritableMap>() {
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;

@Override
Expand Down Expand Up @@ -103,13 +109,15 @@ protected void onPostExecute(WritableMap result) {
return;
}
promise.resolve(result);
tasks.remove(this);
}
}.execute();
tasks.put(task, "completion-" + contextId);
}

public void stopCompletion(double id, final Promise promise) {
final int contextId = (int) id;
new AsyncTask<Void, Void, Void>() {
AsyncTask task = new AsyncTask<Void, Void, Void>() {
private Exception exception;

@Override
Expand All @@ -120,6 +128,13 @@ protected Void doInBackground(Void... voids) {
throw new Exception("Context not found");
}
context.stopCompletion();
AsyncTask completionTask = null;
for (AsyncTask task : tasks.keySet()) {
if (tasks.get(task).equals("completion-" + contextId)) {
task.get();
break;
}
}
} catch (Exception e) {
exception = e;
}
Expand All @@ -133,13 +148,15 @@ protected void onPostExecute(Void result) {
return;
}
promise.resolve(result);
tasks.remove(this);
}
}.execute();
tasks.put(task, "stopCompletion-" + contextId);
}

public void tokenize(double id, final String text, final Promise promise) {
final int contextId = (int) id;
new AsyncTask<Void, Void, WritableMap>() {
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;

@Override
Expand All @@ -163,13 +180,15 @@ protected void onPostExecute(WritableMap result) {
return;
}
promise.resolve(result);
tasks.remove(this);
}
}.execute();
tasks.put(task, "tokenize-" + contextId);
}

public void detokenize(double id, final ReadableArray tokens, final Promise promise) {
final int contextId = (int) id;
new AsyncTask<Void, Void, String>() {
AsyncTask task = new AsyncTask<Void, Void, String>() {
private Exception exception;

@Override
Expand All @@ -193,13 +212,15 @@ protected void onPostExecute(String result) {
return;
}
promise.resolve(result);
tasks.remove(this);
}
}.execute();
tasks.put(task, "detokenize-" + contextId);
}

public void embedding(double id, final String text, final Promise promise) {
final int contextId = (int) id;
new AsyncTask<Void, Void, WritableMap>() {
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;

@Override
Expand All @@ -223,13 +244,15 @@ protected void onPostExecute(WritableMap result) {
return;
}
promise.resolve(result);
tasks.remove(this);
}
}.execute();
tasks.put(task, "embedding-" + contextId);
}

public void releaseContext(double id, Promise promise) {
final int contextId = (int) id;
new AsyncTask<Void, Void, Void>() {
AsyncTask task = new AsyncTask<Void, Void, Void>() {
private Exception exception;

@Override
Expand All @@ -239,6 +262,14 @@ protected Void doInBackground(Void... voids) {
if (context == null) {
throw new Exception("Context " + id + " not found");
}
context.stopCompletion();
AsyncTask completionTask = null;
for (AsyncTask task : tasks.keySet()) {
if (tasks.get(task).equals("completion-" + contextId)) {
task.get();
break;
}
}
context.release();
contexts.remove(contextId);
} catch (Exception e) {
Expand All @@ -254,12 +285,14 @@ protected void onPostExecute(Void result) {
return;
}
promise.resolve(null);
tasks.remove(this);
}
}.execute();
tasks.put(task, "releaseContext-" + contextId);
}

public void releaseAllContexts(Promise promise) {
new AsyncTask<Void, Void, Void>() {
AsyncTask task = new AsyncTask<Void, Void, Void>() {
private Exception exception;

@Override
Expand All @@ -279,8 +312,10 @@ protected void onPostExecute(Void result) {
return;
}
promise.resolve(null);
tasks.remove(this);
}
}.execute();
tasks.put(task, "releaseAllContexts");
}

@Override
Expand All @@ -293,6 +328,17 @@ public void onHostPause() {

@Override
public void onHostDestroy() {
for (LlamaContext context : contexts.values()) {
context.stopCompletion();
}
for (AsyncTask task : tasks.keySet()) {
try {
task.get();
} catch (Exception e) {
Log.e(NAME, "Failed to wait for task", e);
}
}
tasks.clear();
for (LlamaContext context : contexts.values()) {
context.release();
}
Expand Down
4 changes: 2 additions & 2 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
import java.io.FileInputStream;
import java.io.PushbackInputStream;

@ReactModule(name = RNLlamaModule.NAME)
@ReactModule(name = RNLlama.NAME)
public class RNLlamaModule extends NativeRNLlamaSpec {
public static final String NAME = "RNLlama";
public static final String NAME = RNLlama.NAME;

private RNLlama rnllama = null;

Expand Down
4 changes: 2 additions & 2 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
import java.io.FileInputStream;
import java.io.PushbackInputStream;

@ReactModule(name = RNLlamaModule.NAME)
@ReactModule(name = RNLlama.NAME)
public class RNLlamaModule extends ReactContextBaseJavaModule {
public static final String NAME = "RNLlama";
public static final String NAME = RNLlama.NAME;

private RNLlama rnllama = null;

Expand Down
4 changes: 2 additions & 2 deletions example/ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ PODS:
- hermes-engine/Pre-built (= 0.72.3)
- hermes-engine/Pre-built (0.72.3)
- libevent (2.1.12)
- llama-rn (0.2.0-rc.2):
- llama-rn (0.2.0-rc.3):
- RCT-Folly
- RCTRequired
- RCTTypeSafety
Expand Down Expand Up @@ -1242,7 +1242,7 @@ SPEC CHECKSUMS:
glog: 04b94705f318337d7ead9e6d17c019bd9b1f6b1b
hermes-engine: 10fbd3f62405c41ea07e71973ea61e1878d07322
libevent: 4049cae6c81cdb3654a443be001fb9bdceff7913
llama-rn: eda3c9288703cf662d48ade3efee3b14a80b8c21
llama-rn: d4c8780eeb17350f4cd56e6fa23c3fb4f8044e37
RCT-Folly: 424b8c9a7a0b9ab2886ffe9c3b041ef628fd4fb1
RCTRequired: a2faf4bad4e438ca37b2040cb8f7799baa065c18
RCTTypeSafety: cb09f3e4747b6d18331a15eb05271de7441ca0b3
Expand Down
6 changes: 5 additions & 1 deletion ios/RNLlama.mm
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ @implementation RNLlama

NSMutableDictionary *llamaContexts;
double llamaContextLimit = 1;
dispatch_queue_t llamaDQueue = dispatch_queue_create("com.rnllama", DISPATCH_QUEUE_SERIAL);

RCT_EXPORT_MODULE()

Expand Down Expand Up @@ -71,7 +72,7 @@ - (NSArray *)supportedEvents {
reject(@"llama_error", @"Context is busy", nil);
return;
}
dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
dispatch_async(llamaDQueue, ^{
@try {
@autoreleasepool {
NSDictionary* completionResult = [context completion:completionParams
Expand Down Expand Up @@ -168,6 +169,7 @@ - (NSArray *)supportedEvents {
return;
}
[context stopCompletion];
dispatch_barrier_sync(llamaDQueue, ^{});
[context invalidate];
[llamaContexts removeObjectForKey:[NSNumber numberWithDouble:contextId]];
resolve(nil);
Expand All @@ -188,6 +190,8 @@ - (void)invalidate {

for (NSNumber *contextId in llamaContexts) {
RNLlamaContext *context = llamaContexts[contextId];
[context stopCompletion];
dispatch_barrier_sync(llamaDQueue, ^{});
[context invalidate];
}

Expand Down

0 comments on commit f61e18c

Please sign in to comment.