Skip to content

Commit

Permalink
Change interface to support multiple images and add the capability t…
Browse files Browse the repository at this point in the history
…o render the image in the console
  • Loading branch information
SignalRT committed Mar 26, 2024
1 parent 2d9a114 commit 43677c5
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 49 deletions.
96 changes: 83 additions & 13 deletions LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using LLama.Common;
using System.Text.RegularExpressions;
using LLama.Batched;
using LLama.Common;
using Spectre.Console;

namespace LLama.Examples.Examples
{
Expand All @@ -8,15 +11,15 @@ public static async Task Run()
{
string multiModalProj = UserSettings.GetMMProjPath();
string modelPath = UserSettings.GetModelPath();
string imagePath = UserSettings.GetImagePath();
string modelImage = UserSettings.GetImagePath();
const int maxTokens = 1024;

var prompt = (await File.ReadAllTextAsync("Assets/vicuna-llava-v16.txt")).Trim();
var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";

var parameters = new ModelParams(modelPath)
{
ContextSize = 4096,
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
Expand All @@ -26,26 +29,93 @@ public static async Task Run()

var ex = new InteractiveExecutor(context, clipModel );

ex.ImagePath = imagePath;

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 1024 and the context size is 4096. ");
Console.ForegroundColor = ConsoleColor.White;

Console.Write(prompt);
Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize );
Console.WriteLine("To send an image, enter its filename in curly braces, like this {c:/image.jpg}.");

var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List<string> { "USER:" }, MaxTokens = 1024 };
var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List<string> { "\nUSER:" }, MaxTokens = maxTokens };

while (true)
do
{

// Evaluate if we have images
//
var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imageCount = imageMatches.Count();
var hasImages = imageCount > 0;
byte[][] imageBytes = null;

if (hasImages)
{
var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value);

try
{
imageBytes = imagePaths.Select(File.ReadAllBytes).ToArray();
}
catch (IOException exception)
{
Console.ForegroundColor = ConsoleColor.Red;
Console.Write(
$"Could not load your {(imageCount == 1 ? "image" : "images")}:");
Console.Write($"{exception.Message}");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Please try again.");
break;
}


int index = 0;
foreach (var path in imagePathsWithCurlyBraces)
{
// First image: replace the path with the <image> tag; for the remaining images, remove the tag entirely
if (index++ == 0)
prompt = prompt.Replace(path, "<image>");
else
prompt = prompt.Replace(path, "");
}


Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine($"Here are the images, that are sent to the chat model in addition to your message.");
Console.WriteLine();

foreach (var consoleImage in imageBytes?.Select(bytes => new CanvasImage(bytes)))
{
consoleImage.MaxWidth = 50;
AnsiConsole.Write(consoleImage);
}

Console.WriteLine();
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine($"The images were scaled down for the console only, the model gets full versions.");
Console.WriteLine($"Write /exit or press Ctrl+c to return to main menu.");
Console.WriteLine();


// Initialize images in the executor
//
ex.ImagePaths = imagePaths.ToList();
}

Console.ForegroundColor = Color.White;
await foreach (var text in ex.InferAsync(prompt, inferenceParams))
{
Console.Write(text);
}
Console.Write(" ");
Console.ForegroundColor = ConsoleColor.Green;
prompt = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.White;
Console.WriteLine();

// let the user finish with exit
//
if (prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
break;

}
while(true);
}
}
}
1 change: 1 addition & 0 deletions LLama.Examples/LLama.Examples.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
<PackageReference Include="Microsoft.SemanticKernel" Version="1.6.2" />
<PackageReference Include="Microsoft.SemanticKernel.Plugins.Memory" Version="1.6.2-alpha" />
<PackageReference Include="Spectre.Console" Version="0.48.0" />
<PackageReference Include="Spectre.Console.ImageSharp" Version="0.48.0" />
</ItemGroup>

<ItemGroup>
Expand Down
4 changes: 2 additions & 2 deletions LLama/Abstractions/ILLamaExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ public interface ILLamaExecutor
public LLavaWeights? ClipModel { get; }

/// <summary>
/// Image filename and path (jpeg images).
/// List of images: Image filename and path (jpeg images).
/// </summary>
public string? ImagePath { get; set; }
public List<string> ImagePaths { get; set; }


/// <summary>
Expand Down
5 changes: 3 additions & 2 deletions LLama/LLamaExecutorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ public bool IsMultiModal
{
get
{
return ClipModel != null && ImagePath != null;
return ClipModel != null;
}
}

/// <inheritdoc />
public LLavaWeights? ClipModel { get; }

/// <inheritdoc />
public string? ImagePath { get; set; }
public List<string> ImagePaths { get; set; }

/// <summary>
/// Current "mu" value for mirostat sampling
Expand All @@ -95,6 +95,7 @@ public bool IsMultiModal
/// <param name="logger"></param>
protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
{
ImagePaths = new List<string>();
_logger = logger;
Context = context;
_pastTokensCount = 0;
Expand Down
70 changes: 40 additions & 30 deletions LLama/LLamaInteractExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class InteractiveExecutor : StatefulExecutorBase

// LLava
private int _EmbedImagePosition = -1;
private SafeLlavaImageEmbedHandle _imageEmbedHandle = null;
private List<SafeLlavaImageEmbedHandle> _imageEmbedHandles = new List<SafeLlavaImageEmbedHandle>();
private bool _imageInPrompt = false;

/// <summary>
Expand Down Expand Up @@ -125,30 +125,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
}
else
{
// If the prompt contains the tag <image> extract this.
_imageInPrompt = text.Contains("<image>");
if (_imageInPrompt)
{
if (!string.IsNullOrEmpty(ImagePath))
{
_imageEmbedHandle = SafeLlavaImageEmbedHandle.CreateFromFileName( ClipModel.NativeHandle, Context, ImagePath);
}

int imageIndex = text.IndexOf("<image>");
// Tokenize segment 1 (before <image> tag)
string preImagePrompt = text.Substring(0, imageIndex);
var segment1 = Context.Tokenize(preImagePrompt, true);
// Remember the position to add the image embeddings
_EmbedImagePosition = segment1.Length;
string postImagePrompt = text.Substring(imageIndex + 7);
var segment2 = Context.Tokenize(postImagePrompt, false);
_embed_inps.AddRange(segment1);
_embed_inps.AddRange(segment2);
}
else
{
_embed_inps = Context.Tokenize(text, true).ToList();
}
PreprocessLlava(text, args, true );
}
}
else
Expand All @@ -157,6 +134,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
{
text += "\n";
}

var line_inp = Context.Tokenize(text, false);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
Expand All @@ -165,6 +143,37 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
return Task.CompletedTask;
}

/// <summary>
/// Preprocess a llava prompt: build an image embedding handle for every path in
/// <c>ImagePaths</c> and tokenize the prompt text around the first &lt;image&gt; tag,
/// recording the token position at which the image embeddings must be injected
/// during inference.
/// </summary>
/// <param name="text">Prompt text, possibly containing an &lt;image&gt; placeholder tag.</param>
/// <param name="args">Inference state arguments (not used by this method).</param>
/// <param name="addBos">Whether to prepend the BOS token when tokenizing the prompt.</param>
/// <returns>A completed task (the work is synchronous).</returns>
private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true )
{
    // If the prompt contains the tag <image>, split the prompt around it.
    _imageInPrompt = text.Contains("<image>");
    if (_imageInPrompt)
    {
        // Create an embedding handle for every configured image.
        foreach (var image in ImagePaths)
        {
            _imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromFileName( ClipModel.NativeHandle, Context, image ) );
        }

        int imageIndex = text.IndexOf("<image>");
        // Tokenize segment 1 (before the <image> tag).
        string preImagePrompt = text.Substring(0, imageIndex);
        var segment1 = Context.Tokenize(preImagePrompt, addBos );
        // Remember the position to add the image embeddings at.
        _EmbedImagePosition = segment1.Length;
        // Tokenize segment 2 (after the tag; "<image>".Length == 7).
        string postImagePrompt = text.Substring(imageIndex + 7);
        var segment2 = Context.Tokenize(postImagePrompt, false);
        _embed_inps.AddRange(segment1);
        _embed_inps.AddRange(segment2);
    }
    else
    {
        // No image tag: tokenize the whole prompt. Honour the addBos parameter here
        // (previously hard-coded to true, silently ignoring the argument).
        _embed_inps = Context.Tokenize(text, addBos).ToList();
    }
    return Task.CompletedTask;
}

/// <summary>
/// Return whether to break the generation.
/// </summary>
Expand Down Expand Up @@ -216,18 +225,19 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
(DecodeResult, int) header, end, result;
if (IsMultiModal && _EmbedImagePosition > 0)
{
// Previous to Image
// Tokens previous to the images
header = Context.NativeHandle.Decode(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);

// Image
ClipModel.EvalImageEmbed(Context, _imageEmbedHandle, ref _pastTokensCount);
// Images
foreach( var image in _imageEmbedHandles )
ClipModel.EvalImageEmbed(Context, image, ref _pastTokensCount);

// Post-image
// Post-image Tokens
end = Context.NativeHandle.Decode(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);

_EmbedImagePosition = -1;

_imageEmbedHandles.Clear();
}
else
{
Expand Down
5 changes: 3 additions & 2 deletions LLama/LLamaStatelessExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ public class StatelessExecutor
// LLava Section
public bool IsMultiModal => false;
public bool MultiModalProject { get; }
public LLavaWeights ClipModel { get; }
public string ImagePath { get; set; }
public LLavaWeights? ClipModel { get; }
public List<string> ImagePaths { get; set; }

/// <summary>
/// The context used by the executor when running the inference.
Expand All @@ -43,6 +43,7 @@ public class StatelessExecutor
/// <param name="logger"></param>
public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
ImagePaths = new List<string>();
_weights = weights;
_params = @params;
_logger = logger;
Expand Down

0 comments on commit 43677c5

Please sign in to comment.