diff --git a/LLama/LLavaWeights.cs b/LLama/LLavaWeights.cs
index 9594dcdbb..cb9692ead 100644
--- a/LLama/LLavaWeights.cs
+++ b/LLama/LLavaWeights.cs
@@ -21,7 +21,8 @@ private LLavaWeights(SafeLlavaModelHandle weights)
{
NativeHandle = weights;
}
-
+
+ #region load
///
/// Load weights into memory
///
@@ -43,7 +44,9 @@ public static Task LoadFromFileAsync(string mmProject, Cancellatio
{
return Task.Run(() => LoadFromFile(mmProject), token);
}
+ #endregion
+ #region embed
///
/// Create the Image Embeddings from the bytes of an image.
///
@@ -57,9 +60,20 @@ public static Task LoadFromFileAsync(string mmProject, Cancellatio
///
///
///
- public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image )
+ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image)
+ {
+ return NativeHandle.CreateImageEmbeddings(ctxLlama, image);
+ }
+
+ ///
+ /// Create the Image Embeddings.
+ ///
+ /// Image in binary format (it supports jpeg format only)
+ /// Number of threads to use
+ /// return the SafeHandle of these embeddings
+ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1)
{
- return NativeHandle.CreateImageEmbeddings(ctxLlama, image );
+ return NativeHandle.CreateImageEmbeddings(image, threads);
}
///
@@ -76,10 +90,30 @@ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, by
///
///
///
- public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image )
+ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image)
+ {
+ return NativeHandle.CreateImageEmbeddings(ctxLlama, image);
+ }
+
+ ///
+ /// Create the Image Embeddings from the bytes of an image.
+ ///
+ /// Path to the image file. Supported formats:
+ ///
+ /// - JPG
+ /// - PNG
+ /// - BMP
+ /// - TGA
+ ///
+ ///
+ ///
+ ///
+ ///
+ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1)
{
- return NativeHandle.CreateImageEmbeddings(ctxLlama, image );
+ return NativeHandle.CreateImageEmbeddings(image, threads);
}
+ #endregion
///
/// Eval the image embeddings
diff --git a/LLama/Native/LLavaImageEmbed.cs b/LLama/Native/LLavaImageEmbed.cs
index 2030515ec..7704b73de 100644
--- a/LLama/Native/LLavaImageEmbed.cs
+++ b/LLama/Native/LLavaImageEmbed.cs
@@ -5,8 +5,9 @@ namespace LLama.Native;
///
/// LLaVa Image embeddings
///
+/// llava_image_embed
[StructLayout(LayoutKind.Sequential)]
-unsafe public struct LLavaImageEmbed
+public unsafe struct LLavaImageEmbed
{
public float* embed;
public int n_image_pos;
diff --git a/LLama/Native/SafeLlavaImageEmbedHandle.cs b/LLama/Native/SafeLlavaImageEmbedHandle.cs
index aa6da9e0e..77b4eaf66 100644
--- a/LLama/Native/SafeLlavaImageEmbedHandle.cs
+++ b/LLama/Native/SafeLlavaImageEmbedHandle.cs
@@ -1,4 +1,4 @@
-using System;
+using System;
using System.IO;
@@ -10,11 +10,39 @@ namespace LLama.Native
public sealed class SafeLlavaImageEmbedHandle
: SafeLLamaHandleBase
{
+ ///
+ /// Get the model used to create this image embedding
+ ///
+ public SafeLlavaModelHandle Model { get; private set; } = null!;
+
+ #region embed
+ ///
+ /// Create an image embed from an image file
+ ///
+ ///
+ ///
+ /// Path to the image file. Supported formats:
+ ///
+ /// - JPG
+ /// - PNG
+ /// - BMP
+ /// - TGA
+ ///
+ ///
+ ///
+ ///
+ public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, LLamaContext ctx, string image)
+ {
+ if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip))
+ throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})");
+
+ return CreateFromFileName(clip, image, (int)ctx.BatchThreads);
+ }
+
///
/// Create an image embed from an image file
///
- ///
- ///
+ ///
/// Path to the image file. Supported formats:
///
/// - JPG
@@ -23,10 +51,14 @@ public sealed class SafeLlavaImageEmbedHandle
/// - TGA
///
///
+ ///
///
///
- public static SafeLlavaImageEmbedHandle CreateFromFileName( SafeLlavaModelHandle ctxLlava, LLamaContext ctxLlama, string image )
+ public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, string image, int threads = -1)
{
+ if (threads <= 0)
+ threads = Environment.ProcessorCount / 2;
+
// Try to open the image file, this will check:
// - File exists (automatically throws FileNotFoundException)
// - File is readable (explicit check)
@@ -34,14 +66,17 @@ public static SafeLlavaImageEmbedHandle CreateFromFileName( SafeLlavaModelHandle
using (var fs = new FileStream(image, FileMode.Open))
if (!fs.CanRead)
throw new InvalidOperationException($"Llava image file '{image}' is not readable");
- return NativeApi.llava_image_embed_make_with_filename(ctxLlava, (int) ctxLlama.BatchThreads, image);
+
+ var embed = NativeApi.llava_image_embed_make_with_filename(clip, threads, image);
+ embed.Model = clip;
+ return embed;
}
-
+
///
/// Create an image embed from the bytes of an image.
///
- ///
- ///
+ ///
+ ///
/// Image bytes. Supported formats:
///
/// - JPG
@@ -51,11 +86,39 @@ public static SafeLlavaImageEmbedHandle CreateFromFileName( SafeLlavaModelHandle
///
///
///
- public static SafeLlavaImageEmbedHandle CreateFromMemory( SafeLlavaModelHandle ctxLlava, LLamaContext ctxLlama, byte[] image )
+ public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, LLamaContext ctx, byte[] image)
{
- return NativeApi.llava_image_embed_make_with_bytes(ctxLlava, (int) ctxLlama.BatchThreads, image, image.Length);
+ if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip))
+ throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})");
+
+ return CreateFromMemory(clip, image, (int)ctx.BatchThreads);
}
+ ///
+ /// Create an image embed from the bytes of an image.
+ ///
+ ///
+ /// Image bytes. Supported formats:
+ ///
+ /// - JPG
+ /// - PNG
+ /// - BMP
+ /// - TGA
+ ///
+ ///
+ ///
+ ///
+ public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, byte[] image, int threads = -1)
+ {
+ if (threads <= 0)
+ threads = Environment.ProcessorCount / 2;
+
+ var embed = NativeApi.llava_image_embed_make_with_bytes(clip, threads, image, image.Length);
+ embed.Model = clip;
+ return embed;
+ }
+ #endregion
+
///
protected override bool ReleaseHandle()
{
@@ -63,5 +126,27 @@ protected override bool ReleaseHandle()
SetHandle(IntPtr.Zero);
return true;
}
+
+ ///
+ /// Copy the embeddings data to the destination span
+ ///
+ ///
+ ///
+ public void GetEmbedding(Span dest, int index)
+ {
+ if (index < 0)
+ throw new ArgumentOutOfRangeException(nameof(index), "index must be >= 0");
+ if (index >= Model.PatchCount)
+ throw new ArgumentOutOfRangeException(nameof(index), "index must be < Model.PatchCount");
+
+ unsafe
+ {
+ var embed = (LLavaImageEmbed*)DangerousGetHandle();
+ new Span(
+ embed->embed + Model.EmbeddingDimensions * index,
+ Model.EmbeddingDimensions
+ ).CopyTo(dest);
+ }
+ }
}
}
diff --git a/LLama/Native/SafeLlavaModelHandle.cs b/LLama/Native/SafeLlavaModelHandle.cs
index fd898b536..9bc1ec8d2 100644
--- a/LLama/Native/SafeLlavaModelHandle.cs
+++ b/LLama/Native/SafeLlavaModelHandle.cs
@@ -1,4 +1,4 @@
-using System;
+using System;
using System.IO;
using System.Runtime.InteropServices;
using LLama.Exceptions;
@@ -12,6 +12,16 @@ namespace LLama.Native
public sealed class SafeLlavaModelHandle
: SafeLLamaHandleBase
{
+ ///
+ /// Get the number of dimensions in an embedding
+ ///
+ public int EmbeddingDimensions => clip_n_mmproj_embd(this);
+
+ ///
+ /// Get the number of "patches" in an image embedding
+ ///
+ public int PatchCount => clip_n_patches(this);
+
///
protected override bool ReleaseHandle()
{
@@ -30,7 +40,6 @@ protected override bool ReleaseHandle()
///
public static SafeLlavaModelHandle LoadFromFile(string modelPath, int verbosity )
{
-
// Try to open the model file, this will check:
// - File exists (automatically throws FileNotFoundException)
// - File is readable (explicit check)
@@ -57,16 +66,38 @@ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, st
return SafeLlavaImageEmbedHandle.CreateFromFileName(this, ctxLlama, image);
}
+ ///
+ /// Create the Image Embeddings.
+ ///
+ /// Image in binary format (it supports jpeg format only)
+ /// Number of threads to use
+ /// return the SafeHandle of these embeddings
+ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1)
+ {
+ return SafeLlavaImageEmbedHandle.CreateFromFileName(this, image, threads);
+ }
+
///
/// Create the Image Embeddings.
///
/// LLama Context
/// Image in binary format (it supports jpeg format only)
/// return the SafeHandle of these embeddings
- public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image )
+ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image)
{
return SafeLlavaImageEmbedHandle.CreateFromMemory(this, ctxLlama, image );
}
+
+ ///
+ /// Create the Image Embeddings.
+ ///
+ /// Image in binary format (it supports jpeg format only)
+ /// Number of threads to use
+ /// return the SafeHandle of these embeddings
+ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1)
+ {
+ return SafeLlavaImageEmbedHandle.CreateFromMemory(this, image, threads);
+ }
///
/// Evaluates the image embeddings.
@@ -79,7 +110,8 @@ public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imag
{
return NativeApi.llava_eval_image_embed(ctxLlama.NativeHandle, imageEmbed, (int)ctxLlama.Params.BatchSize, ref n_past );
}
-
+
+ #region native API
///
/// Load MULTI MODAL PROJECTIONS model / Clip Model
///
@@ -96,6 +128,11 @@ public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imag
[DllImport(NativeApi.llavaLibraryName, EntryPoint = "clip_free", CallingConvention = CallingConvention.Cdecl)]
private static extern void clip_free(IntPtr ctx);
+ [DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int clip_n_mmproj_embd(SafeLlavaModelHandle ctx);
+ [DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int clip_n_patches(SafeLlavaModelHandle ctx);
+ #endregion
}
}