diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
index fedfe4e71..b82379c5b 100644
--- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs
+++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
@@ -79,7 +79,7 @@ await AnsiConsole
guidance.Prompt(g);
// Early exit if we reach the natural end of the guided sentence
- if (g == model.Tokens.EOS)
+ if (model.Tokens.IsEndOfGeneration(g))
break;
// Update progress bar
diff --git a/LLama.KernelMemory/LLamaSharp.KernelMemory.csproj b/LLama.KernelMemory/LLamaSharp.KernelMemory.csproj
index 7f272e0f7..b2eb0945f 100644
--- a/LLama.KernelMemory/LLamaSharp.KernelMemory.csproj
+++ b/LLama.KernelMemory/LLamaSharp.KernelMemory.csproj
@@ -4,7 +4,7 @@
net6.0;net7.0;net8.0
enable
enable
- 0.11.2
+ 0.12.0
Xbotter
SciSharp STACK
true
@@ -17,7 +17,7 @@
The integration of LLamaSharp and Microsoft kernel-memory. It could make it easy to support document search for LLamaSharp model inference.
- v0.11.2 followed the updating of LLamaSharp.
+ v0.12.0 released with v0.12.0 of LLamaSharp.
MIT
packages
diff --git a/LLama.SemanticKernel/LLamaSharp.SemanticKernel.csproj b/LLama.SemanticKernel/LLamaSharp.SemanticKernel.csproj
index c11a27c05..e46cb249f 100644
--- a/LLama.SemanticKernel/LLamaSharp.SemanticKernel.csproj
+++ b/LLama.SemanticKernel/LLamaSharp.SemanticKernel.csproj
@@ -10,7 +10,7 @@
enable
enable
- 0.11.2
+ 0.12.0
Tim Miller, Xbotter
SciSharp STACK
true
@@ -23,7 +23,7 @@
The integration of LLamaSharp and Microsoft semantic-kernel.
- v0.11.2 followed the updating of LLamaSharp.
+ v0.12.0 released with v0.12.0 of LLamaSharp.
MIT
packages
diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index e87649f9a..f2b92b245 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -29,12 +29,14 @@ public class ModelOptions
///
public int GpuLayerCount { get; set; } = 20;
- public uint SeqMax { get; }
+ ///
+ public uint SeqMax { get; set; }
///
public uint? Seed { get; set; } = 1686349486;
- public bool Embeddings { get; }
+ ///
+ public bool Embeddings { get; set; }
///
public bool UseMemorymap { get; set; } = true;
@@ -102,6 +104,9 @@ public class ModelOptions
///
public bool NoKqvOffload { get; set; }
+ ///
+ public bool FlashAttention { get; set; }
+
///
public Encoding Encoding { get; set; } = Encoding.UTF8;
diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs
index 8f0c00b4d..f93b2145b 100644
--- a/LLama/Abstractions/IContextParams.cs
+++ b/LLama/Abstractions/IContextParams.cs
@@ -109,7 +109,14 @@ public interface IContextParams
bool NoKqvOffload { get; }
///
+ /// Whether to use flash attention
+ ///
+ bool FlashAttention { get; }
+
+ ///
+ /// Threshold that controls KV cache defragmentation:
/// defragment the KV cache if holes/size > defrag_threshold, Set to or < 0 to disable (default)
+
///
float? DefragThreshold { get; }
diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs
index ac81f1fdb..2b1e1679d 100644
--- a/LLama/Abstractions/IModelParams.cs
+++ b/LLama/Abstractions/IModelParams.cs
@@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
+using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Native;
@@ -241,6 +242,7 @@ public sealed record MetadataOverride
private readonly int _valueInt;
private readonly float _valueFloat;
private readonly bool _valueBool;
+ private readonly byte[]? _valueString;
///
/// Create a new override for an int key
@@ -278,6 +280,21 @@ public MetadataOverride(string key, bool value)
Type = LLamaModelKvOverrideType.Bool;
}
+        /// <summary>
+        /// Create a new override for a string key. The value is stored as UTF8 and must encode to fewer than 128 bytes.
+        /// </summary>
+        /// <param name="key">Metadata key to override</param>
+        /// <param name="value">Replacement value; an <see cref="ArgumentException"/> is thrown if its UTF8 encoding is 128 bytes or longer</param>
+        public MetadataOverride(string key, string value)
+        {
+            Key = key;
+            _valueString = Encoding.UTF8.GetBytes(value);
+            Type = LLamaModelKvOverrideType.String;
+            // Exactly 128 bytes is rejected too: it would fill the entire 128 byte native buffer (presumably leaving no room for a terminating NUL - confirm against llama.h)
+            if (_valueString.Length >= 128)
+                throw new ArgumentException("Value string is too long, must be < 128 UTF8 bytes", nameof(value));
+        }
+
internal void WriteValue(ref LLamaModelMetadataOverride dest)
{
switch (Type)
@@ -291,6 +308,13 @@ internal void WriteValue(ref LLamaModelMetadataOverride dest)
case LLamaModelKvOverrideType.Bool:
dest.BoolValue = _valueBool ? -1L : 0;
break;
+ case LLamaModelKvOverrideType.String:
+ unsafe
+ {
+ fixed (byte* strValPtr = dest.StringValue)
+ new Span(_valueString!).CopyTo(new Span(strValPtr, 128));
+ }
+ break;
default:
throw new InvalidEnumArgumentException($"Unknown {nameof(LLamaModelKvOverrideType)} value: {Type}");
}
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index 2d6e7b4d0..28b1ef4e0 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -99,6 +99,10 @@ public record ModelParams
///
public bool NoKqvOffload { get; set; }
+ ///
+
+ public bool FlashAttention { get; set; }
+
///
public float? DefragThreshold { get; set; }
diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs
index 40eca2c2b..6c033f8aa 100644
--- a/LLama/Extensions/IContextParamsExtensions.cs
+++ b/LLama/Extensions/IContextParamsExtensions.cs
@@ -50,6 +50,7 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
- result.type_k = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
+ result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
result.offload_kqv = !@params.NoKqvOffload;
+ result.flash_attention = @params.FlashAttention;
result.llama_pooling_type = @params.PoolingType;
result.n_threads = Threads(@params.Threads);
diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj
index 3947b7c31..5c7a59dec 100644
--- a/LLama/LLamaSharp.csproj
+++ b/LLama/LLamaSharp.csproj
@@ -7,7 +7,7 @@
AnyCPU;x64;Arm64
True
- 0.11.2
+ 0.12.0
Rinne, Martin Evans, jlsantiago and all the other contributors in https://github.com/SciSharp/LLamaSharp/graphs/contributors.
SciSharp STACK
true
@@ -22,7 +22,7 @@
With the higher-level APIs and RAG support, it's convenient to deploy LLM (Large Language Model) in your application with LLamaSharp.
- LLamaSharp 0.11.2 fixed the performance issue of LLaVA on GPU and improved the log suppression.
+ Updated llama.cpp version to include better support for LLama3 tokenization.
MIT
packages
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index ab5f41469..433d9cd16 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
+using System.Text;
using System.Threading;
using LLama.Exceptions;
using LLama.Native;
@@ -123,8 +124,8 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams
);
}
- // Check if this is the EOS token
- if (id == _weights.Tokens.EOS)
+ // Check if this token should end generation
+ if (_weights.Tokens.IsEndOfGeneration(id))
break;
// Decode this token into text
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index 1ea52e6b8..aab903785 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -151,6 +151,16 @@ public bool offload_kqv
}
private sbyte _offload_kqv;
+        /// <summary>
+        /// Enables or disables flash attention. Written as 1/0 into a signed byte; any non-zero byte reads back as true.
+        /// </summary>
+        public bool flash_attention
+        {
+            readonly get => _flash_attention != 0;
+            set => _flash_attention = value ? (sbyte)1 : (sbyte)0;
+        }
+        private sbyte _flash_attention;
+
//todo: implement abort callback support
///
/// ggml_abort_callback
diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs
index ae2702db2..bc4b5c4cb 100644
--- a/LLama/Native/LLamaFtype.cs
+++ b/LLama/Native/LLamaFtype.cs
@@ -171,6 +171,11 @@ public enum LLamaFtype
///
LLAMA_FTYPE_MOSTLY_IQ1_M = 31,
+ ///
+ /// except 1d tensors
+ ///
+ LLAMA_FTYPE_MOSTLY_BF16 = 32,
+
///
/// File type was not specified
///
diff --git a/LLama/Native/LLamaModelMetadataOverride.cs b/LLama/Native/LLamaModelMetadataOverride.cs
index ff4e8dd9d..c0cadd1fe 100644
--- a/LLama/Native/LLamaModelMetadataOverride.cs
+++ b/LLama/Native/LLamaModelMetadataOverride.cs
@@ -43,6 +43,12 @@ public unsafe struct LLamaModelMetadataOverride
///
[FieldOffset(136)]
public long BoolValue;
+
+ /// <summary>
+ /// Value, **must** only be used if Tag == String. Fixed 128 byte buffer declared at the same field offset as BoolValue, so the two overlap.
+ /// </summary>
+ [FieldOffset(136)]
+ public fixed byte StringValue[128];
}
///
@@ -65,4 +71,9 @@ public enum LLamaModelKvOverrideType
/// Overriding a bool value
///
Bool = 2,
+
+ ///
+ /// Overriding a string value
+ ///
+ String = 3,
}
\ No newline at end of file
diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs
index 6fca41fc8..bbece4648 100644
--- a/LLama/Native/LLamaModelParams.cs
+++ b/LLama/Native/LLamaModelParams.cs
@@ -81,6 +81,16 @@ public bool use_mlock
}
private sbyte _use_mlock;
+        /// <summary>
+        /// Controls validation of model tensor data. Written as 1/0 into a signed byte; any non-zero byte reads back as true.
+        /// </summary>
+        public bool check_tensors
+        {
+            readonly get => _check_tensors != 0;
+            set => _check_tensors = value ? (sbyte)1 : (sbyte)0;
+        }
+        private sbyte _check_tensors;
+
///
/// Create a LLamaModelParams with default values
///
diff --git a/LLama/Native/LLamaModelQuantizeParams.cs b/LLama/Native/LLamaModelQuantizeParams.cs
index b2d37eb05..4a6a4e218 100644
--- a/LLama/Native/LLamaModelQuantizeParams.cs
+++ b/LLama/Native/LLamaModelQuantizeParams.cs
@@ -70,6 +70,16 @@ public bool pure
}
private sbyte _pure;
+        /// <summary>
+        /// Quantize to the same number of shards. Written as 1/0 into a signed byte; any non-zero byte reads back as true.
+        /// </summary>
+        public bool keep_split
+        {
+            get => _keep_split != 0;
+            set => _keep_split = value ? (sbyte)1 : (sbyte)0;
+        }
+        private sbyte _keep_split;
+
///
/// pointer to importance matrix data
///
diff --git a/LLama/Native/LLamaVocabPreType.cs b/LLama/Native/LLamaVocabPreType.cs
new file mode 100644
index 000000000..0d31d4347
--- /dev/null
+++ b/LLama/Native/LLamaVocabPreType.cs
@@ -0,0 +1,17 @@
+namespace LLama.Native;
+
+/// <summary>
+/// Pre-tokenization type of a vocabulary
+/// </summary>
+/// <remarks>llama_vocab_pre_type</remarks>
+internal enum LLamaVocabPreType
+{
+ LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
+ LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
+ LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
+ LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
+ LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
+ LLAMA_VOCAB_PRE_TYPE_MPT = 5,
+ LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
+ LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, // NOTE(review): the commented-out llama_vocab_pre_type.cs lists further values 8-12 - confirm which set matches the bundled native library
+}
\ No newline at end of file
diff --git a/LLama/Native/NativeApi.LLava.cs b/LLama/Native/NativeApi.LLava.cs
index e3aeef4b3..183f183a7 100644
--- a/LLama/Native/NativeApi.LLava.cs
+++ b/LLama/Native/NativeApi.LLava.cs
@@ -13,6 +13,7 @@ public static unsafe partial class NativeApi
/// Llava Model
/// True if validate successfully
[DllImport(llavaLibraryName, EntryPoint = "llava_validate_embed_size", CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.U1)]
public static extern bool llava_validate_embed_size( SafeLLamaContextHandle ctxLlama, SafeLlavaModelHandle ctxClip);
///
@@ -56,7 +57,7 @@ SafeLlavaImageEmbedHandle llava_image_embed_make_with_filename(SafeLlavaModelHan
/// Embedding handle
/// True on success
[DllImport(llavaLibraryName, EntryPoint = "llava_eval_image_embed", CallingConvention = CallingConvention.Cdecl)]
- public static extern bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, SafeLlavaImageEmbedHandle embed,
- int n_batch, ref int n_past);
+ [return: MarshalAs(UnmanagedType.U1)]
+ public static extern bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, SafeLlavaImageEmbedHandle embed, int n_batch, ref int n_past);
}
\ No newline at end of file
diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs
index 441e70ecd..1b30a1cf7 100644
--- a/LLama/Native/NativeApi.Sampling.cs
+++ b/LLama/Native/NativeApi.Sampling.cs
@@ -176,7 +176,7 @@ public static void llama_sample_apply_guidance(SafeLLamaContextHandle ctx, Span<
public static extern LLamaToken llama_sample_token_greedy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);
///
- /// Randomly selects a token from the candidates based on their probabilities.
+ /// Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
///
///
/// Pointer to LLamaTokenDataArray
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 9301198ef..715225ed2 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -34,6 +34,7 @@ public static void llama_empty_call()
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_supports_mmap();
///
@@ -41,6 +42,7 @@ public static void llama_empty_call()
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_supports_mlock();
///
@@ -48,6 +50,7 @@ public static void llama_empty_call()
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_supports_gpu_offload();
///
@@ -77,6 +80,7 @@ public static void llama_empty_call()
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out);
///
@@ -88,6 +92,7 @@ public static void llama_empty_call()
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
@@ -133,6 +138,14 @@ public static void llama_empty_call()
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern uint llama_n_seq_max(SafeLLamaContextHandle ctx);
+ /// <summary>
+ /// Get the pooling type for this context
+ /// </summary>
+ /// <param name="ctx">Context to query</param>
+ /// <returns>The pooling type reported by the native library for <paramref name="ctx"/></returns>
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern LLamaPoolingType llama_pooling_type(SafeLLamaContextHandle ctx);
+
///
/// Get the embeddings for the a specific sequence.
/// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
@@ -223,19 +236,20 @@ public static unsafe int llama_chat_apply_template(SafeLlamaModelHandle? model,
///
///
/// buffer to write string into
+ /// If true, special tokens are rendered in the output
/// The length written, or if the buffer is too small a negative that indicates the length required
- public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken llamaToken, Span buffer)
+ public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken llamaToken, Span buffer, bool special)
{
unsafe
{
fixed (byte* bufferPtr = buffer)
{
- return llama_token_to_piece_native(model, llamaToken, bufferPtr, buffer.Length);
+ return llama_token_to_piece_native(model, llamaToken, bufferPtr, buffer.Length, special);
}
}
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")]
- static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, LLamaToken llamaToken, byte* buffer, int length);
+ static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, LLamaToken llamaToken, byte* buffer, int length, [MarshalAs(UnmanagedType.U1)] bool special); // pass the flag as a 1 byte bool, matching the U1 marshalling used for native bool returns in this file
}
///
@@ -282,7 +296,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
public static extern int llama_get_kv_cache_used_cells(SafeLLamaContextHandle ctx);
///
- /// Clear the KV cache
+ /// Clear the KV cache: the cell info is erased and the KV data is zeroed
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
@@ -297,6 +311,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
///
/// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_kv_cache_seq_rm(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1);
///
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 2758c0509..f24cfe5fd 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -361,6 +361,16 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model,
///
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_token_eot(SafeLlamaModelHandle model);
+
+ /// <summary>
+ /// Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
+ /// </summary>
+ /// <param name="model">Model whose vocabulary the token belongs to</param>
+ /// <param name="token">Token to check</param>
+ /// <returns>True if the native library reports the token as end-of-generation</returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.U1)]
+ private static extern bool llama_token_is_eog(SafeLlamaModelHandle model, LLamaToken token);
#endregion
#region LoRA
@@ -402,10 +412,11 @@ public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null
///
/// Token to decode
/// A span to attempt to write into. If this is too small nothing will be written
+ /// If true, special characters will be converted to text. If false they will be invisible.
/// The size of this token. **nothing will be written** if this is larger than `dest`
- public uint TokenToSpan(LLamaToken token, Span dest)
+ public uint TokenToSpan(LLamaToken token, Span dest, bool special = false)
{
- var length = NativeApi.llama_token_to_piece(this, token, dest);
+ var length = NativeApi.llama_token_to_piece(this, token, dest, special);
return (uint)Math.Abs(length);
}
@@ -623,6 +634,16 @@ internal ModelTokens(SafeLlamaModelHandle model)
/// Codellama end of infill middle
///
public LLamaToken? EOT => Normalize(llama_token_eot(_model));
+
+            /// <summary>
+            /// Test whether a token marks the end of generation.
+            /// The decision is delegated to the native llama_token_is_eog function for this model.
+            /// </summary>
+            /// <param name="token">Token to test</param>
+            /// <returns>The value returned by llama_token_is_eog for the given token</returns>
+            /// <remarks>Used by callers in place of a direct comparison against the EOS token</remarks>
+            public bool IsEndOfGeneration(LLamaToken token)
+                => llama_token_is_eog(_model, token);
}
}
}
diff --git a/LLama/Native/llama_vocab_pre_type.cs b/LLama/Native/llama_vocab_pre_type.cs
new file mode 100644
index 000000000..7f08fb35c
--- /dev/null
+++ b/LLama/Native/llama_vocab_pre_type.cs
@@ -0,0 +1,27 @@
+namespace LLama.Native;
+
+/////
+///// pre-tokenization type
+/////
+///// llama_vocab_pre_type
+//public enum llama_vocab_pre_type
+//{
+// ///
+// /// Default pre tokenization type
+// ///
+// /// LLAMA_VOCAB_PRE_TYPE_DEFAULT
+// Default = 0,
+//
+// LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
+// LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
+// LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
+// LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
+// LLAMA_VOCAB_PRE_TYPE_MPT = 5,
+// LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
+// LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
+// LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
+// LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
+// LLAMA_VOCAB_PRE_TYPE_QWEN2 = 10,
+// LLAMA_VOCAB_PRE_TYPE_OLMO = 11,
+// LLAMA_VOCAB_PRE_TYPE_DBRX = 12,
+//}
\ No newline at end of file
diff --git a/LLama/runtimes/deps/avx/libllama.dll b/LLama/runtimes/deps/avx/libllama.dll
index 3b3b3fae3..6eac478ab 100644
Binary files a/LLama/runtimes/deps/avx/libllama.dll and b/LLama/runtimes/deps/avx/libllama.dll differ
diff --git a/LLama/runtimes/deps/avx/libllama.so b/LLama/runtimes/deps/avx/libllama.so
index 08dcee0f7..b4d2fb9c5 100644
Binary files a/LLama/runtimes/deps/avx/libllama.so and b/LLama/runtimes/deps/avx/libllama.so differ
diff --git a/LLama/runtimes/deps/avx/libllava_shared.so b/LLama/runtimes/deps/avx/libllava_shared.so
index 1c8adfcb7..d1ef24e17 100644
Binary files a/LLama/runtimes/deps/avx/libllava_shared.so and b/LLama/runtimes/deps/avx/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/avx/llama.dll b/LLama/runtimes/deps/avx/llama.dll
index 3b3b3fae3..6eac478ab 100644
Binary files a/LLama/runtimes/deps/avx/llama.dll and b/LLama/runtimes/deps/avx/llama.dll differ
diff --git a/LLama/runtimes/deps/avx/llava_shared.dll b/LLama/runtimes/deps/avx/llava_shared.dll
index e08a474b6..5d1b67a93 100644
Binary files a/LLama/runtimes/deps/avx/llava_shared.dll and b/LLama/runtimes/deps/avx/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/avx2/libllama.dll b/LLama/runtimes/deps/avx2/libllama.dll
index bb8e5c48b..23de7074c 100644
Binary files a/LLama/runtimes/deps/avx2/libllama.dll and b/LLama/runtimes/deps/avx2/libllama.dll differ
diff --git a/LLama/runtimes/deps/avx2/libllama.so b/LLama/runtimes/deps/avx2/libllama.so
index e7c27f79f..f3eea88b4 100644
Binary files a/LLama/runtimes/deps/avx2/libllama.so and b/LLama/runtimes/deps/avx2/libllama.so differ
diff --git a/LLama/runtimes/deps/avx2/libllava_shared.so b/LLama/runtimes/deps/avx2/libllava_shared.so
index f9bbdf272..5d55bfa5d 100644
Binary files a/LLama/runtimes/deps/avx2/libllava_shared.so and b/LLama/runtimes/deps/avx2/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/avx2/llama.dll b/LLama/runtimes/deps/avx2/llama.dll
index bb8e5c48b..23de7074c 100644
Binary files a/LLama/runtimes/deps/avx2/llama.dll and b/LLama/runtimes/deps/avx2/llama.dll differ
diff --git a/LLama/runtimes/deps/avx2/llava_shared.dll b/LLama/runtimes/deps/avx2/llava_shared.dll
index 6b4ad9c13..b286c4e54 100644
Binary files a/LLama/runtimes/deps/avx2/llava_shared.dll and b/LLama/runtimes/deps/avx2/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/avx512/libllama.dll b/LLama/runtimes/deps/avx512/libllama.dll
index fcbc052eb..d29a14f20 100644
Binary files a/LLama/runtimes/deps/avx512/libllama.dll and b/LLama/runtimes/deps/avx512/libllama.dll differ
diff --git a/LLama/runtimes/deps/avx512/libllama.so b/LLama/runtimes/deps/avx512/libllama.so
index a0044ad66..abfe110d3 100644
Binary files a/LLama/runtimes/deps/avx512/libllama.so and b/LLama/runtimes/deps/avx512/libllama.so differ
diff --git a/LLama/runtimes/deps/avx512/libllava_shared.so b/LLama/runtimes/deps/avx512/libllava_shared.so
index d0c76ef13..4ff11d280 100644
Binary files a/LLama/runtimes/deps/avx512/libllava_shared.so and b/LLama/runtimes/deps/avx512/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/avx512/llama.dll b/LLama/runtimes/deps/avx512/llama.dll
index fcbc052eb..d29a14f20 100644
Binary files a/LLama/runtimes/deps/avx512/llama.dll and b/LLama/runtimes/deps/avx512/llama.dll differ
diff --git a/LLama/runtimes/deps/avx512/llava_shared.dll b/LLama/runtimes/deps/avx512/llava_shared.dll
index 8d643cb1d..088b1b8d2 100644
Binary files a/LLama/runtimes/deps/avx512/llava_shared.dll and b/LLama/runtimes/deps/avx512/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/clblast/libllama.so b/LLama/runtimes/deps/clblast/libllama.so
index b6bff999a..c3e6eb39d 100644
Binary files a/LLama/runtimes/deps/clblast/libllama.so and b/LLama/runtimes/deps/clblast/libllama.so differ
diff --git a/LLama/runtimes/deps/clblast/libllava_shared.so b/LLama/runtimes/deps/clblast/libllava_shared.so
index 6f63d183a..52b2483b2 100644
Binary files a/LLama/runtimes/deps/clblast/libllava_shared.so and b/LLama/runtimes/deps/clblast/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/clblast/llama.dll b/LLama/runtimes/deps/clblast/llama.dll
index 055b24a84..d7158fcd8 100644
Binary files a/LLama/runtimes/deps/clblast/llama.dll and b/LLama/runtimes/deps/clblast/llama.dll differ
diff --git a/LLama/runtimes/deps/clblast/llava_shared.dll b/LLama/runtimes/deps/clblast/llava_shared.dll
index 349ec89e5..2eb43fd15 100644
Binary files a/LLama/runtimes/deps/clblast/llava_shared.dll and b/LLama/runtimes/deps/clblast/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/cu11.7.1/libllama.so b/LLama/runtimes/deps/cu11.7.1/libllama.so
index 1f1a79a0f..955355d2e 100644
Binary files a/LLama/runtimes/deps/cu11.7.1/libllama.so and b/LLama/runtimes/deps/cu11.7.1/libllama.so differ
diff --git a/LLama/runtimes/deps/cu11.7.1/libllava_shared.so b/LLama/runtimes/deps/cu11.7.1/libllava_shared.so
index 47cba9b13..a9fe23026 100644
Binary files a/LLama/runtimes/deps/cu11.7.1/libllava_shared.so and b/LLama/runtimes/deps/cu11.7.1/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/cu11.7.1/llama.dll b/LLama/runtimes/deps/cu11.7.1/llama.dll
index a1cd82b25..0d18a43d0 100644
Binary files a/LLama/runtimes/deps/cu11.7.1/llama.dll and b/LLama/runtimes/deps/cu11.7.1/llama.dll differ
diff --git a/LLama/runtimes/deps/cu11.7.1/llava_shared.dll b/LLama/runtimes/deps/cu11.7.1/llava_shared.dll
index 00e9794e6..c93b5461f 100644
Binary files a/LLama/runtimes/deps/cu11.7.1/llava_shared.dll and b/LLama/runtimes/deps/cu11.7.1/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/cu12.1.0/libllama.so b/LLama/runtimes/deps/cu12.1.0/libllama.so
index 39b09e6b3..2f0311492 100644
Binary files a/LLama/runtimes/deps/cu12.1.0/libllama.so and b/LLama/runtimes/deps/cu12.1.0/libllama.so differ
diff --git a/LLama/runtimes/deps/cu12.1.0/libllava_shared.so b/LLama/runtimes/deps/cu12.1.0/libllava_shared.so
index ce830a7da..e7948cf79 100644
Binary files a/LLama/runtimes/deps/cu12.1.0/libllava_shared.so and b/LLama/runtimes/deps/cu12.1.0/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/cu12.1.0/llama.dll b/LLama/runtimes/deps/cu12.1.0/llama.dll
index 09a87cb20..ba15b7677 100644
Binary files a/LLama/runtimes/deps/cu12.1.0/llama.dll and b/LLama/runtimes/deps/cu12.1.0/llama.dll differ
diff --git a/LLama/runtimes/deps/cu12.1.0/llava_shared.dll b/LLama/runtimes/deps/cu12.1.0/llava_shared.dll
index 597733ed3..5c5962104 100644
Binary files a/LLama/runtimes/deps/cu12.1.0/llava_shared.dll and b/LLama/runtimes/deps/cu12.1.0/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/libllama.dll b/LLama/runtimes/deps/libllama.dll
index deb86e0df..b2d85078e 100644
Binary files a/LLama/runtimes/deps/libllama.dll and b/LLama/runtimes/deps/libllama.dll differ
diff --git a/LLama/runtimes/deps/libllama.so b/LLama/runtimes/deps/libllama.so
index 85dce3430..09a78f8f7 100644
Binary files a/LLama/runtimes/deps/libllama.so and b/LLama/runtimes/deps/libllama.so differ
diff --git a/LLama/runtimes/deps/libllava_shared.so b/LLama/runtimes/deps/libllava_shared.so
index f41a9c670..7ff06062d 100644
Binary files a/LLama/runtimes/deps/libllava_shared.so and b/LLama/runtimes/deps/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/llama.dll b/LLama/runtimes/deps/llama.dll
index deb86e0df..b2d85078e 100644
Binary files a/LLama/runtimes/deps/llama.dll and b/LLama/runtimes/deps/llama.dll differ
diff --git a/LLama/runtimes/deps/llava_shared.dll b/LLama/runtimes/deps/llava_shared.dll
index 10b057945..43d55ab8d 100644
Binary files a/LLama/runtimes/deps/llava_shared.dll and b/LLama/runtimes/deps/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/osx-arm64/ggml-metal.metal b/LLama/runtimes/deps/osx-arm64/ggml-metal.metal
index 9a29f57a3..46c7d5039 100644
--- a/LLama/runtimes/deps/osx-arm64/ggml-metal.metal
+++ b/LLama/runtimes/deps/osx-arm64/ggml-metal.metal
@@ -213,6 +213,15 @@ kernel void kernel_scale_4(
dst[tpig] = src0[tpig] * scale;
}
+kernel void kernel_clamp(
+ device const float * src0,
+ device float * dst,
+ constant float & min,
+ constant float & max,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); // element tpig: values below min become min, values above max become max, others pass through
+}
+
kernel void kernel_relu(
device const float * src0,
device float * dst,
@@ -233,6 +242,15 @@ constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
kernel void kernel_gelu(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
@@ -246,6 +264,15 @@ kernel void kernel_gelu(
}
kernel void kernel_gelu_quick(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_gelu_quick_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
@@ -255,6 +282,14 @@ kernel void kernel_gelu_quick(
}
kernel void kernel_silu(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+ dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
@@ -317,11 +352,12 @@ kernel void kernel_sum_rows(
dst_row[0] = row_sum;
}
+template
kernel void kernel_soft_max(
- device const float * src0,
- device const float * src1,
- device const float * src2,
- device float * dst,
+ device const char * src0,
+ device const char * src1,
+ device const char * src2,
+ device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -340,10 +376,10 @@ kernel void kernel_soft_max(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
- device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
- device const float * ppos = src2 != src0 ? src2 : nullptr;
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
+ device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
+ device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
float slope = 0.0f;
@@ -421,11 +457,12 @@ kernel void kernel_soft_max(
}
}
+template
kernel void kernel_soft_max_4(
- device const float * src0,
- device const float * src1,
- device const float * src2,
- device float * dst,
+ device const char * src0,
+ device const char * src1,
+ device const char * src2,
+ device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -444,10 +481,10 @@ kernel void kernel_soft_max_4(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
- device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
+ device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
+ device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
float slope = 0.0f;
@@ -464,7 +501,7 @@ kernel void kernel_soft_max_4(
float4 lmax4 = -INFINITY;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -490,7 +527,7 @@ kernel void kernel_soft_max_4(
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
@@ -527,6 +564,14 @@ kernel void kernel_soft_max_4(
}
}
+typedef decltype(kernel_soft_max) kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max;
+template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4;
+
kernel void kernel_diag_mask_inf(
device const float * src0,
device float * dst,
@@ -866,6 +911,7 @@ void mul_vec_q_n_f32_impl(
int64_t ne1,
uint r2,
uint r3,
+ threadgroup int8_t * shared_values,
uint3 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0;
@@ -942,7 +988,7 @@ kernel void kernel_mul_mv_q4_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
@@ -968,7 +1014,7 @@ kernel void kernel_mul_mv_q4_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
@@ -994,7 +1040,7 @@ kernel void kernel_mul_mv_q5_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
@@ -1020,7 +1066,7 @@ kernel void kernel_mul_mv_q5_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
@@ -1030,18 +1076,19 @@ void kernel_mul_mv_q8_0_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nr = N_DST;
const int nsg = N_SIMDGROUP;
const int nw = N_SIMDWIDTH;
@@ -1119,7 +1166,7 @@ kernel void kernel_mul_mv_q8_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
#define N_F32_F32 4
@@ -1128,24 +1175,24 @@ void kernel_mul_mv_f32_f32_impl(
device const char * src0,
device const char * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ uint3 tgpig,
+ uint tiisg) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F32_F32;
@@ -1398,24 +1445,24 @@ void kernel_mul_mv_f16_f32_impl(
device const char * src0,
device const char * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ uint3 tgpig,
+ uint tiisg) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F16_F32;
@@ -2047,9 +2094,12 @@ kernel void kernel_leaky_relu_f32(
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
}
-kernel void kernel_cpy_f16_f16(
- device const half * src0,
- device half * dst,
+typedef void (flash_attn_ext_f16_t)(
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -2058,38 +2108,35 @@ kernel void kernel_cpy_f16_f16(
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne31,
+ constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
- dst_data[i00] = src[0];
- }
-}
-
-kernel void kernel_cpy_f16_f32(
- device const half * src0,
+ constant float & scale,
+ threadgroup half * shared,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]);
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
+template // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_f16(
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -2099,123 +2146,316 @@ kernel void kernel_cpy_f16_f32(
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne31,
+ constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ constant float & scale,
+ threadgroup half * shared [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const short nsg = ntg.y; // number of simdgroups
+
+ const short iq3 = tgpig[2];
+ const short iq2 = tgpig[1];
+ const short iq1 = tgpig[0]*Q;
+
+ const short D4 = D/4;
+ const short D8 = D/8;
+ //const short Q8 = Q/8;
+ const short NW = N_SIMDWIDTH;
+ const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+ const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+ const short TF = T/2; // shared memory size per query in (float)
+ const short T4 = T/4; // shared memory size per query in (half4)
+
+ threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
+ threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
+ threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+ // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+ simdgroup_half8x8 lo[D8];
+
+ // load heads from Q to shared memory
+ for (short j = sgitg; j < Q; j += nsg) {
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+ for (short i = tiisg; i < D4; i += NW) {
+ if (iq1 + j < ne01) {
+ sq4[j*T4 + i] = (half4) q4[i];
+ } else {
+ sq4[j*T4 + i] = 0.0h;
+ }
+ }
+ }
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ // zero out lo
+ for (short i = 0; i < D8; ++i) {
+ lo[i] = make_filled_simdgroup_matrix(0.0h);
+ }
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+ // zero out shared memory SH
+ for (short j = 0; j < Q; ++j) {
+ for (short i = tiisg; i < SH; i += NW) {
+ ss[j*TF + i] = 0.0f;
+ }
+ }
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
- dst_data[i00] = src[0];
- }
-}
+ {
+ float S[Q] = { [0 ... Q-1] = 0.0h };
+ float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
-kernel void kernel_cpy_f32_f16(
- device const float * src0,
- device half * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ // assume K and V are same shape
+ const short ne22 = ne12;
+ const short ne23 = ne13;
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ const uint nb21 = nb11;
+ const uint nb22 = nb12;
+ const uint nb23 = nb13;
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+ // broadcast
+ const short rk2 = ne02/ne12;
+ const short rk3 = ne03/ne13;
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ const short rv2 = ne02/ne22;
+ const short rv3 = ne03/ne23;
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ // k indices
+ const short ik2 = iq2/rk2;
+ const short ik3 = iq3/rk3;
- dst_data[i00] = src[0];
- }
-}
+ // v indices
+ const short iv2 = iq2/rv2;
+ const short iv3 = iq3/rv3;
-kernel void kernel_cpy_f32_f32(
- device const float * src0,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ // load the queries from shared memory into local memory
+ simdgroup_half8x8 mq[D8];
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_load(mq[i], sq + i*8, T);
+ }
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+ // pointer to the mask
+ device const half * mp = (device const half *) (mask + iq1*nb31);
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ // prepare diagonal scale matrix
+ simdgroup_float8x8 mscale(scale);
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ // loop over the KV cache
+ // each simdgroup handles blocks of Q rows and C columns
+ for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+ const int ic = ic0 + C*sgitg;
+ if (ic >= ne11) {
+ break;
+ }
- dst_data[i00] = src[0];
+ // Q*K^T
+ {
+ for (short cc = 0; cc < C/8; ++cc) {
+ simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h);
+
+ device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_half8x8 mk;
+ simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+
+ simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+ }
+
+ // mqk = mqk*scale + mask
+ simdgroup_half8x8 mm;
+ simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
+ simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
+
+ simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+ }
+ }
+
+ // used to detect blocks full of -INF
+ float smax = -INFINITY;
+
+ // online softmax
+ {
+ float ms[Q];
+
+ for (short j = 0; j < Q; ++j) {
+ const short p = tiisg;
+
+ const float m = M[j];
+ const float s = ss[j*TF + p];
+
+ smax = simd_max(max(smax, s));
+ M[j] = simd_max(max(M[j], s));
+
+ ms[j] = exp(m - M[j]);
+ const float vs = exp(s - M[j]);
+
+ S[j] = S[j]*ms[j] + simd_sum(vs);
+
+ // the P matrix from the paper (Q rows, C columns)
+ ss[j*TF + p] = vs;
+ }
+
+ // create a QxQ diagonal matrix for rescaling the output
+ if (tiisg < Q) {
+ ss[tiisg*TF + C + tiisg] = ms[tiisg];
+ }
+ }
+
+ // skip -INF blocks
+ if (smax == -INFINITY) {
+ continue;
+ }
+
+ // O = diag(ms)*O
+ {
+ simdgroup_float8x8 mm;
+ simdgroup_load(mm, ss + C, TF, 0, false);
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_multiply(lo[i], mm, lo[i]);
+ }
+ }
+
+ // O = O + (Q*K^T)*V
+ {
+ for (short cc = 0; cc < C/8; ++cc) {
+ device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_half8x8 mk;
+ simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+
+ simdgroup_float8x8 mv;
+ simdgroup_load(mv, ss + 8*cc, TF, 0, false);
+
+ simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+ }
+ }
+ }
+ }
+
+ // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+ for (short j = 0; j < Q; ++j) {
+ if (tiisg == 0) {
+ ss[j*TF + 0] = S[j];
+ ss[j*TF + 1] = M[j];
+ }
+ }
+ }
+
+ // reduce the simdgroups sequentially
+ for (short sg = 1; sg < nsg; ++sg) {
+ float S = { 0.0h };
+ float M = { -FLT_MAX/2 };
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // each simdgroup stores its output to shared memory, reusing sq
+ if (sgitg == sg) {
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_store(lo[i], sq + i*8, T, 0, false);
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // the first simdgroup accumulates the results from the other simdgroups
+ if (sgitg == 0) {
+ for (short j = 0; j < Q; ++j) {
+ const float S0 = ss[j*TF + 0];
+ const float S1 = ss[j*TF + sg*SH + 0];
+
+ const float M0 = ss[j*TF + 1];
+ const float M1 = ss[j*TF + sg*SH + 1];
+
+ M = max(M0, M1);
+
+ const float ms0 = exp(M0 - M);
+ const float ms1 = exp(M1 - M);
+
+ S = S0*ms0 + S1*ms1;
+
+ if (tiisg == 0) {
+ ss[j*TF + 0] = S;
+ ss[j*TF + 1] = M;
+
+ ss[j*TF + C + j ] = ms0;
+ ss[j*TF + C + j + sg*SH] = ms1;
+ }
+ }
+
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+ {
+ simdgroup_half8x8 t;
+ simdgroup_float8x8 ms0;
+ simdgroup_float8x8 ms1;
+
+ simdgroup_load(ms0, ss + C, TF, 0, false);
+ simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_load (t, sq + i*8, T, 0, false);
+ simdgroup_multiply(t, ms1, t);
+
+ simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+ }
+ }
+ }
+ }
+
+ // store result to shared memory (reuse sq)
+ if (sgitg == 0) {
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_store(lo[i], sq + i*8, T, 0, false);
+ }
+ }
+
+ device float4 * dst4 = (device float4 *) dst;
+
+ // final rescale with 1/S and store to global memory
+ if (sgitg == 0) {
+ for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+ const float S = ss[j*TF + 0];
+
+ for (short i = tiisg; i < D4; i += NW) {
+ dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+ }
+ }
}
}
-kernel void kernel_cpy_f32_q8_0(
- device const float * src0,
- device void * dst,
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+
+template // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec_f16(
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -2224,56 +2464,265 @@ kernel void kernel_cpy_f32_q8_0(
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne31,
+ constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ constant float & scale,
+ threadgroup half * shared [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const short nsg = ntg.y; // number of simdgroups
+
+ const short iq3 = tgpig[2];
+ const short iq2 = tgpig[1];
+ const short iq1 = tgpig[0];
+
+ const short D4 = D/4;
+ const short NW = N_SIMDWIDTH;
+ const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+ const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+
+ //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
+ threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
+ threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+ threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in float4
+ threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
+
+ // store the result for the query in local memory as half4 vectors (the O matrix from the paper)
+ half4 lo[D4/NW];
+
+ // load heads from Q to shared memory
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+
+ for (short i = tiisg; i < D4; i += NW) {
+ if (iq1 < ne01) {
+ sq4[i] = (half4) q4[i];
+ } else {
+ sq4[i] = 0.0h;
+ }
+ }
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ // zero out lo
+ for (short i = tiisg; i < D4; i += NW) {
+ lo[i/NW] = 0.0h;
+ }
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
+ // zero out shared memory SH
+ for (short i = tiisg; i < SH/4; i += NW) {
+ ss4[i] = 0.0h;
+ }
- device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ {
+ float S = { 0.0h };
+ float M = { -FLT_MAX/2 };
- float amax = 0.0f; // absolute max
+ // assume K and V are same shape
+ const short ne22 = ne12;
+ const short ne23 = ne13;
- for (int j = 0; j < QK8_0; j++) {
- const float v = src[j];
- amax = MAX(amax, fabs(v));
+ const uint nb21 = nb11;
+ const uint nb22 = nb12;
+ const uint nb23 = nb13;
+
+ // broadcast
+ const short rk2 = ne02/ne12;
+ const short rk3 = ne03/ne13;
+
+ const short rv2 = ne02/ne22;
+ const short rv3 = ne03/ne23;
+
+ // k indices
+ const short ik2 = iq2 / rk2;
+ const short ik3 = iq3 / rk3;
+
+ // v indices
+ const short iv2 = iq2 / rv2;
+ const short iv3 = iq3 / rv3;
+
+ // load the queries from shared memory into local memory
+ half4 mq[D4];
+
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ mq[i] = sq4[i];
}
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
+ // pointer to the mask
+ device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
- dst_data[i00/QK8_0].d = d;
+ // loop over the KV cache
+ // each simdgroup handles blocks of Q rows and C columns
+ for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+ const int ic = ic0 + C*sgitg;
+ if (ic >= ne11) {
+ break;
+ }
- for (int j = 0; j < QK8_0; ++j) {
- const float x0 = src[j]*id;
+ // Q*K^T
+ {
+#pragma unroll
+ for (short cc = 0; cc < C/4; ++cc) {
+ float4 mqk = { 0.0h };
+
+ device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+#pragma unroll
+ for (short ii = 0; ii < D4; ii += NW) {
+ const short i = ii + tiisg;
+
+ half4x4 mk;
+ mk[0] = pk4[i + 0*(nb11/8)];
+ mk[1] = pk4[i + 1*(nb11/8)];
+ mk[2] = pk4[i + 2*(nb11/8)];
+ mk[3] = pk4[i + 3*(nb11/8)];
+
+ mqk += (float4) (mq[i] * mk);
+ }
+
+ // reduce the results from the threads in the simdgroup
+ mqk += simd_shuffle_down(mqk, 16);
+ mqk += simd_shuffle_down(mqk, 8);
+ mqk += simd_shuffle_down(mqk, 4);
+ mqk += simd_shuffle_down(mqk, 2);
+ mqk += simd_shuffle_down(mqk, 1);
+
+ // mqk = mqk*scale + mask
+ if (tiisg == 0) {
+ float4 mm = (float4) mp4[ic/4 + cc];
+ mqk = mqk*scale + mm;
+
+ ss4[cc] = mqk;
+ }
+ }
+ }
+
+ // online softmax
+ {
+ const short p = tiisg;
+
+ const float m = M;
+ const float s = ss[p];
+
+ M = simd_max(max(M, s));
+
+ const float ms = exp(m - M);
+ const float vs = exp(s - M);
+
+ S = S*ms + simd_sum(vs);
+
+ // the P matrix from the paper (Q rows, C columns)
+ ss[p] = vs;
+
+ // O = diag(ms)*O
+#pragma unroll
+ for (short ii = 0; ii < D4; ii += NW) {
+ const short i = ii + tiisg;
+ lo[i/NW] *= ms;
+ }
+ }
+
+ // O = O + (Q*K^T)*V
+ {
+#pragma unroll
+ for (short cc = 0; cc < C/4; ++cc) {
+ device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+#pragma unroll
+ for (short ii = 0; ii < D4; ii += NW) {
+ const short i = ii + tiisg;
+
+ lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+ lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+ lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+ lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+ }
+ }
+ }
- dst_data[i00/QK8_0].qs[j] = round(x0);
+ }
+
+ // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+ if (tiisg == 0) {
+ ss[0] = S;
+ ss[1] = M;
+ }
+ }
+
+ // store results to shared memory
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ sr4[i] = lo[ii/NW];
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // parallel reduce
+ for (short r = nsg/2; r > 0; r >>= 1) {
+ if (sgitg < r) {
+ const float S0 = ss[ 0];
+ const float S1 = ss[r*SH + 0];
+
+ const float M0 = ss[ 1];
+ const float M1 = ss[r*SH + 1];
+
+ const float M = max(M0, M1);
+
+ const float ms0 = exp(M0 - M);
+ const float ms1 = exp(M1 - M);
+
+ const float S = S0*ms0 + S1*ms1;
+
+ if (tiisg == 0) {
+ ss[0] = S;
+ ss[1] = M;
+ }
+
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ device float4 * dst4 = (device float4 *) dst;
+
+ // final rescale with 1/S and store to global memory
+ if (sgitg == 0) {
+ const float S = ss[0];
+
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
}
}
}
-kernel void kernel_cpy_f32_q4_0(
- device const float * src0,
- device void * dst,
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+
+kernel void kernel_cpy_f16_f16(
+ device const half * src0,
+ device half * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -2302,45 +2751,19 @@ kernel void kernel_cpy_f32_q4_0(
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
-
- device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
- float amax = 0.0f; // absolute max
- float max = 0.0f;
-
- for (int j = 0; j < QK4_0; j++) {
- const float v = src[j];
- if (amax < fabs(v)) {
- amax = fabs(v);
- max = v;
- }
- }
-
- const float d = max / -8;
- const float id = d ? 1.0f/d : 0.0f;
-
- dst_data[i00/QK4_0].d = d;
-
- for (int j = 0; j < QK4_0/2; ++j) {
- const float x0 = src[0 + j]*id;
- const float x1 = src[QK4_0/2 + j]*id;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+ device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
- dst_data[i00/QK4_0].qs[j] = xi0;
- dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
- }
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0];
}
}
-kernel void kernel_cpy_f32_q4_1(
- device const float * src0,
- device void * dst,
+kernel void kernel_cpy_f16_f32(
+ device const half * src0,
+ device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -2369,42 +2792,159 @@ kernel void kernel_cpy_f32_q4_1(
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
-
- device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
- float min = FLT_MAX;
- float max = -FLT_MAX;
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
- for (int j = 0; j < QK4_1; j++) {
- const float v = src[j];
- if (min > v) min = v;
- if (max < v) max = v;
- }
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0];
+ }
+}
- const float d = (max - min) / ((1 << 4) - 1);
- const float id = d ? 1.0f/d : 0.0f;
+kernel void kernel_cpy_f32_f16( // copies one f32 source row per threadgroup into a half destination
+ device const float * src0,
+ device half * dst,
+ constant int64_t & ne00, // ne00..ne02 linearize the source row position; ne00 is also the row length
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03, // not read in this body
+ constant uint64_t & nb00, // nb00..nb03: byte strides applied to the source indices
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0, // ne0..ne2 split the flat index into destination coordinates
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3, // not read in this body
+ constant uint64_t & nb0, // nb0..nb3: byte strides applied to the destination coordinates
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2]; // the threadgroup's grid position picks the source row (i01, i02, i03)
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
- dst_data[i00/QK4_1].d = d;
- dst_data[i00/QK4_1].m = min;
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; // flat element index of the row start
- for (int j = 0; j < QK4_1/2; ++j) {
- const float x0 = (src[0 + j] - min)*id;
- const float x1 = (src[QK4_1/2 + j] - min)*id;
+ const int64_t i3 = n / (ne2*ne1*ne0); // peel n into destination coordinates i3, i2, i1, i0
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+ device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); // start of the destination run, located via byte strides
- dst_data[i00/QK4_1].qs[j] = xi0;
- dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { // thread tpitg.x handles elements tpitg.x, tpitg.x + ntg.x, ... of the row
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ // float -> half conversion happens on assignment; NOTE(review): dst_data[i00] steps by sizeof(half), which presumes a contiguous destination row - confirm against the host dispatch
+ dst_data[i00] = src[0];
+ }
+}
+
+kernel void kernel_cpy_f32_f32( // copies one f32 source row per threadgroup into an f32 destination
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00, // ne00..ne02 linearize the source row position; ne00 is also the row length
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03, // not read in this body
+ constant uint64_t & nb00, // nb00..nb03: byte strides applied to the source indices
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0, // ne0..ne2 split the flat index into destination coordinates
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3, // not read in this body
+ constant uint64_t & nb0, // nb0..nb3: byte strides applied to the destination coordinates
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2]; // the threadgroup's grid position picks the source row (i01, i02, i03)
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+ // n is the flat element index of the first element of this source row
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ // peel n into destination coordinates i3, i2, i1, i0 using the destination extents
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+ // start of the destination run, located through the byte strides nb0..nb3
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ // thread tpitg.x handles elements tpitg.x, tpitg.x + ntg.x, ... of the row
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ // NOTE(review): dst_data[i00] steps by sizeof(float), which presumes a contiguous destination row - confirm against the host dispatch
+ dst_data[i00] = src[0];
+ }
+}
+
+kernel void kernel_cpy_f32_q8_0( // quantizes one f32 source row per threadgroup into block_q8_0 blocks
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00, // ne00..ne02 linearize the source row position; ne00 is also the row length
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03, // not read in this body
+ constant uint64_t & nb00, // nb00..nb03: byte strides applied to the source indices
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0, // ne0..ne2 split the flat index into destination coordinates
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3, // not read in this body
+ constant uint64_t & nb0, // nb0..nb3: byte strides applied to the destination coordinates
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2]; // the threadgroup's grid position picks the source row (i01, i02, i03)
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+ // n is the flat element index of the first element of this source row
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ // peel n into destination coordinates; i0 is divided by QK8_0 so it counts blocks, not elements
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
+ // first destination block for this row, located through the byte strides nb0..nb3
+ device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ // each iteration quantizes one whole block of QK8_0 source values; threads take blocks tpitg.x, tpitg.x + ntg.x, ...
+ for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ // NOTE(review): src[j] below walks adjacent floats, so the source row is presumably contiguous - confirm against the host dispatch
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = src[j];
+ amax = MAX(amax, fabs(v));
+ }
+ // d maps amax onto (1 << 7) - 1 = 127; id is 1/d, or 0 when d is 0 so no division by zero occurs
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK8_0].d = d;
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = src[j]*id;
+ // store the scaled value rounded to the nearest integer
+ dst_data[i00/QK8_0].qs[j] = round(x0);
}
}
}
-kernel void kernel_cpy_f32_q5_0(
+kernel void kernel_cpy_f32_q4_0(
device const float * src0,
device void * dst,
constant int64_t & ne00,
@@ -2435,17 +2975,17 @@ kernel void kernel_cpy_f32_q5_0(
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
- device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
- for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
+ for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
float amax = 0.0f; // absolute max
float max = 0.0f;
- for (int j = 0; j < QK5_0; j++) {
+ for (int j = 0; j < QK4_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
@@ -2453,31 +2993,25 @@ kernel void kernel_cpy_f32_q5_0(
}
}
- const float d = max / -16;
+ const float d = max / -8;
const float id = d ? 1.0f/d : 0.0f;
- dst_data[i00/QK5_0].d = d;
+ dst_data[i00/QK4_0].d = d;
- uint32_t qh = 0;
- for (int j = 0; j < QK5_0/2; ++j) {
+ for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = src[0 + j]*id;
- const float x1 = src[QK5_0/2 + j]*id;
+ const float x1 = src[QK4_0/2 + j]*id;
- const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
- const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
- dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
- }
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
- for (int j = 0; j < 4; ++j) {
- dst_data[i00/QK5_0].qh[j] = qh8[j];
+ dst_data[i00/QK4_0].qs[j] = xi0;
+ dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
}
}
}
-kernel void kernel_cpy_f32_q5_1(
+kernel void kernel_cpy_f32_q4_1(
device const float * src0,
device void * dst,
constant int64_t & ne00,
@@ -2508,63 +3042,42 @@ kernel void kernel_cpy_f32_q5_1(
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
- device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
- for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
+ for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
- float max = src[0];
- float min = src[0];
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
- for (int j = 1; j < QK5_1; j++) {
+ for (int j = 0; j < QK4_1; j++) {
const float v = src[j];
- min = v < min ? v : min;
- max = v > max ? v : max;
+ if (min > v) min = v;
+ if (max < v) max = v;
}
- const float d = (max - min) / 31;
+ const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
- dst_data[i00/QK5_1].d = d;
- dst_data[i00/QK5_1].m = min;
+ dst_data[i00/QK4_1].d = d;
+ dst_data[i00/QK4_1].m = min;
- uint32_t qh = 0;
- for (int j = 0; j < QK5_1/2; ++j) {
+ for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (src[0 + j] - min)*id;
- const float x1 = (src[QK5_1/2 + j] - min)*id;
+ const float x1 = (src[QK4_1/2 + j] - min)*id;
- const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
- const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
- dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
- }
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
- for (int j = 0; j < 4; ++j) {
- dst_data[i00/QK5_1].qh[j] = qh8[j];
+ dst_data[i00/QK4_1].qs[j] = xi0;
+ dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
}
}
}
-static inline int best_index_int8(int n, constant float * val, float x) {
- if (x <= val[0]) return 0;
- if (x >= val[n-1]) return n-1;
- int ml = 0, mu = n-1;
- while (mu-ml > 1) {
- int mav = (ml+mu)/2;
- if (x < val[mav]) mu = mav; else ml = mav;
- }
- return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
-}
-
-constexpr constant static float kvalues_iq4nl_f[16] = {
- -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
-};
-
-kernel void kernel_cpy_f32_iq4_nl(
+kernel void kernel_cpy_f32_q5_0(
device const float * src0,
device void * dst,
constant int64_t & ne00,
@@ -2595,17 +3108,17 @@ kernel void kernel_cpy_f32_iq4_nl(
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
- device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
- for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
+ for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
float amax = 0.0f; // absolute max
float max = 0.0f;
- for (int j = 0; j < QK4_0; j++) {
+ for (int j = 0; j < QK5_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
@@ -2613,16 +3126,176 @@ kernel void kernel_cpy_f32_iq4_nl(
}
}
- const float d = max / kvalues_iq4nl_f[0];
+ const float d = max / -16;
const float id = d ? 1.0f/d : 0.0f;
- float sumqx = 0, sumq2 = 0;
- for (int j = 0; j < QK4_NL/2; ++j) {
- const float x0 = src[0 + j]*id;
- const float x1 = src[QK4_NL/2 + j]*id;
+ dst_data[i00/QK5_0].d = d;
- const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
- const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_0/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK5_0/2 + j]*id;
+
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+ dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+ }
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+ for (int j = 0; j < 4; ++j) {
+ dst_data[i00/QK5_0].qh[j] = qh8[j];
+ }
+ }
+}
+
+kernel void kernel_cpy_f32_q5_1( // quantizes one f32 source row per threadgroup into block_q5_1 blocks
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00, // ne00..ne02 linearize the source row position; ne00 is also the row length
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03, // not read in this body
+ constant uint64_t & nb00, // nb00..nb03: byte strides applied to the source indices
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0, // ne0..ne2 split the flat index into destination coordinates
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3, // not read in this body
+ constant uint64_t & nb0, // nb0..nb3: byte strides applied to the destination coordinates
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2]; // the threadgroup's grid position picks the source row (i01, i02, i03)
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+ // n is the flat element index of the first element of this source row
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ // peel n into destination coordinates; i0 is divided by QK5_1 so it counts blocks, not elements
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
+ // first destination block for this row, located through the byte strides nb0..nb3
+ device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ // each iteration quantizes one whole block of QK5_1 source values; threads take blocks tpitg.x, tpitg.x + ntg.x, ...
+ for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ // NOTE(review): src[j] below walks adjacent floats, so the source row is presumably contiguous - confirm against the host dispatch
+ float max = src[0];
+ float min = src[0];
+
+ for (int j = 1; j < QK5_1; j++) {
+ const float v = src[j];
+ min = v < min ? v : min;
+ max = v > max ? v : max;
+ }
+ // d spreads the [min, max] range over 31 steps (5-bit codes 0..31); id is 1/d, or 0 when d is 0
+ const float d = (max - min) / 31;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK5_1].d = d;
+ dst_data[i00/QK5_1].m = min;
+ // qh gathers bit 4 of every code: bit j for src[j], bit j + QK5_1/2 for src[QK5_1/2 + j]
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_1/2; ++j) {
+ const float x0 = (src[0 + j] - min)*id;
+ const float x1 = (src[QK5_1/2 + j] - min)*id;
+ // x0 and x1 are non-negative here, so adding 0.5 before the truncating cast rounds to nearest
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+ // low nibble of qs[j] takes xi0's low 4 bits, high nibble takes xi1's low 4 bits
+ dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+ }
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh; // view qh as bytes so it can be stored into qh[0..3] one byte at a time
+ for (int j = 0; j < 4; ++j) {
+ dst_data[i00/QK5_1].qh[j] = qh8[j];
+ }
+ }
+}
+
+static inline int best_index_int8(int n, constant float * val, float x) { // returns the index of the val[0..n-1] entry closest to x
+ if (x <= val[0]) return 0; // at or below the first entry
+ if (x >= val[n-1]) return n-1; // at or above the last entry
+ int ml = 0, mu = n-1; // bracket kept below: val[ml] <= x < val[mu]; presumes val is sorted ascending
+ while (mu-ml > 1) { // bisect until ml and mu are adjacent
+ int mav = (ml+mu)/2;
+ if (x < val[mav]) mu = mav; else ml = mav;
+ }
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu; // pick the nearer neighbour; an exact tie goes to mu
+}
+
+constexpr constant static float kvalues_iq4nl_f[16] = { // 16 ascending levels; kernel_cpy_f32_iq4_nl looks codes up in this table via best_index_int8
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
+};
+
+kernel void kernel_cpy_f32_iq4_nl(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
+
+ device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK4_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / kvalues_iq4nl_f[0];
+ const float id = d ? 1.0f/d : 0.0f;
+
+ float sumqx = 0, sumq2 = 0;
+ for (int j = 0; j < QK4_NL/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK4_NL/2 + j]*id;
+
+ const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+ const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
@@ -2700,18 +3373,19 @@ void kernel_mul_mv_q2_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
@@ -2871,7 +3545,7 @@ kernel void kernel_mul_mv_q2_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
#if QK_K == 256
@@ -2879,18 +3553,19 @@ void kernel_mul_mv_q3_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
@@ -3046,6 +3721,7 @@ void kernel_mul_mv_q3_K_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3135,7 +3811,7 @@ kernel void kernel_mul_mv_q3_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
#if QK_K == 256
@@ -3143,18 +3819,19 @@ void kernel_mul_mv_q4_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
@@ -3265,6 +3942,7 @@ void kernel_mul_mv_q4_K_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3373,25 +4051,26 @@ kernel void kernel_mul_mv_q4_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_q5_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
@@ -3579,25 +4258,26 @@ kernel void kernel_mul_mv_q5_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_q6_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const uint8_t kmask1 = 0x03;
const uint8_t kmask2 = 0x0C;
@@ -3713,7 +4393,7 @@ kernel void kernel_mul_mv_q6_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
// ======================= "True" 2-bit
@@ -3722,19 +4402,19 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
@@ -3851,19 +4531,19 @@ void kernel_mul_mv_iq2_xs_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
@@ -3990,19 +4670,19 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
@@ -4122,19 +4802,19 @@ void kernel_mul_mv_iq3_s_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
@@ -4254,19 +4934,19 @@ void kernel_mul_mv_iq2_s_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
@@ -4387,18 +5067,19 @@ void kernel_mul_mv_iq1_s_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_value,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
@@ -4476,18 +5157,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_value,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
@@ -4584,20 +5266,21 @@ void kernel_mul_mv_iq4_nl_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup float * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values_i8,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
const int nb = ne00/QK4_NL;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
@@ -4678,20 +5361,21 @@ void kernel_mul_mv_iq4_xs_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup float * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values_i8,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
@@ -4794,7 +5478,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq1_m_f32")]]
@@ -4822,7 +5506,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
@@ -4846,7 +5530,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
- threadgroup float * shared_values [[threadgroup(0)]],
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -4875,7 +5559,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
- threadgroup float * shared_values [[threadgroup(0)]],
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -5632,25 +6316,25 @@ void kernel_mul_mm_impl(device const uchar * src0,
}
}
-// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
template
void kernel_mul_mm_id_impl(
device const uchar * src0,
device const uchar * src1,
- threadgroup short * src1ids,
+ threadgroup ushort2 * rowids,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
int64_t ne1,
- constant uint & r2,
- constant uint & r3,
+ int64_t ne0ne1,
threadgroup uchar * shared_memory,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
@@ -5661,7 +6345,6 @@ void kernel_mul_mm_id_impl(
const uint r0 = tgpig.y;
const uint r1 = tgpig.x;
- const uint im = tgpig.z;
if (r1 * BLOCK_SIZE_N >= ne1) return;
@@ -5679,19 +6362,16 @@ void kernel_mul_mm_id_impl(
for (int i = 0; i < 8; i++){
c_res[i] = make_filled_simdgroup_matrix(0.f);
}
-
short il = (tiitg % THREAD_PER_ROW);
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
ushort offset1 = il/nl;
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+ threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
+
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
device const float * y = (device const float *)(src1
- + nb12 * im
- + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
+ + nb12 * id[1]
+ + nb11 * (id[0] % ne11)
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
@@ -5720,11 +6400,11 @@ void kernel_mul_mm_id_impl(
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
for (int i = 0; i < 4; i++) {
- simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+ simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
}
simdgroup_barrier(mem_flags::mem_none);
for (int i = 0; i < 2; i++) {
- simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+ simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
}
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
@@ -5746,11 +6426,13 @@ void kernel_mul_mm_id_impl(
threadgroup_barrier(mem_flags::mem_threadgroup);
- device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
+ device float * C = dst + (BLOCK_SIZE_M * r0);
if (sgitg == 0) {
- for (int i = 0; i < n_rows; i++) {
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
- *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+ threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
+ int joff = jid[0] * ne0 + jid[1] * ne0ne1;
+ for (int i = 0; i < n_rows; i++) {
+ *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
}
}
}
@@ -5805,11 +6487,14 @@ kernel void kernel_mul_mm_id(
device const uchar * src1,
device float * dst,
device const uchar * ids,
+ constant int64_t & nei0,
+ constant int64_t & nei1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
@@ -5818,47 +6503,52 @@ kernel void kernel_mul_mm_id(
constant int64_t & ne0,
constant int64_t & ne1,
constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- // expert id
- const int32_t id = tgpig.z/(ne12*ne13);
- device const uchar * src0 = src0s + id*nb02;
+ const int32_t i02 = tgpig.z;
+ tgpig.z = 0;
- tgpig.z = tgpig.z%(ne12*ne13);
+ device const uchar * src0 = src0s + i02*nb02;
- // row indices of src1 for expert id
- threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);
+ // row indices
+ threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+ // TODO: parallelize this loop
int64_t _ne1 = 0;
- for (int64_t i1 = 0; i1 < ne1; i1++) {
- if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
- src1ids[_ne1++] = i1;
+ for (ushort ii1 = 0; ii1 < nei1; ii1++) {
+ for (ushort ii0 = 0; ii0 < nei0; ii0++) {
+ int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+ if (id == i02) {
+ //if (tiitg == 0) {
+ rowids[_ne1] = ushort2(ii0, ii1);
+ //}
+ _ne1++;
+ }
}
}
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
kernel_mul_mm_id_impl(
src0,
src1,
- src1ids,
+ rowids,
dst,
ne00,
ne02,
nb01,
nb02,
+ ne11,
ne12,
nb10,
nb11,
nb12,
ne0,
_ne1,
- r2,
- r3,
+ ne0*ne1,
shared_memory,
tgpig,
tiitg,
@@ -5919,24 +6609,7 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_r
// matrix-matrix multiplication
//
-typedef void (mat_mm_t)(
- device const uchar * src0,
- device const uchar * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup uchar *,
- uint3, uint, uint);
+typedef decltype(kernel_mul_mm) mat_mm_t; // take the function type from kernel_mul_mm itself instead of spelling out its parameter list by hand
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm;
@@ -5968,29 +6641,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
// indirect matrix-matrix multiplication
//
-typedef void (mat_mm_id_t)(
- device const uchar * src0s,
- device const uchar * src1,
- device float * dst,
- device const uchar * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup uchar *,
- uint3, uint, uint);
+typedef decltype(kernel_mul_mm_id) mat_mm_id_t; // take the function type from kernel_mul_mm_id itself instead of spelling out its parameter list by hand
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
@@ -6022,244 +6673,119 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
// matrix-vector multiplication
//
-[[host_name("kernel_mul_mv_id_f32_f32")]]
-kernel void kernel_mul_mv_id_f32_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_f32_f32_impl(
- src0,
- src1 + bid*nb11,
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- nb00,
- nb01,
- nb02,
- ne10,
- ne11,
- ne12,
- nb10,
- nb11,
- nb12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg);
-}
-
-[[host_name("kernel_mul_mv_id_f16_f32")]]
-kernel void kernel_mul_mv_id_f16_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_f16_f32_impl(
- src0,
- src1 + bid*nb11,
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- nb00,
- nb01,
- nb02,
- ne10,
- ne11,
- ne12,
- nb10,
- nb11,
- nb12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg);
-}
+typedef void (kernel_mul_mv_impl_t)( // function type for impls that receive the nb00..nb12 arguments and only tgpig/tiisg (no shared_values, tiitg or sgitg)
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ uint3 tgpig,
+ uint tiisg);
-[[host_name("kernel_mul_mv_id_q8_0_f32")]]
-kernel void kernel_mul_mv_id_q8_0_f32(
- device const char * src0s,
+typedef void (kernel_mul_mv2_impl_t)( // function type for impls that take src0 as void*, src1 as float*, no nb* arguments, plus shared_values and sgitg
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg);
+
+template
+void mmv_fn( // adapter: takes the common argument list and forwards a subset of it to impl_fn. NOTE(review): impl_fn is presumably this template's parameter; its declaration is not visible here - confirm
+ device const char * src0,
device const char * src1,
device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q8_0_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q4_0_f32")]]
-kernel void kernel_mul_mv_id_q4_0_f32(
- device const char * src0s,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ int64_t ne13,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint64_t nb1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiitg,
+ uint tiisg,
+ uint sgitg) {
+ impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); // ne13, nb1, shared_values, tiitg and sgitg are accepted but not forwarded
+}
+
+template
+void mmv_fn( // adapter: takes the common argument list, casts src1 to a float pointer and forwards a subset (including shared_values and sgitg) to impl_fn. NOTE(review): impl_fn is presumably this template's parameter; its declaration is not visible here - confirm
+ device const char * src0,
device const char * src1,
device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- mul_vec_q_n_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q4_1_f32")]]
-kernel void kernel_mul_mv_id_q4_1_f32(
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ int64_t ne13,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint64_t nb1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiitg,
+ uint tiisg,
+ uint sgitg) {
+ impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); // nb00..nb02, ne11, ne13, nb10..nb12, nb1 and tiitg are accepted but not forwarded
+}
+
+typedef decltype(mmv_fn) mul_mv_impl_fn_t; // function type of mmv_fn. NOTE(review): mmv_fn is declared as a template above but no template arguments are visible on this line - confirm
+
+template
+kernel void kernel_mul_mv_id( // picks one src0 slice using a value read from ids, then calls impl_fn on a single src1 row and a single dst row
device const char * src0s,
device const char * src1,
device float * dst,
device const char * ids,
+ constant int64_t & nei0,
+ constant int64_t & nei1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -6277,932 +6803,80 @@ kernel void kernel_mul_mv_id_q4_1_f32(
constant int64_t & ne0,
constant int64_t & ne1,
constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- mul_vec_q_n_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
+ const int iid1 = tgpig.z/nei0; // tgpig.z is split as iid1*nei0 + idx; iid1 scales nbi1 when indexing ids below
+ const int idx = tgpig.z%nei0; // position inside the ids row selected by iid1
+
+ tgpig.z = 0; // z has been consumed above, so impl_fn is called with z == 0
+
+ const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; // value read from ids; used below to offset src0s by i02*nb02
+
+ const int64_t i11 = idx % ne11; // idx reduced modulo ne11 before it is used as the src1 row offset
+ const int64_t i12 = iid1;
+
+ const int64_t i1 = idx;
+ const int64_t i2 = i12;
+
+ device const char * src0_cur = src0s + i02*nb02;
+ device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
+ device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; // dst is float*, so this offset counts floats; the two char* offsets above count bytes
+
+ impl_fn(
+ /* src0 */ src0_cur,
+ /* src1 */ src1_cur,
+ /* dst */ dst_cur,
+ /* ne00 */ ne00,
+ /* ne01 */ ne01,
+ /* ne02 */ 1,//ne02,
+ /* nb00 */ nb00,
+ /* nb01 */ nb01,
+ /* nb02 */ nb02,
+ /* ne10 */ ne10,
+ /* ne11 */ 1,//ne11,
+ /* ne12 */ 1,//ne12,
+ /* ne13 */ 1,//ne13,
+ /* nb10 */ nb10,
+ /* nb11 */ nb11,
+ /* nb12 */ nb12,
+ /* ne0 */ ne0,
+ /* ne1 */ 1,//ne1,
+ /* nb1 */ nb1,
+ /* r2 */ 1,
+ /* r3 */ 1,
+ shared_values,
tgpig,
+ tiitg,
tiisg,
sgitg);
}
-[[host_name("kernel_mul_mv_id_q5_0_f32")]]
-kernel void kernel_mul_mv_id_q5_0_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- mul_vec_q_n_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
+typedef decltype(kernel_mul_mv_id>) kernel_mul_mv_id_t; // function type reused by every explicit instantiation below. NOTE(review): the template argument lists on this and the following lines look truncated (unbalanced '>') - confirm against the upstream file
+
+template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+#if QK_K != 64 // the iq4_xs variant is only instantiated when QK_K is not 64
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+#endif
-[[host_name("kernel_mul_mv_id_q5_1_f32")]]
-kernel void kernel_mul_mv_id_q5_1_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- mul_vec_q_n_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q2_K_f32")]]
-kernel void kernel_mul_mv_id_q2_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q2_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q3_K_f32")]]
-kernel void kernel_mul_mv_id_q3_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q3_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q4_K_f32")]]
-kernel void kernel_mul_mv_id_q4_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q4_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q5_K_f32")]]
-kernel void kernel_mul_mv_id_q5_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q5_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q6_K_f32")]]
-kernel void kernel_mul_mv_id_q6_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q6_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
-kernel void kernel_mul_mv_id_iq2_xxs_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq2_xxs_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
-kernel void kernel_mul_mv_id_iq2_xs_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq2_xs_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
-kernel void kernel_mul_mv_id_iq3_xxs_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq3_xxs_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
-kernel void kernel_mul_mv_id_iq3_s_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq3_s_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
-kernel void kernel_mul_mv_id_iq2_s_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq2_s_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
-kernel void kernel_mul_mv_id_iq1_s_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq1_s_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
-kernel void kernel_mul_mv_id_iq1_m_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq1_m_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
-kernel void kernel_mul_mv_id_iq4_nl_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup float * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq4_nl_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
-kernel void kernel_mul_mv_id_iq4_xs_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup float * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
-#if QK_K == 64
- kernel_mul_mv_iq4_nl_f32_impl(
-#else
- kernel_mul_mv_iq4_xs_f32_impl(
-#endif
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
diff --git a/LLama/runtimes/deps/osx-arm64/libllama.dylib b/LLama/runtimes/deps/osx-arm64/libllama.dylib
index b474697df..61bc4fae2 100644
Binary files a/LLama/runtimes/deps/osx-arm64/libllama.dylib and b/LLama/runtimes/deps/osx-arm64/libllama.dylib differ
diff --git a/LLama/runtimes/deps/osx-arm64/libllava_shared.dylib b/LLama/runtimes/deps/osx-arm64/libllava_shared.dylib
index c47c71ed0..404833398 100644
Binary files a/LLama/runtimes/deps/osx-arm64/libllava_shared.dylib and b/LLama/runtimes/deps/osx-arm64/libllava_shared.dylib differ
diff --git a/LLama/runtimes/deps/osx-x64/libllama.dylib b/LLama/runtimes/deps/osx-x64/libllama.dylib
index bbd84c9c2..c803fd588 100644
Binary files a/LLama/runtimes/deps/osx-x64/libllama.dylib and b/LLama/runtimes/deps/osx-x64/libllama.dylib differ
diff --git a/LLama/runtimes/deps/osx-x64/libllava_shared.dylib b/LLama/runtimes/deps/osx-x64/libllava_shared.dylib
index c6b265ff4..922d4cd73 100644
Binary files a/LLama/runtimes/deps/osx-x64/libllava_shared.dylib and b/LLama/runtimes/deps/osx-x64/libllava_shared.dylib differ
diff --git a/README.md b/README.md
index 8be7314b9..f1b081456 100644
--- a/README.md
+++ b/README.md
@@ -243,6 +243,7 @@ If you want to compile llama.cpp yourself you **must** use the exact commit ID l
| v0.9.0, v0.9.1 | [Mixtral-8x7B](https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF) | [`9fb13f9`](https://github.com/ggerganov/llama.cpp/blob/9fb13f95840c722ad419f390dc8a9c86080a3700) |
| v0.10.0 | [Phi2](https://huggingface.co/TheBloke/phi-2-GGUF) | [`d71ac90`](https://github.com/ggerganov/llama.cpp/tree/d71ac90985854b0905e1abba778e407e17f9f887) |
| v0.11.1, v0.11.2 | [LLaVA-v1.5](https://hf-mirror.com/jartine/llava-v1.5-7B-GGUF/blob/main/llava-v1.5-7b-mmproj-Q4_0.gguf), [Phi2](https://huggingface.co/TheBloke/phi-2-GGUF)| [`3ab8b3a`](https://github.com/ggerganov/llama.cpp/tree/3ab8b3a92ede46df88bc5a2dfca3777de4a2b2b6) |
+| v0.12.0 | LLama3 | [`a743d76`](https://github.com/ggerganov/llama.cpp/tree/a743d76a01f23038b2c85af1e9048ee836767b44) |
## License
diff --git a/llama.cpp b/llama.cpp
index f7001ccc5..a743d76a0 160000
--- a/llama.cpp
+++ b/llama.cpp
@@ -1 +1 @@
-Subproject commit f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7
+Subproject commit a743d76a01f23038b2c85af1e9048ee836767b44