From 35153a77dd51ad72903667590196b2117d0b9e05 Mon Sep 17 00:00:00 2001 From: eublefar Date: Sat, 2 Mar 2024 14:51:03 +0100 Subject: [PATCH 01/13] Chat session Get/Load in-memory state operations, reset state ops for stateful executors and context --- LLama/ChatSession.cs | 92 ++++++++++++++++++++++++++++++++++ LLama/LLamaContext.cs | 11 ++++ LLama/LLamaExecutorBase.cs | 8 +++ LLama/LLamaInstructExecutor.cs | 16 ++++++ LLama/LLamaInteractExecutor.cs | 16 ++++++ 5 files changed, 143 insertions(+) diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 45985b21c..eec9a9d38 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -8,6 +8,8 @@ using LLama.Abstractions; using LLama.Common; using static LLama.InteractiveExecutor; +using static LLama.LLamaContext; +using static LLama.StatefulExecutorBase; namespace LLama; @@ -134,6 +136,60 @@ public void SaveSession(string path) File.WriteAllText(historyFilepath, History.ToJson()); } + /// + /// Get the session state. + /// + /// SessionState object representing session state in-memory + public SessionState GetSessionState() + { + return new() + { + ExecutorState = ((StatefulExecutorBase)Executor).GetStateData(), + ContextState = Executor.Context.GetState(), + InputTransformPipeline = InputTransformPipeline, + OutputTransform = OutputTransform, + HistoryTransform = HistoryTransform, + History = History.ToJson() + }; + } + + /// + /// Load a session from a session state. + /// + /// + /// + /// + public void LoadSession(SessionState state) + { + if (Executor is StatefulExecutorBase statefulExecutor) + { + if (state.ExecutorState is null) + { + statefulExecutor.ResetState(); + } + else + { + statefulExecutor.LoadState(state.ExecutorState); + } + } + else + { + if (state.ExecutorState is not null) + { + throw new ArgumentException("Executor does not support state", nameof(state)); + } + } + if (state.ContextState is null) + { + Executor.Context.ResetState(); + } + else + { + Executor.Context.LoadState(state.ContextState); + } + History = ChatHistory.FromJson(state.History) ?? new(); + } + /// /// Load a session from a directory. /// @@ -494,3 +550,39 @@ in OutputTransform } } } + +/// +/// The state of a chat session in-memory. +/// +public record SessionState +{ + /// + /// Saved executor state for the session. + /// + public ExecutorBaseState? ExecutorState { get; init; } + + /// + /// Saved context state (KV cache) for the session. + /// + public State? ContextState { get; init; } + + /// + /// The input transform pipeline used in this session. + /// + public List InputTransformPipeline { get; init; } = new(); + + /// + /// The output transform used in this session. + /// + public ITextStreamTransform OutputTransform { get; init; } = new LLamaTransforms.EmptyTextOutputStreamTransform(); + + /// + /// The history transform used in this session. + /// + public IHistoryTransform HistoryTransform { get; init; } = new LLamaTransforms.DefaultHistoryTransform(); + + /// + /// The JSON representation of the chat history for this session. + /// + public string History { get; init; } = new ChatHistory().ToJson(); +} \ No newline at end of file diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index d8b418c31..b52b6e54f 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -24,6 +24,7 @@ public sealed class LLamaContext : IDisposable { private readonly ILogger? 
_logger; + private readonly State _emptyState; /// /// Total number of tokens in vocabulary of this model @@ -75,6 +76,7 @@ public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger @params.ToLlamaContextParams(out var lparams); NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); + _emptyState = GetState(); } /// @@ -214,6 +216,15 @@ public void LoadState(State state) } } + /// + /// Reset the context to the empty state. + /// + /// + public void ResetState() + { + LoadState(_emptyState); + } + /// /// Sample a single token from this context, using the given sampling pipeline /// diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 3a697507b..f96b11bd0 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -250,6 +250,14 @@ protected virtual void TryReuseMathingPrefix() /// public abstract ExecutorBaseState GetStateData(); + + /// + /// Resets the executor to its initial state. + /// Note: Does not affect the context and KV cache. + /// + /// + public abstract void ResetState(); + /// /// Load the state from data. /// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 9476976e9..13e8f0675 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -110,6 +110,22 @@ public override async Task LoadState(string filename) } } + /// + public override void ResetState() + { + _n_session_consumed = 0; + _embed_inps = new List(); + _is_prompt_run = true; + _consumedTokensCount = 0; + _embeds = new List(); + _last_n_tokens = new FixedSizeQueue((int) Context.ContextSize); + _n_matching_session_tokens = 0; + _pastTokensCount = 0; + _pathSession = null; + _session_tokens = new List(); + MirostatMu = 0; + } + /// protected override Task GetLoopCondition(InferStateArgs args) { diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 79f1b8cc4..9c6ae9546 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -93,6 +93,22 @@ public override async Task LoadState(string filename) } } + /// + public override void ResetState() + { + _n_session_consumed = 0; + _embed_inps = new List(); + _is_prompt_run = true; + _consumedTokensCount = 0; + _embeds = new List(); + _last_n_tokens = new FixedSizeQueue((int)Context.ContextSize); + _n_matching_session_tokens = 0; + _pastTokensCount = 0; + _pathSession = null; + _session_tokens = new List(); + MirostatMu = 0; + } + /// /// Define whether to continue the loop to generate responses. /// From b2f7dbb39b35fefee320aadd00da98446658f71f Mon Sep 17 00:00:00 2001 From: eublefar Date: Sat, 2 Mar 2024 17:26:06 +0100 Subject: [PATCH 02/13] AddPromptAsync method for stateful executors, Chat session initialize from history and process system message methods for pre-processing prompts. Serializing executor state to JSON, to avoid saved states from being updated by reference. 
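A minimal usage sketch of the API added in this patch, with the model path, history file name, and context size chosen for illustration only (the ChatSession and executor calls follow the diff below):

using LLama;
using LLama.Common;

var parameters = new ModelParams("path/to/model.gguf") { ContextSize = 1024 };
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

// Pre-process an existing history so the KV cache is populated before the first user turn.
var history = ChatHistory.FromJson(File.ReadAllText("chat-history.json")) ?? new ChatHistory();
ChatSession session = await ChatSession.InitializeSessionFromHistoryAsync(executor, history);

// Snapshot the in-memory state; the executor state is serialized to JSON inside
// SessionState, so later changes to the executor do not mutate the snapshot by reference.
SessionState snapshot = session.GetSessionState();

// ... run some turns with session.ChatAsync(...) ...

// Restore the snapshot without re-processing the original prompt.
session.LoadSession(snapshot);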
--- LLama/ChatSession.cs | 61 +++++++++++++++++++++++++++++++++++--- LLama/LLamaContext.cs | 2 +- LLama/LLamaExecutorBase.cs | 28 +++++++++++++++++ 3 files changed, 86 insertions(+), 5 deletions(-) diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index eec9a9d38..2a74bb291 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -3,6 +3,7 @@ using System.IO; using System.Linq; using System.Runtime.CompilerServices; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using LLama.Abstractions; @@ -47,6 +48,27 @@ public class ChatSession /// public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform(); + /// + /// Create a new chat session and preprocess history. + /// + /// The executor for this session + /// History for this session + /// Cancellation token to stop session pre-processing + /// + public static async Task InitializeSessionFromHistoryAsync( + ILLamaExecutor executor, + ChatHistory history, + CancellationToken cancellationToken = default) + { + if (executor is not StatefulExecutorBase statefulExecutor) + { + throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor)); + } + var session = new ChatSession(executor, history); + await statefulExecutor.AddPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken); + return session; + } + /// /// Create a new chat session. /// @@ -144,7 +166,7 @@ public SessionState GetSessionState() { return new() { - ExecutorState = ((StatefulExecutorBase)Executor).GetStateData(), + ExecutorState = JsonSerializer.Serialize(((StatefulExecutorBase)Executor).GetStateData()), ContextState = Executor.Context.GetState(), InputTransformPipeline = InputTransformPipeline, OutputTransform = OutputTransform, @@ -169,7 +191,11 @@ public void LoadSession(SessionState state) } else { - statefulExecutor.LoadState(state.ExecutorState); + statefulExecutor.LoadState( + JsonSerializer.Deserialize( + state.ExecutorState, statefulExecutor.GetStateData().GetType() + ) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state)) + ); } } else @@ -260,6 +286,33 @@ public ChatSession AddMessage(ChatHistory.Message message) return this; } + + /// + /// Compute KV cache for the system message and add it to the chat history. + /// + /// + /// + public async Task ProcessSystemMessage(string content) + { + if (Executor is not StatefulExecutorBase statefulExecutor) + { + throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages."); + } + if (History.Messages.Count > 0) + { + throw new ArgumentException("Cannot add a system message after another message", nameof(content)); + } + foreach (var inputTransform in InputTransformPipeline) + { + content = inputTransform.Transform(content); + } + + await statefulExecutor.AddPromptAsync(content); + + History.AddMessage(AuthorRole.System, content); + return this; + } + /// /// Add a system message to the chat history. /// @@ -557,9 +610,9 @@ in OutputTransform public record SessionState { /// - /// Saved executor state for the session. + /// Saved executor state for the session in JSON format. /// - public ExecutorBaseState? ExecutorState { get; init; } + public string? ExecutorState { get; init; } /// /// Saved context state (KV cache) for the session. 
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index b52b6e54f..dc0508e50 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -222,7 +222,7 @@ public void LoadState(State state) /// public void ResetState() { - LoadState(_emptyState); + NativeApi.llama_kv_cache_clear(NativeHandle); } /// diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index f96b11bd0..b9a0b4124 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -323,6 +323,34 @@ public virtual async IAsyncEnumerable InferAsync(string text, IInference } } + /// + /// Asynchronously runs a prompt through the model to compute KV cache without generating any new tokens. + /// + /// Prompt to process + /// A cancellation token + /// + public virtual async Task AddPromptAsync(string prompt, CancellationToken cancellationToken = default) + { + var inferenceParams = new InferenceParams + { + MaxTokens = 0 + }; + var args = new InferStateArgs + { + Antiprompts = new List(), + RemainedTokens = 0, + ReturnValue = false, + WaitForInput = true, + NeedToSaveSession = false + }; + + await PreprocessInputs(prompt, args); + // First run adds the prompt to the _embeds + await InferInternal(inferenceParams, args); + // Second run puts it through decode + await InferInternal(inferenceParams, args); + } + /// /// State arguments that are used in single inference /// From 0763f307ec06d85725d81c5f654c86218eb9068f Mon Sep 17 00:00:00 2001 From: eublefar Date: Sat, 2 Mar 2024 17:27:18 +0100 Subject: [PATCH 03/13] Example chat session with preprocessing of chat history and reset operation that resets chat to original point of history without extra processing --- LLama.Examples/ExampleRunner.cs | 1 + .../Examples/ChatSessionWithRestart.cs | 94 +++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 LLama.Examples/Examples/ChatSessionWithRestart.cs diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs index 790a1f9c6..b74170e3a 100644 --- a/LLama.Examples/ExampleRunner.cs +++ b/LLama.Examples/ExampleRunner.cs @@ -8,6 +8,7 @@ public class ExampleRunner { "Chat Session: History", ChatSessionWithHistory.Run }, { "Chat Session: Role names", ChatSessionWithRoleName.Run }, { "Chat Session: Role names stripped", ChatSessionStripRoleName.Run }, + { "Chat Session: Pre-processing and reset", ChatSessionWithRestart.Run }, { "Chat Session: Coding Assistant", CodingAssistant.Run }, { "Chat Session: Automatic conversation", TalkToYourself.Run }, { "Chat Session: Chinese characters", ChatChineseGB2312.Run }, diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs new file mode 100644 index 000000000..3462c5062 --- /dev/null +++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs @@ -0,0 +1,94 @@ +using LLama.Common; + +namespace LLama.Examples.Examples; + +public class ChatSessionWithRestart +{ + public static async Task Run() + { + string modelPath = UserSettings.GetModelPath(); + + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); + + var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); + ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? 
new ChatHistory(); + ChatSession prototypeSession = + await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory); + prototypeSession.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( + new string[] { "User:", "Assistant:" }, + redundancyLength: 8)); + var resetState = prototypeSession.GetSessionState(); + + ChatSession session = new ChatSession(executor); + session.LoadSession(resetState); + + InferenceParams inferenceParams = new InferenceParams() + { + Temperature = 0.9f, + AntiPrompts = new List { "User:" } + }; + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The chat session has started."); + + // show the prompt + Console.ForegroundColor = ConsoleColor.Green; + string userInput = Console.ReadLine() ?? ""; + + while (userInput != "exit") + { + if(userInput == "reset") + { + session.LoadSession(resetState); + Console.WriteLine($"History: {session.HistoryTransform.HistoryToText(session.History)}"); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Session reset."); + } + else if (userInput == "save") + { + session.SaveSession("Assets/chat-with-bob"); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Session saved."); + } + else if (userInput == "regenerate") + { + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Regenerating last response ..."); + + await foreach ( + var text + in session.RegenerateAssistantMessageAsync( + inferenceParams)) + { + Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); + } + } + else + { + await foreach ( + var text + in session.ChatAsync( + new ChatHistory.Message(AuthorRole.User, userInput), + inferenceParams)) + { + Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); + } + } + + Console.ForegroundColor = ConsoleColor.Green; + userInput = Console.ReadLine() ?? ""; + + Console.ForegroundColor = ConsoleColor.White; + } + } +} From e05d5d4e140da30c1994dfbd076205eb695b1dd7 Mon Sep 17 00:00:00 2001 From: eublefar Date: Sat, 2 Mar 2024 20:07:17 +0100 Subject: [PATCH 04/13] Remove resetting state ops and make SessionState.ExecutorState and SessionState.ContextState no nullable --- LLama/ChatSession.cs | 50 ++++++++++++++-------------------- LLama/LLamaContext.cs | 11 -------- LLama/LLamaExecutorBase.cs | 8 ------ LLama/LLamaInstructExecutor.cs | 16 ----------- LLama/LLamaInteractExecutor.cs | 16 ----------- 5 files changed, 21 insertions(+), 80 deletions(-) diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 2a74bb291..251573fc4 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -164,10 +164,8 @@ public void SaveSession(string path) /// SessionState object representing session state in-memory public SessionState GetSessionState() { - return new() + return new SessionState(Executor.Context.GetState(), ((StatefulExecutorBase)Executor).GetStateData()) { - ExecutorState = JsonSerializer.Serialize(((StatefulExecutorBase)Executor).GetStateData()), - ContextState = Executor.Context.GetState(), InputTransformPipeline = InputTransformPipeline, OutputTransform = OutputTransform, HistoryTransform = HistoryTransform, @@ -185,34 +183,17 @@ public void LoadSession(SessionState state) { if (Executor is StatefulExecutorBase statefulExecutor) { - if (state.ExecutorState is null) - { - statefulExecutor.ResetState(); - } - else - { - statefulExecutor.LoadState( - JsonSerializer.Deserialize( - state.ExecutorState, statefulExecutor.GetStateData().GetType() - ) as ExecutorBaseState ?? 
throw new ArgumentException("Executor state is invalid", nameof(state)) - ); - } - } - else - { - if (state.ExecutorState is not null) - { - throw new ArgumentException("Executor does not support state", nameof(state)); - } - } - if (state.ContextState is null) - { - Executor.Context.ResetState(); + statefulExecutor.LoadState( + JsonSerializer.Deserialize( + state.ExecutorState, statefulExecutor.GetStateData().GetType() + ) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state)) + ); } else { - Executor.Context.LoadState(state.ContextState); + throw new ArgumentException("Executor must be a StatefulExecutorBase to support loading of session state", nameof(state)); } + Executor.Context.LoadState(state.ContextState); History = ChatHistory.FromJson(state.History) ?? new(); } @@ -612,12 +593,12 @@ public record SessionState /// /// Saved executor state for the session in JSON format. /// - public string? ExecutorState { get; init; } + public string ExecutorState { get; init; } /// /// Saved context state (KV cache) for the session. /// - public State? ContextState { get; init; } + public State ContextState { get; init; } /// /// The input transform pipeline used in this session. @@ -638,4 +619,15 @@ public record SessionState /// The JSON representation of the chat history for this session. /// public string History { get; init; } = new ChatHistory().ToJson(); + + /// + /// Create a new session state. + /// + /// + /// + public SessionState(State contextState, ExecutorBaseState executorState) + { + ContextState = contextState; + ExecutorState = JsonSerializer.Serialize(executorState); + } } \ No newline at end of file diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index dc0508e50..d8b418c31 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -24,7 +24,6 @@ public sealed class LLamaContext : IDisposable { private readonly ILogger? _logger; - private readonly State _emptyState; /// /// Total number of tokens in vocabulary of this model @@ -76,7 +75,6 @@ public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger @params.ToLlamaContextParams(out var lparams); NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); - _emptyState = GetState(); } /// @@ -216,15 +214,6 @@ public void LoadState(State state) } } - /// - /// Reset the context to the empty state. - /// - /// - public void ResetState() - { - NativeApi.llama_kv_cache_clear(NativeHandle); - } - /// /// Sample a single token from this context, using the given sampling pipeline /// diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index b9a0b4124..7bdb8d2b0 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -250,14 +250,6 @@ protected virtual void TryReuseMathingPrefix() /// public abstract ExecutorBaseState GetStateData(); - - /// - /// Resets the executor to its initial state. - /// Note: Does not affect the context and KV cache. - /// - /// - public abstract void ResetState(); - /// /// Load the state from data. 
/// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 13e8f0675..9476976e9 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -110,22 +110,6 @@ public override async Task LoadState(string filename) } } - /// - public override void ResetState() - { - _n_session_consumed = 0; - _embed_inps = new List(); - _is_prompt_run = true; - _consumedTokensCount = 0; - _embeds = new List(); - _last_n_tokens = new FixedSizeQueue((int) Context.ContextSize); - _n_matching_session_tokens = 0; - _pastTokensCount = 0; - _pathSession = null; - _session_tokens = new List(); - MirostatMu = 0; - } - /// protected override Task GetLoopCondition(InferStateArgs args) { diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 9c6ae9546..79f1b8cc4 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -93,22 +93,6 @@ public override async Task LoadState(string filename) } } - /// - public override void ResetState() - { - _n_session_consumed = 0; - _embed_inps = new List(); - _is_prompt_run = true; - _consumedTokensCount = 0; - _embeds = new List(); - _last_n_tokens = new FixedSizeQueue((int)Context.ContextSize); - _n_matching_session_tokens = 0; - _pastTokensCount = 0; - _pathSession = null; - _session_tokens = new List(); - MirostatMu = 0; - } - /// /// Define whether to continue the loop to generate responses. /// From af796fc3e907fdc4bb4bd2f432c62da8eaa993df Mon Sep 17 00:00:00 2001 From: eublefar Date: Sun, 17 Mar 2024 11:58:26 +0100 Subject: [PATCH 05/13] Change List types in executor state to arrays to enforce copy on get/set operations --- LLama/LLamaExecutorBase.cs | 6 +++--- LLama/LLamaInstructExecutor.cs | 12 ++++++------ LLama/LLamaInteractExecutor.cs | 12 ++++++------ 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 7bdb8d2b0..c2b083059 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -388,13 +388,13 @@ public class ExecutorBaseState public string? 
SessionFilePath { get; set; } [JsonPropertyName("embd")] - public List Embeds { get; set; } + public LLamaToken[] Embeds { get; set; } [JsonPropertyName("embd_inps")] - public List EmbedInps { get; set; } + public LLamaToken[] EmbedInps { get; set; } [JsonPropertyName("session_tokens")] - public List SessionTokens { get; set; } + public LLamaToken[] SessionTokens { get; set; } [JsonPropertyName("last_n_tokens")] public LLamaToken[] LastTokens { get; set; } diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 9476976e9..99d45e5a5 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -49,17 +49,17 @@ public override ExecutorBaseState GetStateData() InstructExecutorState state = new() { ConsumedSessionCount = _n_session_consumed, - EmbedInps = _embed_inps, + EmbedInps = _embed_inps.ToArray(), IsPromptRun = _is_prompt_run, ConsumedTokensCount = _consumedTokensCount, - Embeds = _embeds, + Embeds = _embeds.ToArray(), LastTokens = _last_n_tokens.ToArray(), InputPrefixTokens = _inp_pfx, InputSuffixTokens = _inp_sfx, MatchingSessionTokensCount = _n_matching_session_tokens, PastTokensCount = _pastTokensCount, SessionFilePath = _pathSession, - SessionTokens = _session_tokens, + SessionTokens = _session_tokens.ToArray(), LastTokensCapacity = _last_n_tokens.Capacity, MirostatMu = MirostatMu }; @@ -71,17 +71,17 @@ public override Task LoadState(ExecutorBaseState data) if(data is InstructExecutorState state) { _n_session_consumed = state.ConsumedSessionCount; - _embed_inps = state.EmbedInps; + _embed_inps = state.EmbedInps.ToList(); _is_prompt_run = state.IsPromptRun; _consumedTokensCount = state.ConsumedTokensCount; - _embeds = state.Embeds; + _embeds = state.Embeds.ToList(); _last_n_tokens = new FixedSizeQueue(state.LastTokensCapacity, state.LastTokens); _inp_pfx = state.InputPrefixTokens; _inp_sfx = state.InputSuffixTokens; _n_matching_session_tokens = state.MatchingSessionTokensCount; _pastTokensCount = state.PastTokensCount; _pathSession = state.SessionFilePath; - _session_tokens = state.SessionTokens; + _session_tokens = state.SessionTokens.ToList(); } else { diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 79f1b8cc4..2a14eeaf5 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -39,15 +39,15 @@ public override ExecutorBaseState GetStateData() InteractiveExecutorState state = new() { ConsumedSessionCount = _n_session_consumed, - EmbedInps = _embed_inps, + EmbedInps = _embed_inps.ToArray(), IsPromptRun = _is_prompt_run, ConsumedTokensCount = _consumedTokensCount, - Embeds = _embeds, + Embeds = _embeds.ToArray(), LastTokens = _last_n_tokens.ToArray(), MatchingSessionTokensCount = _n_matching_session_tokens, PastTokensCount = _pastTokensCount, SessionFilePath = _pathSession, - SessionTokens = _session_tokens, + SessionTokens = _session_tokens.ToArray(), LastTokensCapacity = _last_n_tokens.Capacity, MirostatMu = MirostatMu }; @@ -59,15 +59,15 @@ public override Task LoadState(ExecutorBaseState data) if (data is InteractiveExecutorState state) { _n_session_consumed = state.ConsumedSessionCount; - _embed_inps = state.EmbedInps; + _embed_inps = state.EmbedInps.ToList(); _is_prompt_run = state.IsPromptRun; _consumedTokensCount = state.ConsumedTokensCount; - _embeds = state.Embeds; + _embeds = state.Embeds.ToList(); _last_n_tokens = new FixedSizeQueue(state.LastTokensCapacity, state.LastTokens); _n_matching_session_tokens = state.MatchingSessionTokensCount; _pastTokensCount = 
state.PastTokensCount; _pathSession = state.SessionFilePath; - _session_tokens = state.SessionTokens; + _session_tokens = state.SessionTokens.ToList(); } else throw new ArgumentException("Invalid state data type."); From 87fe982f102e35404dff26e9e8ce536ccb541b78 Mon Sep 17 00:00:00 2001 From: eublefar Date: Sun, 17 Mar 2024 12:11:19 +0100 Subject: [PATCH 06/13] Change method signature as suggested --- LLama/LLamaExecutorBase.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index c2b083059..ea5616b56 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -317,11 +317,11 @@ public virtual async IAsyncEnumerable InferAsync(string text, IInference /// /// Asynchronously runs a prompt through the model to compute KV cache without generating any new tokens. + /// It could reduce the latency of the first time response if the first input from the user is not immediate. /// /// Prompt to process - /// A cancellation token /// - public virtual async Task AddPromptAsync(string prompt, CancellationToken cancellationToken = default) + public virtual async Task PrefillPromptAsync(string prompt) { var inferenceParams = new InferenceParams { From 5f3803d23c8b4365ea8a456cfb1c4ef49eb4429b Mon Sep 17 00:00:00 2001 From: eublefar Date: Sun, 17 Mar 2024 12:21:52 +0100 Subject: [PATCH 07/13] Make state editable by the user, add deepcopy to fields that require it --- LLama/ChatSession.cs | 64 +++++++++++++++++++++---------------- LLama/Common/ChatHistory.cs | 10 ++++++ 2 files changed, 46 insertions(+), 28 deletions(-) diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 251573fc4..1cc7d29ee 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using LLama.Abstractions; using LLama.Common; +using static LLama.Common.ChatHistory; using static LLama.InteractiveExecutor; using static LLama.LLamaContext; using static LLama.StatefulExecutorBase; @@ -53,19 +54,16 @@ public class ChatSession /// /// The executor for this session /// History for this session - /// Cancellation token to stop session pre-processing /// public static async Task InitializeSessionFromHistoryAsync( - ILLamaExecutor executor, - ChatHistory history, - CancellationToken cancellationToken = default) + ILLamaExecutor executor, ChatHistory history) { if (executor is not StatefulExecutorBase statefulExecutor) { throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor)); } var session = new ChatSession(executor, history); - await statefulExecutor.AddPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken); + await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history)); return session; } @@ -164,13 +162,13 @@ public void SaveSession(string path) /// SessionState object representing session state in-memory public SessionState GetSessionState() { - return new SessionState(Executor.Context.GetState(), ((StatefulExecutorBase)Executor).GetStateData()) - { - InputTransformPipeline = InputTransformPipeline, - OutputTransform = OutputTransform, - HistoryTransform = HistoryTransform, - History = History.ToJson() - }; + return new SessionState( + Executor.Context.GetState(), + ((StatefulExecutorBase)Executor).GetStateData(), + History, + InputTransformPipeline, + OutputTransform, + HistoryTransform); } /// @@ -183,18 +181,17 @@ public void LoadSession(SessionState state) { if (Executor is StatefulExecutorBase 
statefulExecutor) { - statefulExecutor.LoadState( - JsonSerializer.Deserialize( - state.ExecutorState, statefulExecutor.GetStateData().GetType() - ) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state)) - ); + statefulExecutor.LoadState(state.ExecutorState); } else { throw new ArgumentException("Executor must be a StatefulExecutorBase to support loading of session state", nameof(state)); } Executor.Context.LoadState(state.ContextState); - History = ChatHistory.FromJson(state.History) ?? new(); + History = new ChatHistory(state.History); + InputTransformPipeline = state.InputTransformPipeline.ToList(); + OutputTransform = state.OutputTransform; + HistoryTransform = state.HistoryTransform; } /// @@ -288,7 +285,7 @@ public async Task ProcessSystemMessage(string content) content = inputTransform.Transform(content); } - await statefulExecutor.AddPromptAsync(content); + await statefulExecutor.PrefillPromptAsync(content); History.AddMessage(AuthorRole.System, content); return this; @@ -593,41 +590,52 @@ public record SessionState /// /// Saved executor state for the session in JSON format. /// - public string ExecutorState { get; init; } + public ExecutorBaseState ExecutorState { get; set; } /// /// Saved context state (KV cache) for the session. /// - public State ContextState { get; init; } + public State ContextState { get; set; } /// /// The input transform pipeline used in this session. /// - public List InputTransformPipeline { get; init; } = new(); + public ITextTransform[] InputTransformPipeline { get; set; } = Array.Empty(); /// /// The output transform used in this session. /// - public ITextStreamTransform OutputTransform { get; init; } = new LLamaTransforms.EmptyTextOutputStreamTransform(); + public ITextStreamTransform OutputTransform { get; set; } = new LLamaTransforms.EmptyTextOutputStreamTransform(); /// /// The history transform used in this session. /// - public IHistoryTransform HistoryTransform { get; init; } = new LLamaTransforms.DefaultHistoryTransform(); + public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); /// - /// The JSON representation of the chat history for this session. + /// The the chat history messages for this session. /// - public string History { get; init; } = new ChatHistory().ToJson(); + public Message[] History { get; set; } = Array.Empty(); /// /// Create a new session state. 
/// /// /// - public SessionState(State contextState, ExecutorBaseState executorState) + /// + /// + /// + /// + public SessionState( + State contextState, ExecutorBaseState executorState, + ChatHistory history, List inputTransformPipeline, + ITextStreamTransform outputTransform, IHistoryTransform historyTransform) { ContextState = contextState; - ExecutorState = JsonSerializer.Serialize(executorState); + ExecutorState = executorState; + History = history.Messages.ToArray(); + InputTransformPipeline = inputTransformPipeline.ToArray(); + OutputTransform = outputTransform; + HistoryTransform = historyTransform; } } \ No newline at end of file diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs index dc7414490..c22cc7c06 100644 --- a/LLama/Common/ChatHistory.cs +++ b/LLama/Common/ChatHistory.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Text.Json.Serialization; @@ -80,6 +81,15 @@ public Message(AuthorRole authorRole, string content) [JsonConstructor] public ChatHistory() { } + /// + /// Create a new instance of the chat history from array of messages + /// + /// + public ChatHistory(Message[] messageHistory) + { + this.Messages = messageHistory.ToList(); + } + /// /// Add a message to the chat history /// From 6f76d773503936ff52c3c692322865582d43a24e Mon Sep 17 00:00:00 2001 From: eublefar Date: Sun, 17 Mar 2024 12:37:02 +0100 Subject: [PATCH 08/13] Make text transform interfaces have explicit copy operation --- LLama/Abstractions/IHistoryTransform.cs | 6 ++++++ LLama/Abstractions/ITextStreamTransform.cs | 6 ++++++ LLama/Abstractions/ITextTransform.cs | 6 ++++++ LLama/ChatSession.cs | 12 +++++------ LLama/LLamaTransforms.cs | 24 ++++++++++++++++++++++ 5 files changed, 48 insertions(+), 6 deletions(-) diff --git a/LLama/Abstractions/IHistoryTransform.cs b/LLama/Abstractions/IHistoryTransform.cs index c9217ae0f..5651343f2 100644 --- a/LLama/Abstractions/IHistoryTransform.cs +++ b/LLama/Abstractions/IHistoryTransform.cs @@ -21,5 +21,11 @@ public interface IHistoryTransform /// The chat history as plain text. /// The updated history. ChatHistory TextToHistory(AuthorRole role, string text); + + /// + /// Copy the transform. + /// + /// + IHistoryTransform Clone(); } } diff --git a/LLama/Abstractions/ITextStreamTransform.cs b/LLama/Abstractions/ITextStreamTransform.cs index 2725214f5..2b63299da 100644 --- a/LLama/Abstractions/ITextStreamTransform.cs +++ b/LLama/Abstractions/ITextStreamTransform.cs @@ -13,5 +13,11 @@ public interface ITextStreamTransform /// /// IAsyncEnumerable TransformAsync(IAsyncEnumerable tokens); + + /// + /// Copy the transform. + /// + /// + ITextStreamTransform Clone(); } } diff --git a/LLama/Abstractions/ITextTransform.cs b/LLama/Abstractions/ITextTransform.cs index ac196644e..0bfeeb7f6 100644 --- a/LLama/Abstractions/ITextTransform.cs +++ b/LLama/Abstractions/ITextTransform.cs @@ -17,5 +17,11 @@ public interface ITextTransform /// /// string Transform(string text); + + /// + /// Copy the transform. 
+ /// + /// + ITextTransform Clone(); } } diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 1cc7d29ee..b41178427 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -189,9 +189,9 @@ public void LoadSession(SessionState state) } Executor.Context.LoadState(state.ContextState); History = new ChatHistory(state.History); - InputTransformPipeline = state.InputTransformPipeline.ToList(); - OutputTransform = state.OutputTransform; - HistoryTransform = state.HistoryTransform; + InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList(); + OutputTransform = state.OutputTransform.Clone(); + HistoryTransform = state.HistoryTransform.Clone(); } /// @@ -634,8 +634,8 @@ public SessionState( ContextState = contextState; ExecutorState = executorState; History = history.Messages.ToArray(); - InputTransformPipeline = inputTransformPipeline.ToArray(); - OutputTransform = outputTransform; - HistoryTransform = historyTransform; + InputTransformPipeline = inputTransformPipeline.Select(t => t.Clone()).ToArray(); + OutputTransform = outputTransform.Clone(); + HistoryTransform = historyTransform.Clone(); } } \ No newline at end of file diff --git a/LLama/LLamaTransforms.cs b/LLama/LLamaTransforms.cs index 29c16c187..1ac0a79be 100644 --- a/LLama/LLamaTransforms.cs +++ b/LLama/LLamaTransforms.cs @@ -47,6 +47,12 @@ public DefaultHistoryTransform(string? userName = null, string? assistantName = _isInstructMode = isInstructMode; } + /// + public IHistoryTransform Clone() + { + return new DefaultHistoryTransform(_userName, _assistantName, _systemName, _unknownName, _isInstructMode); + } + /// public virtual string HistoryToText(ChatHistory history) { @@ -116,6 +122,12 @@ public string Transform(string text) { return text.Trim(); } + + /// + public ITextTransform Clone() + { + return new NaiveTextInputTransform(); + } } /// @@ -129,6 +141,12 @@ public IAsyncEnumerable TransformAsync(IAsyncEnumerable tokens) { return tokens; } + + /// + public ITextStreamTransform Clone() + { + return new EmptyTextOutputStreamTransform(); + } } /// @@ -157,6 +175,12 @@ public KeywordTextOutputStreamTransform(IEnumerable keywords, int redund _removeAllMatchedTokens = removeAllMatchedTokens; } + /// + public ITextStreamTransform Clone() + { + return new KeywordTextOutputStreamTransform(_keywords, _maxKeywordLength, _removeAllMatchedTokens); + } + /// public async IAsyncEnumerable TransformAsync(IAsyncEnumerable tokens) { From a31391edd7ddd9d7e3e85d8848b037f2dcc9d674 Mon Sep 17 00:00:00 2001 From: eublefar Date: Sun, 17 Mar 2024 15:34:36 +0100 Subject: [PATCH 09/13] Polymorphic serialization for executor state and transforms --- .../Examples/ChatSessionWithHistory.cs | 6 + .../Examples/ChatSessionWithRestart.cs | 7 +- LLama/Abstractions/IHistoryTransform.cs | 2 + LLama/Abstractions/ITextStreamTransform.cs | 5 +- LLama/Abstractions/ITextTransform.cs | 6 +- LLama/ChatSession.cs | 195 ++++++++++++++---- LLama/Common/PolymorphicJSONConverter.cs | 57 +++++ LLama/LLamaContext.cs | 30 ++- LLama/LLamaExecutorBase.cs | 1 + LLama/LLamaTransforms.cs | 43 ++++ 10 files changed, 302 insertions(+), 50 deletions(-) create mode 100644 LLama/Common/PolymorphicJSONConverter.cs diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs index 17908908d..6a84d2fd7 100644 --- a/LLama.Examples/Examples/ChatSessionWithHistory.cs +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -61,6 +61,12 @@ public static async Task Run() Console.ForegroundColor = 
ConsoleColor.Yellow; Console.WriteLine("Session saved."); } + else if (userInput == "load") + { + session.LoadSession("Assets/chat-with-bob"); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Session loaded."); + } else if (userInput == "regenerate") { Console.ForegroundColor = ConsoleColor.Yellow; diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs index 3462c5062..234bac3c6 100644 --- a/LLama.Examples/Examples/ChatSessionWithRestart.cs +++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs @@ -37,7 +37,8 @@ public static async Task Run() }; Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("The chat session has started."); + Console.WriteLine("The chat session has started. Write `save` to save session in memory." + + " Write `reset` to start from the last saved checkpoint"); // show the prompt Console.ForegroundColor = ConsoleColor.Green; @@ -48,13 +49,13 @@ public static async Task Run() if(userInput == "reset") { session.LoadSession(resetState); - Console.WriteLine($"History: {session.HistoryTransform.HistoryToText(session.History)}"); + Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.History)}"); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session reset."); } else if (userInput == "save") { - session.SaveSession("Assets/chat-with-bob"); + resetState = session.GetSessionState(); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session saved."); } diff --git a/LLama/Abstractions/IHistoryTransform.cs b/LLama/Abstractions/IHistoryTransform.cs index 5651343f2..9644b3e1d 100644 --- a/LLama/Abstractions/IHistoryTransform.cs +++ b/LLama/Abstractions/IHistoryTransform.cs @@ -1,10 +1,12 @@ using LLama.Common; +using System.Text.Json.Serialization; namespace LLama.Abstractions { /// /// Transform history to plain text and vice versa. /// + [JsonConverter(typeof(PolymorphicJSONConverter))] public interface IHistoryTransform { /// diff --git a/LLama/Abstractions/ITextStreamTransform.cs b/LLama/Abstractions/ITextStreamTransform.cs index 2b63299da..3ebdba675 100644 --- a/LLama/Abstractions/ITextStreamTransform.cs +++ b/LLama/Abstractions/ITextStreamTransform.cs @@ -1,10 +1,13 @@ -using System.Collections.Generic; +using LLama.Common; +using System.Collections.Generic; +using System.Text.Json.Serialization; namespace LLama.Abstractions { /// /// Takes a stream of tokens and transforms them. /// + [JsonConverter(typeof(PolymorphicJSONConverter))] public interface ITextStreamTransform { /// diff --git a/LLama/Abstractions/ITextTransform.cs b/LLama/Abstractions/ITextTransform.cs index 0bfeeb7f6..f6f743f9f 100644 --- a/LLama/Abstractions/ITextTransform.cs +++ b/LLama/Abstractions/ITextTransform.cs @@ -1,4 +1,7 @@ -namespace LLama.Abstractions +using System.Text.Json.Serialization; +using LLama.Common; + +namespace LLama.Abstractions { /// /// An interface for text transformations. @@ -9,6 +12,7 @@ /// - Trimming /// - etc. 
/// + [JsonConverter(typeof(PolymorphicJSONConverter))] public interface ITextTransform { /// diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index b41178427..80298725b 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -8,7 +8,6 @@ using System.Threading.Tasks; using LLama.Abstractions; using LLama.Common; -using static LLama.Common.ChatHistory; using static LLama.InteractiveExecutor; using static LLama.LLamaContext; using static LLama.StatefulExecutorBase; @@ -20,9 +19,30 @@ namespace LLama; /// public class ChatSession { - private const string _modelStateFilename = "ModelState.st"; - private const string _executorStateFilename = "ExecutorState.json"; - private const string _hsitoryFilename = "ChatHistory.json"; + /// + /// The filename for the serialized model state (KV cache, etc). + /// + public const string MODEL_STATE_FILENAME = "ModelState.st"; + /// + /// The filename for the serialized executor state. + /// + public const string EXECUTOR_STATE_FILENAME = "ExecutorState.json"; + /// + /// The filename for the serialized chat history. + /// + public const string HISTORY_STATE_FILENAME = "ChatHistory.json"; + /// + /// The filename for the serialized input transform pipeline. + /// + public const string INPUT_TRANSFORM_FILENAME = "InputTransform.json"; + /// + /// The filename for the serialized output transform. + /// + public const string OUTPUT_TRANSFORM_FILENAME = "OutputTransform.json"; + /// + /// The filename for the serialized history transform. + /// + public const string HISTORY_TRANSFORM_FILENAME = "HistoryTransform.json"; /// /// The executor for this session. @@ -134,26 +154,7 @@ public ChatSession WithOutputTransform(ITextStreamTransform transform) /// public void SaveSession(string path) { - if (string.IsNullOrWhiteSpace(path)) - { - throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); - } - - if (Directory.Exists(path)) - { - Directory.Delete(path, recursive: true); - } - - Directory.CreateDirectory(path); - - string modelStateFilePath = Path.Combine(path, _modelStateFilename); - Executor.Context.SaveState(modelStateFilePath); - - string executorStateFilepath = Path.Combine(path, _executorStateFilename); - ((StatefulExecutorBase)Executor).SaveState(executorStateFilepath); - - string historyFilepath = Path.Combine(path, _hsitoryFilename); - File.WriteAllText(historyFilepath, History.ToJson()); + GetSessionState().Save(path); } /// @@ -202,26 +203,14 @@ public void LoadSession(SessionState state) /// public void LoadSession(string path) { - if (string.IsNullOrWhiteSpace(path)) + var state = SessionState.Load(path); + // Handle non-polymorphic serialization of executor state + if (state.ExecutorState is ExecutorBaseState) { - throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); + var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME); + ((StatefulExecutorBase) Executor).LoadState(filename: executorPath); } - - if (!Directory.Exists(path)) - { - throw new ArgumentException("Directory does not exist", nameof(path)); - } - - string modelStateFilePath = Path.Combine(path, _modelStateFilename); - Executor.Context.LoadState(modelStateFilePath); - - string executorStateFilepath = Path.Combine(path, _executorStateFilename); - ((StatefulExecutorBase)Executor).LoadState(executorStateFilepath); - - string historyFilepath = Path.Combine(path, _hsitoryFilename); - string historyJson = File.ReadAllText(historyFilepath); - History = ChatHistory.FromJson(historyJson) - ?? 
throw new ArgumentException("History file is invalid", nameof(path)); + LoadSession(state); } /// @@ -615,7 +604,7 @@ public record SessionState /// /// The the chat history messages for this session. /// - public Message[] History { get; set; } = Array.Empty(); + public ChatHistory.Message[] History { get; set; } = Array.Empty(); /// /// Create a new session state. @@ -638,4 +627,124 @@ public SessionState( OutputTransform = outputTransform.Clone(); HistoryTransform = historyTransform.Clone(); } + + /// + /// Save the session state to folder. + /// + /// + public void Save(string path) + { + if (string.IsNullOrWhiteSpace(path)) + { + throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); + } + + if (Directory.Exists(path)) + { + Directory.Delete(path, recursive: true); + } + + Directory.CreateDirectory(path); + + string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); + var bytes = ContextState.ToByteArray(); + File.WriteAllBytes(modelStateFilePath, bytes); + + string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME); + File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState)); + + string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); + File.WriteAllText(historyFilepath, new ChatHistory(History).ToJson()); + + string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); + File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline)); + + string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME); + File.WriteAllText(outputTransformFilepath, JsonSerializer.Serialize(OutputTransform)); + + string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME); + File.WriteAllText(historyTransformFilepath, JsonSerializer.Serialize(HistoryTransform)); + } + + /// + /// Load the session state from folder. + /// + /// + /// + /// Throws when session state is incorrect + public static SessionState Load(string path) + { + if (string.IsNullOrWhiteSpace(path)) + { + throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); + } + + if (!Directory.Exists(path)) + { + throw new ArgumentException("Directory does not exist", nameof(path)); + } + + string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); + var contextState = State.FromByteArray(File.ReadAllBytes(modelStateFilePath)); + + string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME); + var executorState = JsonSerializer.Deserialize(File.ReadAllText(executorStateFilepath)) + ?? throw new ArgumentException("Executor state file is invalid", nameof(path)); + + string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); + string historyJson = File.ReadAllText(historyFilepath); + var history = ChatHistory.FromJson(historyJson) + ?? throw new ArgumentException("History file is invalid", nameof(path)); + + string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); + ITextTransform[] inputTransforms; + try + { + inputTransforms = File.Exists(inputTransformFilepath) ? + (JsonSerializer.Deserialize(File.ReadAllText(inputTransformFilepath)) + ?? 
throw new ArgumentException("Input transform file is invalid", nameof(path))) + : Array.Empty(); + } + catch (JsonException) + { + throw new ArgumentException("Input transform file is invalid", nameof(path)); + } + + string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME); + + ITextStreamTransform outputTransform; + try + { + outputTransform = File.Exists(outputTransformFilepath) ? + (JsonSerializer.Deserialize(File.ReadAllText(outputTransformFilepath)) + ?? throw new ArgumentException("Output transform file is invalid", nameof(path))) + : new LLamaTransforms.EmptyTextOutputStreamTransform(); + } + catch (JsonException) + { + throw new ArgumentException("Output transform file is invalid", nameof(path)); + } + + string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME); + IHistoryTransform historyTransform; + try + { + historyTransform = File.Exists(historyTransformFilepath) ? + (JsonSerializer.Deserialize(File.ReadAllText(historyTransformFilepath)) + ?? throw new ArgumentException("History transform file is invalid", nameof(path))) + : new LLamaTransforms.DefaultHistoryTransform(); + } + catch (JsonException) + { + throw new ArgumentException("History transform file is invalid", nameof(path)); + } + + return new SessionState( + contextState, + executorState, + history, + inputTransforms.ToList(), + outputTransform, + historyTransform); + } } \ No newline at end of file diff --git a/LLama/Common/PolymorphicJSONConverter.cs b/LLama/Common/PolymorphicJSONConverter.cs new file mode 100644 index 000000000..6cec2f27c --- /dev/null +++ b/LLama/Common/PolymorphicJSONConverter.cs @@ -0,0 +1,57 @@ +using LLama.Abstractions; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace LLama.Common +{ + internal class PolymorphicJSONConverter : JsonConverter + { + public override T? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + throw new JsonException(); + reader.Read(); + if (reader.TokenType != JsonTokenType.PropertyName) + throw new JsonException(); + string? propertyName = reader.GetString(); + if (propertyName != "Name") + throw new JsonException(); + reader.Read(); + if (reader.TokenType != JsonTokenType.String) + throw new JsonException(); + string? name = reader.GetString() ?? 
throw new JsonException(); + var inheritedTypes = Assembly.GetExecutingAssembly().GetTypes().Where( + t => typeof(T).IsAssignableFrom(t) && !t.IsAbstract && !t.IsInterface + ); + var type = inheritedTypes.FirstOrDefault(t => t.Name == name); + if (type == null) + throw new JsonException(); + reader.Read(); + if (reader.TokenType != JsonTokenType.PropertyName) + throw new JsonException(); + propertyName = reader.GetString(); + if (propertyName != "Data") + throw new JsonException(); + var data = JsonSerializer.Deserialize(ref reader, type, options); + if (data == null) + throw new JsonException(); + reader.Read(); + reader.Read(); + return (T)data; + } + + public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + writer.WriteString("Name", value.GetType().Name); + writer.WritePropertyName("Data"); + JsonSerializer.Serialize(writer, value, value.GetType(), options); + writer.WriteEndObject(); + } + } +} diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index d8b418c31..4a63be362 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -166,7 +166,7 @@ public State GetState() memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize); // Wrap memory in a "state" - var state = new State(memory); + var state = new State(memory, actualSize); // Set memory to zero, to prevent it being freed in finally block memory = IntPtr.Zero; @@ -384,9 +384,12 @@ public void Dispose() public class State : SafeLLamaHandleBase { - internal State(IntPtr memory) + private ulong _size; + + internal State(IntPtr memory, ulong size) : base(memory, true) { + _size = size; } /// @@ -395,6 +398,29 @@ protected override bool ReleaseHandle() Marshal.FreeHGlobal(handle); return true; } + + /// + /// Convert this state to a byte array + /// + /// + public byte[] ToByteArray() + { + var bytes = new byte[_size]; + Marshal.Copy(handle, bytes, 0, (int)_size); + return bytes; + } + + /// + /// Load state from a byte array + /// + /// + /// + public static State FromByteArray(byte[] bytes) + { + var memory = Marshal.AllocHGlobal(bytes.Length); + Marshal.Copy(bytes, 0, memory, bytes.Length); + return new State(memory, (ulong)bytes.Length); + } } } } diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index ea5616b56..ec72a25ad 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -370,6 +370,7 @@ protected class InferStateArgs public bool NeedToSaveSession { get; set; } } + [JsonConverter(typeof(PolymorphicJSONConverter))] public class ExecutorBaseState { [JsonPropertyName("n_past")] diff --git a/LLama/LLamaTransforms.cs b/LLama/LLamaTransforms.cs index 1ac0a79be..d74d9ddaf 100644 --- a/LLama/LLamaTransforms.cs +++ b/LLama/LLamaTransforms.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using System.Text.Json.Serialization; namespace LLama { @@ -29,6 +30,12 @@ public class DefaultHistoryTransform : IHistoryTransform private readonly string _unknownName; private readonly bool _isInstructMode; + public string UserName => _userName; + public string AssistantName => _assistantName; + public string SystemName => _systemName; + public string UnknownName => _unknownName; + public bool IsInstructMode => _isInstructMode; + /// /// /// @@ -158,6 +165,42 @@ public class KeywordTextOutputStreamTransform : ITextStreamTransform private readonly int _maxKeywordLength; private readonly bool _removeAllMatchedTokens; + /// + /// Keywords that you want to remove from the response. 
+ /// This property is used for JSON serialization. + /// + [JsonPropertyName("keywords")] + public HashSet Keywords => _keywords; + + /// + /// Maximum length of the keywords. + /// This property is used for JSON serialization. + /// + [JsonPropertyName("maxKeywordLength")] + public int MaxKeywordLength => _maxKeywordLength; + + /// + /// If set to true, when getting a matched keyword, all the related tokens will be removed. + /// Otherwise only the part of keyword will be removed. + /// This property is used for JSON serialization. + /// + [JsonPropertyName("removeAllMatchedTokens")] + public bool RemoveAllMatchedTokens => _removeAllMatchedTokens; + + /// + /// JSON constructor. + /// + [JsonConstructor] + public KeywordTextOutputStreamTransform( + HashSet keywords, + int maxKeywordLength, + bool removeAllMatchedTokens) + { + _keywords = new(keywords); + _maxKeywordLength = maxKeywordLength; + _removeAllMatchedTokens = removeAllMatchedTokens; + } + /// /// /// From 00c873a1979189382b1589c7d1e9d1e91f0f0e79 Mon Sep 17 00:00:00 2001 From: eublefar Date: Sun, 17 Mar 2024 15:55:35 +0100 Subject: [PATCH 10/13] Avoid saving empty context state in binary format, it smh messes with the llama.cpp --- LLama/ChatSession.cs | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 80298725b..59140d54d 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -163,9 +163,11 @@ public void SaveSession(string path) /// SessionState object representing session state in-memory public SessionState GetSessionState() { + var executorState = ((StatefulExecutorBase)Executor).GetStateData(); return new SessionState( - Executor.Context.GetState(), - ((StatefulExecutorBase)Executor).GetStateData(), + executorState.PastTokensCount > 0 + ? Executor.Context.GetState() : null, + executorState, History, InputTransformPipeline, OutputTransform, @@ -188,7 +190,14 @@ public void LoadSession(SessionState state) { throw new ArgumentException("Executor must be a StatefulExecutorBase to support loading of session state", nameof(state)); } - Executor.Context.LoadState(state.ContextState); + if (state.ContextState is null) + { + Executor.Context.NativeHandle.KvCacheClear(); + } + else + { + Executor.Context.LoadState(state.ContextState); + } History = new ChatHistory(state.History); InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList(); OutputTransform = state.OutputTransform.Clone(); @@ -584,7 +593,7 @@ public record SessionState /// /// Saved context state (KV cache) for the session. /// - public State ContextState { get; set; } + public State? ContextState { get; set; } /// /// The input transform pipeline used in this session. @@ -616,7 +625,7 @@ public record SessionState /// /// public SessionState( - State contextState, ExecutorBaseState executorState, + State? 
contextState, ExecutorBaseState executorState, ChatHistory history, List inputTransformPipeline, ITextStreamTransform outputTransform, IHistoryTransform historyTransform) { @@ -647,8 +656,11 @@ public void Save(string path) Directory.CreateDirectory(path); string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); - var bytes = ContextState.ToByteArray(); - File.WriteAllBytes(modelStateFilePath, bytes); + var bytes = ContextState?.ToByteArray(); + if (bytes is not null) + { + File.WriteAllBytes(modelStateFilePath, bytes); + } string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME); File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState)); @@ -685,7 +697,9 @@ public static SessionState Load(string path) } string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); - var contextState = State.FromByteArray(File.ReadAllBytes(modelStateFilePath)); + var contextState = File.Exists(modelStateFilePath) ? + State.FromByteArray(File.ReadAllBytes(modelStateFilePath)) + : null; string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME); var executorState = JsonSerializer.Deserialize(File.ReadAllText(executorStateFilepath)) From d88f9e119980c83adce904cf1c11f286942dd7e4 Mon Sep 17 00:00:00 2001 From: eublefar Date: Sun, 17 Mar 2024 16:22:25 +0100 Subject: [PATCH 11/13] Return null executor state if it's serialized in an old way --- LLama/ChatSession.cs | 7 +++---- LLama/Common/PolymorphicJSONConverter.cs | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 59140d54d..6c9accdfb 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -214,7 +214,7 @@ public void LoadSession(string path) { var state = SessionState.Load(path); // Handle non-polymorphic serialization of executor state - if (state.ExecutorState is ExecutorBaseState) + if (state.ExecutorState is null) { var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME); ((StatefulExecutorBase) Executor).LoadState(filename: executorPath); @@ -588,7 +588,7 @@ public record SessionState /// /// Saved executor state for the session in JSON format. /// - public ExecutorBaseState ExecutorState { get; set; } + public ExecutorBaseState? ExecutorState { get; set; } /// /// Saved context state (KV cache) for the session. @@ -702,8 +702,7 @@ public static SessionState Load(string path) : null; string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME); - var executorState = JsonSerializer.Deserialize(File.ReadAllText(executorStateFilepath)) - ?? throw new ArgumentException("Executor state file is invalid", nameof(path)); + var executorState = JsonSerializer.Deserialize(File.ReadAllText(executorStateFilepath)); string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); string historyJson = File.ReadAllText(historyFilepath); diff --git a/LLama/Common/PolymorphicJSONConverter.cs b/LLama/Common/PolymorphicJSONConverter.cs index 6cec2f27c..1af4011cc 100644 --- a/LLama/Common/PolymorphicJSONConverter.cs +++ b/LLama/Common/PolymorphicJSONConverter.cs @@ -20,7 +20,7 @@ internal class PolymorphicJSONConverter : JsonConverter throw new JsonException(); string? 
propertyName = reader.GetString(); if (propertyName != "Name") - throw new JsonException(); + return default; reader.Read(); if (reader.TokenType != JsonTokenType.String) throw new JsonException(); From 9440f153da658dea9bd27937b8befe70fc5bea03 Mon Sep 17 00:00:00 2001 From: eublefar Date: Thu, 21 Mar 2024 12:14:15 +0100 Subject: [PATCH 12/13] Make process message method more flexible --- .../Examples/ChatSessionWithHistory.cs | 6 ++ .../Examples/ChatSessionWithRestart.cs | 36 ++++++---- LLama/ChatSession.cs | 70 ++++++++++++------- 3 files changed, 73 insertions(+), 39 deletions(-) diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs index 6a84d2fd7..31b6a7718 100644 --- a/LLama.Examples/Examples/ChatSessionWithHistory.cs +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -48,6 +48,10 @@ public static async Task Run() Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The chat session has started."); + Console.WriteLine("Type 'exit' to end the chat session."); + Console.WriteLine("Type 'save' to save the chat session to disk."); + Console.WriteLine("Type 'load' to load the chat session from disk."); + Console.WriteLine("Type 'regenerate' to regenerate the last response."); // show the prompt Console.ForegroundColor = ConsoleColor.Green; @@ -55,12 +59,14 @@ public static async Task Run() while (userInput != "exit") { + // Save the chat state to disk if (userInput == "save") { session.SaveSession("Assets/chat-with-bob"); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session saved."); } + // Load the chat state from disk else if (userInput == "load") { session.LoadSession("Assets/chat-with-bob"); diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs index 234bac3c6..923f78f67 100644 --- a/LLama.Examples/Examples/ChatSessionWithRestart.cs +++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs @@ -37,8 +37,11 @@ public static async Task Run() }; Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("The chat session has started. Write `save` to save session in memory." - + " Write `reset` to start from the last saved checkpoint"); + Console.WriteLine("The chat session has started. Starting point saved."); + Console.WriteLine("Type 'exit' to end the chat session."); + Console.WriteLine("Type 'save' to save chat session state in memory."); + Console.WriteLine("Type 'reset' to reset the chat session to its saved state."); + Console.WriteLine("Type 'answer for assistant' to add and process provided user and assistant messages."); // show the prompt Console.ForegroundColor = ConsoleColor.Green; @@ -46,6 +49,7 @@ public static async Task Run() while (userInput != "exit") { + // Load the session state from the reset state if(userInput == "reset") { session.LoadSession(resetState); @@ -53,25 +57,33 @@ public static async Task Run() Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session reset."); } + // Assign new reset state. else if (userInput == "save") { resetState = session.GetSessionState(); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session saved."); } - else if (userInput == "regenerate") + // Provide user and override assistant answer with your own. 
+ else if (userInput == "answer for assistant") { Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("Regenerating last response ..."); + Console.WriteLine("Provide user input: "); - await foreach ( - var text - in session.RegenerateAssistantMessageAsync( - inferenceParams)) - { - Console.ForegroundColor = ConsoleColor.White; - Console.Write(text); - } + Console.ForegroundColor = ConsoleColor.Green; + string userInputOverride = Console.ReadLine() ?? ""; + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Provide assistant input: "); + + Console.ForegroundColor = ConsoleColor.Green; + string assistantInputOverride = Console.ReadLine() ?? ""; + + await session.AddAndProcessUserMessage(userInputOverride); + await session.AddAndProcessAssistantMessage(assistantInputOverride); + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("User and assistant messages processed. Provide next user message:"); } else { diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 6c9accdfb..9620dc4f2 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -262,33 +262,6 @@ public ChatSession AddMessage(ChatHistory.Message message) return this; } - - /// - /// Compute KV cache for the system message and add it to the chat history. - /// - /// - /// - public async Task ProcessSystemMessage(string content) - { - if (Executor is not StatefulExecutorBase statefulExecutor) - { - throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages."); - } - if (History.Messages.Count > 0) - { - throw new ArgumentException("Cannot add a system message after another message", nameof(content)); - } - foreach (var inputTransform in InputTransformPipeline) - { - content = inputTransform.Transform(content); - } - - await statefulExecutor.PrefillPromptAsync(content); - - History.AddMessage(AuthorRole.System, content); - return this; - } - /// /// Add a system message to the chat history. /// @@ -323,6 +296,49 @@ public ChatSession RemoveLastMessage() return this; } + /// + /// Compute KV cache for the message and add it to the chat history. + /// + /// + /// + public async Task AddAndProcessMessage(ChatHistory.Message message) + { + if (Executor is not StatefulExecutorBase statefulExecutor) + { + throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages."); + } + AddMessage(message); + var content = message.Content; + if (message.AuthorRole != AuthorRole.Assistant) + { + foreach (var inputTransform in InputTransformPipeline) + { + content = inputTransform.Transform(content); + } + } + + await statefulExecutor.PrefillPromptAsync(content); + return this; + } + + /// + /// Compute KV cache for the system message and add it to the chat history. + /// + public Task AddAndProcessSystemMessage(string content) + => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content)); + + /// + /// Compute KV cache for the user message and add it to the chat history. + /// + public Task AddAndProcessUserMessage(string content) + => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content)); + + /// + /// Compute KV cache for the assistant message and add it to the chat history. + /// + public Task AddAndProcessAssistantMessage(string content) + => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content)); + /// /// Replace a user message with a new message and remove all messages after the new message. 
/// This is useful when the user wants to edit a message and regenerate the response.

From b8cd5b7ee565b17feb81bbb3330866ad9df1616c Mon Sep 17 00:00:00 2001
From: eublefar
Date: Thu, 21 Mar 2024 12:18:38 +0100
Subject: [PATCH 13/13] loadTransforms flag for LoadSession methods

---
 LLama/ChatSession.cs | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 9620dc4f2..0a5accc5e 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -178,17 +178,17 @@ public SessionState GetSessionState()
     /// Load a session from a session state.
     ///
     ///
+    /// If true, loads the transforms saved in the session state.
     ///
     ///
-    public void LoadSession(SessionState state)
+    public void LoadSession(SessionState state, bool loadTransforms = true)
     {
         if (Executor is StatefulExecutorBase statefulExecutor)
         {
-            statefulExecutor.LoadState(state.ExecutorState);
-        }
-        else
-        {
-            throw new ArgumentException("Executor must be a StatefulExecutorBase to support loading of session state", nameof(state));
+            if (state.ExecutorState is not null)
+            {
+                statefulExecutor.LoadState(state.ExecutorState);
+            }
         }
         if (state.ContextState is null)
         {
@@ -199,18 +199,22 @@ public void LoadSession(SessionState state)
             Executor.Context.LoadState(state.ContextState);
         }
         History = new ChatHistory(state.History);
-        InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList();
-        OutputTransform = state.OutputTransform.Clone();
-        HistoryTransform = state.HistoryTransform.Clone();
+        if (loadTransforms)
+        {
+            InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList();
+            OutputTransform = state.OutputTransform.Clone();
+            HistoryTransform = state.HistoryTransform.Clone();
+        }
     }

     ///
     /// Load a session from a directory.
     ///
     ///
+    /// If true, loads the transforms saved in the session state.
     ///
     ///
-    public void LoadSession(string path)
+    public void LoadSession(string path, bool loadTransforms = true)
     {
         var state = SessionState.Load(path);
         // Handle non-polymorphic serialization of executor state
@@ -219,7 +223,7 @@ public void LoadSession(string path)
             var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME);
             ((StatefulExecutorBase) Executor).LoadState(filename: executorPath);
         }
-        LoadSession(state);
+        LoadSession(state, loadTransforms);
     }

     ///
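
Taken together, patches 10-13 let a caller checkpoint a running `ChatSession` entirely in memory and roll it back later, without round-tripping through `SaveSession`/`LoadSession(path)`. The sketch below shows the intended call pattern; it is illustrative only, not part of the patch series. It assumes a `ChatSession` has already been constructed around a stateful executor such as `InteractiveExecutor` (model loading, history seeding, and transform setup are omitted), and it assumes `ChatAsync` streams the generated text as strings, as in the bundled examples. Only members touched by this series — `GetSessionState`, `LoadSession(SessionState, bool)`, and the `AddAndProcess*` helpers — are exercised, and the prompt strings are placeholders.

```csharp
using System;
using System.Threading.Tasks;
using LLama;
using LLama.Common;

public static class SessionStateDemo
{
    // Minimal sketch (assumed usage, not taken from the patches) of the in-memory
    // checkpoint/rollback flow. `session` is an already-configured ChatSession backed
    // by a StatefulExecutorBase (e.g. InteractiveExecutor).
    public static async Task CheckpointAndRollbackAsync(ChatSession session, InferenceParams inferenceParams)
    {
        // Snapshot executor state, KV cache, history and transforms in memory.
        // Per patch 10, the context state is only captured once tokens have been processed.
        SessionState checkpoint = session.GetSessionState();

        // Generate a reply; this advances the KV cache and appends to History.
        await foreach (var text in session.ChatAsync(
            new ChatHistory.Message(AuthorRole.User, "What is an alpaca?"),
            inferenceParams))
        {
            Console.Write(text);
        }

        // Roll everything back to the checkpoint. With loadTransforms: false (patch 13)
        // the executor state, KV cache and history are restored, but the transforms
        // currently attached to the session are kept as-is.
        session.LoadSession(checkpoint, loadTransforms: false);

        // Instead of generating the assistant turn, pre-fill the KV cache with a
        // caller-provided exchange (patch 12); both messages are appended to History
        // and run through PrefillPromptAsync internally.
        await session.AddAndProcessUserMessage("What is an alpaca?");
        await session.AddAndProcessAssistantMessage("An alpaca is a domesticated South American camelid.");
    }
}
```

Keeping the transforms out of the rollback is the point of the `loadTransforms` flag: a caller who has swapped in new input/output transforms since taking the checkpoint can rewind only the model state and history while leaving the current transform pipeline in place.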