Skip to content

Commit

Permalink
feat: support chat format (#72)
Browse files Browse the repository at this point in the history
* feat: add isChatTemplateSupported in model info

* feat(ts): add formatChat util

* feat(ts): add getFormattedChat native method

* feat(ts): completion: add messages

* feat(example): use messages

* feat(docs): update
  • Loading branch information
jhen0409 authored Jul 28, 2024
1 parent 2f70192 commit 030ebaf
Show file tree
Hide file tree
Showing 16 changed files with 422 additions and 172 deletions.
114 changes: 75 additions & 39 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ You can search HuggingFace for available models (Keyword: [`GGUF`](https://huggi
To create a GGUF model manually — for example, for Llama 2:

Download the Llama 2 model

1. Request access from [here](https://ai.meta.com/llama)
2. Download the model from HuggingFace [here](https://huggingface.co/meta-llama/Llama-2-7b-chat) (`Llama-2-7b-chat`)

Convert the model to ggml format

```bash
# Start with submodule in this repo (or you can clone the repo https://github.com/ggerganov/llama.cpp.git)
yarn && yarn bootstrap
Expand Down Expand Up @@ -80,26 +82,53 @@ const context = await initLlama({
// embedding: true, // use embedding
})

// Do completion
const { text, timings } = await context.completion(
const stopWords = ['</s>', '<|end|>', '<|eot_id|>', '<|end_of_text|>', '<|im_end|>', '<|EOT|>', '<|END_OF_TURN_TOKEN|>', '<|end_of_turn|>', '<|endoftext|>']

// Do chat completion
const msgResult = await context.completion(
{
messages: [
{
role: 'system',
content: 'This is a conversation between user and assistant, a friendly chatbot.',
},
{
role: 'user',
content: 'Hello!',
},
],
n_predict: 100,
stop: stopWords,
// ...other params
},
(data) => {
// This is a partial completion callback
const { token } = data
},
)
console.log('Result:', msgResult.text)
console.log('Timings:', msgResult.timings)

// Or do text completion
const textResult = await context.completion(
{
prompt: 'This is a conversation between user and llama, a friendly chatbot. respond in simple markdown.\n\nUser: Hello!\nLlama:',
n_predict: 100,
stop: ['</s>', 'Llama:', 'User:'],
// n_threads: 4,
stop: [...stopWords, 'Llama:', 'User:'],
// ...other params
},
(data) => {
// This is a partial completion callback
const { token } = data
},
)
console.log('Result:', text)
console.log('Timings:', timings)
console.log('Result:', textResult.text)
console.log('Timings:', textResult.timings)
```

The binding’s design is inspired by the [server.cpp](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) example in llama.cpp, so you can map its API to LlamaContext:

- `/completion`: `context.completion(params, partialCompletionCallback)`
- `/completion` and `/chat/completions`: `context.completion(params, partialCompletionCallback)`
- `/tokenize`: `context.tokenize(content)`
- `/detokenize`: `context.detokenize(tokens)`
- `/embedding`: `context.embedding(content)`
Expand All @@ -114,6 +143,7 @@ Please visit the [Documentation](docs/API) for more details.
You can also visit the [example](example) to see how to use it.

Run the example:

```bash
yarn && yarn bootstrap

Expand Down Expand Up @@ -146,7 +176,9 @@ You can see [GBNF Guide](https://github.com/ggerganov/llama.cpp/tree/master/gram
```js
import { initLlama, convertJsonSchemaToGrammar } from 'llama.rn'

const schema = { /* JSON Schema, see below */ }
const schema = {
/* JSON Schema, see below */
}

const context = await initLlama({
model: 'file://<path to gguf model>',
Expand All @@ -157,7 +189,7 @@ const context = await initLlama({
grammar: convertJsonSchemaToGrammar({
schema,
propOrder: { function: 0, arguments: 1 },
})
}),
})

const { text } = await context.completion({
Expand All @@ -175,80 +207,81 @@ console.log('Result:', text)
{
oneOf: [
{
type: "object",
name: "get_current_weather",
description: "Get the current weather in a given location",
type: 'object',
name: 'get_current_weather',
description: 'Get the current weather in a given location',
properties: {
function: {
const: "get_current_weather",
const: 'get_current_weather',
},
arguments: {
type: "object",
type: 'object',
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
type: 'string',
description: 'The city and state, e.g. San Francisco, CA',
},
unit: {
type: "string",
enum: ["celsius", "fahrenheit"],
type: 'string',
enum: ['celsius', 'fahrenheit'],
},
},
required: ["location"],
required: ['location'],
},
},
},
{
type: "object",
name: "create_event",
description: "Create a calendar event",
type: 'object',
name: 'create_event',
description: 'Create a calendar event',
properties: {
function: {
const: "create_event",
const: 'create_event',
},
arguments: {
type: "object",
type: 'object',
properties: {
title: {
type: "string",
description: "The title of the event",
type: 'string',
description: 'The title of the event',
},
date: {
type: "string",
description: "The date of the event",
type: 'string',
description: 'The date of the event',
},
time: {
type: "string",
description: "The time of the event",
type: 'string',
description: 'The time of the event',
},
},
required: ["title", "date", "time"],
required: ['title', 'date', 'time'],
},
},
},
{
type: "object",
name: "image_search",
description: "Search for an image",
type: 'object',
name: 'image_search',
description: 'Search for an image',
properties: {
function: {
const: "image_search",
const: 'image_search',
},
arguments: {
type: "object",
type: 'object',
properties: {
query: {
type: "string",
description: "The search query",
type: 'string',
description: 'The search query',
},
},
required: ["query"],
required: ['query'],
},
},
},
],
}
```

</details>

<details>
Expand All @@ -272,6 +305,7 @@ string ::= "\"" (
2 ::= "{" space "\"function\"" space ":" space 2-function "," space "\"arguments\"" space ":" space 2-arguments "}" space
root ::= 0 | 1 | 2
```

</details>

## Mock `llama.rn`
Expand All @@ -285,12 +319,14 @@ jest.mock('llama.rn', () => require('llama.rn/jest/mock'))
## NOTE

iOS:

- The [Extended Virtual Addressing](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_extended-virtual-addressing) capability is recommended to enable on iOS project.
- Metal:
  - Our testing shows that some devices are not able to use Metal (`params.n_gpu_layers > 0`) because llama.cpp uses SIMD-scoped operations. You can check whether your device is supported in the [Metal feature set tables](https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf); an Apple7 GPU is the minimum requirement.
  - It's also not supported in the iOS simulator due to [this limitation](https://developer.apple.com/documentation/metal/developing_metal_apps_that_run_in_simulator#3241609): we use more than 14 constant buffers.

Android:

- Currently only the arm64-v8a / x86_64 platforms are supported, which means you can't initialize a context on other platforms. 64-bit platforms are recommended because they can allocate more memory for the model.
- No GPU backend has been integrated yet.

Expand Down
13 changes: 13 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ public WritableMap getModelDetails() {
return modelDetails;
}

// Flattens the ReadableArray of chat messages into a plain ReadableMap array
// and delegates rendering to the native formatter. A null template is mapped
// to the empty string, which the native side treats as "no custom template".
public String getFormattedChat(ReadableArray messages, String chatTemplate) {
  final int count = messages.size();
  ReadableMap[] parsedMessages = new ReadableMap[count];
  int index = 0;
  while (index < count) {
    parsedMessages[index] = messages.getMap(index);
    index++;
  }
  String template = (chatTemplate == null) ? "" : chatTemplate;
  return getFormattedChat(this.context, parsedMessages, template);
}

private void emitPartialCompletion(WritableMap tokenResult) {
WritableMap event = Arguments.createMap();
event.putInt("contextId", LlamaContext.this.id);
Expand Down Expand Up @@ -316,6 +324,11 @@ protected static native long initContext(
protected static native WritableMap loadModelDetails(
long contextPtr
);
// JNI: renders the given chat messages into a single prompt string using a
// chat template. An empty chatTemplate presumably selects the model's
// built-in template — TODO confirm against llama.cpp's template handling.
protected static native String getFormattedChat(
  long contextPtr,
  ReadableMap[] messages,
  String chatTemplate
);
protected static native WritableMap loadSession(
long contextPtr,
String path
Expand Down
32 changes: 32 additions & 0 deletions android/src/main/java/com/rnllama/RNLlama.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,38 @@ protected void onPostExecute(WritableMap result) {
tasks.put(task, "initContext");
}

/**
 * Formats an array of chat messages with the context's chat template and
 * resolves the promise with the rendered prompt string.
 *
 * Runs on the AsyncTask thread pool so template rendering never blocks the
 * JS/UI thread. Rejects the promise if the context id is unknown or the
 * native call throws.
 */
public void getFormattedChat(double id, final ReadableArray messages, final String chatTemplate, Promise promise) {
  final int contextId = (int) id;
  AsyncTask task = new AsyncTask<Void, Void, String>() {
    private Exception exception;

    @Override
    protected String doInBackground(Void... voids) {
      try {
        LlamaContext context = contexts.get(contextId);
        if (context == null) {
          throw new Exception("Context not found");
        }
        return context.getFormattedChat(messages, chatTemplate);
      } catch (Exception e) {
        exception = e;
        return null;
      }
    }

    @Override
    protected void onPostExecute(String result) {
      // Remove the task on BOTH outcomes; the original only removed it on
      // success, leaking the finished task in `tasks` whenever it rejected.
      tasks.remove(this);
      if (exception != null) {
        promise.reject(exception);
        return;
      }
      promise.resolve(result);
    }
  }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
  tasks.put(task, "getFormattedChat-" + contextId);
}

public void loadSession(double id, final String path, Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
Expand Down
40 changes: 40 additions & 0 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,46 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
return reinterpret_cast<jobject>(result);
}

// Renders the Java chat messages through the llama.cpp chat template and
// returns the formatted prompt as a Java String.
//
// Fixes over the original:
//  - the UTF chars of `chat_template` are now released (they were leaked);
//  - per-iteration JNI local references (element, its class, the returned
//    role/content strings) are deleted, so a long conversation cannot
//    exhaust the JNI local reference table;
//  - the "role"/"content" key strings and the getString method ID are
//    resolved once instead of per message.
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_getFormattedChat(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jobjectArray messages,
    jstring chat_template
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    std::vector<llama_chat_msg> chat;

    jstring roleKey = env->NewStringUTF("role");
    jstring contentKey = env->NewStringUTF("content");
    jmethodID getStringMethod = nullptr;

    int messages_len = env->GetArrayLength(messages);
    for (int i = 0; i < messages_len; i++) {
        jobject msg = env->GetObjectArrayElement(messages, i);
        jclass msgClass = env->GetObjectClass(msg);

        // ReadableMap#getString(String) — method ID is stable, look it up once.
        if (getStringMethod == nullptr) {
            getStringMethod = env->GetMethodID(msgClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
        }

        jstring role_str = (jstring) env->CallObjectMethod(msg, getStringMethod, roleKey);
        jstring content_str = (jstring) env->CallObjectMethod(msg, getStringMethod, contentKey);

        const char *role = env->GetStringUTFChars(role_str, nullptr);
        const char *content = env->GetStringUTFChars(content_str, nullptr);

        // llama_chat_msg copies into std::string, so releasing below is safe.
        chat.push_back({ role, content });

        env->ReleaseStringUTFChars(role_str, role);
        env->ReleaseStringUTFChars(content_str, content);
        env->DeleteLocalRef(role_str);
        env->DeleteLocalRef(content_str);
        env->DeleteLocalRef(msgClass);
        env->DeleteLocalRef(msg);
    }

    env->DeleteLocalRef(roleKey);
    env->DeleteLocalRef(contentKey);

    const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
    std::string formatted_chat = llama_chat_apply_template(llama->model, tmpl_chars, chat, true);
    env->ReleaseStringUTFChars(chat_template, tmpl_chars); // was leaked before

    return env->NewStringUTF(formatted_chat.c_str());
}

JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_loadSession(
JNIEnv *env,
Expand Down
5 changes: 5 additions & 0 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ public void initContext(final ReadableMap params, final Promise promise) {
rnllama.initContext(params, promise);
}

@ReactMethod
public void getFormattedChat(double id, ReadableArray messages, String chatTemplate, Promise promise) {
  // Thin bridge method: delegates chat-template formatting to the shared
  // RNLlama implementation, which resolves/rejects the promise.
  rnllama.getFormattedChat(id, messages, chatTemplate, promise);
}

@ReactMethod
public void loadSession(double id, String path, Promise promise) {
rnllama.loadSession(id, path, promise);
Expand Down
5 changes: 5 additions & 0 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ public void initContext(final ReadableMap params, final Promise promise) {
rnllama.initContext(params, promise);
}

@ReactMethod
public void getFormattedChat(double id, ReadableArray messages, String chatTemplate, Promise promise) {
  // Thin bridge method (old architecture): delegates chat-template
  // formatting to the shared RNLlama implementation.
  rnllama.getFormattedChat(id, messages, chatTemplate, promise);
}

@ReactMethod
public void loadSession(double id, String path, Promise promise) {
rnllama.loadSession(id, path, promise);
Expand Down
Loading

0 comments on commit 030ebaf

Please sign in to comment.