# Updated Examples #502

Merged 1 commit on Feb 10, 2024.

**docs/Examples/BatchDecoding.md** (170 additions, 0 deletions)

# Batch decoding
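
This example evaluates a single prompt once and then generates several continuations of it in parallel. The prompt's KV cache entries are copied to every sequence with `KvCacheSequenceCopy`, so the parallel streams reuse the prompt without re-evaluating it, and each stream then samples its own tokens batch by batch.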

```cs
using System.Diagnostics;
using System.Text;
using LLama;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

public class BatchedDecoding
{
private const int n_parallel = 8;
private const int n_len = 32;

public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

Console.WriteLine("Prompt (leave blank to select automatically):");
var prompt = Console.ReadLine();
if (string.IsNullOrWhiteSpace(prompt))
prompt = "Not many people know that";

// Load model
var parameters = new ModelParams(modelPath);

using var model = LLamaWeights.LoadFromFile(parameters);

// Tokenize prompt
var prompt_tokens = model.Tokenize(prompt, true, false, Encoding.UTF8);
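        // The KV cache must hold the prompt once, plus up to (n_len - prompt length) generated tokens for each of the n_parallel sequences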
var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;

// Create a context
parameters.ContextSize = (uint)model.ContextSize;
parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
using var context = model.CreateContext(parameters);

var n_ctx = context.ContextSize;

// make sure the KV cache is big enough to hold all the prompt and generated tokens
if (n_kv_req > n_ctx)
{
await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n");
return;
}

var batch = new LLamaBatch();

// evaluate the initial prompt
batch.AddRange(prompt_tokens, 0, LLamaSeqId.Zero, true);

if (await context.DecodeAsync(batch) != DecodeResult.Ok)
{
await Console.Error.WriteLineAsync("llama_decode failed");
return;
}

// assign the system KV cache to all parallel sequences
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (var i = 1; i < n_parallel; ++i)
{
context.NativeHandle.KvCacheSequenceCopy((LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
}

if (n_parallel > 1)
{
Console.WriteLine();
Console.WriteLine($"generating {n_parallel} sequences...");
}

// remember the batch index of the last token for each parallel sequence
// we need this to determine which logits to sample from
List<int> i_batch = new();
for (var i = 0; i < n_parallel; i++)
i_batch.Add(batch.TokenCount - 1);

// Create per-stream decoder and sampler
var decoders = new StreamingTokenDecoder[n_parallel];
var samplers = new ISamplingPipeline[n_parallel];
for (var i = 0; i < n_parallel; i++)
{
decoders[i] = new StreamingTokenDecoder(context);
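            // Give each stream a slightly different temperature so the parallel completions diverge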
samplers[i] = new DefaultSamplingPipeline
{
Temperature = 0.1f + (float)i / n_parallel,
MinP = 0.25f,
};
}

var n_cur = batch.TokenCount;
var n_decode = 0;

var timer = new Stopwatch();
timer.Start();
while (n_cur <= n_len)
{
batch.Clear();

for (var i = 0; i < n_parallel; i++)
{
// Skip completed streams
if (i_batch[i] < 0)
continue;

// Use the sampling pipeline to select a token
var new_token_id = samplers[i].Sample(
context.NativeHandle,
context.NativeHandle.GetLogitsIth(i_batch[i]),
Array.Empty<LLamaToken>()
);

// Finish this stream early if necessary
if (new_token_id == model.EndOfSentenceToken || new_token_id == model.NewlineToken)
{
i_batch[i] = -1;
Console.WriteLine($"Completed Stream {i} early");
continue;
}

// Add this token to the decoder, so it will be turned into text
decoders[i].Add(new_token_id);

i_batch[i] = batch.TokenCount;

// push this new token for next evaluation
batch.Add(new_token_id, n_cur, (LLamaSeqId)i, true);

n_decode++;
}

// Check if all streams are finished
if (batch.TokenCount == 0)
{
break;
}

n_cur++;

// evaluate the current batch with the transformer model
            if (await context.DecodeAsync(batch) != DecodeResult.Ok)
{
await Console.Error.WriteLineAsync("failed to eval");
return;
}
}

timer.Stop();
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine();
Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");

var index = 0;
foreach (var stream in decoders)
{
var text = stream.Read();

Console.ForegroundColor = ConsoleColor.Green;
Console.Write($"{index++}. {prompt}");
Console.ForegroundColor = ConsoleColor.Red;
Console.WriteLine(text);
}

Console.WriteLine("Press any key to exit demo");
Console.ReadKey(true);
}
}
```
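
To try it, call `Run` from an async entry point. With top-level statements this is a one-liner (a minimal sketch; the console project hosting the class and its LLamaSharp package reference are assumed):

```cs
// Program.cs of a hypothetical console project that contains BatchedDecoding
await BatchedDecoding.Run();
```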

**docs/Examples/ChatChineseGB2312.md** (125 additions, 0 deletions)

# Chat in Chinese (GB2312)
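
This example runs a chat session in Chinese on a console that uses the GB2312 encoding, which is common on Windows: user input is re-encoded from GB2312 to UTF-8 before it is passed to the model, and the model's UTF-8 output is re-encoded to GB2312 for display.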

```cs
using System.Text;
using LLama;
using LLama.Common;

public class ChatChineseGB2312
{
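    // Re-encode a string so text typed on a GB2312 console can be passed to the UTF-8 model, and vice versa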
private static string ConvertEncoding(string input, Encoding original, Encoding target)
{
byte[] bytes = original.GetBytes(input);
var convertedBytes = Encoding.Convert(original, target, bytes);
return target.GetString(convertedBytes);
}

public static async Task Run()
{
// Register provider for GB2312 encoding
Encoding.RegisterProvider(CodePagesEncodingProvider.Instance);

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("This example shows how to use Chinese with gb2312 encoding, which is common in windows. It's recommended" +
" to use https://huggingface.co/hfl/chinese-alpaca-2-7b-gguf/blob/main/ggml-model-q5_0.gguf, which has been verified by LLamaSharp developers.");
Console.ForegroundColor = ConsoleColor.White;

Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

var parameters = new ModelParams(modelPath)
{
ContextSize = 1024,
Seed = 1337,
GpuLayerCount = 5,
Encoding = Encoding.UTF8
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

ChatSession session;
if (Directory.Exists("Assets/chat-with-kunkun-chinese"))
{
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Loading session from disk.");
Console.ForegroundColor = ConsoleColor.White;

session = new ChatSession(executor);
session.LoadSession("Assets/chat-with-kunkun-chinese");
}
else
{
var chatHistoryJson = File.ReadAllText("Assets/chat-with-kunkun-chinese.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

session = new ChatSession(executor, chatHistory);
}

session
.WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户", "坤坤"));

InferenceParams inferenceParams = new InferenceParams()
{
Temperature = 0.9f,
AntiPrompts = new List<string> { "用户:" }
};

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started.");

// show the prompt
Console.ForegroundColor = ConsoleColor.White;
Console.Write("用户:");
Console.ForegroundColor = ConsoleColor.Green;
string userInput = Console.ReadLine() ?? "";

while (userInput != "exit")
{
// Convert the encoding from gb2312 to utf8 for the language model
// and later saving to the history json file.
userInput = ConvertEncoding(userInput, Encoding.GetEncoding("gb2312"), Encoding.UTF8);

if (userInput == "save")
{
session.SaveSession("Assets/chat-with-kunkun-chinese");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session saved.");
}
else if (userInput == "regenerate")
{
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Regenerating last response ...");

await foreach (
var text
in session.RegenerateAssistantMessageAsync(
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;

// Convert the encoding from utf8 to gb2312 for the console output.
Console.Write(ConvertEncoding(text, Encoding.UTF8, Encoding.GetEncoding("gb2312")));
}
}
else
{
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
Console.Write(text);
}
}

Console.ForegroundColor = ConsoleColor.Green;
userInput = Console.ReadLine() ?? "";

Console.ForegroundColor = ConsoleColor.White;
}
}
}

```
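Note that on .NET Core and later, code-page encodings such as GB2312 are not built in: `CodePagesEncodingProvider` ships in the `System.Text.Encoding.CodePages` NuGet package (hence the `Encoding.RegisterProvider` call at the start of `Run`), so the project needs a reference to that package (`dotnet add package System.Text.Encoding.CodePages`).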
**docs/Examples/ChatSessionStripRoleName.md** (39 additions, 9 deletions)

# Chat session with the role names stripped
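
This example runs a chat session whose output transform strips the role names ("User:", "Assistant:") from the streamed text. The class opens as below (the `using` lines are an assumption, chosen to match the other examples):

```cs
// NOTE: these usings are assumed; adjust them to your project layout
using LLama;
using LLama.Common;

public class ChatSessionStripRoleName
{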
    public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

var parameters = new ModelParams(modelPath)
{
ContextSize = 1024,
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

ChatSession session = new(executor, chatHistory);
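        // Filter the role names out of the output stream; redundancyLength buffers a few extra characters so a keyword split across tokens is still caught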
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
new string[] { "User:", "Assistant:" },
redundancyLength: 8));

InferenceParams inferenceParams = new InferenceParams()
{
Temperature = 0.9f,
AntiPrompts = new List<string> { "User:" }
};

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started. The role names won't be printed.");
Console.ForegroundColor = ConsoleColor.White;
Console.WriteLine("The chat session has started.");

// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
string userInput = Console.ReadLine() ?? "";

while (userInput != "exit")
{
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
Console.Write(text);
}

Console.ForegroundColor = ConsoleColor.Green;
userInput = Console.ReadLine() ?? "";

Console.ForegroundColor = ConsoleColor.White;
}
}
}
```