This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

feature: implemented parallel inference for llama-rs, implemented naive sequential async inference for llama-cpp and rwkv-cpp #52

Merged: 5 commits, May 9, 2023
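The diff below touches the CLI, the core test, and every example in packages/core/example for the same reason: LLama.create (along with tokenize and getWordEmbeddings) now runs asynchronously, so callers must await it, and each example wraps its body in an async run() helper. The following minimal sketch only condenses the pattern used by the updated examples; the relative import, model path, and sampling parameters are copied from those files rather than prescribed by this PR.

// Sketch of the async usage pattern adopted by the updated examples.
// Import path, model file, and parameter values mirror the example files;
// they are illustrative, not additions made by this PR.
import { LLama, InferenceResultType } from "../index";
import path from "path";

const run = async () => {
  // create() now returns a Promise and must be awaited.
  const llama = await LLama.create({
    path: path.resolve(process.cwd(), "../../ggml-alpaca-7b-q4.bin"),
    numCtxTokens: 128,
  });

  // inference() still streams results through a callback.
  llama.inference(
    {
      prompt: "how are you",
      numPredict: 128,
      temp: 0.2,
      topP: 1,
      topK: 40,
      repeatPenalty: 1,
      repeatLastN: 64,
      seed: 0,
      feedPrompt: true,
    },
    (response) => {
      if (response.type === InferenceResultType.Data) {
        process.stdout.write(response.data?.token ?? "");
      } else {
        // End and Error results carry no token data; just log them.
        console.log(response);
      }
    }
  );
};

run();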
4 changes: 2 additions & 2 deletions packages/cli/src/index.ts
@@ -12,7 +12,7 @@ import { existsSync } from "fs";

const convertType = ["q4_0", "q4_1", "f16", "f32"] as const;

type ConvertType = typeof convertType[number];
type ConvertType = (typeof convertType)[number];

interface CLIInferenceArguments extends LLamaInferenceArguments, LLamaConfig {
logger?: boolean;
@@ -75,7 +75,7 @@ class InferenceCommand implements yargs.CommandModule {
if (logger) {
LLama.enableLogger();
}
const llama = LLama.create({ path: absolutePath, numCtxTokens });
const llama = await LLama.create({ path: absolutePath, numCtxTokens });
llama.inference(rest, (result) => {
switch (result.type) {
case InferenceResultType.Data:
2 changes: 1 addition & 1 deletion packages/core/__test__/index.spec.ts
@@ -6,7 +6,7 @@ test(
async () => {
LLama.enableLogger();

const llama = LLama.create({
const llama = await LLama.create({
path: process.env.model?.toString()!,
numCtxTokens: 128,
});
68 changes: 36 additions & 32 deletions packages/core/example/cachesession.ts
@@ -6,46 +6,50 @@ const saveSession = path.resolve(process.cwd(), "./tmp/session.bin");

LLama.enableLogger();

const llama = LLama.create({
path: model,
numCtxTokens: 128,
});
const run = async () => {
const llama = await LLama.create({
path: model,
numCtxTokens: 128,
});

const template = `how are you`;
const template = `how are you`;

const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.
const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
${template}
### Response:`;

llama.inference(
{
prompt,
numPredict: 128,
temp: 0.2,
topP: 1,
topK: 40,
repeatPenalty: 1,
repeatLastN: 64,
seed: 0,
feedPrompt: true,
feedPromptOnly: true,
saveSession,
},
(response) => {
switch (response.type) {
case InferenceResultType.Data: {
process.stdout.write(response.data?.token ?? "");
break;
}
case InferenceResultType.End:
case InferenceResultType.Error: {
console.log(response);
break;
llama.inference(
{
prompt,
numPredict: 128,
temp: 0.2,
topP: 1,
topK: 40,
repeatPenalty: 1,
repeatLastN: 64,
seed: 0,
feedPrompt: true,
feedPromptOnly: true,
saveSession,
},
(response) => {
switch (response.type) {
case InferenceResultType.Data: {
process.stdout.write(response.data?.token ?? "");
break;
}
case InferenceResultType.End:
case InferenceResultType.Error: {
console.log(response);
break;
}
}
}
}
);
);
};

run();
74 changes: 35 additions & 39 deletions packages/core/example/embedding.ts
@@ -1,51 +1,47 @@
import { EmbeddingResultType, LLama } from "../index";
import { LLama } from "../index";
import path from "path";
import fs from "fs";

const model = path.resolve(process.cwd(), "../../ggml-alpaca-7b-q4.bin");

LLama.enableLogger();

const llama = LLama.create({
path: model,
numCtxTokens: 128,
});

const getWordEmbeddings = (prompt: string, file: string) => {
llama.getWordEmbeddings(
{
prompt,
numPredict: 128,
temp: 0.2,
topP: 1,
topK: 40,
repeatPenalty: 1,
repeatLastN: 64,
seed: 0,
},
(response) => {
switch (response.type) {
case EmbeddingResultType.Data: {
fs.writeFileSync(
path.resolve(process.cwd(), file),
JSON.stringify(response.data)
);
break;
}
case EmbeddingResultType.Error: {
console.log(response);
break;
}
}
}
const getWordEmbeddings = async (
llama: LLama,
prompt: string,
file: string
) => {
const response = await llama.getWordEmbeddings({
prompt,
numPredict: 128,
temp: 0.2,
topP: 1,
topK: 40,
repeatPenalty: 1,
repeatLastN: 64,
seed: 0,
});

fs.writeFileSync(
path.resolve(process.cwd(), file),
JSON.stringify(response)
);
};

const dog1 = `My favourite animal is the dog`;
getWordEmbeddings(dog1, "./example/semantic-compare/dog1.json");
const run = async () => {
const llama = await LLama.create({
path: model,
numCtxTokens: 128,
});

const dog1 = `My favourite animal is the dog`;
getWordEmbeddings(llama, dog1, "./example/semantic-compare/dog1.json");

const dog2 = `I have just adopted a cute dog`;
getWordEmbeddings(dog2, "./example/semantic-compare/dog2.json");
const dog2 = `I have just adopted a cute dog`;
getWordEmbeddings(llama, dog2, "./example/semantic-compare/dog2.json");

const cat1 = `My favourite animal is the cat`;
getWordEmbeddings(llama, cat1, "./example/semantic-compare/cat1.json");
};

const cat1 = `My favourite animal is the cat`;
getWordEmbeddings(cat1, "./example/semantic-compare/cat1.json");
run();
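The rewritten embedding example awaits llama.getWordEmbeddings directly and writes the returned vectors into example/semantic-compare/*.json, instead of going through the old EmbeddingResultType callback. A plausible follow-up step, not part of this diff, is comparing those saved files; the helper below is a hypothetical sketch that assumes each JSON file holds a flat numeric vector.

// Hypothetical comparison step (not in this PR): cosine similarity between
// two embedding files written by the example above. Assumes each file is a
// plain JSON array of numbers.
import fs from "fs";

const loadVector = (file: string): number[] =>
  JSON.parse(fs.readFileSync(file, "utf-8"));

const cosine = (a: number[], b: number[]): number => {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
};

const dog1 = loadVector("./example/semantic-compare/dog1.json");
const cat1 = loadVector("./example/semantic-compare/cat1.json");
console.log("dog1 vs cat1:", cosine(dog1, cat1));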
64 changes: 33 additions & 31 deletions packages/core/example/inference.ts
@@ -6,45 +6,47 @@ const model = path.resolve(process.cwd(), "../../ggml-alpaca-7b-q4.bin");

LLama.enableLogger();

const llama = LLama.create({
path: model,
numCtxTokens: 128,
});
const run = async () => {
const llama = await LLama.create({
path: model,
numCtxTokens: 128,
});

const template = `how are you`;
const template = `how are you`;

const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.
const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
${template}
### Response:`;

llama.inference(
{
prompt,
numPredict: 128,
temp: 0.2,
topP: 1,
topK: 40,
repeatPenalty: 1,
repeatLastN: 64,
seed: 0,
feedPrompt: true,
// persistSession,
},
(response) => {
switch (response.type) {
case InferenceResultType.Data: {
process.stdout.write(response.data?.token ?? "");
break;
}
case InferenceResultType.End:
case InferenceResultType.Error: {
console.log(response);
break;
llama.inference(
{
prompt,
numPredict: 128,
temp: 0.2,
topP: 1,
topK: 40,
repeatPenalty: 1,
repeatLastN: 64,
seed: 0,
feedPrompt: true,
},
(response) => {
switch (response.type) {
case InferenceResultType.Data: {
process.stdout.write(response.data?.token ?? "");
break;
}
case InferenceResultType.End:
case InferenceResultType.Error: {
console.log(response);
break;
}
}
}
}
);
);
};
run();
60 changes: 32 additions & 28 deletions packages/core/example/loadsession.ts
@@ -6,34 +6,38 @@ const loadSession = path.resolve(process.cwd(), "./tmp/session.bin");

LLama.enableLogger();

const llama = LLama.create({
path: model,
numCtxTokens: 128,
});
const run = async () => {
const llama = await LLama.create({
path: model,
numCtxTokens: 128,
});

llama.inference(
{
prompt: "",
numPredict: 128,
temp: 0.2,
topP: 1,
topK: 40,
repeatPenalty: 1,
repeatLastN: 64,
seed: 0,
loadSession,
},
(response) => {
switch (response.type) {
case InferenceResultType.Data: {
process.stdout.write(response.data?.token ?? "");
break;
}
case InferenceResultType.End:
case InferenceResultType.Error: {
console.log(response);
break;
llama.inference(
{
prompt: "",
numPredict: 128,
temp: 0.2,
topP: 1,
topK: 40,
repeatPenalty: 1,
repeatLastN: 64,
seed: 0,
loadSession,
},
(response) => {
switch (response.type) {
case InferenceResultType.Data: {
process.stdout.write(response.data?.token ?? "");
break;
}
case InferenceResultType.End:
case InferenceResultType.Error: {
console.log(response);
break;
}
}
}
}
);
);
};

run();
21 changes: 12 additions & 9 deletions packages/core/example/tokenize.ts
@@ -5,14 +5,17 @@ const model = path.resolve(process.cwd(), "../../ggml-alpaca-7b-q4.bin");

LLama.enableLogger();

const llama = LLama.create({
path: model,
numCtxTokens: 128,
});
const run = async () => {
const llama = await LLama.create({
path: model,
numCtxTokens: 128,
});

const prompt = "My favourite animal is the cat";
const prompt = "My favourite animal is the cat";

llama.tokenize(prompt, (response) => {
console.log(response);
console.log(response.data.length); // 7
});
const tokens = await llama.tokenize(prompt);

console.log(tokens);
};

run();
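The old tokenize example also logged response.data.length with a comment noting the expected count of 7 tokens; the promise-based version above only logs the token array. A sketch of how that count check could be kept with the awaited API, assuming tokenize resolves to an array and using the same import and model path as the example:

// Sketch only: same tokenize example plus the token-count log the old
// callback version included. Paths and imports mirror the example file.
import { LLama } from "../index";
import path from "path";

const run = async () => {
  const llama = await LLama.create({
    path: path.resolve(process.cwd(), "../../ggml-alpaca-7b-q4.bin"),
    numCtxTokens: 128,
  });

  const tokens = await llama.tokenize("My favourite animal is the cat");
  console.log(tokens);
  console.log(tokens.length); // the old example expected 7 here
};

run();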