Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llava Initial approach to clear images #664

Merged
merged 5 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
using System.Text.RegularExpressions;
using LLama.Batched;
using LLama.Common;
using Spectre.Console;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Examples.Examples
{
Expand All @@ -19,12 +18,8 @@ public static async Task Run()

var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";

var parameters = new ModelParams(modelPath)
{
ContextSize = 4096,
Seed = 1337,
GpuLayerCount = 10
};
var parameters = new ModelParams(modelPath);

using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);

Expand All @@ -47,16 +42,16 @@ public static async Task Run()
var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imageCount = imageMatches.Count();
var hasImages = imageCount > 0;
byte[][] imageBytes = null;

if (hasImages)
{
var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value);
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList();

List<byte[]> imageBytes;
try
{
imageBytes = imagePaths.Select(File.ReadAllBytes).ToArray();
imageBytes = imagePaths.Select(File.ReadAllBytes).ToList();
}
catch (IOException exception)
{
Expand All @@ -69,15 +64,17 @@ public static async Task Run()
break;
}

// Each prompt with images we clear cache
// When the prompt contains images we clear KV_CACHE to restart conversation
// See:
// https://github.com/ggerganov/llama.cpp/discussions/3620
ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 );

int index = 0;
foreach (var path in imagePathsWithCurlyBraces)
{
// Replace the first image path with the <image> tag; remove the tags for the remaining images
if (index++ == 0)
prompt = prompt.Replace(path, "<image>");
else
prompt = prompt.Replace(path, "");
prompt = prompt.Replace(path, index++ == 0 ? "<image>" : "");
}


Expand All @@ -102,7 +99,7 @@ public static async Task Run()
//
foreach (var image in imagePaths)
{
ex.Images.Add(File.ReadAllBytes(image));
ex.Images.Add(await File.ReadAllBytesAsync(image));
}
}

Expand All @@ -118,7 +115,7 @@ public static async Task Run()

// let the user finish with exit
//
if (prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
if (prompt != null && prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
break;

}
Expand Down
2 changes: 1 addition & 1 deletion LLama/Abstractions/ILLamaExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public interface ILLamaExecutor
public LLavaWeights? ClipModel { get; }

/// <summary>
/// List of images: Image filen path, uri or image byte array. See ImageData.
/// List of images: List of images in byte array format.
/// </summary>
public List<byte[]> Images { get; }

Expand Down
2 changes: 1 addition & 1 deletion LLama/LLamaExecutorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
public LLavaWeights? ClipModel { get; }

/// <inheritdoc />
public List<byte[]> Images { get; set; }
public List<byte[]> Images { get; }

/// <summary>
/// Current "mu" value for mirostat sampling
Expand Down Expand Up @@ -419,10 +419,10 @@
public string? SessionFilePath { get; set; }

[JsonPropertyName("embd")]
public LLamaToken[] Embeds { get; set; }

Check warning on line 422 in LLama/LLamaExecutorBase.cs

View workflow job for this annotation

GitHub Actions / Test (windows-release)

Non-nullable property 'Embeds' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

[JsonPropertyName("embd_inps")]
public LLamaToken[] EmbedInps { get; set; }

Check warning on line 425 in LLama/LLamaExecutorBase.cs

View workflow job for this annotation

GitHub Actions / Test (windows-release)

Non-nullable property 'EmbedInps' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

[JsonPropertyName("session_tokens")]
public LLamaToken[] SessionTokens { get; set; }
Expand Down
31 changes: 25 additions & 6 deletions LLama/LLamaInteractExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
using LLama.Exceptions;
using LLama.Extensions;
using Microsoft.Extensions.Logging;
using System.Net.Http;


namespace LLama
{
Expand Down Expand Up @@ -101,7 +101,7 @@
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
await LoadState(state);

Check warning on line 104 in LLama/LLamaInteractExecutor.cs

View workflow job for this annotation

GitHub Actions / Test (linux-release)

Possible null reference argument for parameter 'data' in 'Task InteractiveExecutor.LoadState(ExecutorBaseState data)'.

Check warning on line 104 in LLama/LLamaInteractExecutor.cs

View workflow job for this annotation

GitHub Actions / Test (osx-release)

Possible null reference argument for parameter 'data' in 'Task InteractiveExecutor.LoadState(ExecutorBaseState data)'.
}
}

Expand Down Expand Up @@ -136,24 +136,33 @@
text += "\n";
}

var line_inp = Context.Tokenize(text, false);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
if (!this.IsMultiModal)
{
var line_inp = Context.Tokenize(text, false);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
}
else
{
PreprocessLlava(text, args, false);
}
}

return Task.CompletedTask;
}

/// <inheritdoc />
private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true )
{
int usedTokens = 0;

// If the prompt contains the tag <image> extract this.
_imageInPrompt = text.Contains("<image>");
if (_imageInPrompt && ClipModel != null)
if (_imageInPrompt && IsMultiModal )
{
foreach (var image in Images)
{
_imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromMemory(ClipModel.NativeHandle, Context, image));

Check warning on line 165 in LLama/LLamaInteractExecutor.cs

View workflow job for this annotation

GitHub Actions / Test (linux-release)

Dereference of a possibly null reference.

Check warning on line 165 in LLama/LLamaInteractExecutor.cs

View workflow job for this annotation

GitHub Actions / Test (osx-release)

Dereference of a possibly null reference.
}

int imageIndex = text.IndexOf("<image>");
Expand All @@ -170,7 +179,16 @@
}
else
{
_embed_inps = Context.Tokenize(text, true).ToList();
if (addBos)
{
_embed_inps = Context.Tokenize(text, true).ToList();
}
else
{
var line_inp = Context.Tokenize(text, false);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
}
}
return Task.CompletedTask;
}
Expand All @@ -181,11 +199,11 @@
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <returns></returns>
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)

Check warning on line 202 in LLama/LLamaInteractExecutor.cs

View workflow job for this annotation

GitHub Actions / Test (linux-release)

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))

Check warning on line 206 in LLama/LLamaInteractExecutor.cs

View workflow job for this annotation

GitHub Actions / Test (linux-release)

'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'
args.WaitForInput = true;

if (_pastTokensCount > 0 && args.WaitForInput)
Expand Down Expand Up @@ -239,6 +257,7 @@

_EmbedImagePosition = -1;
_imageEmbedHandles.Clear();
Images.Clear();
}
else
{
Expand Down
32 changes: 17 additions & 15 deletions docs/Examples/LLavaInteractiveModeExecute.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

```cs
using System.Text.RegularExpressions;
using LLama.Batched;
using LLama.Common;
using Spectre.Console;
using LLama.Native;

namespace LLama.Examples.Examples
{
Expand All @@ -21,11 +21,8 @@ namespace LLama.Examples.Examples

var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";

var parameters = new ModelParams(modelPath)
{
ContextSize = 4096,
Seed = 1337,
};
var parameters = new ModelParams(modelPath);

using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);

Expand All @@ -48,16 +45,16 @@ namespace LLama.Examples.Examples
var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imageCount = imageMatches.Count();
var hasImages = imageCount > 0;
byte[][] imageBytes = null;

if (hasImages)
{
var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value);
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList();

List<byte[]> imageBytes;
try
{
imageBytes = imagePaths.Select(File.ReadAllBytes).ToArray();
imageBytes = imagePaths.Select(File.ReadAllBytes).ToList();
}
catch (IOException exception)
{
Expand All @@ -70,15 +67,17 @@ namespace LLama.Examples.Examples
break;
}

// Each prompt with images we clear cache
// When the prompt contains images we clear KV_CACHE to restart conversation
// See:
// https://github.com/ggerganov/llama.cpp/discussions/3620
ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 );

int index = 0;
foreach (var path in imagePathsWithCurlyBraces)
{
// Replace the first image path with the <image> tag; remove the tags for the remaining images
if (index++ == 0)
prompt = prompt.Replace(path, "<image>");
else
prompt = prompt.Replace(path, "");
prompt = prompt.Replace(path, index++ == 0 ? "<image>" : "");
}


Expand All @@ -101,7 +100,10 @@ namespace LLama.Examples.Examples

// Initialize Images in executor
//
ex.ImagePaths = imagePaths.ToList();
foreach (var image in imagePaths)
{
ex.Images.Add(await File.ReadAllBytesAsync(image));
}
}

Console.ForegroundColor = Color.White;
Expand All @@ -116,7 +118,7 @@ namespace LLama.Examples.Examples

// let the user finish with exit
//
if (prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
if (prompt != null && prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
break;

}
Expand Down
4 changes: 2 additions & 2 deletions docs/Tutorials/Executors.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ public interface ILLamaExecutor
public LLavaWeights? ClipModel { get; }

/// <summary>
/// List of images: Image filename and path (jpeg images).
/// List of images: List of images in byte array format.
/// </summary>
public List<string> ImagePaths { get; set; }
public List<byte[]> Images { get; }


/// <summary>
Expand Down
Loading