Make most Tokenizer abstract methods virtual (#7198)
* Make most Tokenizer abstract methods virtual

  All of the functionality in all of the methods can be implemented in terms of just two methods: Decode and EncodeToTokens. Only those two need to be abstract; everything else that was abstract can instead be virtual and implemented in terms of those.

* Address feedback and clean up a few things
This commit is contained in:
Parent: 0c2e82e7d4
Commit: 5b920f9601
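As a quick illustration of the resulting API shape (a sketch, not part of the commit): a Tokenizer subclass now only has to override EncodeToTokens and the span-based Decode, and it inherits working EncodeToIds, CountTokens, GetIndexByTokenCount, and string-returning Decode from the new base-class defaults. Using the EnglishAlphabetTokenizer test type added at the bottom of this diff (it maps 'a'..'z' to ids 0..25):

    Tokenizer tokenizer = new EnglishAlphabetTokenizer();

    // All of these go through the new virtual default implementations,
    // which are built on top of EncodeToTokens and the span-based Decode.
    IReadOnlyList<int> ids = tokenizer.EncodeToIds("hello"); // [7, 4, 11, 11, 14]
    int count = tokenizer.CountTokens("hello");              // 5
    string? text = tokenizer.Decode(ids);                    // "hello"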
@@ -1,7 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Runtime.CompilerServices;

@@ -1,4 +1,4 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

@@ -9,7 +9,11 @@ namespace Microsoft.ML.Tokenizers
 /// </summary>
 public struct EncodeSettings
 {
+    /// <summary>
+    /// Initializes the <see cref="EncodeSettings"/> instance.
+    /// </summary>
+    public EncodeSettings() { MaxTokenCount = int.MaxValue; }

     /// <summary>
     /// Gets or sets a value indicating whether to consider the input normalization during encoding.
     /// </summary>

@@ -2,10 +2,6 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

-using System;
-using System.Collections.Generic;
-using System.Text;
-
 namespace Microsoft.ML.Tokenizers
 {
     /// <summary>

@@ -1,4 +1,4 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

@@ -95,7 +95,7 @@ namespace Microsoft.ML.Tokenizers
 /// https://huggingface.co/microsoft/phi-2/resolve/main/vocab.json?download=true
 /// https://huggingface.co/microsoft/phi-2/resolve/main/merges.txt?download=true
 /// </remarks>
-public new static Phi2Tokenizer Create(
+public static new Phi2Tokenizer Create(
     Stream vocabStream,
     Stream mergesStream,
     bool addPrefixSpace = false,

@@ -9,7 +9,7 @@ using System.Collections.Generic;
 namespace Microsoft.ML.Tokenizers
 {
     /// <summary>
-    /// serves as an abstraction for concrete tokenizers, enabling the encoding of text into tokens and IDs, as well as the decoding of IDs back into text.
+    /// Provides an abstraction for tokenizers, enabling the encoding of text into tokens and the decoding of token IDs back into text.
     /// </summary>
     public abstract class Tokenizer
     {

@@ -35,7 +35,27 @@ namespace Microsoft.ML.Tokenizers
 /// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
 /// <param name="settings">The settings used to encode the text.</param>
 /// <returns>The encoded results containing the list of encoded Ids.</returns>
-protected abstract EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);
+/// <remarks>
+/// Types derived from <see cref="Tokenizer"/> may override this implementation to provide a more efficient implementation.
+/// By default, it uses <see cref="EncodeToTokens(string?, ReadOnlySpan{char}, EncodeSettings)"/>.
+/// </remarks>
+protected virtual EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
+{
+    EncodeResults<EncodedToken> results = EncodeToTokens(text, textSpan, settings);
+
+    var ids = new int[results.Tokens.Count];
+    for (int i = 0; i < ids.Length; i++)
+    {
+        ids[i] = results.Tokens[i].Id;
+    }
+
+    return new EncodeResults<int>
+    {
+        Tokens = ids,
+        CharsConsumed = results.CharsConsumed,
+        NormalizedText = results.NormalizedText,
+    };
+}

 /// <summary>
 /// Encodes input text to token Ids.

@@ -45,7 +65,7 @@ namespace Microsoft.ML.Tokenizers
 /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
 /// <returns>The list of encoded Ids.</returns>
 public IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
-    => EncodeToIds(text, ReadOnlySpan<char>.Empty, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization }).Tokens;
+    => EncodeToIds(text, text.AsSpan(), new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization }).Tokens;

 /// <summary>
 /// Encodes input text to token Ids.

@@ -69,7 +89,7 @@ namespace Microsoft.ML.Tokenizers
 /// <returns>The list of encoded Ids.</returns>
 public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true)
 {
-    EncodeResults<int> result = EncodeToIds(text, ReadOnlySpan<char>.Empty,
+    EncodeResults<int> result = EncodeToIds(text, text.AsSpan(),
         new EncodeSettings
         {
             ConsiderPreTokenization = considerPreTokenization,

@@ -127,7 +147,7 @@ namespace Microsoft.ML.Tokenizers
 /// <returns>The list of encoded <see cref="EncodedToken" />s.</returns>
 public IReadOnlyList<EncodedToken> EncodeToTokens(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true)
 {
-    EncodeResults<EncodedToken> result = EncodeToTokens(text, ReadOnlySpan<char>.Empty, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
+    EncodeResults<EncodedToken> result = EncodeToTokens(text, text.AsSpan(), new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });

     normalizedString = result.NormalizedText;
     return result.Tokens;

@@ -156,7 +176,12 @@ namespace Microsoft.ML.Tokenizers
 /// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
 /// <param name="settings">The settings used to encode the text.</param>
 /// <returns>The number of token Ids that the input text will be encoded to.</returns>
-protected abstract int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);
+/// <remarks>
+/// Types derived from <see cref="Tokenizer"/> may override this implementation to provide a more efficient implementation.
+/// By default, it uses <see cref="EncodeToTokens(string?, ReadOnlySpan{char}, EncodeSettings)"/>.
+/// </remarks>
+protected virtual int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
+    => EncodeToTokens(text, textSpan, settings).Tokens.Count;

 /// <summary>
 /// Get the number of tokens that the input text will be encoded to.

@@ -166,7 +191,7 @@ namespace Microsoft.ML.Tokenizers
 /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
 /// <returns>The number of token Ids that the input text will be encoded to.</returns>
 public int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
-    => CountTokens(text, ReadOnlySpan<char>.Empty, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
+    => CountTokens(text, text.AsSpan(), new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });

 /// <summary>
 /// Get the number of tokens that the input text will be encoded to.

@@ -194,7 +219,44 @@ namespace Microsoft.ML.Tokenizers
 /// If <paramRef name="fromEnd" /> is <see langword="true"/>, it represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely,
 /// if all tokens fit, the result will be zero.
 /// </returns>
-protected abstract int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedString, out int tokenCount);
+/// <remarks>
+/// Types derived from <see cref="Tokenizer"/> may override this implementation to provide a more efficient implementation.
+/// By default, it uses <see cref="EncodeToTokens(string?, ReadOnlySpan{char}, EncodeSettings)"/>.
+/// </remarks>
+protected virtual int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedString, out int tokenCount)
+{
+    int maxTokenCount = settings.MaxTokenCount;
+    if (fromEnd)
+    {
+        // If we're looking from the end, we need to process the whole input.
+        settings.MaxTokenCount = int.MaxValue;
+    }
+
+    EncodeResults<EncodedToken> tokens = EncodeToTokens(text, textSpan, settings);
+    normalizedString = tokens.NormalizedText;
+    tokenCount = Math.Min(maxTokenCount, tokens.Tokens.Count);
+
+    if (!fromEnd)
+    {
+        if (tokenCount > 0)
+        {
+            var token = tokens.Tokens[tokenCount - 1];
+            return token.Offset.Index + token.Offset.Length;
+        }
+
+        return 0;
+    }
+    else
+    {
+        if (tokenCount > 0)
+        {
+            var token = tokens.Tokens[tokens.Tokens.Count - tokenCount];
+            return token.Offset.Index;
+        }
+
+        return tokens.NormalizedText?.Length ?? textSpan.Length;
+    }
+}

 /// <summary>
 /// Find the index of the maximum encoding capacity without surpassing the token limit.

@@ -213,7 +275,7 @@ namespace Microsoft.ML.Tokenizers
 public int GetIndexByTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
     => GetIndexByTokenCount(
         text,
-        ReadOnlySpan<char>.Empty,
+        text.AsSpan(),
         new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization, MaxTokenCount = maxTokenCount },
         fromEnd: false,
         out normalizedString,

@@ -259,7 +321,7 @@ namespace Microsoft.ML.Tokenizers
 public int GetIndexByTokenCountFromEnd(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
     => GetIndexByTokenCount(
         text,
-        ReadOnlySpan<char>.Empty,
+        text.AsSpan(),
         new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization, MaxTokenCount = maxTokenCount },
         fromEnd: true,
         out normalizedString,

@@ -293,7 +355,64 @@ namespace Microsoft.ML.Tokenizers
 /// </summary>
 /// <param name="ids">The list of ids that we want to decode.</param>
 /// <returns>The decoded string.</returns>
-public abstract string? Decode(IEnumerable<int> ids);
+/// <exception cref="ArgumentNullException"><paramref name="ids"/> is null.</exception>
+/// <exception cref="InvalidOperationException"><paramref name="ids"/> contains invalid data.</exception>
+/// <remarks>
+/// Types derived from <see cref="Tokenizer"/> may override this implementation to provide a more efficient implementation.
+/// By default, it uses <see cref="Decode(IEnumerable{int}, Span{char}, out int, out int)"/>.
+/// </remarks>
+public virtual string? Decode(IEnumerable<int> ids)
+{
+    if (ids is null)
+    {
+        throw new ArgumentNullException(nameof(ids));
+    }
+
+    int idCount = 0;
+    if (ids is ICollection<int> c)
+    {
+        idCount = c.Count;
+        if (idCount == 0)
+        {
+            return string.Empty;
+        }
+    }
+
+    char[] destination = ArrayPool<char>.Shared.Rent(
+#if DEBUG
+        1); // to help validate growth logic
+#else
+        idCount == 0 ? 1024 : idCount * 8); // arbitrary starting point / heuristic
+#endif
+    while (true)
+    {
+        switch (Decode(ids, destination, out int idsConsumed, out int charsWritten))
+        {
+            case OperationStatus.Done:
+                string result = destination.AsSpan(0, charsWritten).ToString();
+                ArrayPool<char>.Shared.Return(destination);
+                return result;
+
+            case OperationStatus.DestinationTooSmall:
+                long newSize = (long)destination.Length * 2;
+                if (newSize > int.MaxValue)
+                {
+                    newSize = (long)destination.Length + 1;
+                    if (newSize > int.MaxValue)
+                    {
+                        throw new OutOfMemoryException();
+                    }
+                }
+
+                ArrayPool<char>.Shared.Return(destination);
+                destination = ArrayPool<char>.Shared.Rent((int)newSize);
+                break;
+
+            default:
+                throw new InvalidOperationException("The provided token IDs could not be decoded.");
+        }
+    }
+}

 /// <summary>
 /// Decode the given ids back to text and store the result in the <paramref name="destination"/> span.

@@ -1,4 +1,4 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

@@ -15,7 +15,7 @@ namespace Microsoft.ML.Tokenizers
 /// This should only be used with a Ptr/Length for querying. For storing in a dictionary, this should
 /// always be used with a string.
 /// </remarks>
-internal unsafe readonly struct StringSpanOrdinalKey : IEquatable<StringSpanOrdinalKey>
+internal readonly unsafe struct StringSpanOrdinalKey : IEquatable<StringSpanOrdinalKey>
 {
     public readonly char* Ptr;
     public readonly int Length;

@@ -45,7 +45,7 @@ namespace Microsoft.ML.Tokenizers
     public override int GetHashCode() => Helpers.GetHashCode(Span);
 }

-internal unsafe readonly struct StringSpanOrdinalKeyPair : IEquatable<StringSpanOrdinalKeyPair>
+internal readonly unsafe struct StringSpanOrdinalKeyPair : IEquatable<StringSpanOrdinalKeyPair>
 {
     private readonly StringSpanOrdinalKey _left;
     private readonly StringSpanOrdinalKey _right;

@@ -173,7 +173,7 @@ namespace Microsoft.ML.Tokenizers
 /// </summary>
 internal static class StringSpanOrdinalKeyExtensions
 {
-    public unsafe static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, ReadOnlySpan<char> key, out TValue value)
+    public static unsafe bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, ReadOnlySpan<char> key, out TValue value)
     {
         fixed (char* ptr = key)
         {

@@ -184,7 +184,7 @@ namespace Microsoft.ML.Tokenizers
 public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, string key, out TValue value) =>
     map.TryGetValue(new StringSpanOrdinalKey(key), out value!);

-public unsafe static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKeyPair, TValue> map, ReadOnlySpan<char> key1, ReadOnlySpan<char> key2, out TValue value)
+public static unsafe bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKeyPair, TValue> map, ReadOnlySpan<char> key1, ReadOnlySpan<char> key2, out TValue value)
 {
     fixed (char* ptr1 = key1)
     fixed (char* ptr2 = key2)

@@ -1,5 +1,6 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

+using System.Buffers;
 using System.Diagnostics;

@@ -20,7 +20,7 @@ namespace Microsoft.ML.Tokenizers.Tests
 {
     private const string UnknownToken = "[unk]";

-    private readonly static Dictionary<string, int> _vocabDataWithWordPrefixAndEndOfWordSuffix =
+    private static readonly Dictionary<string, int> _vocabDataWithWordPrefixAndEndOfWordSuffix =
         new Dictionary<string, int>() { { UnknownToken, 0 }, { "!", 5 }, { ",", 6 }, { ".", 7 }, { "B", 8 }, { "H", 9 }, { "T", 10 }, { "W", 11 }, { "a", 12 }, { "b", 13 }, { "c", 14 }, { "d", 15 }, { "e", 16 },
         { "f", 17 }, { "g", 18 }, { "h", 19 }, { "i", 20 }, { "k", 21 }, { "l", 22 }, { "m", 23 }, { "n", 24 }, { "o", 25 }, { "p", 26 }, { "r", 27 }, { "s", 28 }, { "t", 29 }, { "u", 30 }, { "v", 31 },
         { "z", 32 }, { ".</w>", 33 }, { "##o", 34 }, { "##r", 35 }, { "##l", 36 }, { "##d</w>", 37 }, { "##h", 38 }, { "##i", 39 }, { "##s</w>", 40 }, { "##s", 41 }, { "##e</w>", 42 }, { "a</w>", 43 },

@@ -31,7 +31,7 @@ namespace Microsoft.ML.Tokenizers.Tests
         { "##ken", 92 }, { "##um", 93 }, { "##ent</w>", 94 }, { "Bpe</w>", 95 }, { "Hell", 96 }, { "This</w>", 97 }, { "Worl", 98 }, { "and</w>", 99 }, { "docum", 100 }, { "file", 101 }, { "genera", 102 },
         { "merg", 103 }, { "token", 104 }, { "the</w>", 105 }, { "train</w>", 106 }, { "use</w>", 107 }, { "vocab</w>", 108 }, { "##izer</w>", 109 }, { "Hello</w>", 110 }, { "World</w>", 111 },
         { "document</w>", 112 }, { "files</w>", 113 }, { "generate</w>", 114 }, { "merge</w>", 115 }, { "tokenizer</w>", 116 } };
-    private readonly static (string, string)[] _mergeDataWithWordPrefixAndEndOfWordSuffix =
+    private static readonly (string, string)[] _mergeDataWithWordPrefixAndEndOfWordSuffix =
         new (string, string)[] { ("t", "##o</w>"), ("##e", "##n"), ("##o", "##c"), ("##r", "##a"), ("B", "##p"), ("H", "##e"), ("T", "##h"), ("W", "##o"), ("a", "##n"),
         ("d", "##oc"), ("f", "##i"), ("g", "##en"), ("i", "##s</w>"), ("m", "##e"), ("t", "##o"), ("t", "##h"), ("t", "##ra"), ("u", "##s"), ("v", "##oc"), ("##r", "##l"), ("##r", "##g"), ("##l", "##l"),
         ("##l", "##e"), ("##i", "##s</w>"), ("##i", "##n</w>"), ("##i", "##z"), ("##a", "##b</w>"), ("##e", "##r</w>"), ("##e", "##ra"), ("##t", "##e</w>"), ("##k", "##en"), ("##u", "##m"), ("##en", "##t</w>"),

@@ -1,9 +1,9 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

 using Microsoft.ML.Tokenizers;
 using System;
 using System.Buffers;
 using System.Collections.Generic;
 using System.Linq;
 using Xunit;

@@ -12,6 +12,114 @@ namespace Microsoft.ML.Tokenizers.Tests
 {
     public class TokenizerTests
     {
+        [Fact]
+        public void Decode_DefaultImplementation()
+        {
+            var tokenizer = new EnglishAlphabetTokenizer();
+
+            Assert.Equal("", tokenizer.Decode([]));
+
+            Assert.Equal("hello", tokenizer.Decode([7, 4, 11, 11, 14]));
+
+            Assert.Equal(
+                string.Concat(Enumerable.Repeat("abcdefghijklmnopqrstuvwxyz", 100)),
+                tokenizer.Decode(Enumerable.Repeat("abcdefghijklmnopqrstuvwxyz", 100).SelectMany(s => s.Select(c => c - 'a'))));
+
+            Assert.Throws<InvalidOperationException>(() => tokenizer.Decode([26, 27, 28, 29]));
+        }
+
+        [Fact]
+        public void EncodeToIds_DefaultImplementation()
+        {
+            var tokenizer = new EnglishAlphabetTokenizer();
+
+            IReadOnlyList<int> ids = tokenizer.EncodeToIds("hello, world", 5, out string? normalizedText, out int charsConsumed);
+
+            Assert.Equal([7, 4, 11, 11, 14], ids);
+            Assert.Null(normalizedText);
+            Assert.Equal(5, charsConsumed);
+        }
+
+        [Fact]
+        public void CountTokens_DefaultImplementation()
+        {
+            var tokenizer = new EnglishAlphabetTokenizer();
+
+            Assert.Equal(5, tokenizer.CountTokens("hello"));
+        }
+
+        [Fact]
+        public void GetIndexByTokenCount_DefaultImplementation()
+        {
+            var tokenizer = new EnglishAlphabetTokenizer();
+
+            Assert.Equal(2, tokenizer.GetIndexByTokenCount("hello", 2, out string? normalizedString, out int tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(2, tokenCount);
+
+            Assert.Equal(5, tokenizer.GetIndexByTokenCount("hello", 8, out normalizedString, out tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(5, tokenCount);
+        }
+
+        [Fact]
+        public void GetIndexByTokenCountFromEnd_DefaultImplementation()
+        {
+            var tokenizer = new EnglishAlphabetTokenizer();
+
+            Assert.Equal(3, tokenizer.GetIndexByTokenCountFromEnd("hello", 2, out string? normalizedString, out int tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(2, tokenCount);
+
+            Assert.Equal(0, tokenizer.GetIndexByTokenCountFromEnd("hello", 8, out normalizedString, out tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(5, tokenCount);
+        }
+
+        private sealed class EnglishAlphabetTokenizer : Tokenizer
+        {
+            public override OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten)
+            {
+                int pos = 0;
+                foreach (int i in ids)
+                {
+                    if (pos >= destination.Length)
+                    {
+                        charsWritten = idsConsumed = pos;
+                        return OperationStatus.DestinationTooSmall;
+                    }
+
+                    if (i is < 0 or >= 26)
+                    {
+                        charsWritten = idsConsumed = pos;
+                        return OperationStatus.InvalidData;
+                    }
+
+                    destination[pos++] = (char)('a' + i);
+                }
+
+                charsWritten = idsConsumed = pos;
+                return OperationStatus.Done;
+            }
+
+            protected override EncodeResults<EncodedToken> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
+            {
+                var tokens = new List<EncodedToken>();
+
+                int count = 0;
+                foreach (char c in textSpan)
+                {
+                    if (count >= settings.MaxTokenCount)
+                        break;
+
+                    tokens.Add(new EncodedToken(c - 'a', c.ToString(), (count, 1)));
+                    count++;
+                }
+
+                return new EncodeResults<EncodedToken> { Tokens = tokens, CharsConsumed = count };
+            }
+        }
+
         internal static void TestTokenLimits(Tokenizer tokenizer)
         {
             string input = @"

@@ -116,4 +224,4 @@ namespace Microsoft.ML.Tokenizers.Tests
             Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.GetIndexByTokenCountFromEnd(input, maxTokenCount: -1, out _, out _));
         }
     }
 }