Prevent duplication of user prompts / chat history in ChatSession. #266

Merged · 3 commits · Nov 9, 2023

Changes from 1 commit
50 changes: 37 additions & 13 deletions LLama/ChatSession.cs
@@ -95,11 +95,11 @@ public virtual void SaveSession(string path)
             Directory.CreateDirectory(path);
         }
         _executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
-        if(Executor is StatelessExecutor)
+        if (Executor is StatelessExecutor)
         {
 
         }
-        else if(Executor is StatefulExecutorBase statefulExecutor)
+        else if (Executor is StatefulExecutorBase statefulExecutor)
         {
             statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
         }
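
For reference, a rough sketch of how session persistence is driven from user code, following the usual LLamaSharp setup pattern of this era. The model path and directory name are placeholders, and the exact ModelParams surface may differ between versions:

using LLama;
using LLama.Common;

// Placeholder model path; any chat-tuned GGUF model works here.
var parameters = new ModelParams("models/llama-2-7b-chat.gguf") { ContextSize = 1024 };
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);

// InteractiveExecutor derives from StatefulExecutorBase, so SaveSession
// persists the executor state alongside the context state.
var session = new ChatSession(new InteractiveExecutor(context));

session.SaveSession("./saved-session");   // writes model + executor state
session.LoadSession("./saved-session");   // a later run can restore both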
@@ -135,30 +135,54 @@ public virtual void LoadSession(string path)
     }
 
     /// <summary>
-    /// Get the response from the LLama model. Note that prompt could not only be the preset words,
-    /// but also the question you want to ask.
+    /// Generates a response for a given user prompt and manages history state for the user.
+    /// This will always pass the whole history to the model. Don't pass a whole history
+    /// to this method as the user prompt will be appended to the history of the current session.
+    /// If more control is needed, use the other overload of this method that accepts a ChatHistory object.
     /// </summary>
     /// <param name="prompt"></param>
     /// <param name="inferenceParams"></param>
     /// <param name="cancellationToken"></param>
-    /// <returns></returns>
+    /// <returns>Returns generated tokens of the assistant message.</returns>
     public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
     {
-        foreach(var inputTransform in InputTransformPipeline)
+        foreach (var inputTransform in InputTransformPipeline)
             prompt = inputTransform.Transform(prompt);
 
-        History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
+        History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
+
+        string internalPrompt = HistoryTransform.HistoryToText(History);

         StringBuilder sb = new();
-        await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
+
+        await foreach (var result in ChatAsyncInternal(internalPrompt, inferenceParams, cancellationToken))
         {
             yield return result;
             sb.Append(result);
         }
-        History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
+
+        string assistantMessage = sb.ToString();
+
+        // Remove end tokens from the assistant message
+        // if defined in inferenceParams.AntiPrompts.
+        // We only want the response that was generated and not tokens
+        // that are delimiting the beginning or end of the response.
+        if (inferenceParams?.AntiPrompts != null)
+        {
+            foreach (var stopToken in inferenceParams.AntiPrompts)
+            {
+                assistantMessage = assistantMessage.Replace(stopToken, "");
+            }
+        }
+
+        History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
     }
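
A short usage sketch of the changed overload, continuing the setup sketched above (the question string and anti-prompt are illustrative). After this commit the caller passes only the newest user message; ChatAsync renders the full History itself and records the user prompt and the cleaned assistant reply exactly once:

var inferenceParams = new InferenceParams
{
    // Generation stops at the anti-prompt, and the loop added above
    // also strips it from the message stored in History.
    AntiPrompts = new List<string> { "User:" }
};

// Pass only the new user message, not the accumulated transcript;
// the session expands History via HistoryTransform.HistoryToText.
await foreach (var token in session.ChatAsync("What is the capital of France?", inferenceParams))
{
    Console.Write(token);
}

// session.History now ends with one new User message and one new
// Assistant message, with no duplicated prompt text.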

     /// <summary>
-    /// Get the response from the LLama model with chat histories.
+    /// Generates a response for a given chat history. This method does not manage history state for the user.
+    /// If you want to e.g. truncate the history of a session to fit into the model's context window,
+    /// use this method and pass the truncated history to it. If you don't need this control, use the other
+    /// overload of this method that accepts a user prompt instead.
     /// </summary>
     /// <param name="history"></param>
     /// <param name="inferenceParams"></param>
@@ -167,14 +191,14 @@ public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams?
     public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
     {
         var prompt = HistoryTransform.HistoryToText(history);
-        History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
 
         StringBuilder sb = new();
 
         await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
         {
             yield return result;
             sb.Append(result);
         }
-        History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
     }
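
And a sketch of the history-controlled overload, which after this change no longer appends anything to session.History; the caller owns all state management. The naive windowing below is purely illustrative of the truncation use case the doc comment mentions (it assumes System.Linq for TakeLast):

// Keep only the most recent messages so the rendered prompt
// fits the model's context window.
var truncated = new ChatHistory();
truncated.Messages.AddRange(session.History.Messages.TakeLast(6));

// This overload renders exactly the history it is given and leaves
// session.History untouched.
await foreach (var token in session.ChatAsync(truncated, inferenceParams))
{
    Console.Write(token);
}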

private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)