diff --git a/LLama.Unittest/BeamTests.cs b/LLama.Unittest/BeamTests.cs
index 1bbeae9c5..be21a5f2f 100644
--- a/LLama.Unittest/BeamTests.cs
+++ b/LLama.Unittest/BeamTests.cs
@@ -40,7 +40,8 @@ public void BasicBeam()
         var initial_tokens = context.Tokenize(prompt);
         result.Append(prompt);
-        context.Eval(initial_tokens.AsSpan(), 0);
+        //context.Eval(initial_tokens.AsSpan(), 0);
+        throw new NotImplementedException("Replace Eval");
 
         NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
         {
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 44531a9f7..d8b418c31 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -370,67 +370,6 @@ public Task DecodeAsync(LLamaBatch batch, CancellationToken cancel
         {
             return Task.Run(() => Decode(batch), cancellationToken);
         }
-
-        ///
-        ///
-        ///
-        ///
-        /// The updated `pastTokensCount`.
-        ///
-        [Obsolete("use Decode() instead")]
-        public int Eval(List<LLamaToken> tokens, int pastTokensCount)
-        {
-#if NET5_0_OR_GREATER
-            var span = CollectionsMarshal.AsSpan(tokens);
-            return Eval(span, pastTokensCount);
-#else
-            // on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
-            // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
-            // avoid the copying.
-
-            var rented = System.Buffers.ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
-            try
-            {
-                tokens.CopyTo(rented, 0);
-                return Eval(rented.AsSpan(0, tokens.Count), pastTokensCount);
-            }
-            finally
-            {
-                System.Buffers.ArrayPool<LLamaToken>.Shared.Return(rented);
-            }
-#endif
-        }
-
-        ///
-        ///
-        ///
-        ///
-        /// The updated `pastTokensCount`.
-        ///
-        [Obsolete("use Decode() instead")]
-        public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount)
-        {
-            var total = tokens.Length;
-            for(var i = 0; i < total; i += (int)Params.BatchSize)
-            {
-                var n_eval = total - i;
-                if (n_eval > Params.BatchSize)
-                {
-                    n_eval = (int)Params.BatchSize;
-                }
-
-                if (!NativeHandle.Eval(tokens.Slice(i, n_eval), pastTokensCount))
-                {
-                    _logger?.LogError("[LLamaContext] Failed to eval.");
-                    throw new RuntimeError("Failed to eval.");
-                }
-
-                pastTokensCount += n_eval;
-            }
-            return pastTokensCount;
-        }
         #endregion
 
         ///
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 31e975caf..9476976e9 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -8,6 +8,7 @@
 using System.Text.Json;
 using System.Text.Json.Serialization;
 using System.Threading.Tasks;
+using LLama.Exceptions;
 using LLama.Extensions;
 using Microsoft.Extensions.Logging;
 
@@ -178,6 +179,8 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
         ///
         protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
+            var batch = new LLamaBatch();
+
             if (_embeds.Count > 0)
             {
                 _is_prompt_run = false;
@@ -187,7 +190,10 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 }
 
                 TryReuseMathingPrefix();
-                _pastTokensCount = Context.Eval(_embeds, _pastTokensCount);
+
+                var (result, _) = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
+                if (result != DecodeResult.Ok)
+                    throw new LLamaDecodeError(result);
 
                 if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
                 {
@@ -212,12 +218,12 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 LLamaToken id;
                 if (inferenceParams.SamplingPipeline is not null)
                 {
-                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(batch.TokenCount - 1), _last_n_tokens.ToArray());
                     inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id);
                 }
                 else
                 {
-                    var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(batch.TokenCount - 1, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
                     var mu = MirostatMu;
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index bd36f6128..79f1b8cc4 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -8,6 +8,7 @@
 using System.Text.Json;
 using System.Text.Json.Serialization;
 using System.Threading.Tasks;
+using LLama.Exceptions;
 using LLama.Extensions;
 using Microsoft.Extensions.Logging;
 
@@ -157,6 +158,8 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
         ///
         protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
+            var batch = new LLamaBatch();
+
             if (_embeds.Count > 0)
             {
                 _is_prompt_run = false;
@@ -166,7 +169,10 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 }
 
                 TryReuseMathingPrefix();
-                _pastTokensCount = Context.Eval(_embeds, _pastTokensCount);
+
+                var (result, _) = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
+                if (result != DecodeResult.Ok)
+                    throw new LLamaDecodeError(result);
 
                 if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
                 {
@@ -191,12 +197,12 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 LLamaToken id;
                 if (inferenceParams.SamplingPipeline is not null)
                 {
-                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(batch.TokenCount - 1), _last_n_tokens.ToArray());
                     inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id);
                 }
                 else
                 {
-                    var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(batch.TokenCount - 1, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
                     var mu = MirostatMu;
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index b18f36c14..14884f858 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -81,21 +81,9 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams
 
             // Evaluate the prompt, in chunks smaller than the max batch size
             var n_past = 0;
-            var batchSize = (int)Context.Params.BatchSize;
-            for (var i = 0; i < tokens.Count; i += batchSize)
-            {
-                var n_eval = tokens.Count - i;
-                if (n_eval > batchSize)
-                    n_eval = batchSize;
-
-                _batch.Clear();
-                for (var j = 0; j < n_eval; j++)
-                    _batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);
-
-                var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
-                if (returnCode != 0)
-                    throw new LLamaDecodeError(returnCode);
-            }
+            var (r, _) = Context.NativeHandle.Decode(tokens, LLamaSeqId.Zero, _batch, ref n_past);
+            if (r != DecodeResult.Ok)
+                throw new LLamaDecodeError(r);
 
             // Begin loop, evaluating one token at a time
             var mu = (float?)null;
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 902808f64..61172c188 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -141,20 +141,6 @@ public static void llama_empty_call()
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern bool llama_save_session_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count);
 
-        ///
-        /// Run the llama inference to obtain the logits and probabilities for the next token.
-        /// tokens + n_tokens is the provided batch of new tokens to process
-        /// n_past is the number of tokens to use from previous eval calls
-        ///
-        ///
-        ///
-        ///
-        /// Returns 0 on success
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        [Obsolete("use llama_decode() instead")]
-        public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, LLamaToken* tokens, int n_tokens, int n_past);
-
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, LLamaToken token);
 
@@ -181,7 +167,7 @@ public static void llama_empty_call()
         public static extern uint llama_n_batch(SafeLLamaContextHandle ctx);
 
         ///
-        /// Token logits obtained from the last call to llama_eval()
+        /// Token logits obtained from the last call to llama_decode
         /// The logits for the last token are stored in the last row
         /// Can be mutated in order to change the probabilities of the next token.
         /// Rows: n_tokens
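A minimal usage sketch (illustrative only, not part of the patch): with the `llama_eval` export removed, callers feed tokens through an `LLamaBatch` and read logits for the final position, as the executor changes above do. It assumes an existing `SafeLLamaContextHandle` named `context` and an already-tokenized prompt `tokens`:

    // Build one batch for the whole prompt; request logits only for the last token.
    var batch = new LLamaBatch();
    for (var i = 0; i < tokens.Count; i++)
        batch.Add(tokens[i], i, LLamaSeqId.Zero, i == tokens.Count - 1);

    // Decode replaces the old eval call; non-Ok results surface as LLamaDecodeError.
    var result = context.Decode(batch);
    if (result != DecodeResult.Ok)
        throw new LLamaDecodeError(result);

    // Read logits for the last position instead of llama_get_logits after llama_eval.
    var logits = context.GetLogitsIth(batch.TokenCount - 1);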
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 2d9387ae4..b1f1e4b70 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,6 +1,8 @@
 using System;
+using System.Collections.Generic;
 using System.Runtime.InteropServices;
 using System.Text;
+using System.Threading;
 using LLama.Exceptions;
 
 namespace LLama.Native
@@ -28,6 +30,11 @@ public sealed class SafeLLamaContextHandle
         ///
         public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize;
 
+        ///
+        /// Get the maximum batch size for this context
+        ///
+        public uint BatchSize => NativeApi.llama_n_batch(this);
+
         ///
         /// Get the model which this context is using
         ///
@@ -108,7 +115,7 @@ static SafeLLamaContextHandle()
         #endregion
 
         ///
-        /// Token logits obtained from the last call to llama_eval()
+        /// Token logits obtained from the last call to llama_decode
         /// The logits for the last token are stored in the last row
         /// Can be mutated in order to change the probabilities of the next token.
         /// Rows: n_tokens
@@ -170,26 +177,6 @@ public uint TokenToSpan(LLamaToken token, Span dest)
         #endregion
 
         #region infer
-        ///
-        /// Run the llama inference to obtain the logits and probabilities for the next token.
-        ///
-        /// The provided batch of new tokens to process
-        /// the number of tokens to use from previous eval calls
-        /// Returns true on success
-        [Obsolete("use llama_decode() instead")]
-        public bool Eval(ReadOnlySpan<LLamaToken> tokens, int n_past)
-        {
-            unsafe
-            {
-                fixed (LLamaToken* pinned = tokens)
-                {
-                    // the entire `eval` system needs replacing with the new batch system!
-                    var ret = NativeApi.llama_eval(this, pinned, tokens.Length, n_past);
-                    return ret == 0;
-                }
-            }
-        }
-
         ///
         ///
         ///
@@ -198,10 +185,44 @@ public bool Eval(ReadOnlySpan tokens, int n_past)
         ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
         ///  - < 0: error
         ///
-        public int Decode(LLamaBatch batch)
+        public DecodeResult Decode(LLamaBatch batch)
         {
             using (batch.ToNativeBatch(out var nb))
-                return NativeApi.llama_decode(this, nb);
+                return (DecodeResult)NativeApi.llama_decode(this, nb);
+        }
+
+        ///
+        /// Decode a set of tokens in batch-size chunks.
+        ///
+        ///
+        ///
+        ///
+        ///
+        /// A tuple, containing the decode result and the number of tokens that have not been decoded yet.
+        internal (DecodeResult, int) Decode(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, ref int n_past)
+        {
+            var batchSize = checked((int)BatchSize);
+
+            // Evaluate the prompt, in chunks smaller than the max batch size
+            var n_left = tokens.Count;
+            for (var i = 0; i < tokens.Count; i += batchSize)
+            {
+                var n_eval = tokens.Count - i;
+                if (n_eval > batchSize)
+                    n_eval = batchSize;
+
+                batch.Clear();
+                for (var j = 0; j < n_eval; j++)
+                    batch.Add(tokens[i + j], n_past++, id, (i + j) == tokens.Count - 1);
+
+                var returnCode = Decode(batch);
+                if (returnCode != DecodeResult.Ok)
+                    return (returnCode, n_left);
+
+                n_left -= n_eval;
+            }
+
+            return (DecodeResult.Ok, 0);
         }
 
         #endregion
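A short sketch of how the new chunked overload above is driven (it is internal, so only code inside the LLama assembly, such as the executors changed in this diff, can call it); `context` is assumed to be a `SafeLLamaContextHandle` and `tokens` a `List<LLamaToken>`:

    // Decodes the list in BatchSize-sized chunks, advancing n_past as it goes;
    // on failure the discarded second element reports how many tokens were left undecoded.
    var batch = new LLamaBatch();
    var n_past = 0;
    var (status, _) = context.Decode(tokens, LLamaSeqId.Zero, batch, ref n_past);
    if (status != DecodeResult.Ok)
        throw new LLamaDecodeError(status);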