Tokenizer's APIs Polishing (#7108)
* Remove teh confusing specialTokens flag parameter from all APIs * Normalize casing of the Tiktoken name * Make Model.Encode work with span instead of string * Support granular Last/IndexOf. * Remove wrong assert. * Address the feedback * Update the package doc.
This commit is contained in:
Родитель
214e12aefc
Коммит
c980eaf964
|
@ -20,6 +20,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
{
|
||||
/// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
|
||||
|
||||
private const int MaxWordLengthToCache = 15;
|
||||
private string? _unknownToken;
|
||||
|
||||
/// <summary>
|
||||
|
@ -176,10 +177,9 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <summary>
|
||||
/// Encode a text string to a list of tokens.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.</param>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <returns>The list of tokens generated from the text tokenization.</returns>
|
||||
public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false)
|
||||
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
|
||||
{
|
||||
if (text.Length == 0)
|
||||
{
|
||||
|
@ -192,34 +192,44 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <summary>
|
||||
/// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.</param>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
|
||||
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
|
||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsWithCache(text, accumulatedIds, maxTokens, out textLength);
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to. This parameter is ignored in this model.</returns>
|
||||
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsWithCache(text, null, maxTokens, out textLength);
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => EncodeToIdsFromEndWithCache(text, null, maxTokens, out textIndex);
|
||||
|
||||
/// <summary>
|
||||
/// Map the token to encoded Id.
|
||||
/// </summary>
|
||||
/// <param name="token">The token to map to the Id.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>The mapped Id of the token.</returns>
|
||||
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;
|
||||
public override int? MapTokenToId(ReadOnlySpan<char> token) => _vocab.TryGetValue(token, out int value) ? value : null;
|
||||
|
||||
/// <summary>
|
||||
/// Map the encoded Id to the token.
|
||||
/// </summary>
|
||||
/// <param name="id">The Id to map to the token.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the decoding.</param>
|
||||
/// <returns>The mapped token of the Id.</returns>
|
||||
public override string? MapIdToToken(int id, bool considerSpecialTokens = true)
|
||||
public override string? MapIdToToken(int id)
|
||||
{
|
||||
if (VocabReverse.TryGetValue(id, out string? value))
|
||||
{
|
||||
|
@ -434,7 +444,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
internal List<Token> WordToTokens(ref Word word) => word.ToTokens(VocabReverse);
|
||||
|
||||
internal List<Token> EncodeWithCache(string text)
|
||||
internal List<Token> EncodeWithCache(ReadOnlySpan<char> text)
|
||||
{
|
||||
Word word;
|
||||
if (Cache is not null)
|
||||
|
@ -444,28 +454,64 @@ namespace Microsoft.ML.Tokenizers
|
|||
return WordToTokens(ref word);
|
||||
}
|
||||
|
||||
word = MergeWord(text.AsSpan());
|
||||
Cache.Set(text, word);
|
||||
word = MergeWord(text);
|
||||
|
||||
if (text.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
Cache.Set(text.ToString(), word);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
word = MergeWord(text.AsSpan());
|
||||
word = MergeWord(text);
|
||||
}
|
||||
|
||||
return WordToTokens(ref word);
|
||||
}
|
||||
|
||||
internal int WordToIds(ref Word word, IList<int>? accumulatedIds)
|
||||
internal int WordToIds(ref Word word, IList<int>? accumulatedIds, out int textLength, int fullTextLength, int maxTokens)
|
||||
{
|
||||
if (accumulatedIds is not null)
|
||||
if (word.SymbolsCount < maxTokens)
|
||||
{
|
||||
word.PopulateIds(accumulatedIds);
|
||||
textLength = fullTextLength;
|
||||
if (accumulatedIds is not null)
|
||||
{
|
||||
word.PopulateIds(accumulatedIds);
|
||||
}
|
||||
|
||||
return word.SymbolsCount;
|
||||
}
|
||||
|
||||
return word.SymbolsCount;
|
||||
if (accumulatedIds is not null)
|
||||
{
|
||||
return word.PopulateIdsUpToMax(accumulatedIds, maxTokens, out textLength);
|
||||
}
|
||||
|
||||
return word.CountIdsUpToMax(maxTokens, out textLength);
|
||||
}
|
||||
|
||||
internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
|
||||
internal int WordToIdsFromEnd(ref Word word, IList<int>? accumulatedIds, out int textIndex, int fullTextLength, int maxTokens)
|
||||
{
|
||||
if (word.SymbolsCount < maxTokens)
|
||||
{
|
||||
textIndex = 0;
|
||||
if (accumulatedIds is not null)
|
||||
{
|
||||
word.PopulateIds(accumulatedIds);
|
||||
}
|
||||
|
||||
return word.SymbolsCount;
|
||||
}
|
||||
|
||||
if (accumulatedIds is not null)
|
||||
{
|
||||
return word.PopulateIdsUpToMaxFromEnd(accumulatedIds, maxTokens, fullTextLength, out textIndex);
|
||||
}
|
||||
|
||||
return word.CountIdsUpToMaxFromEnd(maxTokens, fullTextLength, out textIndex);
|
||||
}
|
||||
|
||||
internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textLength)
|
||||
{
|
||||
Word word;
|
||||
|
||||
|
@ -473,18 +519,48 @@ namespace Microsoft.ML.Tokenizers
|
|||
{
|
||||
if (Cache.TryGetValue(text, out Word hit))
|
||||
{
|
||||
return WordToIds(ref hit, accumulatedIds);
|
||||
return WordToIds(ref hit, accumulatedIds, out textLength, text.Length, maxTokens);
|
||||
}
|
||||
|
||||
word = MergeWord(text);
|
||||
Cache.Set(text.ToString(), word);
|
||||
|
||||
if (text.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
Cache.Set(text.ToString(), word);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
word = MergeWord(text);
|
||||
}
|
||||
|
||||
return WordToIds(ref word, accumulatedIds);
|
||||
return WordToIds(ref word, accumulatedIds, out textLength, text.Length, maxTokens);
|
||||
}
|
||||
|
||||
internal int EncodeToIdsFromEndWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textIndex)
|
||||
{
|
||||
Word word;
|
||||
|
||||
if (Cache is not null)
|
||||
{
|
||||
if (Cache.TryGetValue(text, out Word hit))
|
||||
{
|
||||
return WordToIdsFromEnd(ref hit, accumulatedIds, out textIndex, text.Length, maxTokens);
|
||||
}
|
||||
|
||||
word = MergeWord(text);
|
||||
|
||||
if (text.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
Cache.Set(text.ToString(), word);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
word = MergeWord(text);
|
||||
}
|
||||
|
||||
return WordToIdsFromEnd(ref word, accumulatedIds, out textIndex, text.Length, maxTokens);
|
||||
}
|
||||
|
||||
internal static readonly List<Token> EmptyTokensList = new();
|
||||
|
|
|
@ -135,15 +135,9 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Map the encoded Id to the token.
|
||||
/// </summary>
|
||||
/// <param name="id">The Id to map to the string.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the decoding.</param>
|
||||
/// <returns>The mapped token of the Id.</returns>
|
||||
public override string? MapIdToToken(int id, bool considerSpecialTokens = true)
|
||||
public override string? MapIdToToken(int id)
|
||||
{
|
||||
if (!considerSpecialTokens && id < 0)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
if (_vocabReverse.TryGetValue(id, out var value))
|
||||
{
|
||||
string v = value.Data!;
|
||||
|
@ -176,12 +170,11 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <summary>
|
||||
/// Encode a text string to a list of tokens.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.</param>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <returns>The list of tokens generated from the text tokenization.</returns>
|
||||
public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false)
|
||||
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
|
||||
{
|
||||
if (string.IsNullOrEmpty(text))
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
return Bpe.EmptyTokensList;
|
||||
}
|
||||
|
@ -215,7 +208,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
|
||||
List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
|
||||
_cache.Set(text, result);
|
||||
_cache.Set(text.ToString(), result);
|
||||
ArrayPool<char>.Shared.Return(token);
|
||||
ArrayPool<int>.Shared.Return(indexMapping);
|
||||
return result;
|
||||
|
@ -224,37 +217,116 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <summary>
|
||||
/// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.</param>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
|
||||
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIds(text, accumulatedIds);
|
||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsInternal(text, accumulatedIds, out textLength, maxTokens);
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.</param>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIds(text, null);
|
||||
public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsInternal(text, null, out textLength, maxTokens);
|
||||
|
||||
private int EncodeToIds(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => EncodeToIdsFromEndInternal(text, null, out textIndex, maxTokens);
|
||||
|
||||
private int EncodeToIdsResult(List<Token> tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textLength)
|
||||
{
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
textLength = 0;
|
||||
|
||||
if (_cache.TryGetValue(text, out List<Token>? hit))
|
||||
if (tokens.Count <= maxTokens)
|
||||
{
|
||||
if (accumulatedIds is not null)
|
||||
{
|
||||
foreach (var t in hit)
|
||||
foreach (var t in tokens)
|
||||
{
|
||||
accumulatedIds.Add(t.Id);
|
||||
}
|
||||
}
|
||||
|
||||
return hit.Count;
|
||||
textLength = fullTextLength;
|
||||
return tokens.Count;
|
||||
}
|
||||
|
||||
if (accumulatedIds is not null)
|
||||
{
|
||||
for (int i = 0; i < maxTokens; i++)
|
||||
{
|
||||
accumulatedIds.Add(tokens[i].Id);
|
||||
textLength += tokens[i].Offset.Length;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < maxTokens; i++)
|
||||
{
|
||||
textLength += tokens[i].Offset.Length;
|
||||
}
|
||||
}
|
||||
|
||||
return maxTokens;
|
||||
}
|
||||
|
||||
private int EncodeToIdsFromEndResult(List<Token> tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textIndex)
|
||||
{
|
||||
textIndex = fullTextLength;
|
||||
|
||||
if (tokens.Count <= maxTokens)
|
||||
{
|
||||
if (accumulatedIds is not null)
|
||||
{
|
||||
foreach (var t in tokens)
|
||||
{
|
||||
accumulatedIds.Add(t.Id);
|
||||
}
|
||||
}
|
||||
|
||||
textIndex = 0;
|
||||
return tokens.Count;
|
||||
}
|
||||
|
||||
if (accumulatedIds is not null)
|
||||
{
|
||||
for (int i = tokens.Count - maxTokens; i < tokens.Count; i++)
|
||||
{
|
||||
accumulatedIds.Add(tokens[i].Id);
|
||||
textIndex -= tokens[i].Offset.Length;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = tokens.Count - maxTokens; i < tokens.Count; i++)
|
||||
{
|
||||
textIndex -= tokens[i].Offset.Length;
|
||||
}
|
||||
}
|
||||
|
||||
return maxTokens;
|
||||
}
|
||||
|
||||
private int EncodeToIdsInternal(ReadOnlySpan<char> text, IList<int>? accumulatedIds, out int textLength, int maxTokens)
|
||||
{
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
textLength = 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (_cache.TryGetValue(text, out List<Token>? hit))
|
||||
{
|
||||
return EncodeToIdsResult(hit, accumulatedIds, maxTokens, text.Length, out textLength);
|
||||
}
|
||||
|
||||
char[] token = ArrayPool<char>.Shared.Rent(text.Length);
|
||||
|
@ -275,6 +347,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
{
|
||||
ArrayPool<char>.Shared.Return(token);
|
||||
ArrayPool<int>.Shared.Return(indexMapping);
|
||||
textLength = 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -283,24 +356,58 @@ namespace Microsoft.ML.Tokenizers
|
|||
ArrayPool<char>.Shared.Return(token);
|
||||
ArrayPool<int>.Shared.Return(indexMapping);
|
||||
|
||||
if (accumulatedIds is not null)
|
||||
return EncodeToIdsResult(result, accumulatedIds, maxTokens, text.Length, out textLength);
|
||||
}
|
||||
|
||||
private int EncodeToIdsFromEndInternal(ReadOnlySpan<char> text, IList<int>? accumulatedIds, out int textIndex, int maxTokens)
|
||||
{
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
foreach (var t in result)
|
||||
textIndex = text.Length;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (_cache.TryGetValue(text, out List<Token>? hit))
|
||||
{
|
||||
return EncodeToIdsFromEndResult(hit, accumulatedIds, maxTokens, text.Length, out textIndex);
|
||||
}
|
||||
|
||||
char[] token = ArrayPool<char>.Shared.Rent(text.Length);
|
||||
int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);
|
||||
|
||||
int newTokenIndex = 0;
|
||||
for (int i = 0; i < text.Length; i++)
|
||||
{
|
||||
if (_byteToUnicode.TryGetValue(text[i], out var value))
|
||||
{
|
||||
accumulatedIds.Add(t.Id);
|
||||
token[newTokenIndex] = value;
|
||||
indexMapping[newTokenIndex] = i;
|
||||
newTokenIndex++;
|
||||
}
|
||||
}
|
||||
|
||||
return result.Count;
|
||||
if (newTokenIndex == 0)
|
||||
{
|
||||
ArrayPool<char>.Shared.Return(token);
|
||||
ArrayPool<int>.Shared.Return(indexMapping);
|
||||
textIndex = text.Length;
|
||||
return 0;
|
||||
}
|
||||
|
||||
List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
|
||||
_cache.Set(text.ToString(), result);
|
||||
ArrayPool<char>.Shared.Return(token);
|
||||
ArrayPool<int>.Shared.Return(indexMapping);
|
||||
|
||||
return EncodeToIdsFromEndResult(result, accumulatedIds, maxTokens, text.Length, out textIndex);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Map the token to encoded Id.
|
||||
/// </summary>
|
||||
/// <param name="token">The token to map to the Id.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>The mapped Id of the token.</returns>
|
||||
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;
|
||||
public override int? MapTokenToId(ReadOnlySpan<char> token) => _vocab.TryGetValue(token, out int value) ? value : null;
|
||||
|
||||
/// <summary>
|
||||
/// Convert a list of tokens Ids to highest occurrence rankings.
|
||||
|
|
|
@ -16,22 +16,23 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <summary>
|
||||
/// Encode a text to a list of tokens.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <returns>The list of tokens generated from the text tokenization.</returns>
|
||||
public abstract IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false);
|
||||
public abstract IReadOnlyList<Token> Encode(ReadOnlySpan<char> text);
|
||||
|
||||
/// <summary>
|
||||
/// Encode a text to a list of Ids and add them to the accumulatedIds list.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
|
||||
/// <param name="text">The text to encode. </param>
|
||||
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
|
||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
/// <remarks>
|
||||
/// This method does the default implementation that uses the Encode method to get the token's Ids.
|
||||
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
|
||||
/// </remarks>
|
||||
public virtual void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds)
|
||||
public virtual int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
|
||||
{
|
||||
if (accumulatedIds is null)
|
||||
{
|
||||
|
@ -39,60 +40,126 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
|
||||
// Default implementation is not optimized for memory allocation. It is recommended to override this method for the sake of the performance.
|
||||
var tokens = Encode(text.ToString());
|
||||
foreach (var token in tokens)
|
||||
textLength = 0;
|
||||
var tokens = Encode(text);
|
||||
|
||||
int count = Math.Min(tokens.Count, maxTokens);
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
accumulatedIds.Add(token.Id);
|
||||
textLength += tokens[i].Offset.Length;
|
||||
accumulatedIds.Add(tokens[i].Id);
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
/// <remarks>
|
||||
/// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids.
|
||||
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
|
||||
/// </remarks>
|
||||
public virtual int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
|
||||
public virtual int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue)
|
||||
{
|
||||
if (maxTokens <= 0)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0.");
|
||||
}
|
||||
|
||||
var ids = new List<int>();
|
||||
EncodeToIds(text, isSpecialToken, ids);
|
||||
return ids.Count;
|
||||
|
||||
if (maxTokens == int.MaxValue)
|
||||
{
|
||||
EncodeToIds(text, ids, out _);
|
||||
textLength = text.Length;
|
||||
return ids.Count;
|
||||
}
|
||||
|
||||
IReadOnlyList<Token> tokens = Encode(text);
|
||||
textLength = 0;
|
||||
int count = Math.Min(tokens.Count, maxTokens);
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
textLength += tokens[i].Offset.Length;
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
/// <remarks>
|
||||
/// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids.
|
||||
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
|
||||
/// </remarks>
|
||||
public virtual int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue)
|
||||
{
|
||||
if (maxTokens <= 0)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0.");
|
||||
}
|
||||
|
||||
var ids = new List<int>();
|
||||
|
||||
if (maxTokens == int.MaxValue)
|
||||
{
|
||||
EncodeToIds(text, ids, out _);
|
||||
textIndex = 0;
|
||||
return ids.Count;
|
||||
}
|
||||
|
||||
IReadOnlyList<Token> tokens = Encode(text);
|
||||
textIndex = text.Length;
|
||||
int count = Math.Min(tokens.Count, maxTokens);
|
||||
|
||||
int tokensCount = tokens.Count;
|
||||
int end = tokensCount - count;
|
||||
for (int i = tokensCount - 1; i >= end; i--)
|
||||
{
|
||||
textIndex -= tokens[i].Offset.Length;
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Map the token to encoded id with the option to skip the special tokens.
|
||||
/// </summary>
|
||||
/// <param name="token">The token to map to Id</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>The mapped Id of the token.</returns>
|
||||
public abstract int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true);
|
||||
public abstract int? MapTokenToId(ReadOnlySpan<char> token);
|
||||
|
||||
/// <summary>
|
||||
/// Map the encoded Id to the token.
|
||||
/// </summary>
|
||||
/// <param name="id">The Id to map to the token.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the decoding.</param>
|
||||
/// <returns>The mapped token of the Id.</returns>
|
||||
public abstract string? MapIdToToken(int id, bool considerSpecialTokens = true);
|
||||
public abstract string? MapIdToToken(int id);
|
||||
|
||||
/// <summary>
|
||||
/// Decode the given ids, back to a String.
|
||||
/// </summary>
|
||||
/// <param name="ids">The list of ids that we want to decode.</param>
|
||||
/// <param name="considerSpecialTokens">Whether the special tokens should be kept in the decoded string.</param>
|
||||
/// <param name="decoder">The optional Decoder to merge the given list of tokens in a string.</param>
|
||||
/// <returns>The decoded string.</returns>
|
||||
public virtual string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null, bool considerSpecialTokens = true)
|
||||
public virtual string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null)
|
||||
{
|
||||
List<string> tokens = new List<string>();
|
||||
|
||||
foreach (int id in ids)
|
||||
{
|
||||
if (MapIdToToken(id, considerSpecialTokens) is string s)
|
||||
if (MapIdToToken(id) is string s)
|
||||
{
|
||||
tokens.Add(s);
|
||||
}
|
||||
|
|
|
@ -154,11 +154,10 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <summary>
|
||||
/// Encode a text to a list of tokens.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <returns>The list of tokens generated from the text tokenization.</returns>
|
||||
/// <remarks>The input text has to be normalized before calling this method.</remarks>
|
||||
public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false) => Encode(text, AddBeginningOfSentence, AddEndOfSentence);
|
||||
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text) => Encode(text, AddBeginningOfSentence, AddEndOfSentence);
|
||||
|
||||
/// <summary>
|
||||
/// Encode a text to a list of tokens.
|
||||
|
@ -168,13 +167,8 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
|
||||
/// <returns>The list of tokens generated from the text tokenization.</returns>
|
||||
/// <remarks>The input text has to be normalized before calling this method.</remarks>
|
||||
public IReadOnlyList<Token> Encode(string text, bool addBeginOfSentence, bool addEndOfSentence)
|
||||
public IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence)
|
||||
{
|
||||
if (text is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(text));
|
||||
}
|
||||
|
||||
if (text.Length == 0)
|
||||
{
|
||||
return Array.Empty<Token>();
|
||||
|
@ -182,7 +176,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length);
|
||||
|
||||
Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text.AsSpan(), symbols);
|
||||
Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols);
|
||||
|
||||
List<Token> tokens = new();
|
||||
|
||||
|
@ -198,7 +192,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
if (id == UninitializedId)
|
||||
{
|
||||
if (_vocab.TryGetValue(text.AsSpan().Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
|
||||
if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
|
||||
{
|
||||
id = tokenInfo.Id;
|
||||
type = tokenInfo.Type;
|
||||
|
@ -214,19 +208,19 @@ namespace Microsoft.ML.Tokenizers
|
|||
{
|
||||
if (id == UnknownId && ByteFallback)
|
||||
{
|
||||
EncodeAsBytes(text.AsSpan().Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index);
|
||||
EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index);
|
||||
}
|
||||
else
|
||||
{
|
||||
tokens.Add(new Token(
|
||||
id,
|
||||
GetTokenString(id, symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length, text.AsSpan()),
|
||||
GetTokenString(id, symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length, text),
|
||||
(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length)));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
Segment(symbols[index].pieceSpan, text.AsSpan());
|
||||
Segment(symbols[index].pieceSpan, text);
|
||||
}
|
||||
|
||||
ArrayPool<BpeSymbol>.Shared.Return(symbols);
|
||||
|
@ -314,11 +308,14 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <summary>
|
||||
/// Encode a text to a list of Ids and add them to the accumulatedIds list.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
|
||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
/// <remarks>The input text has to be normalized before calling this method.</remarks>
|
||||
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIds(text, AddBeginningOfSentence, AddEndOfSentence, accumulatedIds);
|
||||
public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
|
||||
=> EncodeToIds(text, AddBeginningOfSentence, AddEndOfSentence, accumulatedIds, out textLength, maxTokens);
|
||||
|
||||
/// <summary>
|
||||
/// Encode a text to a list of Ids and add them to the accumulatedIds list.
|
||||
|
@ -327,26 +324,29 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
|
||||
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
|
||||
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
|
||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
/// <remarks>The input text has to be normalized before calling this method.</remarks>
|
||||
public void EncodeToIds(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, IList<int> accumulatedIds)
|
||||
public int EncodeToIds(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
|
||||
{
|
||||
if (maxTokens <= 0)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0.");
|
||||
}
|
||||
|
||||
textLength = 0;
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
return;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int idsCount = 0;
|
||||
|
||||
if (addBeginOfSentence)
|
||||
{
|
||||
accumulatedIds.Add(BeginningOfSentenceId);
|
||||
}
|
||||
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
if (addEndOfSentence)
|
||||
{
|
||||
accumulatedIds.Add(EndOfSentenceId);
|
||||
}
|
||||
return;
|
||||
idsCount++;
|
||||
}
|
||||
|
||||
BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length);
|
||||
|
@ -376,34 +376,64 @@ namespace Microsoft.ML.Tokenizers
|
|||
{
|
||||
if (id == UnknownId && ByteFallback)
|
||||
{
|
||||
EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index);
|
||||
if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textLength))
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
accumulatedIds.Add(id);
|
||||
if (idsCount < maxTokens)
|
||||
{
|
||||
accumulatedIds.Add(id);
|
||||
textLength += symbols[index].pieceSpan.Length;
|
||||
idsCount++;
|
||||
}
|
||||
else
|
||||
{
|
||||
return idsCount;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
Segment(symbols[index].pieceSpan, text);
|
||||
if (!Segment(symbols[index].pieceSpan, text, ref textLength))
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
ArrayPool<BpeSymbol>.Shared.Return(symbols);
|
||||
|
||||
if (addEndOfSentence)
|
||||
{
|
||||
accumulatedIds.Add(EndOfSentenceId);
|
||||
if (idsCount < maxTokens)
|
||||
{
|
||||
accumulatedIds.Add(EndOfSentenceId);
|
||||
idsCount++;
|
||||
}
|
||||
}
|
||||
|
||||
return idsCount;
|
||||
|
||||
// Encode the Unknown token to bytes.
|
||||
void EncodeAsBytes(ReadOnlySpan<char> text, int index)
|
||||
bool EncodeAsBytes(ReadOnlySpan<char> text, int index, ref int textLength)
|
||||
{
|
||||
for (int i = 0; i < text.Length; i++)
|
||||
{
|
||||
char c = text[i];
|
||||
if (c <= 0x7F)
|
||||
{
|
||||
accumulatedIds.Add((int)c + _byteCodeToIdOffset); // byte code is mapped to the to the Ids starting from 4.
|
||||
if (idsCount < maxTokens)
|
||||
{
|
||||
textLength++;
|
||||
accumulatedIds.Add((int)c + _byteCodeToIdOffset); // byte code is mapped to the to the Ids starting from 4.
|
||||
idsCount++;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -419,9 +449,21 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
// Need to convert the text into UTF-8 bytes and then encode the bytes.
|
||||
int bytesWritten = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes);
|
||||
for (int j = 0; j < bytesWritten; j++)
|
||||
|
||||
bool ret;
|
||||
if (idsCount + bytesWritten <= maxTokens)
|
||||
{
|
||||
accumulatedIds.Add((int)utf8Bytes[j] + _byteCodeToIdOffset); // byte code is mapped to the to the Ids starting from 4.
|
||||
for (int j = 0; j < bytesWritten; j++)
|
||||
{
|
||||
accumulatedIds.Add((int)utf8Bytes[j] + _byteCodeToIdOffset); // byte code is mapped to the to the Ids starting from 4.
|
||||
}
|
||||
|
||||
textLength += text.Length - i;
|
||||
ret = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
ret = false;
|
||||
}
|
||||
|
||||
if (arrayPoolArray is not null)
|
||||
|
@ -429,40 +471,60 @@ namespace Microsoft.ML.Tokenizers
|
|||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
}
|
||||
|
||||
break;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void Segment((int Index, int Length) pieceSpan, ReadOnlySpan<char> text)
|
||||
bool Segment((int Index, int Length) pieceSpan, ReadOnlySpan<char> text, ref int textLength)
|
||||
{
|
||||
if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
|
||||
{
|
||||
EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index);
|
||||
return;
|
||||
return EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref textLength);
|
||||
}
|
||||
|
||||
if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
|
||||
revMerge is null ||
|
||||
!revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
|
||||
{
|
||||
accumulatedIds.Add(id.Id);
|
||||
return;
|
||||
if (idsCount < maxTokens)
|
||||
{
|
||||
accumulatedIds.Add(id.Id);
|
||||
textLength += pieceSpan.Length;
|
||||
idsCount++;
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Segment((merge.LeftIndex, merge.LeftLen), text);
|
||||
Segment((merge.RightIndex, merge.RightLen), text);
|
||||
return Segment((merge.LeftIndex, merge.LeftLen), text, ref textLength) && Segment((merge.RightIndex, merge.RightLen), text, ref textLength);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
/// <remarks>The input text has to be normalized before calling this method.</remarks>
|
||||
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => CountTokens(text, AddBeginningOfSentence, AddEndOfSentence);
|
||||
public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => CountTokens(text, AddBeginningOfSentence, AddEndOfSentence, out textLength, maxTokens);
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => CountTokensFromEnd(text, AddBeginningOfSentence, AddEndOfSentence, out textIndex, maxTokens);
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
|
@ -470,10 +532,13 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
|
||||
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
|
||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
/// <remarks>The input text has to be normalized before calling this method.</remarks>
|
||||
public int CountTokens(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence)
|
||||
public int CountTokens(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textLength, int maxTokens = int.MaxValue)
|
||||
{
|
||||
textLength = 0;
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
return 0;
|
||||
|
@ -508,36 +573,61 @@ namespace Microsoft.ML.Tokenizers
|
|||
{
|
||||
if (id == UnknownId && ByteFallback)
|
||||
{
|
||||
EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index);
|
||||
if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textLength))
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
tokenCount++;
|
||||
if (tokenCount < maxTokens)
|
||||
{
|
||||
tokenCount++;
|
||||
textLength += symbols[index].pieceSpan.Length;
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
Segment(symbols[index].pieceSpan, text);
|
||||
if (!Segment(symbols[index].pieceSpan, text, ref textLength))
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
ArrayPool<BpeSymbol>.Shared.Return(symbols);
|
||||
|
||||
if (addEndOfSentence)
|
||||
{
|
||||
tokenCount++;
|
||||
if (tokenCount < maxTokens)
|
||||
{
|
||||
tokenCount++;
|
||||
}
|
||||
}
|
||||
|
||||
return tokenCount;
|
||||
|
||||
// Encode the Unknown token to bytes.
|
||||
void EncodeAsBytes(ReadOnlySpan<char> text, int index)
|
||||
bool EncodeAsBytes(ReadOnlySpan<char> text, int index, ref int textLength)
|
||||
{
|
||||
for (int i = 0; i < text.Length; i++)
|
||||
{
|
||||
char c = text[i];
|
||||
if (c <= 0x7F)
|
||||
{
|
||||
tokenCount++;
|
||||
if (tokenCount < maxTokens)
|
||||
{
|
||||
tokenCount++;
|
||||
textLength++;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -552,36 +642,233 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
|
||||
// Need to convert the text into UTF-8 bytes and then encode the bytes.
|
||||
tokenCount += Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes);
|
||||
int encodedCount = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes);
|
||||
bool ret;
|
||||
|
||||
if (tokenCount + encodedCount <= maxTokens)
|
||||
{
|
||||
tokenCount += encodedCount;
|
||||
textLength += text.Length - i;
|
||||
ret = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
ret = false;
|
||||
}
|
||||
|
||||
if (arrayPoolArray is not null)
|
||||
{
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
}
|
||||
|
||||
break;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void Segment((int Index, int Length) pieceSpan, ReadOnlySpan<char> text)
|
||||
bool Segment((int Index, int Length) pieceSpan, ReadOnlySpan<char> text, ref int textLength)
|
||||
{
|
||||
if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
|
||||
{
|
||||
EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index);
|
||||
return;
|
||||
return EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref textLength);
|
||||
}
|
||||
|
||||
if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
|
||||
revMerge is null ||
|
||||
!revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
|
||||
{
|
||||
tokenCount++;
|
||||
return;
|
||||
if (tokenCount < maxTokens)
|
||||
{
|
||||
tokenCount++;
|
||||
textLength += pieceSpan.Length;
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Segment((merge.LeftIndex, merge.LeftLen), text);
|
||||
Segment((merge.RightIndex, merge.RightLen), text);
|
||||
return Segment((merge.LeftIndex, merge.LeftLen), text, ref textLength) && Segment((merge.RightIndex, merge.RightLen), text, ref textLength);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
|
||||
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
|
||||
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
/// <remarks>The input text has to be normalized before calling this method.</remarks>
|
||||
public int CountTokensFromEnd(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue)
|
||||
{
|
||||
textIndex = text.Length;
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tokenCount = addEndOfSentence ? 1 : 0;
|
||||
|
||||
BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length);
|
||||
|
||||
Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols);
|
||||
|
||||
// Move to the last symbol.
|
||||
int lastSymbolIndex = 0;
|
||||
while (symbols[lastSymbolIndex].next != -1 && lastSymbolIndex < symbols.Length)
|
||||
{
|
||||
lastSymbolIndex = symbols[lastSymbolIndex].next;
|
||||
}
|
||||
|
||||
for (int index = lastSymbolIndex; index >= 0; index = symbols[index].prev)
|
||||
{
|
||||
int id = symbols[index].id;
|
||||
byte type = symbols[index].type;
|
||||
|
||||
if (id == UninitializedId)
|
||||
{
|
||||
if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
|
||||
{
|
||||
id = tokenInfo.Id;
|
||||
type = tokenInfo.Type;
|
||||
}
|
||||
else
|
||||
{
|
||||
id = UnknownId;
|
||||
type = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused)
|
||||
{
|
||||
if (id == UnknownId && ByteFallback)
|
||||
{
|
||||
if (!EncodeAsBytesFromEnd(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textIndex))
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (tokenCount < maxTokens)
|
||||
{
|
||||
tokenCount++;
|
||||
textIndex -= symbols[index].pieceSpan.Length;
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!SegmentFromEnd(symbols[index].pieceSpan, text, ref textIndex))
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
ArrayPool<BpeSymbol>.Shared.Return(symbols);
|
||||
|
||||
if (AddBeginningOfSentence)
|
||||
{
|
||||
if (tokenCount < maxTokens)
|
||||
{
|
||||
tokenCount++;
|
||||
}
|
||||
}
|
||||
|
||||
return tokenCount;
|
||||
|
||||
// Encode the Unknown token to bytes.
|
||||
bool EncodeAsBytesFromEnd(ReadOnlySpan<char> text, int index, ref int textIndex)
|
||||
{
|
||||
for (int i = text.Length - 1; i >= 0; i--)
|
||||
{
|
||||
char c = text[i];
|
||||
if (c <= 0x7F)
|
||||
{
|
||||
if (tokenCount < maxTokens)
|
||||
{
|
||||
tokenCount++;
|
||||
textIndex--;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Span<byte> utf8Bytes = stackalloc byte[100];
|
||||
byte[]? arrayPoolArray = null;
|
||||
|
||||
int len = Encoding.UTF8.GetMaxByteCount(text.Length - i);
|
||||
if (len > utf8Bytes.Length)
|
||||
{
|
||||
arrayPoolArray = ArrayPool<byte>.Shared.Rent(len);
|
||||
utf8Bytes = arrayPoolArray;
|
||||
}
|
||||
|
||||
// Need to convert the text into UTF-8 bytes and then encode the bytes.
|
||||
int encodedCount = Helpers.GetUtf8Bytes(text.Slice(0, i + 1), utf8Bytes);
|
||||
bool ret;
|
||||
|
||||
if (tokenCount + encodedCount <= maxTokens)
|
||||
{
|
||||
tokenCount += encodedCount;
|
||||
textIndex -= i + 1;
|
||||
ret = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
ret = false;
|
||||
}
|
||||
|
||||
if (arrayPoolArray is not null)
|
||||
{
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SegmentFromEnd((int Index, int Length) pieceSpan, ReadOnlySpan<char> text, ref int textIndex)
|
||||
{
|
||||
if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
|
||||
{
|
||||
return EncodeAsBytesFromEnd(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref textIndex);
|
||||
}
|
||||
|
||||
if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
|
||||
revMerge is null ||
|
||||
!revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
|
||||
{
|
||||
if (tokenCount < maxTokens)
|
||||
{
|
||||
tokenCount++;
|
||||
textIndex -= pieceSpan.Length;
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Segment the right part first.
|
||||
return SegmentFromEnd((merge.RightIndex, merge.RightLen), text, ref textIndex) && SegmentFromEnd((merge.LeftIndex, merge.LeftLen), text, ref textIndex);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -589,32 +876,28 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Map the token to encoded id with the option to skip the special tokens.
|
||||
/// </summary>
|
||||
/// <param name="token">The token to map to Id</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>The mapped Id of the token.</returns>
|
||||
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true)
|
||||
public override int? MapTokenToId(ReadOnlySpan<char> token)
|
||||
=> _vocab.TryGetValue(token, out (int Id, float Score, byte Type) value) ? value.Id : null;
|
||||
|
||||
/// <summary>
|
||||
/// Map the encoded Id to the token.
|
||||
/// </summary>
|
||||
/// <param name="id">The Id to map to the token.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the decoding.</param>
|
||||
/// <returns>The mapped token of the Id.</returns>
|
||||
public override string? MapIdToToken(int id, bool considerSpecialTokens = true)
|
||||
public override string? MapIdToToken(int id)
|
||||
=> _vocabReverse.TryGetValue(id, out string? value) ? value : null;
|
||||
|
||||
/// <summary>
|
||||
/// Decode the given ids, back to a String.
|
||||
/// </summary>
|
||||
/// <param name="ids">The list of ids that we want to decode.</param>
|
||||
/// <param name="considerSpecialTokens">Whether the special tokens should be kept in the decoded string.</param>
|
||||
/// <param name="decoder">The optional Decoder to merge the given list of tokens in a string.</param>
|
||||
/// <returns>The decoded string.</returns>
|
||||
/// <remarks>
|
||||
/// The decoder is not used here because the SentencePiece Bpe model knows how to decode the ids in additions to avoid any performance overhead.
|
||||
/// considerSpecialTokens is not used here because the SentencePiece Bpe model always remove unknown or control tokens during the decoding.
|
||||
/// </remarks>
|
||||
public override string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null, bool considerSpecialTokens = true)
|
||||
public override string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null)
|
||||
{
|
||||
if (ids is null)
|
||||
{
|
||||
|
|
|
@ -24,12 +24,10 @@ namespace Microsoft.ML.Tokenizers
|
|||
{
|
||||
private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder;
|
||||
private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder;
|
||||
private readonly LruCache<int[]> _cache;
|
||||
private readonly Dictionary<StringSpanOrdinalKey, int>? _specialTokensEncoder;
|
||||
private Dictionary<string, int>? _specialTokensEncoderOriginal;
|
||||
private readonly Dictionary<int, string>? _specialTokensDecoder;
|
||||
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
|
||||
private readonly LruCache<(int[] Bytes, string Token)> _cache;
|
||||
private readonly Dictionary<StringSpanOrdinalKey, (int Id, string Token)> _vocab;
|
||||
private IReadOnlyDictionary<string, int>? _vocabOriginal;
|
||||
private const int MaxWordLengthToCache = 15;
|
||||
|
||||
/// <summary>
|
||||
/// Create a new Tiktoken tokenizer's model object.
|
||||
|
@ -68,7 +66,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
internal Tiktoken(
|
||||
Dictionary<ReadOnlyMemory<byte>, int> encoder,
|
||||
Dictionary<int, ReadOnlyMemory<byte>> decoder,
|
||||
Dictionary<StringSpanOrdinalKey, int> vocab,
|
||||
Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab,
|
||||
IReadOnlyDictionary<string, int>? specialTokens,
|
||||
int cacheSize = LruCache<int[]>.DefaultCacheSize)
|
||||
{
|
||||
|
@ -76,23 +74,24 @@ namespace Microsoft.ML.Tokenizers
|
|||
_decoder = decoder ?? throw new ArgumentNullException(nameof(decoder));
|
||||
_vocab = vocab ?? throw new ArgumentNullException(nameof(vocab));
|
||||
|
||||
Debug.Assert(encoder.Count == decoder.Count);
|
||||
|
||||
_encoder = encoder!;
|
||||
_decoder = decoder!;
|
||||
_vocab = vocab!;
|
||||
_cache = new LruCache<int[]>(cacheSize);
|
||||
_cache = new LruCache<(int[] Bytes, string Token)>(cacheSize);
|
||||
|
||||
(_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens);
|
||||
SpecialTokens = specialTokens;
|
||||
CacheSpecialTokensEncoding(specialTokens);
|
||||
}
|
||||
|
||||
private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens, int cacheSize, bool disposeStream)
|
||||
{
|
||||
try
|
||||
{
|
||||
_cache = new LruCache<int[]>(cacheSize);
|
||||
(_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
|
||||
(_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens);
|
||||
_cache = new LruCache<(int[] Bytes, string Token)>(cacheSize);
|
||||
(_encoder, _vocab, _decoder) = LoadTiktokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
|
||||
|
||||
SpecialTokens = specialTokens;
|
||||
CacheSpecialTokensEncoding(specialTokens);
|
||||
}
|
||||
finally
|
||||
{
|
||||
|
@ -103,15 +102,19 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
}
|
||||
|
||||
private static (Dictionary<StringSpanOrdinalKey, int>?, Dictionary<int, string>?) CreateEncoderDecoder(IReadOnlyDictionary<string, int>? specialTokens)
|
||||
private void CacheSpecialTokensEncoding(IReadOnlyDictionary<string, int>? specialTokens)
|
||||
{
|
||||
Debug.Assert(_cache is not null);
|
||||
Debug.Assert(_decoder is not null);
|
||||
|
||||
if (specialTokens is not null)
|
||||
{
|
||||
var encoder = specialTokens.ToDictionary(e => new StringSpanOrdinalKey(e.Key), e => e.Value);
|
||||
return (encoder, encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key.Data!));
|
||||
foreach (KeyValuePair<string, int> specialToken in specialTokens)
|
||||
{
|
||||
_decoder![specialToken.Value] = Encoding.UTF8.GetBytes(specialToken.Key);
|
||||
_cache!.Add(specialToken.Key, (new[] { specialToken.Value }, specialToken.Key));
|
||||
}
|
||||
}
|
||||
|
||||
return (null, null);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -133,8 +136,8 @@ namespace Microsoft.ML.Tokenizers
|
|||
throw new ArgumentNullException(nameof(vocabStream));
|
||||
}
|
||||
|
||||
(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
|
||||
await LoadTikTokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);
|
||||
(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
|
||||
await LoadTiktokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
return new Tiktoken(encoder, decoder, vocab, specialTokens, cacheSize);
|
||||
}
|
||||
|
@ -170,11 +173,11 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
|
||||
/// <returns>Map of byte[] to integer token id</returns>
|
||||
/// <exception cref="InvalidOperationException"></exception>
|
||||
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<StringSpanOrdinalKey, int>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTikTokenBpeAsync(
|
||||
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<StringSpanOrdinalKey, (int Id, string Token)>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTiktokenBpeAsync(
|
||||
Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
|
||||
var vocab = new Dictionary<StringSpanOrdinalKey, int>();
|
||||
var vocab = new Dictionary<StringSpanOrdinalKey, (int Id, string Token)>();
|
||||
var decoder = new Dictionary<int, ReadOnlyMemory<byte>>();
|
||||
|
||||
try
|
||||
|
@ -212,7 +215,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
if (decodedToken.IndexOf('\uFFFD') < 0)
|
||||
{
|
||||
vocab[new StringSpanOrdinalKey(decodedToken)] = rank;
|
||||
vocab[new StringSpanOrdinalKey(decodedToken)] = (rank, decodedToken);
|
||||
}
|
||||
}
|
||||
else
|
||||
|
@ -231,140 +234,98 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Encode a split text string to a list of tokens.
|
||||
/// Encode text to a list of tokens.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <returns>The list of tokens generated from the text tokenization.</returns>
|
||||
public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false)
|
||||
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
|
||||
{
|
||||
Token[] tokens;
|
||||
|
||||
if (string.IsNullOrEmpty(text))
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
return Array.Empty<Token>();
|
||||
}
|
||||
|
||||
if (isSpecialToken)
|
||||
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
|
||||
{
|
||||
if (_specialTokensEncoder?.TryGetValue(text, out int id) is true)
|
||||
{
|
||||
return new List<Token> { new(id, text, (0, text.Length)) };
|
||||
}
|
||||
|
||||
throw new InvalidOperationException($"The special token {text} doesn't exist in the tokenizer");
|
||||
}
|
||||
|
||||
if (_cache.TryGetValue(text, out int[]? ids))
|
||||
{
|
||||
tokens = new Token[ids.Length];
|
||||
tokens[0] = new Token(ids[0], text, (0, text.Length));
|
||||
for (int i = 1; i < ids.Length; i++)
|
||||
tokens = new Token[value.Ids.Length];
|
||||
tokens[0] = new Token(value.Ids[0], value.Token, (0, value.Token.Length));
|
||||
for (int i = 1; i < value.Ids.Length; i++)
|
||||
{
|
||||
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
|
||||
tokens[i] = new Token(ids[i], "", (text.Length, 0));
|
||||
tokens[i] = new Token(value.Ids[i], "", (text.Length, 0));
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
// cache miss
|
||||
if (_vocab.TryGetValue(text, out int mappedId))
|
||||
if (_vocab.TryGetValue(text, out (int Id, string Token) mappedId))
|
||||
{
|
||||
return new Token[1] { new(mappedId, text, (0, text.Length)) };
|
||||
}
|
||||
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
|
||||
int encodedLength = Helpers.GetUtf8Bytes(text.AsSpan(), arrayPoolArray);
|
||||
|
||||
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
Debug.Assert(encodedIds.Length > 0);
|
||||
_cache.Add(text, encodedIds);
|
||||
|
||||
tokens = new Token[encodedIds.Length];
|
||||
tokens[0] = new Token(encodedIds[0], text, (0, text.Length));
|
||||
for (int i = 1; i < encodedIds.Length; i++)
|
||||
{
|
||||
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
|
||||
tokens[i] = new Token(encodedIds[i], "", (text.Length, 0));
|
||||
}
|
||||
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
return tokens;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Encode text to a list of Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
|
||||
/// <param name="accumulatedIds">The list of accumulated Ids.</param>
|
||||
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds)
|
||||
{
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (isSpecialToken)
|
||||
{
|
||||
if (_specialTokensEncoder?.TryGetValue(text, out int id) is true)
|
||||
{
|
||||
accumulatedIds.Add(id);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (_cache.TryGetValue(text, out int[]? tokenIds))
|
||||
{
|
||||
accumulatedIds.AddRange(tokenIds);
|
||||
return;
|
||||
}
|
||||
|
||||
if (_vocab.TryGetValue(text, out int mappedId))
|
||||
{
|
||||
accumulatedIds.Add(mappedId);
|
||||
return;
|
||||
return new Token[1] { new(mappedId.Id, mappedId.Token, (0, mappedId.Token.Length)) };
|
||||
}
|
||||
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
|
||||
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
|
||||
|
||||
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
_cache.Add(text.ToString(), encodedIds);
|
||||
|
||||
accumulatedIds.AddRange(encodedIds);
|
||||
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
return;
|
||||
|
||||
Debug.Assert(encodedIds.Length > 0);
|
||||
string textAsString = text.ToString();
|
||||
|
||||
if (text.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
_cache.Add(textAsString, (encodedIds, textAsString));
|
||||
}
|
||||
|
||||
tokens = new Token[encodedIds.Length];
|
||||
tokens[0] = new Token(encodedIds[0], textAsString, (0, text.Length));
|
||||
for (int i = 1; i < encodedIds.Length; i++)
|
||||
{
|
||||
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
|
||||
tokens[i] = new Token(encodedIds[i], "", (text.Length, 0));
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// Encode text to a list of Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode. If the value of the parameter <paramref name="isSpecialToken"/> is true, the entire text will be treated as a special token.</param>
|
||||
/// <param name="isSpecialToken">Specifies whether the entire <paramref name="text"/> is considered a special token.</param>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="accumulatedIds">The list of accumulated Ids.</param>
|
||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
|
||||
public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
|
||||
{
|
||||
Debug.Assert(maxTokens > 0);
|
||||
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
textLength = 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (isSpecialToken && _specialTokensEncoder is not null)
|
||||
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
|
||||
{
|
||||
return _specialTokensEncoder.TryGetValue(text, out _) ? 1 : 0;
|
||||
if (value.Ids.Length <= maxTokens)
|
||||
{
|
||||
accumulatedIds.AddRange(value.Ids);
|
||||
textLength = text.Length;
|
||||
return value.Ids.Length;
|
||||
}
|
||||
|
||||
textLength = 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (_cache.TryGetValue(text, out int[] ids))
|
||||
{
|
||||
return ids.Length;
|
||||
}
|
||||
|
||||
if (_vocab.TryGetValue(text, out _))
|
||||
if (_vocab.TryGetValue(text, out (int Id, string Token) mappedId))
|
||||
{
|
||||
textLength = text.Length;
|
||||
accumulatedIds.Add(mappedId.Id);
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
@ -372,46 +333,178 @@ namespace Microsoft.ML.Tokenizers
|
|||
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
|
||||
|
||||
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
_cache.Add(text.ToString(), encodedIds);
|
||||
|
||||
if (text.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
string textAsString = text.ToString();
|
||||
_cache.Add(textAsString, (encodedIds, textAsString));
|
||||
}
|
||||
|
||||
int result;
|
||||
if (encodedIds.Length <= maxTokens)
|
||||
{
|
||||
accumulatedIds.AddRange(encodedIds);
|
||||
textLength = text.Length;
|
||||
result = encodedIds.Length;
|
||||
}
|
||||
else
|
||||
{
|
||||
textLength = 0;
|
||||
result = 0;
|
||||
}
|
||||
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
return encodedIds.Length;
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue)
|
||||
{
|
||||
Debug.Assert(maxTokens > 0);
|
||||
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
textLength = 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
|
||||
{
|
||||
if (value.Ids.Length <= maxTokens)
|
||||
{
|
||||
textLength = text.Length;
|
||||
return value.Ids.Length;
|
||||
}
|
||||
|
||||
textLength = 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (_vocab.TryGetValue(text, out _))
|
||||
{
|
||||
textLength = text.Length;
|
||||
return 1;
|
||||
}
|
||||
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
|
||||
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
|
||||
|
||||
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
if (text.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
string textAsString = text.ToString();
|
||||
_cache.Add(textAsString, (encodedIds, textAsString));
|
||||
}
|
||||
|
||||
int result;
|
||||
if (encodedIds.Length <= maxTokens)
|
||||
{
|
||||
textLength = text.Length;
|
||||
result = encodedIds.Length;
|
||||
}
|
||||
else
|
||||
{
|
||||
textLength = 0;
|
||||
result = 0;
|
||||
}
|
||||
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue)
|
||||
{
|
||||
Debug.Assert(maxTokens > 0);
|
||||
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
textIndex = 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
|
||||
{
|
||||
if (value.Ids.Length <= maxTokens)
|
||||
{
|
||||
textIndex = 0;
|
||||
return value.Ids.Length;
|
||||
}
|
||||
|
||||
textIndex = text.Length;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (_vocab.TryGetValue(text, out _))
|
||||
{
|
||||
textIndex = 0;
|
||||
return 1;
|
||||
}
|
||||
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
|
||||
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
|
||||
|
||||
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
|
||||
if (text.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
string textAsString = text.ToString();
|
||||
_cache.Add(textAsString, (encodedIds, textAsString));
|
||||
}
|
||||
|
||||
int result;
|
||||
if (encodedIds.Length <= maxTokens)
|
||||
{
|
||||
textIndex = 0;
|
||||
result = encodedIds.Length;
|
||||
}
|
||||
else
|
||||
{
|
||||
textIndex = text.Length;
|
||||
result = 0;
|
||||
}
|
||||
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Map the token to encoded Id.
|
||||
/// </summary>
|
||||
/// <param name="token">The token to map to the Id.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>The mapped Id of the token.</returns>
|
||||
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true)
|
||||
public override int? MapTokenToId(ReadOnlySpan<char> token)
|
||||
{
|
||||
if (token.IsEmpty)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (considerSpecialTokens && _specialTokensEncoder is not null)
|
||||
if (_cache.TryGetValue(token, out (int[] Ids, string Token) value))
|
||||
{
|
||||
if (_specialTokensEncoder.TryGetValue(token, out int specialTokenId))
|
||||
if (value.Ids.Length == 1)
|
||||
{
|
||||
return specialTokenId;
|
||||
}
|
||||
}
|
||||
|
||||
if (_cache.TryGetValue(token, out int[] ids))
|
||||
{
|
||||
if (ids.Length == 1)
|
||||
{
|
||||
return ids[0];
|
||||
return value.Ids[0];
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
if (_vocab.TryGetValue(token, out int id))
|
||||
if (_vocab.TryGetValue(token, out (int Id, string Token) id))
|
||||
{
|
||||
return id;
|
||||
return id.Id;
|
||||
}
|
||||
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(token.Length));
|
||||
|
@ -420,7 +513,12 @@ namespace Microsoft.ML.Tokenizers
|
|||
int encodedLength = Helpers.GetUtf8Bytes(token, arrayPoolArray);
|
||||
|
||||
int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
_cache.Add(token.ToString(), idsToCache);
|
||||
|
||||
if (token.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
string tokenAsString = token.ToString();
|
||||
_cache.Add(tokenAsString, (idsToCache, tokenAsString));
|
||||
}
|
||||
|
||||
if (idsToCache.Length == 1)
|
||||
{
|
||||
|
@ -439,15 +537,9 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Map the encoded Id to the token.
|
||||
/// </summary>
|
||||
/// <param name="id">The Id to map to the token.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the decoding.</param>
|
||||
/// <returns>The mapped token of the Id.</returns>
|
||||
public override string? MapIdToToken(int id, bool considerSpecialTokens = true)
|
||||
public override string? MapIdToToken(int id)
|
||||
{
|
||||
if (considerSpecialTokens && _specialTokensDecoder is not null && _specialTokensDecoder.TryGetValue(id, out string? token))
|
||||
{
|
||||
return token;
|
||||
}
|
||||
|
||||
if (_decoder.TryGetValue(id, out ReadOnlyMemory<byte> tokenBytes))
|
||||
{
|
||||
return Helpers.GetString(tokenBytes.Span);
|
||||
|
@ -460,10 +552,9 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Decode the given ids, back to a String.
|
||||
/// </summary>
|
||||
/// <param name="ids">The list of ids that we want to decode.</param>
|
||||
/// <param name="considerSpecialTokens">Whether the special tokens should be kept in the decoded string.</param>
|
||||
/// <param name="decoder">The optional Decoder to merge the given list of tokens in a string.</param>
|
||||
/// <returns>The decoded string.</returns>
|
||||
public override string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null, bool considerSpecialTokens = true)
|
||||
public override string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null)
|
||||
{
|
||||
// Tiktoken doesn't guarantee a one-to-one correspondence between IDs and UTF-16 words.
|
||||
// Consequently, decoding individual IDs into UTF-16 string is not supported; instead, decoding all IDs must be performed collectively.
|
||||
|
@ -483,8 +574,6 @@ namespace Microsoft.ML.Tokenizers
|
|||
Span<byte> utf8Bytes = stackalloc byte[256];
|
||||
int utf8ByteCount = 0;
|
||||
|
||||
bool useSpecialTokens = considerSpecialTokens && _specialTokensDecoder is not null;
|
||||
|
||||
foreach (int id in ids)
|
||||
{
|
||||
if (_decoder.TryGetValue(id, out ReadOnlyMemory<byte> tokenBytes))
|
||||
|
@ -497,19 +586,6 @@ namespace Microsoft.ML.Tokenizers
|
|||
tokenBytes.Span.CopyTo(utf8Bytes.Slice(utf8ByteCount));
|
||||
utf8ByteCount += tokenBytes.Length;
|
||||
}
|
||||
else if (useSpecialTokens && _specialTokensDecoder!.TryGetValue(id, out string? token))
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
if (Helpers.TryGetUtf8Bytes(token.AsSpan(), utf8Bytes.Slice(utf8ByteCount), out int bytesWritten))
|
||||
{
|
||||
utf8ByteCount += bytesWritten;
|
||||
break;
|
||||
}
|
||||
|
||||
ArrayPoolGrow(ref utf8Bytes, ref arrayPoolArray, utf8ByteCount + Encoding.UTF8.GetByteCount(token));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return null;
|
||||
|
@ -543,12 +619,12 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Gets the dictionary mapping tokens to Ids.
|
||||
/// </summary>
|
||||
/// <remarks>This may not contain the full set of vocabulary tokens, use Encoder to get the full set of vocabulary.</remarks>
|
||||
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
|
||||
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value.Id);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the dictionary mapping special tokens to Ids.
|
||||
/// </summary>
|
||||
public IReadOnlyDictionary<string, int>? SpecialTokensEncoder => _specialTokensEncoderOriginal ??= _specialTokensEncoder?.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
|
||||
public IReadOnlyDictionary<string, int>? SpecialTokens { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the dictionary mapping token bytes to Ids.
|
||||
|
@ -732,31 +808,31 @@ namespace Microsoft.ML.Tokenizers
|
|||
case ModelEncoding.Cl100kBase:
|
||||
var specialTokens = new Dictionary<string, int>
|
||||
{ { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} };
|
||||
return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
|
||||
return CreateTiktokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
|
||||
|
||||
case ModelEncoding.P50kBase:
|
||||
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
|
||||
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
|
||||
return CreateTiktokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
|
||||
|
||||
case ModelEncoding.P50kEdit:
|
||||
specialTokens = new Dictionary<string, int>
|
||||
{ { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } };
|
||||
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
|
||||
return CreateTiktokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
|
||||
|
||||
case ModelEncoding.R50kBase:
|
||||
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
|
||||
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
|
||||
return CreateTiktokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
|
||||
|
||||
case ModelEncoding.GPT2:
|
||||
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 }, };
|
||||
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
|
||||
return CreateTiktokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
|
||||
|
||||
default:
|
||||
throw new NotSupportedException($"The encoder '{modelEncoding}' is not supported.");
|
||||
}
|
||||
}
|
||||
|
||||
private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
|
||||
private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
|
||||
|
||||
/// <summary>
|
||||
/// Create tokenizer based on regex pattern, BPE rank file and special tokens
|
||||
|
@ -768,7 +844,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="normalizer">To normalize the text before tokenization</param>
|
||||
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
|
||||
/// <returns>The tokenizer</returns>
|
||||
private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
|
||||
private static async Task<Tokenizer> CreateTiktokenTokenizerAsync(
|
||||
Regex regex,
|
||||
string mergeableRanksFileUrl,
|
||||
Dictionary<string, int> specialTokens,
|
||||
|
@ -784,17 +860,17 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
}
|
||||
|
||||
if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
|
||||
if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
|
||||
{
|
||||
using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
|
||||
{
|
||||
cache = await LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
|
||||
cache = await LoadTiktokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
_tiktokenCache.TryAdd(mergeableRanksFileUrl, cache);
|
||||
}
|
||||
|
||||
return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TikTokenPreTokenizer(regex, specialTokens), normalizer);
|
||||
return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TiktokenPreTokenizer(regex, specialTokens), normalizer);
|
||||
}
|
||||
|
||||
internal static Tokenizer CreateTokenizerForModel(
|
||||
|
@ -818,17 +894,17 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
|
||||
if (!_tiktokenCache.TryGetValue(tiktokenConfiguration.Url,
|
||||
out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
|
||||
out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int I, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
|
||||
{
|
||||
using Stream stream = Helpers.GetStream(_httpClient, tiktokenConfiguration.Url);
|
||||
cache = LoadTikTokenBpeAsync(stream, useAsync: false).GetAwaiter().GetResult();
|
||||
cache = LoadTiktokenBpeAsync(stream, useAsync: false).GetAwaiter().GetResult();
|
||||
|
||||
_tiktokenCache.TryAdd(tiktokenConfiguration.Url, cache);
|
||||
}
|
||||
|
||||
return new Tokenizer(
|
||||
new Tiktoken(cache.encoder, cache.decoder, cache.vocab, tiktokenConfiguration.SpecialTokens, LruCache<int[]>.DefaultCacheSize),
|
||||
new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
|
||||
new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
|
||||
normalizer);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -202,6 +202,64 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
}
|
||||
|
||||
public int PopulateIdsUpToMax(IList<int> accumulatedIds, int maxTokens, out int textLength)
|
||||
{
|
||||
textLength = 0;
|
||||
|
||||
int count = Math.Min(SymbolsCount, maxTokens);
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
accumulatedIds.Add(_symbols[i].C);
|
||||
textLength += _symbols[i].Len;
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
public int PopulateIdsUpToMaxFromEnd(IList<int> accumulatedIds, int maxTokens, int fullTextLength, out int textIndex)
|
||||
{
|
||||
textIndex = fullTextLength;
|
||||
|
||||
int count = Math.Min(SymbolsCount, maxTokens);
|
||||
|
||||
for (int i = SymbolsCount - count; i < SymbolsCount; i++)
|
||||
{
|
||||
accumulatedIds.Add(_symbols[i].C);
|
||||
textIndex -= _symbols[i].Len;
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
public int CountIdsUpToMax(int maxTokens, out int textLength)
|
||||
{
|
||||
textLength = 0;
|
||||
|
||||
int count = Math.Min(SymbolsCount, maxTokens);
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
textLength += _symbols[i].Len;
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
public int CountIdsUpToMaxFromEnd(int maxTokens, int fullTextLength, out int textIndex)
|
||||
{
|
||||
textIndex = fullTextLength;
|
||||
|
||||
int count = Math.Min(SymbolsCount, maxTokens);
|
||||
|
||||
for (int i = SymbolsCount - count; i < SymbolsCount; i++)
|
||||
{
|
||||
textIndex -= _symbols[i].Len;
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
public Vec<int> GetChars()
|
||||
{
|
||||
Vec<int> chars = new Vec<int>();
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
## About
|
||||
|
||||
Microsoft.ML.Tokenizers supports various the implmentation of the tokenization used in the NLP transforms.
|
||||
Microsoft.ML.Tokenizers supports various the implementation of the tokenization used in the NLP transforms.
|
||||
|
||||
## Key Features
|
||||
|
||||
* Extensisble tokenizer architecture that allows for specialization of Normalizer, PreTokenizer, Model/Encoder, Decoder
|
||||
* Extensible tokenizer architecture that allows for specialization of Normalizer, PreTokenizer, Model/Encoder, Decoder
|
||||
* BPE - Byte pair encoding model
|
||||
* English Roberta model
|
||||
* Tiktoken model
|
||||
|
@ -21,8 +21,8 @@ using System.IO;
|
|||
// Using Tiktoken Tokenizer
|
||||
//
|
||||
|
||||
// initialize the tokenizer for `gpt-4` model, downloading data files
|
||||
Tokenizer tokenizer = await Tokenizer.CreateTiktokenForModelAsync("gpt-4");
|
||||
// initialize the tokenizer for `gpt-4` model
|
||||
Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel("gpt-4");
|
||||
|
||||
string source = "Text tokenization is the process of splitting a string into a list of tokens.";
|
||||
|
||||
|
@ -68,7 +68,7 @@ The main types provided by this library are:
|
|||
* `Microsoft.ML.Tokenizers.Tokenizer`
|
||||
* `Microsoft.ML.Tokenizers.Bpe`
|
||||
* `Microsoft.ML.Tokenizers.EnglishRoberta`
|
||||
* `Microsoft.ML.Tokenizers.TikToken`
|
||||
* `Microsoft.ML.Tokenizers.Tiktoken`
|
||||
* `Microsoft.ML.Tokenizers.TokenizerDecoder`
|
||||
* `Microsoft.ML.Tokenizers.Normalizer`
|
||||
* `Microsoft.ML.Tokenizers.PreTokenizer`
|
||||
|
|
|
@ -40,34 +40,25 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// </summary>
|
||||
/// <param name="token">The token string</param>
|
||||
/// <param name="offset">The offset mapping to the original string</param>
|
||||
/// <param name="isSpecialToken">Indicates whether the token is a special token</param>
|
||||
public Split(string token, (int Index, int Length) offset, bool isSpecialToken = false)
|
||||
public Split(string token, (int Index, int Length) offset)
|
||||
{
|
||||
_tokenString = token;
|
||||
Offset = offset;
|
||||
IsSpecialToken = isSpecialToken;
|
||||
}
|
||||
|
||||
internal Split(string originalString, string? token, (int Index, int Length) offset, bool isSpecialToken = false)
|
||||
internal Split(string originalString, string? token, (int Index, int Length) offset)
|
||||
{
|
||||
_originalString = originalString;
|
||||
_tokenString = token;
|
||||
Offset = offset;
|
||||
IsSpecialToken = isSpecialToken;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets if the current Split is a special token.
|
||||
/// </summary>
|
||||
public bool IsSpecialToken { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Indicates whether the current Split object is equal to another Split object.
|
||||
/// </summary>
|
||||
/// <param name="other">The Split object to compare with the current object.</param>
|
||||
public bool Equals(Split other) =>
|
||||
(_originalString == other._originalString || TokenString == other.TokenString) &&
|
||||
IsSpecialToken == other.IsSpecialToken &&
|
||||
Offset.Index == other.Offset.Index &&
|
||||
Offset.Length == other.Offset.Length;
|
||||
}
|
||||
|
@ -82,9 +73,8 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
|
||||
/// </summary>
|
||||
/// <param name="text">The string to split into tokens.</param>
|
||||
/// <param name="considerSpecialTokens">Indicates whether to consider the special tokens.</param>
|
||||
/// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
|
||||
public abstract IEnumerable<Split> PreTokenize(string text, bool considerSpecialTokens = true);
|
||||
public abstract IEnumerable<Split> PreTokenize(string text);
|
||||
|
||||
internal static IEnumerable<Split> SplitText(string text, Regex regex)
|
||||
{
|
||||
|
|
|
@ -21,9 +21,8 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
|
||||
/// </summary>
|
||||
/// <param name="text">The string to split into tokens.</param>
|
||||
/// <param name="considerSpecialTokens">Indicates whether to keep the special tokens.</param>
|
||||
/// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
|
||||
public override IEnumerable<Split> PreTokenize(string text, bool considerSpecialTokens = true)
|
||||
public override IEnumerable<Split> PreTokenize(string text)
|
||||
{
|
||||
if (string.IsNullOrEmpty(text))
|
||||
{
|
||||
|
|
|
@ -21,9 +21,8 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Return the whole text as one chunk.
|
||||
/// </summary>
|
||||
/// <param name="text">The string to split into tokens.</param>
|
||||
/// <param name="considerSpecialTokens">Indicates whether to keep the special tokens.</param>
|
||||
/// <returns>The original string as one chunk.</returns>
|
||||
public override IEnumerable<Split> PreTokenize(string text, bool considerSpecialTokens = true)
|
||||
public override IEnumerable<Split> PreTokenize(string text)
|
||||
{
|
||||
if (string.IsNullOrEmpty(text))
|
||||
{
|
||||
|
|
|
@ -12,18 +12,18 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <summary>
|
||||
/// The pre-tokenizer for Tiktoken tokenizer.
|
||||
/// </summary>
|
||||
public sealed class TikTokenPreTokenizer : PreTokenizer
|
||||
public sealed class TiktokenPreTokenizer : PreTokenizer
|
||||
{
|
||||
private readonly Regex? _specialTokensRegex;
|
||||
private readonly Regex _regex;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="TikTokenPreTokenizer"/> class.
|
||||
/// Initializes a new instance of the <see cref="TiktokenPreTokenizer"/> class.
|
||||
/// </summary>
|
||||
/// <param name="regex">The regex to use for splitting the text into smaller tokens in the pre-tokenization process.</param>
|
||||
/// <param name="specialTokensEncoder">Encode the special token to Id.</param>
|
||||
/// <param name="specialTokensEncoder">The dictionary containing the special tokens and their corresponding ids.</param>
|
||||
/// <exception cref="ArgumentNullException">When regex is null</exception>
|
||||
public TikTokenPreTokenizer(Regex regex, IReadOnlyDictionary<string, int>? specialTokensEncoder)
|
||||
public TiktokenPreTokenizer(Regex regex, IReadOnlyDictionary<string, int>? specialTokensEncoder)
|
||||
{
|
||||
if (regex is null)
|
||||
{
|
||||
|
@ -42,16 +42,15 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
|
||||
/// </summary>
|
||||
/// <param name="text">The string to split into tokens.</param>
|
||||
/// <param name="considerSpecialTokens">Indicates whether to consider the special tokens.</param>
|
||||
/// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
|
||||
public override IEnumerable<Split> PreTokenize(string text, bool considerSpecialTokens = true)
|
||||
public override IEnumerable<Split> PreTokenize(string text)
|
||||
{
|
||||
if (string.IsNullOrEmpty(text))
|
||||
{
|
||||
return Array.Empty<Split>();
|
||||
}
|
||||
|
||||
return SplitText(text, _regex, considerSpecialTokens ? _specialTokensRegex : null);
|
||||
return SplitText(text, _regex, _specialTokensRegex);
|
||||
|
||||
static IEnumerable<Split> SplitText(string text, Regex regex, Regex? specialTokensRegex)
|
||||
{
|
||||
|
@ -74,7 +73,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
beginning = match.Offset + match.Length;
|
||||
}
|
||||
|
||||
yield return new Split(text, null, (specialMatch.Offset, specialMatch.Length), isSpecialToken: true);
|
||||
yield return new Split(text, null, (specialMatch.Offset, specialMatch.Length));
|
||||
beginning = specialMatch.Offset + specialMatch.Length;
|
||||
}
|
||||
}
|
|
@ -32,9 +32,8 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
|
||||
/// </summary>
|
||||
/// <param name="text">The string to split into tokens.</param>
|
||||
/// <param name="considerSpecialTokens">Indicates whether to consider the special tokens.</param>
|
||||
/// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
|
||||
public override IEnumerable<Split> PreTokenize(string text, bool considerSpecialTokens = false)
|
||||
public override IEnumerable<Split> PreTokenize(string text)
|
||||
{
|
||||
if (string.IsNullOrEmpty(text))
|
||||
{
|
||||
|
|
|
@ -58,9 +58,8 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Encodes input text to object has the tokens list, tokens Ids, tokens offset mapping.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>The tokenization result includes the tokens list, tokens Ids, tokens offset mapping.</returns>
|
||||
public EncodingResult Encode(string text, bool considerSpecialTokens = true)
|
||||
public EncodingResult Encode(string text)
|
||||
{
|
||||
if (text is null)
|
||||
{
|
||||
|
@ -70,11 +69,11 @@ namespace Microsoft.ML.Tokenizers
|
|||
string normalized = Normalizer is null ? text : Normalizer.Normalize(text);
|
||||
bool offsetsMappedToOriginal = true;
|
||||
|
||||
EncodingResult encoding = new(text, normalized, PreTokenizer.PreTokenize(normalized, considerSpecialTokens), offsetsMappedToOriginal);
|
||||
EncodingResult encoding = new(text, normalized, PreTokenizer.PreTokenize(normalized), offsetsMappedToOriginal);
|
||||
|
||||
foreach (Split split in encoding.Splits)
|
||||
{
|
||||
IReadOnlyList<Token> tokens = Model.Encode(split.TokenString, split.IsSpecialToken);
|
||||
IReadOnlyList<Token> tokens = Model.Encode(split.TokenString.AsSpan());
|
||||
foreach (Token token in tokens)
|
||||
{
|
||||
token.Offset = (token.Offset.Index + split.Offset.Index, token.Offset.Length);
|
||||
|
@ -90,9 +89,8 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Encodes input text to tokens Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>The tokenization result includes the tokens list, tokens Ids, tokens offset mapping.</returns>
|
||||
public IReadOnlyList<int> EncodeToIds(string text, bool considerSpecialTokens = true)
|
||||
public IReadOnlyList<int> EncodeToIds(string text)
|
||||
{
|
||||
if (text is null)
|
||||
{
|
||||
|
@ -102,9 +100,9 @@ namespace Microsoft.ML.Tokenizers
|
|||
string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text;
|
||||
List<int> idsList = new();
|
||||
|
||||
foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
|
||||
foreach (Split split in PreTokenizer.PreTokenize(normalized))
|
||||
{
|
||||
Model.EncodeToIds(split.TokenSpan, split.IsSpecialToken, idsList);
|
||||
Model.EncodeToIds(split.TokenSpan, idsList, out _);
|
||||
}
|
||||
|
||||
return idsList;
|
||||
|
@ -114,11 +112,10 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>The number of tokens Ids that the input text will be encoded to.</returns>
|
||||
/// <exception cref="ArgumentNullException">The input text is null.</exception>
|
||||
/// <exception cref="ArgumentException">Unable to encode the text.</exception>
|
||||
public int CountTokens(string text, bool considerSpecialTokens = true)
|
||||
public int CountTokens(string text)
|
||||
{
|
||||
if (text is null)
|
||||
{
|
||||
|
@ -128,9 +125,9 @@ namespace Microsoft.ML.Tokenizers
|
|||
string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text;
|
||||
|
||||
int idsCount = 0;
|
||||
foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
|
||||
foreach (Split split in PreTokenizer.PreTokenize(normalized))
|
||||
{
|
||||
idsCount += Model.CountTokens(split.TokenSpan, split.IsSpecialToken);
|
||||
idsCount += Model.CountTokens(split.TokenSpan, out _);
|
||||
}
|
||||
|
||||
return idsCount;
|
||||
|
@ -143,15 +140,14 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
|
||||
/// <param name="processedText">If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.</param>
|
||||
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>
|
||||
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
|
||||
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, if all tokens fit, the result will be length of the <paramref name="processedText"/>.
|
||||
/// </returns>
|
||||
/// <exception cref="ArgumentNullException">The input text is null.</exception>
|
||||
/// <exception cref="ArgumentOutOfRangeException">The maximum token count must be greater than 0.</exception>
|
||||
public int IndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount, bool considerSpecialTokens = true)
|
||||
=> IndexOf(text, maxTokenCount, fromStart: true, considerSpecialTokens, out processedText, out tokenCount);
|
||||
public int IndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount)
|
||||
=> IndexOf(text, maxTokenCount, out processedText, out tokenCount);
|
||||
|
||||
/// <summary>
|
||||
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
|
||||
|
@ -160,7 +156,6 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
|
||||
/// <param name="processedText">If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.</param>
|
||||
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>
|
||||
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
|
||||
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="processedText"/>; conversely, if all tokens fit, the result will be 0.
|
||||
|
@ -170,10 +165,10 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <remarks>
|
||||
/// If the whole text can be encoded within the token limit, the returned index will be 0.
|
||||
/// </remarks>
|
||||
public int LastIndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount, bool considerSpecialTokens = true)
|
||||
=> IndexOf(text, maxTokenCount, fromStart: false, considerSpecialTokens, out processedText, out tokenCount);
|
||||
public int LastIndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount)
|
||||
=> LastIndexOf(text, maxTokenCount, out processedText, out tokenCount);
|
||||
|
||||
private int IndexOf(string text, int maxTokenCount, bool fromStart, bool considerSpecialTokens, out string processedText, out int tokenCount)
|
||||
private int IndexOf(string text, int maxTokenCount, out string processedText, out int tokenCount)
|
||||
{
|
||||
if (text is null)
|
||||
{
|
||||
|
@ -188,36 +183,60 @@ namespace Microsoft.ML.Tokenizers
|
|||
processedText = Normalizer is not null ? Normalizer.Normalize(text) : text;
|
||||
tokenCount = 0;
|
||||
|
||||
IEnumerable<Split> splits = PreTokenizer.PreTokenize(processedText, considerSpecialTokens);
|
||||
foreach (Split split in (fromStart ? splits : splits.Reverse()))
|
||||
IEnumerable<Split> splits = PreTokenizer.PreTokenize(processedText);
|
||||
foreach (Split split in splits)
|
||||
{
|
||||
int count = Model.CountTokens(split.TokenSpan, split.IsSpecialToken);
|
||||
if (tokenCount > maxTokenCount - count)
|
||||
tokenCount += Model.CountTokens(split.TokenSpan, out int textLength, maxTokenCount - tokenCount);
|
||||
if (textLength < split.Offset.Length || tokenCount >= maxTokenCount)
|
||||
{
|
||||
return fromStart ? split.Offset.Index : split.Offset.Index + split.Offset.Length;
|
||||
return split.Offset.Index + textLength;
|
||||
}
|
||||
|
||||
tokenCount += count;
|
||||
}
|
||||
|
||||
return fromStart ? processedText.Length : 0;
|
||||
return processedText.Length;
|
||||
}
|
||||
|
||||
private int LastIndexOf(string text, int maxTokenCount, out string processedText, out int tokenCount)
|
||||
{
|
||||
if (text is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(text));
|
||||
}
|
||||
|
||||
if (maxTokenCount <= 0)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
|
||||
}
|
||||
|
||||
processedText = Normalizer is not null ? Normalizer.Normalize(text) : text;
|
||||
tokenCount = 0;
|
||||
|
||||
IEnumerable<Split> splits = PreTokenizer.PreTokenize(processedText);
|
||||
foreach (Split split in splits.Reverse())
|
||||
{
|
||||
tokenCount += Model.CountTokensFromEnd(split.TokenSpan, out int textIndex, maxTokenCount - tokenCount);
|
||||
if (textIndex > 0 || tokenCount >= maxTokenCount)
|
||||
{
|
||||
return split.Offset.Index + textIndex;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Decodes the Id to the mapped token.
|
||||
/// </summary>
|
||||
/// <param name="id">The id to map to the token.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the decoding.</param>
|
||||
/// <returns>The decoded string or null if there is no token mapped to the input id.</returns>
|
||||
public string? Decode(int id, bool considerSpecialTokens = true) => Model.MapIdToToken(id, considerSpecialTokens);
|
||||
public string? Decode(int id) => Model.MapIdToToken(id);
|
||||
|
||||
/// <summary>
|
||||
/// Decode the given ids, back to a String.
|
||||
/// </summary>
|
||||
/// <param name="ids">The list of ids that we want to decode.</param>
|
||||
/// <param name="considerSpecialTokens">Whether the special tokens should be kept in the decoded string.</param>
|
||||
/// <returns>The decoded string.</returns>
|
||||
public string? Decode(IEnumerable<int> ids, bool considerSpecialTokens = true) => Model.Decode(ids, Decoder, considerSpecialTokens);
|
||||
public string? Decode(IEnumerable<int> ids) => Model.Decode(ids, Decoder);
|
||||
|
||||
/// <summary>
|
||||
/// Create a Tiktoken tokenizer based on model name and vocab file.
|
||||
|
@ -252,7 +271,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
return new Tokenizer(
|
||||
new Tiktoken(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize),
|
||||
new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
|
||||
new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
|
||||
normalizer);
|
||||
}
|
||||
|
||||
|
@ -291,7 +310,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
return new Tokenizer(
|
||||
await Tiktoken.CreateAsync(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize, cancellationToken).ConfigureAwait(false),
|
||||
new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
|
||||
new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
|
||||
normalizer);
|
||||
}
|
||||
|
||||
|
|
|
@ -193,13 +193,13 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
if (robertaModel.FilterUnsupportedChars)
|
||||
{
|
||||
string[]? filteredToken = p[5] as string[];
|
||||
Assert.Equal(filteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i], considerSpecialTokens: false));
|
||||
Assert.Equal(filteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
|
||||
}
|
||||
else
|
||||
{
|
||||
Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i], considerSpecialTokens: false));
|
||||
Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
|
||||
string[]? unfilteredToken = p[2] as string[];
|
||||
Assert.Equal(unfilteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i], considerSpecialTokens: false));
|
||||
Assert.Equal(unfilteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
|
||||
}
|
||||
|
||||
Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan()));
|
||||
|
|
|
@ -208,32 +208,32 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
|
||||
bool isEmptyInput = string.IsNullOrEmpty(input);
|
||||
|
||||
IReadOnlyList<Token> bpeTokens = bpe.Encode(normalizedInput, addBeginOfSentence: false, addEndOfSentence: false);
|
||||
IReadOnlyList<Token> bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false);
|
||||
Assert.Equal(ids.Skip(1), bpeTokens.Select(token => token.Id));
|
||||
Assert.Equal(tokens.Skip(1), bpeTokens.Select(token => token.Value));
|
||||
Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
|
||||
List<int> encodedIds = new();
|
||||
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds: encodedIds);
|
||||
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds: encodedIds, out _);
|
||||
Assert.Equal(ids.Skip(1), encodedIds);
|
||||
Assert.Equal(isEmptyInput ? 0 : ids.Length - 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false));
|
||||
Assert.Equal(isEmptyInput ? 0 : ids.Length - 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false, out _));
|
||||
|
||||
bpeTokens = bpe.Encode(normalizedInput, addBeginOfSentence: false, addEndOfSentence: true);
|
||||
bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true);
|
||||
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id));
|
||||
Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Skip(1).Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value));
|
||||
Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
|
||||
encodedIds.Clear();
|
||||
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true, accumulatedIds: encodedIds);
|
||||
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true, accumulatedIds: encodedIds, out _);
|
||||
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), encodedIds);
|
||||
Assert.Equal(isEmptyInput ? 0 : ids.Length, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true));
|
||||
Assert.Equal(isEmptyInput ? 0 : ids.Length, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true, out _));
|
||||
|
||||
bpeTokens = bpe.Encode(normalizedInput, addBeginOfSentence: true, addEndOfSentence: true);
|
||||
bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true);
|
||||
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id));
|
||||
Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value));
|
||||
Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
|
||||
encodedIds.Clear();
|
||||
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true, accumulatedIds: encodedIds);
|
||||
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true, accumulatedIds: encodedIds, out _);
|
||||
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), encodedIds);
|
||||
Assert.Equal(isEmptyInput ? 0 : ids.Length + 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true));
|
||||
Assert.Equal(isEmptyInput ? 0 : ids.Length + 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true, out _));
|
||||
}
|
||||
|
||||
public static IEnumerable<object[]> LlamaTokenizersListData()
|
||||
|
@ -250,7 +250,6 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
Assert.Throws<ArgumentNullException>(() => llamaTokenizer.EncodeToIds(null!));
|
||||
Assert.Throws<ArgumentNullException>(() => llamaTokenizer.CountTokens(null!));
|
||||
Assert.Throws<ArgumentNullException>(() => llamaTokenizer.Decode(null!));
|
||||
Assert.Throws<ArgumentNullException>(() => (llamaTokenizer.Model as SentencePieceBpe)!.Encode(null!));
|
||||
}
|
||||
|
||||
[Theory]
|
||||
|
@ -280,6 +279,8 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
Assert.True(bpe.AddDummyPrefix);
|
||||
Assert.True(bpe.EscapeWhiteSpaces);
|
||||
Assert.False(bpe.TreatWhitespaceAsSuffix);
|
||||
|
||||
TokenizerTests.TestTokenLimits(llamaTokenizer);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
|
|
@ -68,7 +68,7 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
|
||||
public class SpacePreTokenizer : PreTokenizer
|
||||
{
|
||||
public override IEnumerable<Split> PreTokenize(string text, bool considerSpecialTokens = true)
|
||||
public override IEnumerable<Split> PreTokenize(string text)
|
||||
{
|
||||
List<Split> splits = new();
|
||||
if (string.IsNullOrEmpty(text))
|
||||
|
|
|
@ -38,7 +38,7 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
TestGPT4TokenizationEncoding(GPT4);
|
||||
|
||||
Assert.True(GPT4.Model is Tiktoken);
|
||||
IReadOnlyDictionary<string, int>? specialTokensEncoder = (GPT4.Model as Tiktoken)!.SpecialTokensEncoder;
|
||||
IReadOnlyDictionary<string, int>? specialTokensEncoder = (GPT4.Model as Tiktoken)!.SpecialTokens;
|
||||
|
||||
string tokenizerDataFileName = Utils.CreateTemporaryFile("tiktoken");
|
||||
await Utils.DownloadFile(@"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken", tokenizerDataFileName);
|
||||
|
@ -122,9 +122,9 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
private void TestGPT4Tokenizer(Tokenizer gpt4Tokenizer)
|
||||
{
|
||||
string text = ReadAndSanitizeFile("./Data/lib.rs.txt");
|
||||
IReadOnlyList<int> encoded = gpt4Tokenizer.EncodeToIds(text, considerSpecialTokens: false);
|
||||
IReadOnlyList<int> encoded = gpt4Tokenizer.EncodeToIds(text);
|
||||
Assert.Equal(5584, encoded.Count);
|
||||
int idsCount = gpt4Tokenizer.CountTokens(text, considerSpecialTokens: false);
|
||||
int idsCount = gpt4Tokenizer.CountTokens(text);
|
||||
Assert.Equal(encoded.Count, idsCount);
|
||||
|
||||
using (Stream stream = File.OpenRead("./Data/tokens.json"))
|
||||
|
|
|
@ -29,14 +29,26 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string processedText1, out int tokenCount1);
|
||||
int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string processedText2, out int tokenCount2);
|
||||
|
||||
|
||||
IReadOnlyList<int>? prefixIds = null;
|
||||
IReadOnlyList<int>? suffixIds = null;
|
||||
|
||||
if (tokenCount1 > 0)
|
||||
// It is possible with Llama tokenizer to produce start of sentence token <s> token only if we have the maxTokenCount is 1.
|
||||
// In this case, we'll get index1 equal to zero and nothing really will need to be tested.
|
||||
if (tokenCount1 > 0 && index1 > 0)
|
||||
{
|
||||
string prefixString = processedText1.Substring(0, index1);
|
||||
prefixIds = tokenizer.EncodeToIds(prefixString);
|
||||
|
||||
if (tokenizer.Model is SentencePieceBpe)
|
||||
{
|
||||
// SentencePieceBpe model normalize the text and insert more characters.
|
||||
// We call the model directly to bypass the normalization step
|
||||
prefixIds = new List<int>();
|
||||
tokenizer.Model.EncodeToIds(prefixString.AsSpan(), (prefixIds as IList<int>)!, out _);
|
||||
}
|
||||
else
|
||||
{
|
||||
prefixIds = tokenizer.EncodeToIds(prefixString);
|
||||
}
|
||||
Assert.Equal(tokenCount1, prefixIds.Count);
|
||||
Assert.Equal(prefixIds, fullIdsList.Take(prefixIds.Count));
|
||||
}
|
||||
|
@ -44,15 +56,45 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
if (tokenCount2 > 0)
|
||||
{
|
||||
string suffixString = processedText2.Substring(index2);
|
||||
suffixIds = tokenizer.EncodeToIds(suffixString);
|
||||
|
||||
if (tokenizer.Model is SentencePieceBpe)
|
||||
{
|
||||
// SentencePieceBpe model normalize the text and insert more characters.
|
||||
// We call the model directly to bypass the normalization step
|
||||
suffixIds = new List<int>();
|
||||
tokenizer.Model.EncodeToIds(suffixString.AsSpan(), (suffixIds as IList<int>)!, out _);
|
||||
if (i < fullIdsList.Count)
|
||||
{
|
||||
suffixIds = suffixIds.Skip(1).ToList(); // Skip the start of sentence token <s>
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
suffixIds = tokenizer.EncodeToIds(suffixString);
|
||||
}
|
||||
|
||||
Assert.Equal(tokenCount2, suffixIds.Count);
|
||||
Assert.Equal(suffixIds, fullIdsList.Skip(fullIdsList.Count - suffixIds.Count));
|
||||
}
|
||||
|
||||
if (i == fullIdsList.Count)
|
||||
{
|
||||
Assert.Equal(processedText1.Length, index1);
|
||||
Assert.Equal(0, index2);
|
||||
if (index1 != processedText1.Length)
|
||||
{
|
||||
// It's possible that the remaining text on the left doesn't produce any tokens, as in the case of BPE,
|
||||
// where the pre-tokenizer removes spaces and the left text consists entirely of spaces.
|
||||
Assert.True(index1 < processedText1.Length);
|
||||
Assert.Equal(0, tokenizer.CountTokens(processedText1.Substring(index1)));
|
||||
}
|
||||
|
||||
if (index2 != 0)
|
||||
{
|
||||
// It's possible that the remaining text on the right doesn't produce any tokens, as in the case of BPE,
|
||||
// where the pre-tokenizer removes spaces and the left text consists entirely of spaces.
|
||||
Assert.True(index2 > 0);
|
||||
Assert.Equal(0, tokenizer.CountTokens(processedText1.Substring(0, index2)));
|
||||
}
|
||||
|
||||
Assert.Equal(fullIdsList, prefixIds);
|
||||
Assert.Equal(fullIdsList, suffixIds);
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче