diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs index 7f1c94960..7d774f46f 100644 --- a/LLama.Unittest/LLamaContextTests.cs +++ b/LLama.Unittest/LLamaContextTests.cs @@ -41,6 +41,14 @@ public void Tokenize() Assert.Equal(new LLamaToken[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens); } + [Fact] + public void TokenizeNewline() + { + var tokens = _context.Tokenize("\n"); + + Assert.Equal(new LLamaToken[] { 1, 29871, 13 }, tokens); + } + [Fact] public void TokenizeWithoutBOS() { diff --git a/LLama/Native/LLamaToken.cs b/LLama/Native/LLamaToken.cs index 0bc485856..128d9f58f 100644 --- a/LLama/Native/LLamaToken.cs +++ b/LLama/Native/LLamaToken.cs @@ -1,4 +1,5 @@ -using System.Runtime.InteropServices; +using System.Diagnostics; +using System.Runtime.InteropServices; namespace LLama.Native; @@ -6,6 +7,7 @@ namespace LLama.Native; /// A single token /// [StructLayout(LayoutKind.Sequential)] +[DebuggerDisplay("Value")] public readonly record struct LLamaToken { /// @@ -35,4 +37,10 @@ private LLamaToken(int value) /// /// public static implicit operator LLamaToken(int value) => new(value); + + /// + public override string ToString() + { + return Value.ToString(); + } } \ No newline at end of file diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 17fa13cf5..3f303123b 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -155,37 +155,7 @@ public Span GetLogitsIth(int i) /// public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) { - ThrowIfDisposed(); - - if (string.IsNullOrEmpty(text) && !add_bos) - return Array.Empty(); - - // Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't - // possibly be more than this. - var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0); - - // "Rent" an array to write results into (avoiding an allocation of a large array) - var temporaryArray = ArrayPool.Shared.Rent(count); - try - { - // Do the actual conversion - var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos, special); - if (n < 0) - { - throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " + - "specify the encoding."); - } - - // Copy the results from the rented into an array which is exactly the right size - var result = new LLamaToken[n]; - Array.ConstrainedCopy(temporaryArray, 0, result, 0, n); - - return result; - } - finally - { - ArrayPool.Shared.Return(temporaryArray); - } + return ThrowIfDisposed().Tokenize(text, add_bos, special, encoding); } /// diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 47ffb4dd6..8ffa2be31 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -1,4 +1,5 @@ using System; +using System.Buffers; using System.Collections.Generic; using System.Diagnostics; using System.Runtime.InteropServices; @@ -172,34 +173,40 @@ internal Span TokensToSpan(IReadOnlyList tokens, Span de /// public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) { + // Early exit if there's no work to do + if (text == "" && !add_bos) + return Array.Empty(); + // Convert string to bytes, adding one extra byte to the end (null terminator) var bytesCount = encoding.GetByteCount(text); - var bytes = new byte[bytesCount + 1]; - unsafe + var bytes = ArrayPool.Shared.Rent(bytesCount + 1); + try { - fixed (char* charPtr = text) - fixed (byte* bytePtr = &bytes[0]) + unsafe { - encoding.GetBytes(charPtr, text.Length, bytePtr, bytes.Length); - } - } - - unsafe - { - fixed (byte* bytesPtr = &bytes[0]) - { - // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space) - var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, add_bos, special); - - // Tokenize again, this time outputting into an array of exactly the right size - var tokens = new LLamaToken[count]; - fixed (LLamaToken* tokensPtr = &tokens[0]) + fixed (char* textPtr = text) + fixed (byte* bytesPtr = bytes) { - NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special); - return tokens; + // Convert text into bytes + encoding.GetBytes(textPtr, text.Length, bytesPtr, bytes.Length); + + // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space) + var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, add_bos, special); + + // Tokenize again, this time outputting into an array of exactly the right size + var tokens = new LLamaToken[count]; + fixed (LLamaToken* tokensPtr = tokens) + { + NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special); + return tokens; + } } } } + finally + { + ArrayPool.Shared.Return(bytes, true); + } } #endregion