
Commit

Removed llama_eval. It is going to be completely removed in the next version of llama.cpp (#553)
martindevans authored Feb 28, 2024
1 parent f0e7e7c commit 8ac1634
Showing 7 changed files with 68 additions and 121 deletions.
3 changes: 2 additions & 1 deletion LLama.Unittest/BeamTests.cs
@@ -40,7 +40,8 @@ public void BasicBeam()

var initial_tokens = context.Tokenize(prompt);
result.Append(prompt);
context.Eval(initial_tokens.AsSpan(), 0);
//context.Eval(initial_tokens.AsSpan(), 0);
throw new NotImplementedException("Replace Eval");

NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
{
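The test is left throwing deliberately until it is ported to the batch API. A hedged sketch of such a port, assuming the internal chunked Decode helper added later in this commit is visible to the test assembly and that ToList() (System.Linq) is available; nPast is a hypothetical local:

var batch = new LLamaBatch();
var nPast = 0;
// Decode the prompt in batch-size chunks (helper added in this commit; assumed accessible here)
var (result, _) = context.NativeHandle.Decode(initial_tokens.ToList(), LLamaSeqId.Zero, batch, ref nPast);
if (result != DecodeResult.Ok)
    throw new InvalidOperationException($"Decode failed: {result}");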
61 changes: 0 additions & 61 deletions LLama/LLamaContext.cs
@@ -370,67 +370,6 @@ public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancel
{
return Task.Run(() => Decode(batch), cancellationToken);
}

/// <summary>
///
/// </summary>
/// <param name="tokens"></param>
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use Decode() instead")]
public int Eval(List<LLamaToken> tokens, int pastTokensCount)
{
#if NET5_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(tokens);
return Eval(span, pastTokensCount);
#else
// on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
// avoid the copying.

var rented = System.Buffers.ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
try
{
tokens.CopyTo(rented, 0);
return Eval(rented.AsSpan(0, tokens.Count), pastTokensCount);
}
finally
{
System.Buffers.ArrayPool<LLamaToken>.Shared.Return(rented);
}
#endif
}

/// <summary>
///
/// </summary>
/// <param name="tokens"></param>
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use Decode() instead")]
public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount)
{
var total = tokens.Length;
for(var i = 0; i < total; i += (int)Params.BatchSize)
{
var n_eval = total - i;
if (n_eval > Params.BatchSize)
{
n_eval = (int)Params.BatchSize;
}

if (!NativeHandle.Eval(tokens.Slice(i, n_eval), pastTokensCount))
{
_logger?.LogError("[LLamaContext] Failed to eval.");
throw new RuntimeError("Failed to eval.");
}

pastTokensCount += n_eval;
}
return pastTokensCount;
}
#endregion

/// <inheritdoc />
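For callers migrating off the removed Eval overloads, a minimal sketch of the equivalent call through the batch API. This assumes LLamaContext.Decode(LLamaBatch) returns DecodeResult (as the DecodeAsync wrapper above suggests) and that the tokens fit within a single batch of Params.BatchSize; longer inputs must be split into chunks, as the executors below do:

var batch = new LLamaBatch();
var pastTokensCount = 0;
for (var i = 0; i < tokens.Count; i++)
{
    // Request logits only for the final token of the batch
    batch.Add(tokens[i], pastTokensCount++, LLamaSeqId.Zero, i == tokens.Count - 1);
}
if (context.Decode(batch) != DecodeResult.Ok)
    throw new RuntimeError("Failed to decode.");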
12 changes: 9 additions & 3 deletions LLama/LLamaInstructExecutor.cs
@@ -8,6 +8,7 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using LLama.Exceptions;
using LLama.Extensions;
using Microsoft.Extensions.Logging;

@@ -178,6 +179,8 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
/// <inheritdoc />
protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
var batch = new LLamaBatch();

if (_embeds.Count > 0)
{
_is_prompt_run = false;
@@ -187,7 +190,10 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
}

TryReuseMathingPrefix();
_pastTokensCount = Context.Eval(_embeds, _pastTokensCount);

var (result, _) = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
if (result != DecodeResult.Ok)
throw new LLamaDecodeError(result);

if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
{
@@ -212,12 +218,12 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(batch.TokenCount - 1), _last_n_tokens.ToArray());
inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id);
}
else
{
var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = Context.ApplyPenalty(batch.TokenCount - 1, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
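llama_decode only produces logits for tokens that were added to the batch with the logits flag set, and the chunked decode helper sets that flag on the final token alone. Sampling therefore reads row batch.TokenCount - 1 via GetLogitsIth instead of the old GetLogits(); the identical change appears in LLamaInteractExecutor below. A hedged sketch, assuming GetLogitsIth returns the logit row for the given batch index:

// Only the last token in the batch was added with logits enabled,
// so its row index is batch.TokenCount - 1.
var lastLogits = Context.NativeHandle.GetLogitsIth(batch.TokenCount - 1);
var next = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, lastLogits, _last_n_tokens.ToArray());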
12 changes: 9 additions & 3 deletions LLama/LLamaInteractExecutor.cs
@@ -8,6 +8,7 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using LLama.Exceptions;
using LLama.Extensions;
using Microsoft.Extensions.Logging;

@@ -157,6 +158,8 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
/// <inheritdoc />
protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
var batch = new LLamaBatch();

if (_embeds.Count > 0)
{
_is_prompt_run = false;
@@ -166,7 +169,10 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
}

TryReuseMathingPrefix();
_pastTokensCount = Context.Eval(_embeds, _pastTokensCount);

var (result, _) = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
if (result != DecodeResult.Ok)
throw new LLamaDecodeError(result);

if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
{
@@ -191,12 +197,12 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(batch.TokenCount - 1), _last_n_tokens.ToArray());
inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id);
}
else
{
var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = Context.ApplyPenalty(batch.TokenCount - 1, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
18 changes: 3 additions & 15 deletions LLama/LLamaStatelessExecutor.cs
@@ -81,21 +81,9 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams

// Evaluate the prompt, in chunks smaller than the max batch size
var n_past = 0;
var batchSize = (int)Context.Params.BatchSize;
for (var i = 0; i < tokens.Count; i += batchSize)
{
var n_eval = tokens.Count - i;
if (n_eval > batchSize)
n_eval = batchSize;

_batch.Clear();
for (var j = 0; j < n_eval; j++)
_batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);

var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
if (returnCode != 0)
throw new LLamaDecodeError(returnCode);
}
var (r, _) = Context.NativeHandle.Decode(tokens, LLamaSeqId.Zero, _batch, ref n_past);
if (r != DecodeResult.Ok)
throw new LLamaDecodeError(r);

// Begin loop, evaluating one token at a time
var mu = (float?)null;
16 changes: 1 addition & 15 deletions LLama/Native/NativeApi.cs
@@ -141,20 +141,6 @@ public static void llama_empty_call()
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_save_session_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count);

/// <summary>
/// Run the llama inference to obtain the logits and probabilities for the next token.
/// tokens + n_tokens is the provided batch of new tokens to process
/// n_past is the number of tokens to use from previous eval calls
/// </summary>
/// <param name="ctx"></param>
/// <param name="tokens"></param>
/// <param name="n_tokens"></param>
/// <param name="n_past"></param>
/// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[Obsolete("use llama_decode() instead")]
public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, LLamaToken* tokens, int n_tokens, int n_past);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, LLamaToken token);

@@ -181,7 +167,7 @@ public static void llama_empty_call()
public static extern uint llama_n_batch(SafeLLamaContextHandle ctx);

/// <summary>
/// Token logits obtained from the last call to llama_eval()
/// Token logits obtained from the last call to llama_decode
/// The logits for the last token are stored in the last row
/// Can be mutated in order to change the probabilities of the next token.<br />
/// Rows: n_tokens<br />
67 changes: 44 additions & 23 deletions LLama/Native/SafeLLamaContextHandle.cs
@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using LLama.Exceptions;

namespace LLama.Native
@@ -28,6 +30,11 @@ public sealed class SafeLLamaContextHandle
/// </summary>
public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize;

/// <summary>
/// Get the maximum batch size for this context
/// </summary>
public uint BatchSize => NativeApi.llama_n_batch(this);

/// <summary>
/// Get the model which this context is using
/// </summary>
@@ -108,7 +115,7 @@ static SafeLLamaContextHandle()
#endregion

/// <summary>
/// Token logits obtained from the last call to llama_eval()
/// Token logits obtained from the last call to llama_decode
/// The logits for the last token are stored in the last row
/// Can be mutated in order to change the probabilities of the next token.<br />
/// Rows: n_tokens<br />
@@ -170,26 +177,6 @@ public uint TokenToSpan(LLamaToken token, Span<byte> dest)
#endregion

#region infer
/// <summary>
/// Run the llama inference to obtain the logits and probabilities for the next token.
/// </summary>
/// <param name="tokens">The provided batch of new tokens to process</param>
/// <param name="n_past">the number of tokens to use from previous eval calls</param>
/// <returns>Returns true on success</returns>
[Obsolete("use llama_decode() instead")]
public bool Eval(ReadOnlySpan<LLamaToken> tokens, int n_past)
{
unsafe
{
fixed (LLamaToken* pinned = tokens)
{
// the entire `eval` system needs replacing with the new batch system!
var ret = NativeApi.llama_eval(this, pinned, tokens.Length, n_past);
return ret == 0;
}
}
}

/// <summary>
/// </summary>
/// <param name="batch"></param>
@@ -198,10 +185,44 @@ public bool Eval(ReadOnlySpan<LLamaToken> tokens, int n_past)
/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
/// - &lt; 0: error<br />
/// </returns>
public int Decode(LLamaBatch batch)
public DecodeResult Decode(LLamaBatch batch)
{
using (batch.ToNativeBatch(out var nb))
return NativeApi.llama_decode(this, nb);
return (DecodeResult)NativeApi.llama_decode(this, nb);
}

/// <summary>
/// Decode a set of tokens in batch-size chunks.
/// </summary>
/// <param name="tokens"></param>
/// <param name="id"></param>
/// <param name="batch"></param>
/// <param name="n_past"></param>
/// <returns>A tuple, containing the decode result and the number of tokens that have <b>not</b> been decoded yet.</returns>
internal (DecodeResult, int) Decode(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, ref int n_past)
{
var batchSize = checked((int)BatchSize);

// Evaluate the prompt, in chunks smaller than the max batch size
var n_left = tokens.Count;
for (var i = 0; i < tokens.Count; i += batchSize)
{
var n_eval = tokens.Count - i;
if (n_eval > batchSize)
n_eval = batchSize;

batch.Clear();
for (var j = 0; j < n_eval; j++)
batch.Add(tokens[i + j], n_past++, id, (i + j) == tokens.Count - 1);

var returnCode = Decode(batch);
if (returnCode != DecodeResult.Ok)
return (returnCode, n_left);

n_left -= n_eval;
}

return (DecodeResult.Ok, 0);
}
#endregion

Expand Down
