Make most Tokenizer abstract methods virtual (#7198)

* Make most Tokenizer abstract methods virtual

All of the functionality in all of the methods can be implemented in terms of just the Decode and EncodeToTokens methods. Only those two need to be abstract; everything else that was abstract can instead be made virtual and implemented in terms of those two.

* Address feedback and clean up a few things
This commit is contained in:
Stephen Toub 2024-07-25 22:53:25 -04:00 коммит произвёл GitHub
Родитель 0c2e82e7d4
Коммит 5b920f9601
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
9 изменённых файлов: 259 добавлений и 31 удалений

Просмотреть файл

@ -1,7 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;

Просмотреть файл

@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
@ -9,7 +9,11 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
public struct EncodeSettings
{
/// <summary>
/// Initializes a new <see cref="EncodeSettings"/> instance, with <see cref="MaxTokenCount"/> defaulting to <see cref="int.MaxValue"/>.
/// </summary>
public EncodeSettings() => MaxTokenCount = int.MaxValue;
/// <summary>
/// Gets or sets a value indicating whether to consider the input normalization during encoding.
/// </summary>

Просмотреть файл

@ -2,10 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
namespace Microsoft.ML.Tokenizers
{
/// <summary>

Просмотреть файл

@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
@ -95,7 +95,7 @@ namespace Microsoft.ML.Tokenizers
/// https://huggingface.co/microsoft/phi-2/resolve/main/vocab.json?download=true
/// https://huggingface.co/microsoft/phi-2/resolve/main/merges.txt?download=true
/// </remarks>
public new static Phi2Tokenizer Create(
public static new Phi2Tokenizer Create(
Stream vocabStream,
Stream mergesStream,
bool addPrefixSpace = false,

Просмотреть файл

@ -9,7 +9,7 @@ using System.Collections.Generic;
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// serves as an abstraction for concrete tokenizers, enabling the encoding of text into tokens and IDs, as well as the decoding of IDs back into text.
/// Provides an abstraction for tokenizers, enabling the encoding of text into tokens and the decoding of token IDs back into text.
/// </summary>
public abstract class Tokenizer
{
@ -35,7 +35,27 @@ namespace Microsoft.ML.Tokenizers
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <returns>The encoded results containing the list of encoded Ids.</returns>
protected abstract EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);
/// <remarks>
/// Types derived from <see cref="Tokenizer"/> may override this implementation to provide a more efficient implementation.
/// By default, it delegates to <see cref="EncodeToTokens(string?, ReadOnlySpan{char}, EncodeSettings)"/> and projects out the ID of each encoded token.
/// </remarks>
protected virtual EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
{
    EncodeResults<EncodedToken> encoded = EncodeToTokens(text, textSpan, settings);

    // Copy just the IDs out of the encoded tokens.
    int[] tokenIds = new int[encoded.Tokens.Count];
    int index = 0;
    foreach (EncodedToken token in encoded.Tokens)
    {
        tokenIds[index++] = token.Id;
    }

    return new EncodeResults<int>
    {
        Tokens = tokenIds,
        CharsConsumed = encoded.CharsConsumed,
        NormalizedText = encoded.NormalizedText,
    };
}
/// <summary>
/// Encodes input text to token Ids.
@ -45,7 +65,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, ReadOnlySpan<char>.Empty, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization }).Tokens;
=> EncodeToIds(text, text.AsSpan(), new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization }).Tokens;
/// <summary>
/// Encodes input text to token Ids.
@ -69,7 +89,7 @@ namespace Microsoft.ML.Tokenizers
/// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true)
{
EncodeResults<int> result = EncodeToIds(text, ReadOnlySpan<char>.Empty,
EncodeResults<int> result = EncodeToIds(text, text.AsSpan(),
new EncodeSettings
{
ConsiderPreTokenization = considerPreTokenization,
@ -127,7 +147,7 @@ namespace Microsoft.ML.Tokenizers
/// <returns>The list of encoded <see cref="EncodedToken" />s.</returns>
public IReadOnlyList<EncodedToken> EncodeToTokens(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true)
{
EncodeResults<EncodedToken> result = EncodeToTokens(text, ReadOnlySpan<char>.Empty, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
EncodeResults<EncodedToken> result = EncodeToTokens(text, text.AsSpan(), new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
normalizedString = result.NormalizedText;
return result.Tokens;
@ -156,7 +176,12 @@ namespace Microsoft.ML.Tokenizers
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
protected abstract int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);
/// <remarks>
/// Types derived from <see cref="Tokenizer"/> may override this implementation to provide a more efficient implementation.
/// By default, it counts the tokens produced by <see cref="EncodeToTokens(string?, ReadOnlySpan{char}, EncodeSettings)"/>.
/// </remarks>
protected virtual int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
{
    EncodeResults<EncodedToken> encoded = EncodeToTokens(text, textSpan, settings);
    return encoded.Tokens.Count;
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
@ -166,7 +191,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(text, ReadOnlySpan<char>.Empty, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
=> CountTokens(text, text.AsSpan(), new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
@ -194,7 +219,44 @@ namespace Microsoft.ML.Tokenizers
/// If <paramRef name="fromEnd" /> is <see langword="true"/>, it represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely,
/// if all tokens fit, the result will be zero.
/// </returns>
protected abstract int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedString, out int tokenCount);
/// <remarks>
/// Types derived from <see cref="Tokenizer"/> may override this implementation to provide a more efficient implementation.
/// By default, it encodes via <see cref="EncodeToTokens(string?, ReadOnlySpan{char}, EncodeSettings)"/> and derives the index from the token offsets.
/// </remarks>
protected virtual int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedString, out int tokenCount)
{
    int requestedMax = settings.MaxTokenCount;
    if (fromEnd)
    {
        // Counting from the end requires tokenizing the entire input first.
        settings.MaxTokenCount = int.MaxValue;
    }

    EncodeResults<EncodedToken> encoded = EncodeToTokens(text, textSpan, settings);
    IReadOnlyList<EncodedToken> allTokens = encoded.Tokens;

    normalizedString = encoded.NormalizedText;
    tokenCount = Math.Min(requestedMax, allTokens.Count);

    if (fromEnd)
    {
        if (tokenCount > 0)
        {
            // Index of the first character covered by the trailing tokenCount tokens.
            return allTokens[allTokens.Count - tokenCount].Offset.Index;
        }

        // No tokens fit: report the full (possibly normalized) text length.
        return encoded.NormalizedText?.Length ?? textSpan.Length;
    }

    if (tokenCount > 0)
    {
        // Index just past the last included token.
        EncodedToken lastIncluded = allTokens[tokenCount - 1];
        return lastIncluded.Offset.Index + lastIncluded.Offset.Length;
    }

    // No tokens fit from the start.
    return 0;
}
/// <summary>
/// Find the index of the maximum encoding capacity without surpassing the token limit.
@ -213,7 +275,7 @@ namespace Microsoft.ML.Tokenizers
public int GetIndexByTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> GetIndexByTokenCount(
text,
ReadOnlySpan<char>.Empty,
text.AsSpan(),
new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization, MaxTokenCount = maxTokenCount },
fromEnd: false,
out normalizedString,
@ -259,7 +321,7 @@ namespace Microsoft.ML.Tokenizers
public int GetIndexByTokenCountFromEnd(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> GetIndexByTokenCount(
text,
ReadOnlySpan<char>.Empty,
text.AsSpan(),
new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization, MaxTokenCount = maxTokenCount },
fromEnd: true,
out normalizedString,
@ -293,7 +355,64 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <returns>The decoded string.</returns>
public abstract string? Decode(IEnumerable<int> ids);
/// <exception cref="ArgumentNullException"><paramref name="ids"/> is null.</exception>
/// <exception cref="InvalidOperationException"><paramref name="ids"/> contains invalid data.</exception>
/// <remarks>
/// Types derived from <see cref="Tokenizer"/> may override this implementation to provide a more efficient implementation.
/// By default, it uses <see cref="Decode(IEnumerable{int}, Span{char}, out int, out int)"/>.
/// </remarks>
public virtual string? Decode(IEnumerable<int> ids)
{
    if (ids is null)
    {
        throw new ArgumentNullException(nameof(ids));
    }

    // If the input size is known up front, use it both to early-exit on empty
    // input and to seed the initial destination buffer size below.
    int idCount = 0;
    if (ids is ICollection<int> c)
    {
        idCount = c.Count;
        if (idCount == 0)
        {
            return string.Empty;
        }
    }

    char[] destination = ArrayPool<char>.Shared.Rent(
#if DEBUG
        1); // to help validate growth logic
#else
        idCount == 0 ? 1024 : idCount * 8); // arbitrary starting point / heuristic
#endif
    // Repeatedly decode into the rented buffer, growing it whenever the
    // span-based overload reports it ran out of room. Each retry re-enumerates
    // ids from the beginning, which the span-based overload supports.
    while (true)
    {
        switch (Decode(ids, destination, out int idsConsumed, out int charsWritten))
        {
            case OperationStatus.Done:
                string result = destination.AsSpan(0, charsWritten).ToString();
                ArrayPool<char>.Shared.Return(destination);
                return result;

            case OperationStatus.DestinationTooSmall:
                // Grow geometrically; near int.MaxValue fall back to +1 growth
                // before giving up entirely.
                long newSize = (long)destination.Length * 2;
                if (newSize > int.MaxValue)
                {
                    newSize = (long)destination.Length + 1;
                    if (newSize > int.MaxValue)
                    {
                        throw new OutOfMemoryException();
                    }
                }

                ArrayPool<char>.Shared.Return(destination);
                destination = ArrayPool<char>.Shared.Rent((int)newSize);
                break;

            default:
                // Any other status (e.g. InvalidData) means the IDs can't be decoded.
                throw new InvalidOperationException("The provided token IDs could not be decoded.");
        }
    }
}
/// <summary>
/// Decode the given ids back to text and store the result in the <paramref name="destination"/> span.

Просмотреть файл

@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
@ -15,7 +15,7 @@ namespace Microsoft.ML.Tokenizers
/// This should only be used with a Ptr/Length for querying. For storing in a dictionary, this should
/// always be used with a string.
/// </remarks>
internal unsafe readonly struct StringSpanOrdinalKey : IEquatable<StringSpanOrdinalKey>
internal readonly unsafe struct StringSpanOrdinalKey : IEquatable<StringSpanOrdinalKey>
{
public readonly char* Ptr;
public readonly int Length;
@ -45,7 +45,7 @@ namespace Microsoft.ML.Tokenizers
public override int GetHashCode() => Helpers.GetHashCode(Span);
}
internal unsafe readonly struct StringSpanOrdinalKeyPair : IEquatable<StringSpanOrdinalKeyPair>
internal readonly unsafe struct StringSpanOrdinalKeyPair : IEquatable<StringSpanOrdinalKeyPair>
{
private readonly StringSpanOrdinalKey _left;
private readonly StringSpanOrdinalKey _right;
@ -173,7 +173,7 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
internal static class StringSpanOrdinalKeyExtensions
{
public unsafe static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, ReadOnlySpan<char> key, out TValue value)
public static unsafe bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, ReadOnlySpan<char> key, out TValue value)
{
fixed (char* ptr = key)
{
@ -184,7 +184,7 @@ namespace Microsoft.ML.Tokenizers
public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, string key, out TValue value) =>
map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
public unsafe static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKeyPair, TValue> map, ReadOnlySpan<char> key1, ReadOnlySpan<char> key2, out TValue value)
public static unsafe bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKeyPair, TValue> map, ReadOnlySpan<char> key1, ReadOnlySpan<char> key2, out TValue value)
{
fixed (char* ptr1 = key1)
fixed (char* ptr2 = key2)

Просмотреть файл

@ -1,5 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Buffers;
using System.Diagnostics;

Просмотреть файл

@ -20,7 +20,7 @@ namespace Microsoft.ML.Tokenizers.Tests
{
private const string UnknownToken = "[unk]";
private readonly static Dictionary<string, int> _vocabDataWithWordPrefixAndEndOfWordSuffix =
private static readonly Dictionary<string, int> _vocabDataWithWordPrefixAndEndOfWordSuffix =
new Dictionary<string, int>() { { UnknownToken, 0 }, { "!", 5 }, { ",", 6 }, { ".", 7 }, { "B", 8 }, { "H", 9 }, { "T", 10 }, { "W", 11 }, { "a", 12 }, { "b", 13 }, { "c", 14 }, { "d", 15 }, { "e", 16 },
{ "f", 17 }, { "g", 18 }, { "h", 19 }, { "i", 20 }, { "k", 21 }, { "l", 22 }, { "m", 23 }, { "n", 24 }, { "o", 25 }, { "p", 26 }, { "r", 27 }, { "s", 28 }, { "t", 29 }, { "u", 30 }, { "v", 31 },
{ "z", 32 }, { ".</w>", 33 }, { "##o", 34 }, { "##r", 35 }, { "##l", 36 }, { "##d</w>", 37 }, { "##h", 38 }, { "##i", 39 }, { "##s</w>", 40 }, { "##s", 41 }, { "##e</w>", 42 }, { "a</w>", 43 },
@ -31,7 +31,7 @@ namespace Microsoft.ML.Tokenizers.Tests
{ "##ken", 92 }, { "##um", 93 }, { "##ent</w>", 94 }, { "Bpe</w>", 95 }, { "Hell", 96 }, { "This</w>", 97 }, { "Worl", 98 }, { "and</w>", 99 }, { "docum", 100 }, { "file", 101 }, { "genera", 102 },
{ "merg", 103 }, { "token", 104 }, { "the</w>", 105 }, { "train</w>", 106 }, { "use</w>", 107 }, { "vocab</w>", 108 }, { "##izer</w>", 109 }, { "Hello</w>", 110 }, { "World</w>", 111 },
{ "document</w>", 112 }, { "files</w>", 113 }, { "generate</w>", 114 }, { "merge</w>", 115 }, { "tokenizer</w>", 116 } };
private readonly static (string, string)[] _mergeDataWithWordPrefixAndEndOfWordSuffix =
private static readonly (string, string)[] _mergeDataWithWordPrefixAndEndOfWordSuffix =
new (string, string)[] { ("t", "##o</w>"), ("##e", "##n"), ("##o", "##c"), ("##r", "##a"), ("B", "##p"), ("H", "##e"), ("T", "##h"), ("W", "##o"), ("a", "##n"),
("d", "##oc"), ("f", "##i"), ("g", "##en"), ("i", "##s</w>"), ("m", "##e"), ("t", "##o"), ("t", "##h"), ("t", "##ra"), ("u", "##s"), ("v", "##oc"), ("##r", "##l"), ("##r", "##g"), ("##l", "##l"),
("##l", "##e"), ("##i", "##s</w>"), ("##i", "##n</w>"), ("##i", "##z"), ("##a", "##b</w>"), ("##e", "##r</w>"), ("##e", "##ra"), ("##t", "##e</w>"), ("##k", "##en"), ("##u", "##m"), ("##en", "##t</w>"),

Просмотреть файл

@ -1,9 +1,9 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Tokenizers;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using Xunit;
@ -12,6 +12,114 @@ namespace Microsoft.ML.Tokenizers.Tests
{
public class TokenizerTests
{
[Fact]
public void Decode_DefaultImplementation()
{
    var sut = new EnglishAlphabetTokenizer();

    // Empty input decodes to the empty string without renting a buffer.
    Assert.Equal("", sut.Decode([]));

    // Simple round-trip through the base-class implementation.
    Assert.Equal("hello", sut.Decode([7, 4, 11, 11, 14]));

    // A large input forces the base implementation's buffer-growth path.
    string repeatedAlphabet = string.Concat(Enumerable.Repeat("abcdefghijklmnopqrstuvwxyz", 100));
    Assert.Equal(
        repeatedAlphabet,
        sut.Decode(Enumerable.Repeat("abcdefghijklmnopqrstuvwxyz", 100).SelectMany(s => s.Select(c => c - 'a'))));

    // Out-of-range IDs surface as InvalidOperationException from the base Decode.
    Assert.Throws<InvalidOperationException>(() => sut.Decode([26, 27, 28, 29]));
}
[Fact]
public void EncodeToIds_DefaultImplementation()
{
    var sut = new EnglishAlphabetTokenizer();

    IReadOnlyList<int> ids = sut.EncodeToIds("hello, world", 5, out string? normalizedText, out int charsConsumed);

    // Only the first five characters fit within the token budget.
    Assert.Equal([7, 4, 11, 11, 14], ids);
    Assert.Null(normalizedText);
    Assert.Equal(5, charsConsumed);
}
[Fact]
public void CountTokens_DefaultImplementation()
{
    // The base-class CountTokens delegates to EncodeToTokens and counts the results.
    var sut = new EnglishAlphabetTokenizer();
    Assert.Equal(5, sut.CountTokens("hello"));
}
[Fact]
public void GetIndexByTokenCount_DefaultImplementation()
{
    var sut = new EnglishAlphabetTokenizer();

    // A budget of 2 tokens covers exactly the first two characters.
    Assert.Equal(2, sut.GetIndexByTokenCount("hello", 2, out string? normalized, out int count));
    Assert.Null(normalized);
    Assert.Equal(2, count);

    // A budget larger than the input covers the whole string.
    Assert.Equal(5, sut.GetIndexByTokenCount("hello", 8, out normalized, out count));
    Assert.Null(normalized);
    Assert.Equal(5, count);
}
[Fact]
public void GetIndexByTokenCountFromEnd_DefaultImplementation()
{
    var sut = new EnglishAlphabetTokenizer();

    // Keeping the trailing 2 tokens means the first included character is at index 3.
    Assert.Equal(3, sut.GetIndexByTokenCountFromEnd("hello", 2, out string? normalized, out int count));
    Assert.Null(normalized);
    Assert.Equal(2, count);

    // A budget larger than the input includes the whole string from index 0.
    Assert.Equal(0, sut.GetIndexByTokenCountFromEnd("hello", 8, out normalized, out count));
    Assert.Null(normalized);
    Assert.Equal(5, count);
}
/// <summary>
/// Minimal test tokenizer mapping 'a'..'z' to IDs 0..25; overrides only the two
/// required members so the base-class default implementations are exercised.
/// </summary>
private sealed class EnglishAlphabetTokenizer : Tokenizer
{
    public override OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten)
    {
        int written = 0;
        foreach (int id in ids)
        {
            // Report how far we got before running out of room...
            if (written >= destination.Length)
            {
                charsWritten = idsConsumed = written;
                return OperationStatus.DestinationTooSmall;
            }

            // ...or before hitting an ID outside the alphabet.
            if (id is < 0 or >= 26)
            {
                charsWritten = idsConsumed = written;
                return OperationStatus.InvalidData;
            }

            destination[written++] = (char)('a' + id);
        }

        charsWritten = idsConsumed = written;
        return OperationStatus.Done;
    }

    protected override EncodeResults<EncodedToken> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
    {
        // One token per character, honoring the configured token budget.
        var tokens = new List<EncodedToken>();
        int position = 0;
        foreach (char c in textSpan)
        {
            if (position >= settings.MaxTokenCount)
            {
                break;
            }

            tokens.Add(new EncodedToken(c - 'a', c.ToString(), (position, 1)));
            position++;
        }

        return new EncodeResults<EncodedToken> { Tokens = tokens, CharsConsumed = position };
    }
}
internal static void TestTokenLimits(Tokenizer tokenizer)
{
string input = @"
@ -116,4 +224,4 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.GetIndexByTokenCountFromEnd(input, maxTokenCount: -1, out _, out _));
}
}
}
}