* Remove the confusing specialTokens flag parameter from all APIs (see the usage sketch below)

* Normalize casing of the Tiktoken name

* Make Model.Encode work with span instead of string

* Support granular Last/IndexOf.

* Remove wrong assert.

* Address the feedback

* Update the package doc.
Tarek Mahmoud Sayed 2024-03-27 11:24:30 -07:00 committed by GitHub
Parent 214e12aefc
Commit c980eaf964
18 changed files with 1125 additions and 410 deletions
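The reshaped surface is easiest to see from the caller's side. A minimal sketch, assuming a Model-derived instance (the Demo wrapper, its arguments, and the Model base-type name are assumptions for illustration, not part of this diff):

using System;
using System.Collections.Generic;
using Microsoft.ML.Tokenizers;

internal static class ModelApiSketch
{
    // Exercises the new surface: span-based Encode with no specialTokens flag,
    // plus budgeted EncodeToIds/CountTokens/CountTokensFromEnd overloads.
    public static void Demo(Model model, string input)
    {
        ReadOnlySpan<char> text = input.AsSpan();

        // Encode now takes ReadOnlySpan<char> instead of string.
        IReadOnlyList<Token> tokens = model.Encode(text);

        // EncodeToIds returns how many ids were produced; textLength reports
        // how many leading characters those ids cover.
        var ids = new List<int>();
        int produced = model.EncodeToIds(text, ids, out int textLength, maxTokens: 10);

        // The counting counterparts walk from the start or from the end.
        int fromStart = model.CountTokens(text, out int coveredLength, maxTokens: 10);
        int fromEnd = model.CountTokensFromEnd(text, out int startIndex, maxTokens: 10);

        Console.WriteLine($"{tokens.Count} tokens; {produced} ids cover [0..{textLength}); " +
                          $"{fromStart} tokens cover the first {coveredLength} chars; " +
                          $"the last {fromEnd} tokens start at index {startIndex}.");
    }
}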

View file

@ -20,6 +20,7 @@ namespace Microsoft.ML.Tokenizers
{
/// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
private const int MaxWordLengthToCache = 15;
private string? _unknownToken;
/// <summary>
@ -176,10 +177,9 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Encode a text string to a list of tokens.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.</param>
/// <param name="text">The text to encode.</param>
/// <returns>The list of tokens generated from the text tokenization.</returns>
public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false)
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
{
if (text.Length == 0)
{
@ -192,34 +192,44 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.</param>
/// <param name="text">The text to encode.</param>
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsWithCache(text, accumulatedIds, maxTokens, out textLength);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.</param>
/// <returns>The number of tokens that the input text will be encoded to. This parameter is ignored in this model.</returns>
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
/// <param name="text">The text to encode.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsWithCache(text, null, maxTokens, out textLength);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => EncodeToIdsFromEndWithCache(text, null, maxTokens, out textIndex);
/// <summary>
/// Map the token to encoded Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;
public override int? MapTokenToId(ReadOnlySpan<char> token) => _vocab.TryGetValue(token, out int value) ? value : null;
/// <summary>
/// Map the encoded Id to the token.
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the decoding.</param>
/// <returns>The mapped token of the Id.</returns>
public override string? MapIdToToken(int id, bool considerSpecialTokens = true)
public override string? MapIdToToken(int id)
{
if (VocabReverse.TryGetValue(id, out string? value))
{
@ -434,7 +444,7 @@ namespace Microsoft.ML.Tokenizers
internal List<Token> WordToTokens(ref Word word) => word.ToTokens(VocabReverse);
internal List<Token> EncodeWithCache(string text)
internal List<Token> EncodeWithCache(ReadOnlySpan<char> text)
{
Word word;
if (Cache is not null)
@ -444,28 +454,64 @@ namespace Microsoft.ML.Tokenizers
return WordToTokens(ref word);
}
word = MergeWord(text.AsSpan());
Cache.Set(text, word);
word = MergeWord(text);
if (text.Length <= MaxWordLengthToCache)
{
Cache.Set(text.ToString(), word);
}
}
else
{
word = MergeWord(text.AsSpan());
word = MergeWord(text);
}
return WordToTokens(ref word);
}
internal int WordToIds(ref Word word, IList<int>? accumulatedIds)
internal int WordToIds(ref Word word, IList<int>? accumulatedIds, out int textLength, int fullTextLength, int maxTokens)
{
if (accumulatedIds is not null)
if (word.SymbolsCount < maxTokens)
{
word.PopulateIds(accumulatedIds);
textLength = fullTextLength;
if (accumulatedIds is not null)
{
word.PopulateIds(accumulatedIds);
}
return word.SymbolsCount;
}
return word.SymbolsCount;
if (accumulatedIds is not null)
{
return word.PopulateIdsUpToMax(accumulatedIds, maxTokens, out textLength);
}
return word.CountIdsUpToMax(maxTokens, out textLength);
}
internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
internal int WordToIdsFromEnd(ref Word word, IList<int>? accumulatedIds, out int textIndex, int fullTextLength, int maxTokens)
{
if (word.SymbolsCount < maxTokens)
{
textIndex = 0;
if (accumulatedIds is not null)
{
word.PopulateIds(accumulatedIds);
}
return word.SymbolsCount;
}
if (accumulatedIds is not null)
{
return word.PopulateIdsUpToMaxFromEnd(accumulatedIds, maxTokens, fullTextLength, out textIndex);
}
return word.CountIdsUpToMaxFromEnd(maxTokens, fullTextLength, out textIndex);
}
internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textLength)
{
Word word;
@ -473,18 +519,48 @@ namespace Microsoft.ML.Tokenizers
{
if (Cache.TryGetValue(text, out Word hit))
{
return WordToIds(ref hit, accumulatedIds);
return WordToIds(ref hit, accumulatedIds, out textLength, text.Length, maxTokens);
}
word = MergeWord(text);
Cache.Set(text.ToString(), word);
if (text.Length <= MaxWordLengthToCache)
{
Cache.Set(text.ToString(), word);
}
}
else
{
word = MergeWord(text);
}
return WordToIds(ref word, accumulatedIds);
return WordToIds(ref word, accumulatedIds, out textLength, text.Length, maxTokens);
}
internal int EncodeToIdsFromEndWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textIndex)
{
Word word;
if (Cache is not null)
{
if (Cache.TryGetValue(text, out Word hit))
{
return WordToIdsFromEnd(ref hit, accumulatedIds, out textIndex, text.Length, maxTokens);
}
word = MergeWord(text);
if (text.Length <= MaxWordLengthToCache)
{
Cache.Set(text.ToString(), word);
}
}
else
{
word = MergeWord(text);
}
return WordToIdsFromEnd(ref word, accumulatedIds, out textIndex, text.Length, maxTokens);
}
internal static readonly List<Token> EmptyTokensList = new();
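What the new out parameters buy callers: a sketch of budget-based truncation built on the counting APIs (the TruncationSketch helpers are hypothetical; only the CountTokens/CountTokensFromEnd signatures above are taken from this diff).

using System;
using Microsoft.ML.Tokenizers;

internal static class TruncationSketch
{
    // Keep the longest prefix that encodes to at most budget tokens;
    // textLength is the prefix length those tokens cover.
    public static string TruncateFromStart(Model model, string text, int budget)
    {
        model.CountTokens(text.AsSpan(), out int textLength, maxTokens: budget);
        return text.Substring(0, textLength);
    }

    // Keep the longest suffix that encodes to at most budget tokens;
    // textIndex is where that suffix starts.
    public static string TruncateFromEnd(Model model, string text, int budget)
    {
        model.CountTokensFromEnd(text.AsSpan(), out int textIndex, maxTokens: budget);
        return text.Substring(textIndex);
    }
}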

View file

@ -135,15 +135,9 @@ namespace Microsoft.ML.Tokenizers
/// Map the encoded Id to the token.
/// </summary>
/// <param name="id">The Id to map to the string.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the decoding.</param>
/// <returns>The mapped token of the Id.</returns>
public override string? MapIdToToken(int id, bool considerSpecialTokens = true)
public override string? MapIdToToken(int id)
{
if (!considerSpecialTokens && id < 0)
{
return null;
}
if (_vocabReverse.TryGetValue(id, out var value))
{
string v = value.Data!;
@ -176,12 +170,11 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Encode a text string to a list of tokens.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.</param>
/// <param name="text">The text to encode.</param>
/// <returns>The list of tokens generated from the text tokenization.</returns>
public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false)
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
{
if (string.IsNullOrEmpty(text))
if (text.IsEmpty)
{
return Bpe.EmptyTokensList;
}
@ -215,7 +208,7 @@ namespace Microsoft.ML.Tokenizers
}
List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
_cache.Set(text, result);
_cache.Set(text.ToString(), result);
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
return result;
@ -224,37 +217,116 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.</param>
/// <param name="text">The text to encode.</param>
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIds(text, accumulatedIds);
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsInternal(text, accumulatedIds, out textLength, maxTokens);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.</param>
/// <param name="text">The text to encode.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIds(text, null);
public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsInternal(text, null, out textLength, maxTokens);
private int EncodeToIds(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => EncodeToIdsFromEndInternal(text, null, out textIndex, maxTokens);
private int EncodeToIdsResult(List<Token> tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textLength)
{
if (text.IsEmpty)
{
return 0;
}
textLength = 0;
if (_cache.TryGetValue(text, out List<Token>? hit))
if (tokens.Count <= maxTokens)
{
if (accumulatedIds is not null)
{
foreach (var t in hit)
foreach (var t in tokens)
{
accumulatedIds.Add(t.Id);
}
}
return hit.Count;
textLength = fullTextLength;
return tokens.Count;
}
if (accumulatedIds is not null)
{
for (int i = 0; i < maxTokens; i++)
{
accumulatedIds.Add(tokens[i].Id);
textLength += tokens[i].Offset.Length;
}
}
else
{
for (int i = 0; i < maxTokens; i++)
{
textLength += tokens[i].Offset.Length;
}
}
return maxTokens;
}
private int EncodeToIdsFromEndResult(List<Token> tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textIndex)
{
textIndex = fullTextLength;
if (tokens.Count <= maxTokens)
{
if (accumulatedIds is not null)
{
foreach (var t in tokens)
{
accumulatedIds.Add(t.Id);
}
}
textIndex = 0;
return tokens.Count;
}
if (accumulatedIds is not null)
{
for (int i = tokens.Count - maxTokens; i < tokens.Count; i++)
{
accumulatedIds.Add(tokens[i].Id);
textIndex -= tokens[i].Offset.Length;
}
}
else
{
for (int i = tokens.Count - maxTokens; i < tokens.Count; i++)
{
textIndex -= tokens[i].Offset.Length;
}
}
return maxTokens;
}
private int EncodeToIdsInternal(ReadOnlySpan<char> text, IList<int>? accumulatedIds, out int textLength, int maxTokens)
{
if (text.IsEmpty)
{
textLength = 0;
return 0;
}
if (_cache.TryGetValue(text, out List<Token>? hit))
{
return EncodeToIdsResult(hit, accumulatedIds, maxTokens, text.Length, out textLength);
}
char[] token = ArrayPool<char>.Shared.Rent(text.Length);
@ -275,6 +347,7 @@ namespace Microsoft.ML.Tokenizers
{
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
textLength = 0;
return 0;
}
@ -283,24 +356,58 @@ namespace Microsoft.ML.Tokenizers
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
if (accumulatedIds is not null)
return EncodeToIdsResult(result, accumulatedIds, maxTokens, text.Length, out textLength);
}
private int EncodeToIdsFromEndInternal(ReadOnlySpan<char> text, IList<int>? accumulatedIds, out int textIndex, int maxTokens)
{
if (text.IsEmpty)
{
foreach (var t in result)
textIndex = text.Length;
return 0;
}
if (_cache.TryGetValue(text, out List<Token>? hit))
{
return EncodeToIdsFromEndResult(hit, accumulatedIds, maxTokens, text.Length, out textIndex);
}
char[] token = ArrayPool<char>.Shared.Rent(text.Length);
int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);
int newTokenIndex = 0;
for (int i = 0; i < text.Length; i++)
{
if (_byteToUnicode.TryGetValue(text[i], out var value))
{
accumulatedIds.Add(t.Id);
token[newTokenIndex] = value;
indexMapping[newTokenIndex] = i;
newTokenIndex++;
}
}
return result.Count;
if (newTokenIndex == 0)
{
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
textIndex = text.Length;
return 0;
}
List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
_cache.Set(text.ToString(), result);
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
return EncodeToIdsFromEndResult(result, accumulatedIds, maxTokens, text.Length, out textIndex);
}
/// <summary>
/// Map the token to encoded Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;
public override int? MapTokenToId(ReadOnlySpan<char> token) => _vocab.TryGetValue(token, out int value) ? value : null;
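With the considerSpecialTokens flag gone, id/token mapping is a plain round trip. A minimal sketch (the MappingSketch wrapper is hypothetical):

using System;
using Microsoft.ML.Tokenizers;

internal static class MappingSketch
{
    // Map a token string to its id and back; neither call takes a flag anymore.
    public static string? RoundTrip(Model model, string token)
    {
        int? id = model.MapTokenToId(token.AsSpan());
        return id.HasValue ? model.MapIdToToken(id.Value) : null;
    }
}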
/// <summary>
/// Convert a list of tokens Ids to highest occurrence rankings.

View file

@ -16,22 +16,23 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Encode a text to a list of tokens.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
/// <param name="text">The text to encode.</param>
/// <returns>The list of tokens generated from the text tokenization.</returns>
public abstract IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false);
public abstract IReadOnlyList<Token> Encode(ReadOnlySpan<char> text);
/// <summary>
/// Encode a text to a list of Ids and add them to the accumulatedIds list.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
/// <param name="text">The text to encode. </param>
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>
/// This default implementation uses the Encode method to get the token Ids.
/// Tokenizer models that care about performance may choose to override this method to provide a more efficient implementation.
/// </remarks>
public virtual void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds)
public virtual int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
{
if (accumulatedIds is null)
{
@ -39,60 +40,126 @@ namespace Microsoft.ML.Tokenizers
}
// The default implementation is not optimized for memory allocation. It is recommended to override this method for better performance.
var tokens = Encode(text.ToString());
foreach (var token in tokens)
textLength = 0;
var tokens = Encode(text);
int count = Math.Min(tokens.Count, maxTokens);
for (int i = 0; i < count; i++)
{
accumulatedIds.Add(token.Id);
textLength += tokens[i].Offset.Length;
accumulatedIds.Add(tokens[i].Id);
}
return count;
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
/// <param name="text">The text to encode.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>
/// This default implementation uses the EncodeToIds method to get the number of token Ids.
/// Tokenizer models that care about performance may choose to override this method to provide a more efficient implementation.
/// </remarks>
public virtual int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
public virtual int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue)
{
if (maxTokens <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0.");
}
var ids = new List<int>();
EncodeToIds(text, isSpecialToken, ids);
return ids.Count;
if (maxTokens == int.MaxValue)
{
EncodeToIds(text, ids, out _);
textLength = text.Length;
return ids.Count;
}
IReadOnlyList<Token> tokens = Encode(text);
textLength = 0;
int count = Math.Min(tokens.Count, maxTokens);
for (int i = 0; i < count; i++)
{
textLength += tokens[i].Offset.Length;
}
return count;
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>
/// This default implementation uses the EncodeToIds method to get the number of token Ids.
/// Tokenizer models that care about performance may choose to override this method to provide a more efficient implementation.
/// </remarks>
public virtual int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue)
{
if (maxTokens <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0.");
}
var ids = new List<int>();
if (maxTokens == int.MaxValue)
{
EncodeToIds(text, ids, out _);
textIndex = 0;
return ids.Count;
}
IReadOnlyList<Token> tokens = Encode(text);
textIndex = text.Length;
int count = Math.Min(tokens.Count, maxTokens);
int tokensCount = tokens.Count;
int end = tokensCount - count;
for (int i = tokensCount - 1; i >= end; i--)
{
textIndex -= tokens[i].Offset.Length;
}
return count;
}
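Given the abstract/virtual split above, a derived model only has to supply Encode and the two mapping methods; the budgeted defaults come along for free. A toy sketch under that assumption (CharModel is hypothetical and maps each character to its code unit):

using System;
using System.Collections.Generic;
using Microsoft.ML.Tokenizers;

// One token per character; inherits the default EncodeToIds, CountTokens,
// and CountTokensFromEnd implementations shown above.
internal sealed class CharModel : Model
{
    public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
    {
        var tokens = new List<Token>(text.Length);
        for (int i = 0; i < text.Length; i++)
        {
            tokens.Add(new Token(text[i], text[i].ToString(), (i, 1)));
        }
        return tokens;
    }

    public override int? MapTokenToId(ReadOnlySpan<char> token) => token.Length == 1 ? (int?)token[0] : null;

    public override string? MapIdToToken(int id) => id >= 0 && id <= char.MaxValue ? ((char)id).ToString() : null;
}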
/// <summary>
/// Map the token to encoded Id.
/// </summary>
/// <param name="token">The token to map to Id</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public abstract int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true);
public abstract int? MapTokenToId(ReadOnlySpan<char> token);
/// <summary>
/// Map the encoded Id to the token.
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the decoding.</param>
/// <returns>The mapped token of the Id.</returns>
public abstract string? MapIdToToken(int id, bool considerSpecialTokens = true);
public abstract string? MapIdToToken(int id);
/// <summary>
/// Decode the given ids, back to a String.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="considerSpecialTokens">Whether the special tokens should be kept in the decoded string.</param>
/// <param name="decoder">The optional Decoder to merge the given list of tokens in a string.</param>
/// <returns>The decoded string.</returns>
public virtual string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null, bool considerSpecialTokens = true)
public virtual string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null)
{
List<string> tokens = new List<string>();
foreach (int id in ids)
{
if (MapIdToToken(id, considerSpecialTokens) is string s)
if (MapIdToToken(id) is string s)
{
tokens.Add(s);
}

View file

@ -154,11 +154,10 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Encode a text to a list of tokens.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
/// <param name="text">The text to encode.</param>
/// <returns>The list of tokens generated from the text tokenization.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks>
public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false) => Encode(text, AddBeginningOfSentence, AddEndOfSentence);
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text) => Encode(text, AddBeginningOfSentence, AddEndOfSentence);
/// <summary>
/// Encode a text to a list of tokens.
@ -168,13 +167,8 @@ namespace Microsoft.ML.Tokenizers
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <returns>The list of tokens generated from the text tokenization.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks>
public IReadOnlyList<Token> Encode(string text, bool addBeginOfSentence, bool addEndOfSentence)
public IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence)
{
if (text is null)
{
throw new ArgumentNullException(nameof(text));
}
if (text.Length == 0)
{
return Array.Empty<Token>();
@ -182,7 +176,7 @@ namespace Microsoft.ML.Tokenizers
BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length);
Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text.AsSpan(), symbols);
Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols);
List<Token> tokens = new();
@ -198,7 +192,7 @@ namespace Microsoft.ML.Tokenizers
if (id == UninitializedId)
{
if (_vocab.TryGetValue(text.AsSpan().Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
{
id = tokenInfo.Id;
type = tokenInfo.Type;
@ -214,19 +208,19 @@ namespace Microsoft.ML.Tokenizers
{
if (id == UnknownId && ByteFallback)
{
EncodeAsBytes(text.AsSpan().Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index);
EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index);
}
else
{
tokens.Add(new Token(
id,
GetTokenString(id, symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length, text.AsSpan()),
GetTokenString(id, symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length, text),
(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length)));
}
continue;
}
Segment(symbols[index].pieceSpan, text.AsSpan());
Segment(symbols[index].pieceSpan, text);
}
ArrayPool<BpeSymbol>.Shared.Return(symbols);
@ -314,11 +308,14 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Encode a text to a list of Ids and add them to the accumulatedIds list.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
/// <param name="text">The text to encode.</param>
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks>
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIds(text, AddBeginningOfSentence, AddEndOfSentence, accumulatedIds);
public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
=> EncodeToIds(text, AddBeginningOfSentence, AddEndOfSentence, accumulatedIds, out textLength, maxTokens);
/// <summary>
/// Encode a text to a list of Ids and add them to the accumulatedIds list.
@ -327,26 +324,29 @@ namespace Microsoft.ML.Tokenizers
/// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks>
public void EncodeToIds(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, IList<int> accumulatedIds)
public int EncodeToIds(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
{
if (maxTokens <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0.");
}
textLength = 0;
if (text.IsEmpty)
{
return;
return 0;
}
int idsCount = 0;
if (addBeginOfSentence)
{
accumulatedIds.Add(BeginningOfSentenceId);
}
if (text.IsEmpty)
{
if (addEndOfSentence)
{
accumulatedIds.Add(EndOfSentenceId);
}
return;
idsCount++;
}
BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length);
@ -376,34 +376,64 @@ namespace Microsoft.ML.Tokenizers
{
if (id == UnknownId && ByteFallback)
{
EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index);
if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textLength))
{
break;
}
}
else
{
accumulatedIds.Add(id);
if (idsCount < maxTokens)
{
accumulatedIds.Add(id);
textLength += symbols[index].pieceSpan.Length;
idsCount++;
}
else
{
return idsCount;
}
}
continue;
}
Segment(symbols[index].pieceSpan, text);
if (!Segment(symbols[index].pieceSpan, text, ref textLength))
{
break;
}
}
ArrayPool<BpeSymbol>.Shared.Return(symbols);
if (addEndOfSentence)
{
accumulatedIds.Add(EndOfSentenceId);
if (idsCount < maxTokens)
{
accumulatedIds.Add(EndOfSentenceId);
idsCount++;
}
}
return idsCount;
// Encode the Unknown token to bytes.
void EncodeAsBytes(ReadOnlySpan<char> text, int index)
bool EncodeAsBytes(ReadOnlySpan<char> text, int index, ref int textLength)
{
for (int i = 0; i < text.Length; i++)
{
char c = text[i];
if (c <= 0x7F)
{
accumulatedIds.Add((int)c + _byteCodeToIdOffset); // byte code is mapped to the Ids starting from 4.
if (idsCount < maxTokens)
{
textLength++;
accumulatedIds.Add((int)c + _byteCodeToIdOffset); // byte code is mapped to the Ids starting from 4.
idsCount++;
}
else
{
return false;
}
}
else
{
@ -419,9 +449,21 @@ namespace Microsoft.ML.Tokenizers
// Need to convert the text into UTF-8 bytes and then encode the bytes.
int bytesWritten = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes);
for (int j = 0; j < bytesWritten; j++)
bool ret;
if (idsCount + bytesWritten <= maxTokens)
{
accumulatedIds.Add((int)utf8Bytes[j] + _byteCodeToIdOffset); // byte code is mapped to the Ids starting from 4.
for (int j = 0; j < bytesWritten; j++)
{
accumulatedIds.Add((int)utf8Bytes[j] + _byteCodeToIdOffset); // byte code is mapped to the Ids starting from 4.
}
textLength += text.Length - i;
ret = true;
}
else
{
ret = false;
}
if (arrayPoolArray is not null)
@ -429,40 +471,60 @@ namespace Microsoft.ML.Tokenizers
ArrayPool<byte>.Shared.Return(arrayPoolArray);
}
break;
return ret;
}
}
return true;
}
void Segment((int Index, int Length) pieceSpan, ReadOnlySpan<char> text)
bool Segment((int Index, int Length) pieceSpan, ReadOnlySpan<char> text, ref int textLength)
{
if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
{
EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index);
return;
return EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref textLength);
}
if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
revMerge is null ||
!revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
{
accumulatedIds.Add(id.Id);
return;
if (idsCount < maxTokens)
{
accumulatedIds.Add(id.Id);
textLength += pieceSpan.Length;
idsCount++;
return true;
}
else
{
return false;
}
}
Segment((merge.LeftIndex, merge.LeftLen), text);
Segment((merge.RightIndex, merge.RightLen), text);
return Segment((merge.LeftIndex, merge.LeftLen), text, ref textLength) && Segment((merge.RightIndex, merge.RightLen), text, ref textLength);
}
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
/// <param name="text">The text to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks>
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => CountTokens(text, AddBeginningOfSentence, AddEndOfSentence);
public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => CountTokens(text, AddBeginningOfSentence, AddEndOfSentence, out textLength, maxTokens);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => CountTokensFromEnd(text, AddBeginningOfSentence, AddEndOfSentence, out textIndex, maxTokens);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
@ -470,10 +532,13 @@ namespace Microsoft.ML.Tokenizers
/// <param name="text">The text to encode.</param>
/// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks>
public int CountTokens(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence)
public int CountTokens(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textLength, int maxTokens = int.MaxValue)
{
textLength = 0;
if (text.IsEmpty)
{
return 0;
@ -508,36 +573,61 @@ namespace Microsoft.ML.Tokenizers
{
if (id == UnknownId && ByteFallback)
{
EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index);
if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textLength))
{
break;
}
}
else
{
tokenCount++;
if (tokenCount < maxTokens)
{
tokenCount++;
textLength += symbols[index].pieceSpan.Length;
}
else
{
break;
}
}
continue;
}
Segment(symbols[index].pieceSpan, text);
if (!Segment(symbols[index].pieceSpan, text, ref textLength))
{
break;
}
}
ArrayPool<BpeSymbol>.Shared.Return(symbols);
if (addEndOfSentence)
{
tokenCount++;
if (tokenCount < maxTokens)
{
tokenCount++;
}
}
return tokenCount;
// Encode the Unknown token to bytes.
void EncodeAsBytes(ReadOnlySpan<char> text, int index)
bool EncodeAsBytes(ReadOnlySpan<char> text, int index, ref int textLength)
{
for (int i = 0; i < text.Length; i++)
{
char c = text[i];
if (c <= 0x7F)
{
tokenCount++;
if (tokenCount < maxTokens)
{
tokenCount++;
textLength++;
}
else
{
return false;
}
}
else
{
@ -552,36 +642,233 @@ namespace Microsoft.ML.Tokenizers
}
// Need to convert the text into UTF-8 bytes and then encode the bytes.
tokenCount += Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes);
int encodedCount = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes);
bool ret;
if (tokenCount + encodedCount <= maxTokens)
{
tokenCount += encodedCount;
textLength += text.Length - i;
ret = true;
}
else
{
ret = false;
}
if (arrayPoolArray is not null)
{
ArrayPool<byte>.Shared.Return(arrayPoolArray);
}
break;
return ret;
}
}
return true;
}
void Segment((int Index, int Length) pieceSpan, ReadOnlySpan<char> text)
bool Segment((int Index, int Length) pieceSpan, ReadOnlySpan<char> text, ref int textLength)
{
if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
{
EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index);
return;
return EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref textLength);
}
if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
revMerge is null ||
!revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
{
tokenCount++;
return;
if (tokenCount < maxTokens)
{
tokenCount++;
textLength += pieceSpan.Length;
return true;
}
else
{
return false;
}
}
Segment((merge.LeftIndex, merge.LeftLen), text);
Segment((merge.RightIndex, merge.RightLen), text);
return Segment((merge.LeftIndex, merge.LeftLen), text, ref textLength) && Segment((merge.RightIndex, merge.RightLen), text, ref textLength);
}
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks>
public int CountTokensFromEnd(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue)
{
textIndex = text.Length;
if (text.IsEmpty)
{
return 0;
}
int tokenCount = addEndOfSentence ? 1 : 0;
BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length);
Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols);
// Move to the last symbol.
int lastSymbolIndex = 0;
while (lastSymbolIndex < symbols.Length && symbols[lastSymbolIndex].next != -1)
{
lastSymbolIndex = symbols[lastSymbolIndex].next;
}
for (int index = lastSymbolIndex; index >= 0; index = symbols[index].prev)
{
int id = symbols[index].id;
byte type = symbols[index].type;
if (id == UninitializedId)
{
if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
{
id = tokenInfo.Id;
type = tokenInfo.Type;
}
else
{
id = UnknownId;
type = 0;
}
}
if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused)
{
if (id == UnknownId && ByteFallback)
{
if (!EncodeAsBytesFromEnd(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textIndex))
{
break;
}
}
else
{
if (tokenCount < maxTokens)
{
tokenCount++;
textIndex -= symbols[index].pieceSpan.Length;
}
else
{
break;
}
}
continue;
}
if (!SegmentFromEnd(symbols[index].pieceSpan, text, ref textIndex))
{
break;
}
}
ArrayPool<BpeSymbol>.Shared.Return(symbols);
if (addBeginOfSentence)
{
if (tokenCount < maxTokens)
{
tokenCount++;
}
}
return tokenCount;
// Encode the Unknown token to bytes.
bool EncodeAsBytesFromEnd(ReadOnlySpan<char> text, int index, ref int textIndex)
{
for (int i = text.Length - 1; i >= 0; i--)
{
char c = text[i];
if (c <= 0x7F)
{
if (tokenCount < maxTokens)
{
tokenCount++;
textIndex--;
}
else
{
return false;
}
}
else
{
Span<byte> utf8Bytes = stackalloc byte[100];
byte[]? arrayPoolArray = null;
int len = Encoding.UTF8.GetMaxByteCount(i + 1); // sized for the prefix text.Slice(0, i + 1) converted below
if (len > utf8Bytes.Length)
{
arrayPoolArray = ArrayPool<byte>.Shared.Rent(len);
utf8Bytes = arrayPoolArray;
}
// Need to convert the text into UTF-8 bytes and then encode the bytes.
int encodedCount = Helpers.GetUtf8Bytes(text.Slice(0, i + 1), utf8Bytes);
bool ret;
if (tokenCount + encodedCount <= maxTokens)
{
tokenCount += encodedCount;
textIndex -= i + 1;
ret = true;
}
else
{
ret = false;
}
if (arrayPoolArray is not null)
{
ArrayPool<byte>.Shared.Return(arrayPoolArray);
}
return ret;
}
}
return true;
}
bool SegmentFromEnd((int Index, int Length) pieceSpan, ReadOnlySpan<char> text, ref int textIndex)
{
if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
{
return EncodeAsBytesFromEnd(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref textIndex);
}
if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
revMerge is null ||
!revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
{
if (tokenCount < maxTokens)
{
tokenCount++;
textIndex -= pieceSpan.Length;
return true;
}
else
{
return false;
}
}
// Segment the right part first.
return SegmentFromEnd((merge.RightIndex, merge.RightLen), text, ref textIndex) && SegmentFromEnd((merge.LeftIndex, merge.LeftLen), text, ref textIndex);
}
}
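A usage sketch for the overloads above (the SentencePieceBpe type name and its construction are assumptions; only the signatures come from this diff). BOS/EOS emission is controlled per call, and the returned count plus textLength describe what fit in the budget.

using System;
using System.Collections.Generic;
using Microsoft.ML.Tokenizers;

internal static class SentencePieceSketch
{
    // Encode with explicit BOS/EOS and a token budget; textLength reports how
    // many input characters the emitted ids cover (BOS/EOS cover none).
    public static int EncodeBudgeted(SentencePieceBpe bpe, string text, IList<int> ids, int budget)
    {
        return bpe.EncodeToIds(text.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true,
                               ids, out int textLength, maxTokens: budget);
    }
}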
@ -589,32 +876,28 @@ namespace Microsoft.ML.Tokenizers
/// Map the token to encoded Id.
/// </summary>
/// <param name="token">The token to map to Id</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true)
public override int? MapTokenToId(ReadOnlySpan<char> token)
=> _vocab.TryGetValue(token, out (int Id, float Score, byte Type) value) ? value.Id : null;
/// <summary>
/// Map the encoded Id to the token.
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the decoding.</param>
/// <returns>The mapped token of the Id.</returns>
public override string? MapIdToToken(int id, bool considerSpecialTokens = true)
public override string? MapIdToToken(int id)
=> _vocabReverse.TryGetValue(id, out string? value) ? value : null;
/// <summary>
/// Decode the given ids, back to a String.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="considerSpecialTokens">Whether the special tokens should be kept in the decoded string.</param>
/// <param name="decoder">The optional Decoder to merge the given list of tokens in a string.</param>
/// <returns>The decoded string.</returns>
/// <remarks>
/// The decoder is not used here because the SentencePiece Bpe model knows how to decode the ids on its own, which also avoids any performance overhead.
/// considerSpecialTokens is not used here because the SentencePiece Bpe model always removes unknown or control tokens during decoding.
/// </remarks>
public override string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null, bool considerSpecialTokens = true)
public override string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null)
{
if (ids is null)
{

View file

@ -24,12 +24,10 @@ namespace Microsoft.ML.Tokenizers
{
private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder;
private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder;
private readonly LruCache<int[]> _cache;
private readonly Dictionary<StringSpanOrdinalKey, int>? _specialTokensEncoder;
private Dictionary<string, int>? _specialTokensEncoderOriginal;
private readonly Dictionary<int, string>? _specialTokensDecoder;
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
private readonly LruCache<(int[] Bytes, string Token)> _cache;
private readonly Dictionary<StringSpanOrdinalKey, (int Id, string Token)> _vocab;
private IReadOnlyDictionary<string, int>? _vocabOriginal;
private const int MaxWordLengthToCache = 15;
/// <summary>
/// Create a new Tiktoken tokenizer's model object.
@ -68,7 +66,7 @@ namespace Microsoft.ML.Tokenizers
internal Tiktoken(
Dictionary<ReadOnlyMemory<byte>, int> encoder,
Dictionary<int, ReadOnlyMemory<byte>> decoder,
Dictionary<StringSpanOrdinalKey, int> vocab,
Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab,
IReadOnlyDictionary<string, int>? specialTokens,
int cacheSize = LruCache<int[]>.DefaultCacheSize)
{
@ -76,23 +74,24 @@ namespace Microsoft.ML.Tokenizers
_decoder = decoder ?? throw new ArgumentNullException(nameof(decoder));
_vocab = vocab ?? throw new ArgumentNullException(nameof(vocab));
Debug.Assert(encoder.Count == decoder.Count);
_encoder = encoder!;
_decoder = decoder!;
_vocab = vocab!;
_cache = new LruCache<int[]>(cacheSize);
_cache = new LruCache<(int[] Bytes, string Token)>(cacheSize);
(_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens);
SpecialTokens = specialTokens;
CacheSpecialTokensEncoding(specialTokens);
}
private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens, int cacheSize, bool disposeStream)
{
try
{
_cache = new LruCache<int[]>(cacheSize);
(_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
(_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens);
_cache = new LruCache<(int[] Bytes, string Token)>(cacheSize);
(_encoder, _vocab, _decoder) = LoadTiktokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
SpecialTokens = specialTokens;
CacheSpecialTokensEncoding(specialTokens);
}
finally
{
@ -103,15 +102,19 @@ namespace Microsoft.ML.Tokenizers
}
}
private static (Dictionary<StringSpanOrdinalKey, int>?, Dictionary<int, string>?) CreateEncoderDecoder(IReadOnlyDictionary<string, int>? specialTokens)
private void CacheSpecialTokensEncoding(IReadOnlyDictionary<string, int>? specialTokens)
{
Debug.Assert(_cache is not null);
Debug.Assert(_decoder is not null);
if (specialTokens is not null)
{
var encoder = specialTokens.ToDictionary(e => new StringSpanOrdinalKey(e.Key), e => e.Value);
return (encoder, encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key.Data!));
foreach (KeyValuePair<string, int> specialToken in specialTokens)
{
_decoder![specialToken.Value] = Encoding.UTF8.GetBytes(specialToken.Key);
_cache!.Add(specialToken.Key, (new[] { specialToken.Value }, specialToken.Key));
}
}
return (null, null);
}
/// <summary>
@ -133,8 +136,8 @@ namespace Microsoft.ML.Tokenizers
throw new ArgumentNullException(nameof(vocabStream));
}
(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
await LoadTikTokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);
(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
await LoadTiktokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);
return new Tiktoken(encoder, decoder, vocab, specialTokens, cacheSize);
}
@ -170,11 +173,11 @@ namespace Microsoft.ML.Tokenizers
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>Map of byte[] to integer token id</returns>
/// <exception cref="InvalidOperationException"></exception>
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<StringSpanOrdinalKey, int>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTikTokenBpeAsync(
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<StringSpanOrdinalKey, (int Id, string Token)>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTiktokenBpeAsync(
Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
{
var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
var vocab = new Dictionary<StringSpanOrdinalKey, int>();
var vocab = new Dictionary<StringSpanOrdinalKey, (int Id, string Token)>();
var decoder = new Dictionary<int, ReadOnlyMemory<byte>>();
try
@ -212,7 +215,7 @@ namespace Microsoft.ML.Tokenizers
if (decodedToken.IndexOf('\uFFFD') < 0)
{
vocab[new StringSpanOrdinalKey(decodedToken)] = rank;
vocab[new StringSpanOrdinalKey(decodedToken)] = (rank, decodedToken);
}
}
else
@ -231,140 +234,98 @@ namespace Microsoft.ML.Tokenizers
}
/// <summary>
/// Encode a split text string to a list of tokens.
/// Encode text to a list of tokens.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
/// <param name="text">The text to encode.</param>
/// <returns>The list of tokens generated from the text tokenization.</returns>
public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false)
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
{
Token[] tokens;
if (string.IsNullOrEmpty(text))
if (text.IsEmpty)
{
return Array.Empty<Token>();
}
if (isSpecialToken)
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
{
if (_specialTokensEncoder?.TryGetValue(text, out int id) is true)
{
return new List<Token> { new(id, text, (0, text.Length)) };
}
throw new InvalidOperationException($"The special token {text} doesn't exist in the tokenizer");
}
if (_cache.TryGetValue(text, out int[]? ids))
{
tokens = new Token[ids.Length];
tokens[0] = new Token(ids[0], text, (0, text.Length));
for (int i = 1; i < ids.Length; i++)
tokens = new Token[value.Ids.Length];
tokens[0] = new Token(value.Ids[0], value.Token, (0, value.Token.Length));
for (int i = 1; i < value.Ids.Length; i++)
{
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
tokens[i] = new Token(ids[i], "", (text.Length, 0));
tokens[i] = new Token(value.Ids[i], "", (text.Length, 0));
}
return tokens;
}
// cache miss
if (_vocab.TryGetValue(text, out int mappedId))
if (_vocab.TryGetValue(text, out (int Id, string Token) mappedId))
{
return new Token[1] { new(mappedId, text, (0, text.Length)) };
}
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
int encodedLength = Helpers.GetUtf8Bytes(text.AsSpan(), arrayPoolArray);
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
Debug.Assert(encodedIds.Length > 0);
_cache.Add(text, encodedIds);
tokens = new Token[encodedIds.Length];
tokens[0] = new Token(encodedIds[0], text, (0, text.Length));
for (int i = 1; i < encodedIds.Length; i++)
{
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
tokens[i] = new Token(encodedIds[i], "", (text.Length, 0));
}
ArrayPool<byte>.Shared.Return(arrayPoolArray);
return tokens;
}
/// <summary>
/// Encode text to a list of Ids.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
/// <param name="accumulatedIds">The list of accumulated Ids.</param>
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds)
{
if (text.IsEmpty)
{
return;
}
if (isSpecialToken)
{
if (_specialTokensEncoder?.TryGetValue(text, out int id) is true)
{
accumulatedIds.Add(id);
}
return;
}
if (_cache.TryGetValue(text, out int[]? tokenIds))
{
accumulatedIds.AddRange(tokenIds);
return;
}
if (_vocab.TryGetValue(text, out int mappedId))
{
accumulatedIds.Add(mappedId);
return;
return new Token[1] { new(mappedId.Id, mappedId.Token, (0, mappedId.Token.Length)) };
}
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
_cache.Add(text.ToString(), encodedIds);
accumulatedIds.AddRange(encodedIds);
ArrayPool<byte>.Shared.Return(arrayPoolArray);
return;
Debug.Assert(encodedIds.Length > 0);
string textAsString = text.ToString();
if (text.Length <= MaxWordLengthToCache)
{
_cache.Add(textAsString, (encodedIds, textAsString));
}
tokens = new Token[encodedIds.Length];
tokens[0] = new Token(encodedIds[0], textAsString, (0, text.Length));
for (int i = 1; i < encodedIds.Length; i++)
{
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
tokens[i] = new Token(encodedIds[i], "", (text.Length, 0));
}
return tokens;
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// Encode text to a list of Ids.
/// </summary>
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
/// <param name="text">The text to encode.</param>
/// <param name="accumulatedIds">The list of accumulated Ids.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
{
Debug.Assert(maxTokens > 0);
if (text.IsEmpty)
{
textLength = 0;
return 0;
}
if (isSpecialToken && _specialTokensEncoder is not null)
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
{
return _specialTokensEncoder.TryGetValue(text, out _) ? 1 : 0;
if (value.Ids.Length <= maxTokens)
{
accumulatedIds.AddRange(value.Ids);
textLength = text.Length;
return value.Ids.Length;
}
textLength = 0;
return 0;
}
if (_cache.TryGetValue(text, out int[] ids))
{
return ids.Length;
}
if (_vocab.TryGetValue(text, out _))
if (_vocab.TryGetValue(text, out (int Id, string Token) mappedId))
{
textLength = text.Length;
accumulatedIds.Add(mappedId.Id);
return 1;
}
@@ -372,46 +333,178 @@ namespace Microsoft.ML.Tokenizers
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
_cache.Add(text.ToString(), encodedIds);
if (text.Length <= MaxWordLengthToCache)
{
string textAsString = text.ToString();
_cache.Add(textAsString, (encodedIds, textAsString));
}
int result;
if (encodedIds.Length <= maxTokens)
{
accumulatedIds.AddRange(encodedIds);
textLength = text.Length;
result = encodedIds.Length;
}
else
{
textLength = 0;
result = 0;
}
ArrayPool<byte>.Shared.Return(arrayPoolArray);
return encodedIds.Length;
return result;
}
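A minimal usage sketch of the budgeted overload above, highlighting its all-or-nothing behavior for a single pre-tokenized word (`model` is an assumed Tiktoken instance; the budget value is illustrative):
var ids = new List<int>();
int written = model.EncodeToIds("unbelievable".AsSpan(), ids, out int textLength, maxTokens: 2);
// If the word needs more than 2 ids, nothing is appended: written == 0 and textLength == 0.
// Otherwise every id is appended: written == ids.Count and textLength == 12.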
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue)
{
Debug.Assert(maxTokens > 0);
if (text.IsEmpty)
{
textLength = 0;
return 0;
}
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
{
if (value.Ids.Length <= maxTokens)
{
textLength = text.Length;
return value.Ids.Length;
}
textLength = 0;
return 0;
}
if (_vocab.TryGetValue(text, out _))
{
textLength = text.Length;
return 1;
}
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
if (text.Length <= MaxWordLengthToCache)
{
string textAsString = text.ToString();
_cache.Add(textAsString, (encodedIds, textAsString));
}
int result;
if (encodedIds.Length <= maxTokens)
{
textLength = text.Length;
result = encodedIds.Length;
}
else
{
textLength = 0;
result = 0;
}
ArrayPool<byte>.Shared.Return(arrayPoolArray);
return result;
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue)
{
Debug.Assert(maxTokens > 0);
if (text.IsEmpty)
{
textIndex = 0;
return 0;
}
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
{
if (value.Ids.Length <= maxTokens)
{
textIndex = 0;
return value.Ids.Length;
}
textIndex = text.Length;
return 0;
}
if (_vocab.TryGetValue(text, out _))
{
textIndex = 0;
return 1;
}
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
if (text.Length <= MaxWordLengthToCache)
{
string textAsString = text.ToString();
_cache.Add(textAsString, (encodedIds, textAsString));
}
int result;
if (encodedIds.Length <= maxTokens)
{
textIndex = 0;
result = encodedIds.Length;
}
else
{
textIndex = text.Length;
result = 0;
}
ArrayPool<byte>.Shared.Return(arrayPoolArray);
return result;
}
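The two counting directions differ only in which end of the word the reported cut point refers to; a hedged sketch (`model` and `word` assumed as above):
int n1 = model.CountTokens(word.AsSpan(), out int textLength, maxTokens: 8);
// textLength characters from the start are covered by the n1 counted tokens.
int n2 = model.CountTokensFromEnd(word.AsSpan(), out int textIndex, maxTokens: 8);
// Counting from the end instead: textIndex is the first character included.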
/// <summary>
/// Map the token to encoded Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true)
public override int? MapTokenToId(ReadOnlySpan<char> token)
{
if (token.IsEmpty)
{
return 0;
}
if (considerSpecialTokens && _specialTokensEncoder is not null)
if (_cache.TryGetValue(token, out (int[] Ids, string Token) value))
{
if (_specialTokensEncoder.TryGetValue(token, out int specialTokenId))
if (value.Ids.Length == 1)
{
return specialTokenId;
}
}
if (_cache.TryGetValue(token, out int[] ids))
{
if (ids.Length == 1)
{
return ids[0];
return value.Ids[0];
}
return null;
}
if (_vocab.TryGetValue(token, out int id))
if (_vocab.TryGetValue(token, out (int Id, string Token) id))
{
return id;
return id.Id;
}
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(token.Length));
@@ -420,7 +513,12 @@ namespace Microsoft.ML.Tokenizers
int encodedLength = Helpers.GetUtf8Bytes(token, arrayPoolArray);
int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
_cache.Add(token.ToString(), idsToCache);
if (token.Length <= MaxWordLengthToCache)
{
string tokenAsString = token.ToString();
_cache.Add(tokenAsString, (idsToCache, tokenAsString));
}
if (idsToCache.Length == 1)
{
@@ -439,15 +537,9 @@ namespace Microsoft.ML.Tokenizers
/// Map the encoded Id to the token.
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the decoding.</param>
/// <returns>The mapped token of the Id.</returns>
public override string? MapIdToToken(int id, bool considerSpecialTokens = true)
public override string? MapIdToToken(int id)
{
if (considerSpecialTokens && _specialTokensDecoder is not null && _specialTokensDecoder.TryGetValue(id, out string? token))
{
return token;
}
if (_decoder.TryGetValue(id, out ReadOnlyMemory<byte> tokenBytes))
{
return Helpers.GetString(tokenBytes.Span);
@@ -460,10 +552,9 @@ namespace Microsoft.ML.Tokenizers
/// Decode the given ids, back to a String.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="considerSpecialTokens">Whether the special tokens should be kept in the decoded string.</param>
/// <param name="decoder">The optional Decoder to merge the given list of tokens in a string.</param>
/// <returns>The decoded string.</returns>
public override string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null, bool considerSpecialTokens = true)
public override string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null)
{
// Tiktoken doesn't guarantee a one-to-one correspondence between IDs and UTF-16 words.
// Consequently, decoding individual IDs into a UTF-16 string is not supported; instead, all IDs must be decoded collectively.
@@ -483,8 +574,6 @@ namespace Microsoft.ML.Tokenizers
Span<byte> utf8Bytes = stackalloc byte[256];
int utf8ByteCount = 0;
bool useSpecialTokens = considerSpecialTokens && _specialTokensDecoder is not null;
foreach (int id in ids)
{
if (_decoder.TryGetValue(id, out ReadOnlyMemory<byte> tokenBytes))
@@ -497,19 +586,6 @@ namespace Microsoft.ML.Tokenizers
tokenBytes.Span.CopyTo(utf8Bytes.Slice(utf8ByteCount));
utf8ByteCount += tokenBytes.Length;
}
else if (useSpecialTokens && _specialTokensDecoder!.TryGetValue(id, out string? token))
{
while (true)
{
if (Helpers.TryGetUtf8Bytes(token.AsSpan(), utf8Bytes.Slice(utf8ByteCount), out int bytesWritten))
{
utf8ByteCount += bytesWritten;
break;
}
ArrayPoolGrow(ref utf8Bytes, ref arrayPoolArray, utf8ByteCount + Encoding.UTF8.GetByteCount(token));
}
}
else
{
return null;
@@ -543,12 +619,12 @@ namespace Microsoft.ML.Tokenizers
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
/// <remarks>This may not contain the full set of vocabulary tokens, use Encoder to get the full set of vocabulary.</remarks>
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value.Id);
/// <summary>
/// Gets the dictionary mapping special tokens to Ids.
/// </summary>
public IReadOnlyDictionary<string, int>? SpecialTokensEncoder => _specialTokensEncoderOriginal ??= _specialTokensEncoder?.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
public IReadOnlyDictionary<string, int>? SpecialTokens { get; }
/// <summary>
/// Gets the dictionary mapping token bytes to Ids.
@@ -732,31 +808,31 @@ namespace Microsoft.ML.Tokenizers
case ModelEncoding.Cl100kBase:
var specialTokens = new Dictionary<string, int>
{ { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} };
return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
return CreateTiktokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
case ModelEncoding.P50kBase:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
return CreateTiktokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
case ModelEncoding.P50kEdit:
specialTokens = new Dictionary<string, int>
{ { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } };
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
return CreateTiktokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
case ModelEncoding.R50kBase:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
return CreateTiktokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
case ModelEncoding.GPT2:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 }, };
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
return CreateTiktokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
default:
throw new NotSupportedException($"The encoder '{modelEncoding}' is not supported.");
}
}
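For orientation, this switch backs the public factory shown in the package README; a hedged usage sketch:
// "gpt-4" resolves to ModelEncoding.Cl100kBase, registering EndOfText as id 100257.
Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel("gpt-4");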
private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
/// <summary>
/// Create tokenizer based on regex pattern, BPE rank file and special tokens
@@ -768,7 +844,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer</returns>
private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
private static async Task<Tokenizer> CreateTiktokenTokenizerAsync(
Regex regex,
string mergeableRanksFileUrl,
Dictionary<string, int> specialTokens,
@@ -784,17 +860,17 @@ namespace Microsoft.ML.Tokenizers
}
}
if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
{
using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
{
cache = await LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
cache = await LoadTiktokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
}
_tiktokenCache.TryAdd(mergeableRanksFileUrl, cache);
}
return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TikTokenPreTokenizer(regex, specialTokens), normalizer);
return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TiktokenPreTokenizer(regex, specialTokens), normalizer);
}
internal static Tokenizer CreateTokenizerForModel(
@@ -818,17 +894,17 @@ namespace Microsoft.ML.Tokenizers
}
if (!_tiktokenCache.TryGetValue(tiktokenConfiguration.Url,
out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
{
using Stream stream = Helpers.GetStream(_httpClient, tiktokenConfiguration.Url);
cache = LoadTikTokenBpeAsync(stream, useAsync: false).GetAwaiter().GetResult();
cache = LoadTiktokenBpeAsync(stream, useAsync: false).GetAwaiter().GetResult();
_tiktokenCache.TryAdd(tiktokenConfiguration.Url, cache);
}
return new Tokenizer(
new Tiktoken(cache.encoder, cache.decoder, cache.vocab, tiktokenConfiguration.SpecialTokens, LruCache<int[]>.DefaultCacheSize),
new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
normalizer);
}
}

@@ -202,6 +202,64 @@ namespace Microsoft.ML.Tokenizers
}
}
public int PopulateIdsUpToMax(IList<int> accumulatedIds, int maxTokens, out int textLength)
{
textLength = 0;
int count = Math.Min(SymbolsCount, maxTokens);
for (int i = 0; i < count; i++)
{
accumulatedIds.Add(_symbols[i].C);
textLength += _symbols[i].Len;
}
return count;
}
public int PopulateIdsUpToMaxFromEnd(IList<int> accumulatedIds, int maxTokens, int fullTextLength, out int textIndex)
{
textIndex = fullTextLength;
int count = Math.Min(SymbolsCount, maxTokens);
for (int i = SymbolsCount - count; i < SymbolsCount; i++)
{
accumulatedIds.Add(_symbols[i].C);
textIndex -= _symbols[i].Len;
}
return count;
}
public int CountIdsUpToMax(int maxTokens, out int textLength)
{
textLength = 0;
int count = Math.Min(SymbolsCount, maxTokens);
for (int i = 0; i < count; i++)
{
textLength += _symbols[i].Len;
}
return count;
}
public int CountIdsUpToMaxFromEnd(int maxTokens, int fullTextLength, out int textIndex)
{
textIndex = fullTextLength;
int count = Math.Min(SymbolsCount, maxTokens);
for (int i = SymbolsCount - count; i < SymbolsCount; i++)
{
textIndex -= _symbols[i].Len;
}
return count;
}
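A worked example for the from-end variant above, with hypothetical symbol lengths:
// Suppose 5 symbols with lengths {2, 1, 3, 1, 2}, fullTextLength = 9, maxTokens = 2.
// The loop visits only the last two symbols, so count == 2 and
// textIndex == 9 - (1 + 2) == 6, i.e. the counted ids cover text[6..].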
public Vec<int> GetChars()
{
Vec<int> chars = new Vec<int>();

@@ -1,10 +1,10 @@
## About
Microsoft.ML.Tokenizers supports various the implmentation of the tokenization used in the NLP transforms.
Microsoft.ML.Tokenizers supports various implementations of the tokenization used in NLP transforms.
## Key Features
* Extensisble tokenizer architecture that allows for specialization of Normalizer, PreTokenizer, Model/Encoder, Decoder
* Extensible tokenizer architecture that allows for specialization of Normalizer, PreTokenizer, Model/Encoder, Decoder
* BPE - Byte pair encoding model
* English Roberta model
* Tiktoken model
@@ -21,8 +21,8 @@ using System.IO;
// Using Tiktoken Tokenizer
//
// initialize the tokenizer for `gpt-4` model, downloading data files
Tokenizer tokenizer = await Tokenizer.CreateTiktokenForModelAsync("gpt-4");
// initialize the tokenizer for `gpt-4` model
Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel("gpt-4");
string source = "Text tokenization is the process of splitting a string into a list of tokens.";
@@ -68,7 +68,7 @@ The main types provided by this library are:
* `Microsoft.ML.Tokenizers.Tokenizer`
* `Microsoft.ML.Tokenizers.Bpe`
* `Microsoft.ML.Tokenizers.EnglishRoberta`
* `Microsoft.ML.Tokenizers.TikToken`
* `Microsoft.ML.Tokenizers.Tiktoken`
* `Microsoft.ML.Tokenizers.TokenizerDecoder`
* `Microsoft.ML.Tokenizers.Normalizer`
* `Microsoft.ML.Tokenizers.PreTokenizer`

@@ -40,34 +40,25 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
/// <param name="token">The token string</param>
/// <param name="offset">The offset mapping to the original string</param>
/// <param name="isSpecialToken">Indicates whether the token is a special token</param>
public Split(string token, (int Index, int Length) offset, bool isSpecialToken = false)
public Split(string token, (int Index, int Length) offset)
{
_tokenString = token;
Offset = offset;
IsSpecialToken = isSpecialToken;
}
internal Split(string originalString, string? token, (int Index, int Length) offset, bool isSpecialToken = false)
internal Split(string originalString, string? token, (int Index, int Length) offset)
{
_originalString = originalString;
_tokenString = token;
Offset = offset;
IsSpecialToken = isSpecialToken;
}
/// <summary>
/// Gets if the current Split is a special token.
/// </summary>
public bool IsSpecialToken { get; }
/// <summary>
/// Indicates whether the current Split object is equal to another Split object.
/// </summary>
/// <param name="other">The Split object to compare with the current object.</param>
public bool Equals(Split other) =>
(_originalString == other._originalString || TokenString == other.TokenString) &&
IsSpecialToken == other.IsSpecialToken &&
Offset.Index == other.Offset.Index &&
Offset.Length == other.Offset.Length;
}
@@ -82,9 +73,8 @@ namespace Microsoft.ML.Tokenizers
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
/// </summary>
/// <param name="text">The string to split into tokens.</param>
/// <param name="considerSpecialTokens">Indicates whether to consider the special tokens.</param>
/// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
public abstract IEnumerable<Split> PreTokenize(string text, bool considerSpecialTokens = true);
public abstract IEnumerable<Split> PreTokenize(string text);
internal static IEnumerable<Split> SplitText(string text, Regex regex)
{

@@ -21,9 +21,8 @@ namespace Microsoft.ML.Tokenizers
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
/// </summary>
/// <param name="text">The string to split into tokens.</param>
/// <param name="considerSpecialTokens">Indicates whether to keep the special tokens.</param>
/// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
public override IEnumerable<Split> PreTokenize(string text, bool considerSpecialTokens = true)
public override IEnumerable<Split> PreTokenize(string text)
{
if (string.IsNullOrEmpty(text))
{

@@ -21,9 +21,8 @@ namespace Microsoft.ML.Tokenizers
/// Return the whole text as one chunk.
/// </summary>
/// <param name="text">The string to split into tokens.</param>
/// <param name="considerSpecialTokens">Indicates whether to keep the special tokens.</param>
/// <returns>The original string as one chunk.</returns>
public override IEnumerable<Split> PreTokenize(string text, bool considerSpecialTokens = true)
public override IEnumerable<Split> PreTokenize(string text)
{
if (string.IsNullOrEmpty(text))
{

@@ -12,18 +12,18 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// The pre-tokenizer for Tiktoken tokenizer.
/// </summary>
public sealed class TikTokenPreTokenizer : PreTokenizer
public sealed class TiktokenPreTokenizer : PreTokenizer
{
private readonly Regex? _specialTokensRegex;
private readonly Regex _regex;
/// <summary>
/// Initializes a new instance of the <see cref="TikTokenPreTokenizer"/> class.
/// Initializes a new instance of the <see cref="TiktokenPreTokenizer"/> class.
/// </summary>
/// <param name="regex">The regex to use for splitting the text into smaller tokens in the pre-tokenization process.</param>
/// <param name="specialTokensEncoder">Encode the special token to Id.</param>
/// <param name="specialTokensEncoder">The dictionary containing the special tokens and their corresponding ids.</param>
/// <exception cref="ArgumentNullException">When regex is null</exception>
public TikTokenPreTokenizer(Regex regex, IReadOnlyDictionary<string, int>? specialTokensEncoder)
public TiktokenPreTokenizer(Regex regex, IReadOnlyDictionary<string, int>? specialTokensEncoder)
{
if (regex is null)
{
@@ -42,16 +42,15 @@ namespace Microsoft.ML.Tokenizers
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
/// </summary>
/// <param name="text">The string to split into tokens.</param>
/// <param name="considerSpecialTokens">Indicates whether to consider the special tokens.</param>
/// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
public override IEnumerable<Split> PreTokenize(string text, bool considerSpecialTokens = true)
public override IEnumerable<Split> PreTokenize(string text)
{
if (string.IsNullOrEmpty(text))
{
return Array.Empty<Split>();
}
return SplitText(text, _regex, considerSpecialTokens ? _specialTokensRegex : null);
return SplitText(text, _regex, _specialTokensRegex);
static IEnumerable<Split> SplitText(string text, Regex regex, Regex? specialTokensRegex)
{
@@ -74,7 +73,7 @@ namespace Microsoft.ML.Tokenizers
beginning = match.Offset + match.Length;
}
yield return new Split(text, null, (specialMatch.Offset, specialMatch.Length), isSpecialToken: true);
yield return new Split(text, null, (specialMatch.Offset, specialMatch.Length));
beginning = specialMatch.Offset + specialMatch.Length;
}
}
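A construction sketch for the renamed pre-tokenizer (the regex is a stand-in, not the real Cl100k pattern, and the "<|endoftext|>" literal is an assumption for the EndOfText token):
var preTokenizer = new TiktokenPreTokenizer(
    new Regex(@"\S+", RegexOptions.Compiled),                      // stand-in split pattern
    new Dictionary<string, int> { { "<|endoftext|>", 100257 } });  // assumed literal
foreach (Split split in preTokenizer.PreTokenize("hello <|endoftext|> world"))
{
    // Special tokens surface as their own splits with exact offsets.
}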

@@ -32,9 +32,8 @@ namespace Microsoft.ML.Tokenizers
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
/// </summary>
/// <param name="text">The string to split into tokens.</param>
/// <param name="considerSpecialTokens">Indicates whether to consider the special tokens.</param>
/// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
public override IEnumerable<Split> PreTokenize(string text, bool considerSpecialTokens = false)
public override IEnumerable<Split> PreTokenize(string text)
{
if (string.IsNullOrEmpty(text))
{

@@ -58,9 +58,8 @@ namespace Microsoft.ML.Tokenizers
/// Encodes input text to object has the tokens list, tokens Ids, tokens offset mapping.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>The tokenization result includes the tokens list, tokens Ids, tokens offset mapping.</returns>
public EncodingResult Encode(string text, bool considerSpecialTokens = true)
public EncodingResult Encode(string text)
{
if (text is null)
{
@@ -70,11 +69,11 @@ namespace Microsoft.ML.Tokenizers
string normalized = Normalizer is null ? text : Normalizer.Normalize(text);
bool offsetsMappedToOriginal = true;
EncodingResult encoding = new(text, normalized, PreTokenizer.PreTokenize(normalized, considerSpecialTokens), offsetsMappedToOriginal);
EncodingResult encoding = new(text, normalized, PreTokenizer.PreTokenize(normalized), offsetsMappedToOriginal);
foreach (Split split in encoding.Splits)
{
IReadOnlyList<Token> tokens = Model.Encode(split.TokenString, split.IsSpecialToken);
IReadOnlyList<Token> tokens = Model.Encode(split.TokenString.AsSpan());
foreach (Token token in tokens)
{
token.Offset = (token.Offset.Index + split.Offset.Index, token.Offset.Length);
@@ -90,9 +89,8 @@ namespace Microsoft.ML.Tokenizers
/// Encodes input text to tokens Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>The tokenization result includes the tokens list, tokens Ids, tokens offset mapping.</returns>
public IReadOnlyList<int> EncodeToIds(string text, bool considerSpecialTokens = true)
public IReadOnlyList<int> EncodeToIds(string text)
{
if (text is null)
{
@@ -102,9 +100,9 @@ namespace Microsoft.ML.Tokenizers
string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text;
List<int> idsList = new();
foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
foreach (Split split in PreTokenizer.PreTokenize(normalized))
{
Model.EncodeToIds(split.TokenSpan, split.IsSpecialToken, idsList);
Model.EncodeToIds(split.TokenSpan, idsList, out _);
}
return idsList;
@@ -114,11 +112,10 @@ namespace Microsoft.ML.Tokenizers
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>The number of tokens Ids that the input text will be encoded to.</returns>
/// <exception cref="ArgumentNullException">The input text is null.</exception>
/// <exception cref="ArgumentException">Unable to encode the text.</exception>
public int CountTokens(string text, bool considerSpecialTokens = true)
public int CountTokens(string text)
{
if (text is null)
{
@@ -128,9 +125,9 @@ namespace Microsoft.ML.Tokenizers
string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text;
int idsCount = 0;
foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
foreach (Split split in PreTokenizer.PreTokenize(normalized))
{
idsCount += Model.CountTokens(split.TokenSpan, split.IsSpecialToken);
idsCount += Model.CountTokens(split.TokenSpan, out _);
}
return idsCount;
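Counting and encoding walk the same pre-tokenized splits, so the totals agree; a quick hedged check (`tokenizer` and `text` assumed):
Debug.Assert(tokenizer.CountTokens(text) == tokenizer.EncodeToIds(text).Count);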
@@ -143,15 +140,14 @@ namespace Microsoft.ML.Tokenizers
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="processedText">If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, if all tokens fit, the result will be the length of the <paramref name="processedText"/>.
/// </returns>
/// <exception cref="ArgumentNullException">The input text is null.</exception>
/// <exception cref="ArgumentOutOfRangeException">The maximum token count must be greater than 0.</exception>
public int IndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount, bool considerSpecialTokens = true)
=> IndexOf(text, maxTokenCount, fromStart: true, considerSpecialTokens, out processedText, out tokenCount);
public int IndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount)
=> IndexOf(text, maxTokenCount, out processedText, out tokenCount);
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
@@ -160,7 +156,6 @@ namespace Microsoft.ML.Tokenizers
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="processedText">If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be the length of the <paramref name="processedText"/>; conversely, if all tokens fit, the result will be 0.
@@ -170,10 +165,10 @@ namespace Microsoft.ML.Tokenizers
/// <remarks>
/// If the whole text can be encoded within the token limit, the returned index will be 0.
/// </remarks>
public int LastIndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount, bool considerSpecialTokens = true)
=> IndexOf(text, maxTokenCount, fromStart: false, considerSpecialTokens, out processedText, out tokenCount);
public int LastIndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount)
=> LastIndexOf(text, maxTokenCount, out processedText, out tokenCount);
private int IndexOf(string text, int maxTokenCount, bool fromStart, bool considerSpecialTokens, out string processedText, out int tokenCount)
private int IndexOf(string text, int maxTokenCount, out string processedText, out int tokenCount)
{
if (text is null)
{
@@ -188,36 +183,60 @@ namespace Microsoft.ML.Tokenizers
processedText = Normalizer is not null ? Normalizer.Normalize(text) : text;
tokenCount = 0;
IEnumerable<Split> splits = PreTokenizer.PreTokenize(processedText, considerSpecialTokens);
foreach (Split split in (fromStart ? splits : splits.Reverse()))
IEnumerable<Split> splits = PreTokenizer.PreTokenize(processedText);
foreach (Split split in splits)
{
int count = Model.CountTokens(split.TokenSpan, split.IsSpecialToken);
if (tokenCount > maxTokenCount - count)
tokenCount += Model.CountTokens(split.TokenSpan, out int textLength, maxTokenCount - tokenCount);
if (textLength < split.Offset.Length || tokenCount >= maxTokenCount)
{
return fromStart ? split.Offset.Index : split.Offset.Index + split.Offset.Length;
return split.Offset.Index + textLength;
}
tokenCount += count;
}
return fromStart ? processedText.Length : 0;
return processedText.Length;
}
private int LastIndexOf(string text, int maxTokenCount, out string processedText, out int tokenCount)
{
if (text is null)
{
throw new ArgumentNullException(nameof(text));
}
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
}
processedText = Normalizer is not null ? Normalizer.Normalize(text) : text;
tokenCount = 0;
IEnumerable<Split> splits = PreTokenizer.PreTokenize(processedText);
foreach (Split split in splits.Reverse())
{
tokenCount += Model.CountTokensFromEnd(split.TokenSpan, out int textIndex, maxTokenCount - tokenCount);
if (textIndex > 0 || tokenCount >= maxTokenCount)
{
return split.Offset.Index + textIndex;
}
}
return 0;
}
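A truncation sketch using the two public wrappers above (the budget of 128 is illustrative):
int index = tokenizer.IndexOfTokenCount(text, maxTokenCount: 128, out string processed, out int count);
string head = processed.Substring(0, index);   // encodes to count (<= 128) tokens
int start = tokenizer.LastIndexOfTokenCount(text, maxTokenCount: 128, out processed, out count);
string tail = processed.Substring(start);      // the trailing count tokens of the text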
/// <summary>
/// Decodes the Id to the mapped token.
/// </summary>
/// <param name="id">The id to map to the token.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the decoding.</param>
/// <returns>The decoded string or null if there is no token mapped to the input id.</returns>
public string? Decode(int id, bool considerSpecialTokens = true) => Model.MapIdToToken(id, considerSpecialTokens);
public string? Decode(int id) => Model.MapIdToToken(id);
/// <summary>
/// Decode the given ids, back to a String.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="considerSpecialTokens">Whether the special tokens should be kept in the decoded string.</param>
/// <returns>The decoded string.</returns>
public string? Decode(IEnumerable<int> ids, bool considerSpecialTokens = true) => Model.Decode(ids, Decoder, considerSpecialTokens);
public string? Decode(IEnumerable<int> ids) => Model.Decode(ids, Decoder);
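A round-trip sketch for the simplified Decode overloads (the ids are assumed to come from the same tokenizer; as noted earlier, Tiktoken must decode ids collectively):
IReadOnlyList<int> ids = tokenizer.EncodeToIds("Hello world");
string? text = tokenizer.Decode(ids);      // "Hello world"
string? first = tokenizer.Decode(ids[0]);  // the token mapped to a single id, or null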
/// <summary>
/// Create a Tiktoken tokenizer based on model name and vocab file.
@@ -252,7 +271,7 @@ namespace Microsoft.ML.Tokenizers
return new Tokenizer(
new Tiktoken(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize),
new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
normalizer);
}
@@ -291,7 +310,7 @@ namespace Microsoft.ML.Tokenizers
return new Tokenizer(
await Tiktoken.CreateAsync(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize, cancellationToken).ConfigureAwait(false),
new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
normalizer);
}

@@ -193,13 +193,13 @@ namespace Microsoft.ML.Tokenizers.Tests
if (robertaModel.FilterUnsupportedChars)
{
string[]? filteredToken = p[5] as string[];
Assert.Equal(filteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i], considerSpecialTokens: false));
Assert.Equal(filteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
}
else
{
Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i], considerSpecialTokens: false));
Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
string[]? unfilteredToken = p[2] as string[];
Assert.Equal(unfilteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i], considerSpecialTokens: false));
Assert.Equal(unfilteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
}
Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan()));

@@ -208,32 +208,32 @@ namespace Microsoft.ML.Tokenizers.Tests
bool isEmptyInput = string.IsNullOrEmpty(input);
IReadOnlyList<Token> bpeTokens = bpe.Encode(normalizedInput, addBeginOfSentence: false, addEndOfSentence: false);
IReadOnlyList<Token> bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false);
Assert.Equal(ids.Skip(1), bpeTokens.Select(token => token.Id));
Assert.Equal(tokens.Skip(1), bpeTokens.Select(token => token.Value));
Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
List<int> encodedIds = new();
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds: encodedIds);
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds: encodedIds, out _);
Assert.Equal(ids.Skip(1), encodedIds);
Assert.Equal(isEmptyInput ? 0 : ids.Length - 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false));
Assert.Equal(isEmptyInput ? 0 : ids.Length - 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false, out _));
bpeTokens = bpe.Encode(normalizedInput, addBeginOfSentence: false, addEndOfSentence: true);
bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id));
Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Skip(1).Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value));
Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
encodedIds.Clear();
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true, accumulatedIds: encodedIds);
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true, accumulatedIds: encodedIds, out _);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), encodedIds);
Assert.Equal(isEmptyInput ? 0 : ids.Length, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true));
Assert.Equal(isEmptyInput ? 0 : ids.Length, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true, out _));
bpeTokens = bpe.Encode(normalizedInput, addBeginOfSentence: true, addEndOfSentence: true);
bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id));
Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value));
Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
encodedIds.Clear();
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true, accumulatedIds: encodedIds);
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true, accumulatedIds: encodedIds, out _);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), encodedIds);
Assert.Equal(isEmptyInput ? 0 : ids.Length + 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true));
Assert.Equal(isEmptyInput ? 0 : ids.Length + 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true, out _));
}
public static IEnumerable<object[]> LlamaTokenizersListData()
@@ -250,7 +250,6 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Throws<ArgumentNullException>(() => llamaTokenizer.EncodeToIds(null!));
Assert.Throws<ArgumentNullException>(() => llamaTokenizer.CountTokens(null!));
Assert.Throws<ArgumentNullException>(() => llamaTokenizer.Decode(null!));
Assert.Throws<ArgumentNullException>(() => (llamaTokenizer.Model as SentencePieceBpe)!.Encode(null!));
}
[Theory]
@@ -280,6 +279,8 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.True(bpe.AddDummyPrefix);
Assert.True(bpe.EscapeWhiteSpaces);
Assert.False(bpe.TreatWhitespaceAsSuffix);
TokenizerTests.TestTokenLimits(llamaTokenizer);
}
[Fact]

@@ -68,7 +68,7 @@ namespace Microsoft.ML.Tokenizers.Tests
public class SpacePreTokenizer : PreTokenizer
{
public override IEnumerable<Split> PreTokenize(string text, bool considerSpecialTokens = true)
public override IEnumerable<Split> PreTokenize(string text)
{
List<Split> splits = new();
if (string.IsNullOrEmpty(text))

@@ -38,7 +38,7 @@ namespace Microsoft.ML.Tokenizers.Tests
TestGPT4TokenizationEncoding(GPT4);
Assert.True(GPT4.Model is Tiktoken);
IReadOnlyDictionary<string, int>? specialTokensEncoder = (GPT4.Model as Tiktoken)!.SpecialTokensEncoder;
IReadOnlyDictionary<string, int>? specialTokensEncoder = (GPT4.Model as Tiktoken)!.SpecialTokens;
string tokenizerDataFileName = Utils.CreateTemporaryFile("tiktoken");
await Utils.DownloadFile(@"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken", tokenizerDataFileName);
@@ -122,9 +122,9 @@ namespace Microsoft.ML.Tokenizers.Tests
private void TestGPT4Tokenizer(Tokenizer gpt4Tokenizer)
{
string text = ReadAndSanitizeFile("./Data/lib.rs.txt");
IReadOnlyList<int> encoded = gpt4Tokenizer.EncodeToIds(text, considerSpecialTokens: false);
IReadOnlyList<int> encoded = gpt4Tokenizer.EncodeToIds(text);
Assert.Equal(5584, encoded.Count);
int idsCount = gpt4Tokenizer.CountTokens(text, considerSpecialTokens: false);
int idsCount = gpt4Tokenizer.CountTokens(text);
Assert.Equal(encoded.Count, idsCount);
using (Stream stream = File.OpenRead("./Data/tokens.json"))

@@ -29,14 +29,26 @@ namespace Microsoft.ML.Tokenizers.Tests
int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string processedText1, out int tokenCount1);
int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string processedText2, out int tokenCount2);
IReadOnlyList<int>? prefixIds = null;
IReadOnlyList<int>? suffixIds = null;
if (tokenCount1 > 0)
// The Llama tokenizer can produce only the start-of-sentence token <s> when maxTokenCount is 1.
// In that case index1 is zero and there is nothing to test.
if (tokenCount1 > 0 && index1 > 0)
{
string prefixString = processedText1.Substring(0, index1);
prefixIds = tokenizer.EncodeToIds(prefixString);
if (tokenizer.Model is SentencePieceBpe)
{
// The SentencePieceBpe model normalizes the text and inserts extra characters.
// Call the model directly to bypass the normalization step.
prefixIds = new List<int>();
tokenizer.Model.EncodeToIds(prefixString.AsSpan(), (prefixIds as IList<int>)!, out _);
}
else
{
prefixIds = tokenizer.EncodeToIds(prefixString);
}
Assert.Equal(tokenCount1, prefixIds.Count);
Assert.Equal(prefixIds, fullIdsList.Take(prefixIds.Count));
}
@@ -44,15 +56,45 @@ namespace Microsoft.ML.Tokenizers.Tests
if (tokenCount2 > 0)
{
string suffixString = processedText2.Substring(index2);
suffixIds = tokenizer.EncodeToIds(suffixString);
if (tokenizer.Model is SentencePieceBpe)
{
// The SentencePieceBpe model normalizes the text and inserts extra characters.
// Call the model directly to bypass the normalization step.
suffixIds = new List<int>();
tokenizer.Model.EncodeToIds(suffixString.AsSpan(), (suffixIds as IList<int>)!, out _);
if (i < fullIdsList.Count)
{
suffixIds = suffixIds.Skip(1).ToList(); // Skip the start of sentence token <s>
}
}
else
{
suffixIds = tokenizer.EncodeToIds(suffixString);
}
Assert.Equal(tokenCount2, suffixIds.Count);
Assert.Equal(suffixIds, fullIdsList.Skip(fullIdsList.Count - suffixIds.Count));
}
if (i == fullIdsList.Count)
{
Assert.Equal(processedText1.Length, index1);
Assert.Equal(0, index2);
if (index1 != processedText1.Length)
{
// It's possible that the remaining text on the right doesn't produce any tokens, as in the case of BPE,
// where the pre-tokenizer removes spaces and the remaining text consists entirely of spaces.
Assert.True(index1 < processedText1.Length);
Assert.Equal(0, tokenizer.CountTokens(processedText1.Substring(index1)));
}
if (index2 != 0)
{
// It's possible that the remaining text on the left doesn't produce any tokens, as in the case of BPE,
// where the pre-tokenizer removes spaces and the remaining text consists entirely of spaces.
Assert.True(index2 > 0);
Assert.Equal(0, tokenizer.CountTokens(processedText1.Substring(0, index2)));
}
Assert.Equal(fullIdsList, prefixIds);
Assert.Equal(fullIdsList, suffixIds);
}