diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs
index b25563aa5..7945a7e12 100644
--- a/LLama.Examples/ExampleRunner.cs
+++ b/LLama.Examples/ExampleRunner.cs
@@ -5,6 +5,7 @@ public class ExampleRunner
 {
     private static readonly Dictionary<string, Func<Task>> Examples = new()
     {
+        { "Chat Session: LLama3", LLama3ChatSession.Run },
         { "Chat Session: History", ChatSessionWithHistory.Run },
         { "Chat Session: Role names", ChatSessionWithRoleName.Run },
         { "Chat Session: Role names stripped", ChatSessionStripRoleName.Run },
diff --git a/LLama.Examples/Examples/LLama3ChatSession.cs b/LLama.Examples/Examples/LLama3ChatSession.cs
new file mode 100644
index 000000000..c9a32e0ce
--- /dev/null
+++ b/LLama.Examples/Examples/LLama3ChatSession.cs
@@ -0,0 +1,144 @@
+using System.Text;
+using LLama.Abstractions;
+using LLama.Common;
+
+namespace LLama.Examples.Examples;
+
+/// <summary>
+/// Interactive chat example for Llama 3 models. It plugs a custom <see cref="IHistoryTransform"/>
+/// into the chat session so the history is rendered with the Llama 3 special tokens.
+/// </summary>
+public class LLama3ChatSession
+{
+    /// <summary>
+    /// Loads the model at the path returned by <c>UserSettings.GetModelPath()</c>, seeds the session
+    /// with the history stored in <c>Assets/chat-with-bob.json</c> and chats on the console until
+    /// the user types <c>exit</c> or the input stream ends.
+    /// </summary>
+    public static async Task Run()
+    {
+        string modelPath = UserSettings.GetModelPath();
+
+        var parameters = new ModelParams(modelPath)
+        {
+            Seed = 1337,
+            GpuLayerCount = 10
+        };
+        using var model = LLamaWeights.LoadFromFile(parameters);
+        using var context = model.CreateContext(parameters);
+        var executor = new InteractiveExecutor(context);
+
+        var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
+        ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
+
+        ChatSession session = new(executor, chatHistory);
+        session.WithHistoryTransform(new LLama3HistoryTransform());
+        session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
+            new string[] { "User:", "Assistant:", "�" },
+            redundancyLength: 5));
+
+        InferenceParams inferenceParams = new InferenceParams()
+        {
+            Temperature = 0.6f,
+            AntiPrompts = new List<string> { "User:" }
+        };
+
+        Console.ForegroundColor = ConsoleColor.Yellow;
+        Console.WriteLine("The chat session has started.");
+
+        // User input is shown in green, model output in white.
+        Console.ForegroundColor = ConsoleColor.Green;
+        string? userInput = Console.ReadLine();
+
+        // ReadLine returns null once the input stream is closed; stop then instead of
+        // feeding empty messages to the model forever.
+        while (userInput is not null && userInput != "exit")
+        {
+            await foreach (
+                var text
+                in session.ChatAsync(
+                    new ChatHistory.Message(AuthorRole.User, userInput),
+                    inferenceParams))
+            {
+                Console.ForegroundColor = ConsoleColor.White;
+                Console.Write(text);
+            }
+            Console.WriteLine();
+
+            Console.ForegroundColor = ConsoleColor.Green;
+            userInput = Console.ReadLine();
+
+            Console.ForegroundColor = ConsoleColor.White;
+        }
+    }
+
+    /// <summary>
+    /// Renders a <see cref="ChatHistory"/> as a Llama 3 prompt: each message becomes a role
+    /// header, a blank line, the content and an end-of-turn token.
+    /// </summary>
+    private sealed class LLama3HistoryTransform : IHistoryTransform
+    {
+        private const string StartHeaderId = "<|start_header_id|>";
+        private const string EndHeaderId = "<|end_header_id|>";
+        private const string Bos = "<|begin_of_text|>";
+        private const string EndOfTurn = "<|eot_id|>";
+
+        /// <summary>
+        /// Convert a ChatHistory instance to plain text.
+        /// </summary>
+        /// <param name="history">The ChatHistory instance</param>
+        /// <returns>
+        /// The begin-of-text token, every message encoded in order, then an assistant header
+        /// with no content.
+        /// </returns>
+        public string HistoryToText(ChatHistory history)
+        {
+            var prompt = new StringBuilder(Bos);
+            foreach (var message in history.Messages)
+            {
+                prompt.Append(EncodeMessage(message));
+            }
+            prompt.Append(EncodeHeader(new ChatHistory.Message(AuthorRole.Assistant, "")));
+            return prompt.ToString();
+        }
+
+        /// <summary>
+        /// Converts plain text to a ChatHistory instance.
+        /// </summary>
+        /// <param name="role">The role for the author.</param>
+        /// <param name="text">The message text.</param>
+        /// <returns>A new history holding a single message with the given role and text.</returns>
+        public ChatHistory TextToHistory(AuthorRole role, string text)
+        {
+            return new ChatHistory(new ChatHistory.Message[] { new ChatHistory.Message(role, text) });
+        }
+
+        /// <summary>
+        /// Copy the transform.
+        /// </summary>
+        /// <returns>A new <see cref="LLama3HistoryTransform"/>.</returns>
+        public IHistoryTransform Clone()
+        {
+            return new LLama3HistoryTransform();
+        }
+
+        /// <summary>
+        /// Builds the header that opens a turn: the role between the header tokens, then a blank line.
+        /// </summary>
+        private static string EncodeHeader(ChatHistory.Message message)
+        {
+            // The Llama 3 prompt format spells roles in lowercase ("system", "user", "assistant"),
+            // while the AuthorRole members are named in PascalCase.
+            string role = message.AuthorRole.ToString().ToLowerInvariant();
+            return $"{StartHeaderId}{role}{EndHeaderId}\n\n";
+        }
+
+        /// <summary>
+        /// Encodes one complete turn: header, message content and the end-of-turn token.
+        /// </summary>
+        private static string EncodeMessage(ChatHistory.Message message)
+        {
+            return EncodeHeader(message) + message.Content + EndOfTurn;
+        }
+    }
+}
diff --git a/LLama/LLamaTransforms.cs b/LLama/LLamaTransforms.cs
index d74d9ddaf..f50d32c7f 100644
--- a/LLama/LLamaTransforms.cs
+++ b/LLama/LLamaTransforms.cs
@@ -235,7 +235,7 @@ public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> to
                     var current = string.Join("", window);
                     if (_keywords.Any(x => current.Contains(x)))
                     {
-                        var matchedKeyword = _keywords.First(x => current.Contains(x));
+                        var matchedKeywords = _keywords.Where(x => current.Contains(x));
                         int total = window.Count;
                         for (int i = 0; i < total; i++)
                         {
@@ -243,7 +243,14 @@ public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> to
                         }
                         if (!_removeAllMatchedTokens)
                         {
-                            yield return current.Replace(matchedKeyword, "");
+                            // Strip the keywords from a copy so `current` stays intact: the deferred
+                            // `matchedKeywords` query captures it and the length check below reads it.
+                            var stripped = current;
+                            foreach (var keyword in matchedKeywords)
+                            {
+                                stripped = stripped.Replace(keyword, "");
+                            }
+                            yield return stripped;
                         }
                     }
                     if (current.Length >= _maxKeywordLength)