* Tokenizers APIs Update

* Address the feedback
This commit is contained in:
Tarek Mahmoud Sayed 2024-07-15 07:48:58 -07:00 committed by GitHub
Parent f5abe6a086
Commit 579fe03ab7
No key found matching this signature
GPG key ID: B5690EEEBB952194
38 changed files: 4900 additions and 2886 deletions


@ -8,13 +8,13 @@ This guide provides general guidance on how to migrate from various tokenizer li
| Microsoft.DeepDev.TokenizerLib | Microsoft.ML.Tokenizers
| --- | --- |
| [TikTokenizer](https://github.com/microsoft/Tokenizer/blob/2c9ba5d343de52eb27521afef7c0c2f0f76c9c52/Tokenizer_C%23/TokenizerLib/TikTokenizer.cs#L20) | [Tokenizer](https://github.com/dotnet/machinelearning/blob/4d5317e8090e158dc7c3bc6c435926ccf1cbd8e2/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs#L41) |
| [ITokenizer](https://github.com/microsoft/Tokenizer/blob/2c9ba5d343de52eb27521afef7c0c2f0f76c9c52/Tokenizer_C%23/TokenizerLib/ITokenizer.cs#L7) | [Tokenizer](https://github.com/dotnet/machinelearning/blob/4d5317e8090e158dc7c3bc6c435926ccf1cbd8e2/src/Microsoft.ML.Tokenizers/Tokenizer.cs#L29) |
| [TokenizerBuilder](https://github.com/microsoft/Tokenizer/blob/2c9ba5d343de52eb27521afef7c0c2f0f76c9c52/Tokenizer_C%23/TokenizerLib/TokenizerBuilder.cs#L14) | [Tokenizer.CreateTiktokenForModel](https://github.com/dotnet/machinelearning/blob/4d5317e8090e158dc7c3bc6c435926ccf1cbd8e2/src/Microsoft.ML.Tokenizers/Tokenizer.cs#L324) embedded<br> [Tokenizer.CreateTiktokenForModel(Async/Stream)](https://github.com/dotnet/machinelearning/blob/4d5317e8090e158dc7c3bc6c435926ccf1cbd8e2/src/Microsoft.ML.Tokenizers/Tokenizer.cs#L241-L315) user provided file stream |
| [TikTokenizer](https://github.com/microsoft/Tokenizer/blob/2c9ba5d343de52eb27521afef7c0c2f0f76c9c52/Tokenizer_C%23/TokenizerLib/TikTokenizer.cs#L20) | Tokenizer |
| [ITokenizer](https://github.com/microsoft/Tokenizer/blob/2c9ba5d343de52eb27521afef7c0c2f0f76c9c52/Tokenizer_C%23/TokenizerLib/ITokenizer.cs#L7) | Tokenizer |
| [TokenizerBuilder](https://github.com/microsoft/Tokenizer/blob/2c9ba5d343de52eb27521afef7c0c2f0f76c9c52/Tokenizer_C%23/TokenizerLib/TokenizerBuilder.cs#L14) | TiktokenTokenizer.CreateForModel <br> TiktokenTokenizer.CreateForModel(Async/Stream) user provided file stream |
### General Guidance
- To avoid embedding the tokenizer's vocabulary files in the code assembly or downloading them at runtime when using one of the standard Tiktoken vocabulary files, utilize the [`CreateTiktokenForModel`](https://github.com/dotnet/machinelearning/blob/4d5317e8090e158dc7c3bc6c435926ccf1cbd8e2/src/Microsoft.ML.Tokenizers/Tokenizer.cs#L324) function. The [table](https://github.com/dotnet/machinelearning/blob/4d5317e8090e158dc7c3bc6c435926ccf1cbd8e2/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs#L683-L734) lists the mapping of model names to the corresponding vocabulary files, so users of a listed model need not carry or download the vocabulary file themselves.
- Avoid hard-coding tiktoken regexes and special tokens. Instead, use the appropriate `Tokenizer.CreateTiktokenForModel/Async` method to create the tokenizer from the model name or a provided stream.
- Avoid doing encoding if you need the token count or encoded IDs. Instead, use `Tokenizer.CountTokens` to get the token count and `Tokenizer.EncodeToIds` to get the encoded IDs.
- Avoid doing encoding if all you need is to truncate to a token budget. Instead, use `Tokenizer.IndexOfTokenCount` or `LastIndexOfTokenCount` to find the index to truncate from the start or end of a string, respectively.
- To avoid embedding the tokenizer's vocabulary files in the code assembly or downloading them at runtime when using one of the standard Tiktoken vocabulary files, utilize the `TiktokenTokenizer.CreateForModel` function. The [table](https://github.com/dotnet/machinelearning/blob/4d5317e8090e158dc7c3bc6c435926ccf1cbd8e2/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs#L683-L734) lists the mapping of model names to the corresponding vocabulary files, so users of a listed model need not carry or download the vocabulary file themselves.
- Avoid hard-coding tiktoken regexes and special tokens. Instead, use the appropriate `TiktokenTokenizer.CreateForModel/Async` method to create the tokenizer from the model name or a provided stream.
- Avoid doing encoding if you need the token count or encoded IDs. Instead, use `TiktokenTokenizer.CountTokens` to get the token count and `TiktokenTokenizer.EncodeToIds` to get the encoded IDs.
- Avoid doing encoding if all you need is to truncate to a token budget. Instead use `TiktokenTokenizer.GetIndexByTokenCount` or `GetIndexByTokenCountFromEnd` to find the index to truncate from the start or end of a string, respectively.
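The guidance above can be sketched in code. This is an illustrative sketch only, assuming the `Microsoft.ML.Tokenizers` API names listed in the table; the model name `"gpt-4"` and the token budget are placeholders (creating the tokenizer for a standard model resolves the vocabulary automatically):

```csharp
using System;
using Microsoft.ML.Tokenizers;

class MigrationSample
{
    static void Main()
    {
        // CreateForModel resolves the standard Tiktoken vocabulary from the
        // model name, so no vocabulary file needs to be embedded or downloaded.
        Tokenizer tokenizer = TiktokenTokenizer.CreateForModel("gpt-4");

        string text = "Hello, World!";

        // Prefer the purpose-built APIs over a full Encode pass:
        int count = tokenizer.CountTokens(text);   // token count only
        var ids = tokenizer.EncodeToIds(text);     // encoded IDs only

        // Truncate to a token budget without encoding the whole text:
        int index = tokenizer.GetIndexByTokenCount(
            text, maxTokenCount: 5, out string? normalized, out int tokenCount);
        Console.WriteLine(text.Substring(0, index));
    }
}
```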


@ -0,0 +1,30 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// The result of encoding a text.
/// </summary>
/// <typeparam name="T">The type of the tokens.</typeparam>
public struct EncodeResults<T>
{
/// <summary>
/// Gets or sets the list of tokens generated from the encoded text.
/// </summary>
public IReadOnlyList<T> Tokens { get; set; }
/// <summary>
/// Gets or sets the normalized text generated during the encoding process. This can be <see langword="null"/> if the encoding process does not normalize the input text.
/// </summary>
public string? NormalizedText { get; set; }
/// <summary>
/// Gets or sets the count of characters consumed from the input text.
/// </summary>
public int CharsConsumed { get; set; }
}
}
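A minimal sketch of how a caller might consume the new result struct. The stand-in type below mirrors the declaration above so the sketch compiles on its own; the token IDs are made-up illustration values:

```csharp
using System;
using System.Collections.Generic;

// Stand-in mirroring the struct above; the real type lives in Microsoft.ML.Tokenizers.
public struct EncodeResults<T>
{
    public IReadOnlyList<T> Tokens { get; set; }
    public string? NormalizedText { get; set; }
    public int CharsConsumed { get; set; }
}

class Sample
{
    static void Main()
    {
        var results = new EncodeResults<int>
        {
            Tokens = new List<int> { 15496, 11, 2159 }, // made-up IDs
            NormalizedText = null,                      // null when no normalization ran
            CharsConsumed = 12
        };

        Console.WriteLine(results.Tokens.Count); // 3
    }
}
```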


@ -0,0 +1,29 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// The settings used to encode a text.
/// </summary>
public struct EncodeSettings
{
public EncodeSettings() { MaxTokenCount = int.MaxValue; }
/// <summary>
/// Gets or sets a value indicating whether to consider the input normalization during encoding.
/// </summary>
public bool ConsiderNormalization { get; set; }
/// <summary>
/// Gets or sets a value indicating whether to consider the pre-tokenization during encoding.
/// </summary>
public bool ConsiderPreTokenization { get; set; }
/// <summary>
/// Gets or sets the maximum number of tokens to generate.
/// </summary>
public int MaxTokenCount { get; set; }
}
}
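One subtlety worth noting: the explicit parameterless constructor (a C# 10+ feature for structs) is what gives `MaxTokenCount` its effectively unbounded default, and `default(EncodeSettings)` bypasses it. A self-contained sketch using a stand-in that mirrors the struct above:

```csharp
using System;

// Stand-in mirroring the struct above; the real type lives in Microsoft.ML.Tokenizers.
public struct EncodeSettings
{
    public EncodeSettings() { MaxTokenCount = int.MaxValue; }
    public bool ConsiderNormalization { get; set; }
    public bool ConsiderPreTokenization { get; set; }
    public int MaxTokenCount { get; set; }
}

class Sample
{
    static void Main()
    {
        // `new` runs the parameterless constructor: unbounded token budget.
        var settings = new EncodeSettings();
        Console.WriteLine(settings.MaxTokenCount == int.MaxValue); // True

        // `default` zero-initializes and skips the constructor entirely,
        // so the budget would be 0 - a reason to construct with `new`.
        Console.WriteLine(default(EncodeSettings).MaxTokenCount); // 0
    }
}
```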


@ -12,7 +12,7 @@ namespace Microsoft.ML.Tokenizers
/// Represents the token produced by the tokenization process, containing the token substring,
/// the id associated with the token substring, and the offset mapping to the original string.
/// </summary>
public readonly struct Token
public readonly struct EncodedToken
{
/// <summary>
/// Gets the Id value associated with the token.
@ -35,7 +35,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="id">The Id value associated with the token.</param>
/// <param name="value">The token string value.</param>
/// <param name="offset">The offset mapping to the original string.</param>
public Token(int id, string value, (int, int) offset)
public EncodedToken(int id, string value, (int, int) offset)
{
Id = id;
Offset = offset;


@ -1,904 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// Represents a [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
/// </summary>
public sealed class Bpe : Tokenizer
{
private const int MaxWordLengthToCache = 15;
private string? _unknownToken;
private int? _unknownTokenId;
private readonly PreTokenizer? _preTokenizer;
private readonly Normalizer? _normalizer;
/// <summary>
/// Gets the unknown token, used when an unknown character or substring is encountered during encoding.
/// </summary>
public string? UnknownToken
{
get
{
return _unknownToken;
}
private set
{
if (value is null)
{
_unknownToken = value;
_unknownTokenId = null;
return;
}
if (!_vocab.TryGetValue(value, out int id))
{
throw new InvalidOperationException($"Unknown Token '{value}' was not present in '{nameof(Vocab)}'.");
}
_unknownTokenId = id;
_unknownToken = value;
}
}
/// <summary>
/// A prefix to be used for every subword that is not a beginning-of-word
/// </summary>
public string? ContinuingSubwordPrefix { get; }
/// <summary>
/// An optional suffix to characterize an end-of-word sub-word
/// </summary>
public string? EndOfWordSuffix { get; }
/// <summary>
/// Gets a value indicating whether multiple consecutive unknown tokens are fused
/// </summary>
public bool FuseUnknownTokens { get; }
/// <summary>
/// Construct a new Bpe model object to use for text encoding.
/// </summary>
/// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergesFile">The file path containing the tokens's pairs list.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="unknownToken"> The unknown token to be used by the model.</param>
/// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that don't represent a beginning of word.</param>
/// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
/// <param name="fuseUnknownTokens">Indicates whether multiple consecutive unknown tokens should be fused.</param>
public Bpe(string vocabFile, string? mergesFile, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
this(vocabFile is null ? throw new ArgumentNullException(nameof(vocabFile)) : File.Open(vocabFile, FileMode.Open, FileAccess.Read),
mergesFile is null ? null : File.Open(mergesFile, FileMode.Open, FileAccess.Read), preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: true)
{
}
/// <summary>
/// Construct a new Bpe model object to use for text encoding.
/// </summary>
/// <param name="vocabStream">The JSON stream containing the dictionary of string keys and their ids.</param>
/// <param name="mergesStream">The stream containing the tokens's pairs list.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="unknownToken"> The unknown token to be used by the model.</param>
/// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that don't represent a beginning of word.</param>
/// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
/// <param name="fuseUnknownTokens">Indicates whether multiple consecutive unknown tokens should be fused.</param>
public Bpe(Stream vocabStream, Stream? mergesStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
this(vocabStream, mergesStream, preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: false)
{
}
private Bpe(Stream vocabStream, Stream? mergesStream, PreTokenizer? preTokenizer, Normalizer? normalizer, string? unknownToken, string? continuingSubwordPrefix, string? endOfWordSuffix, bool fuseUnknownTokens, bool disposeStreams)
{
try
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
FuseUnknownTokens = fuseUnknownTokens;
ContinuingSubwordPrefix = continuingSubwordPrefix;
EndOfWordSuffix = endOfWordSuffix;
_preTokenizer = preTokenizer ?? WhiteSpace.Instance; // Default to WhiteSpace pre-tokenizer
_normalizer = normalizer;
(Dictionary<StringSpanOrdinalKey, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
_vocab = vocab1 ?? new Dictionary<StringSpanOrdinalKey, int>();
Cache = new StringSpanOrdinalKeyCache<Word>();
VocabReverse = new();
foreach (KeyValuePair<StringSpanOrdinalKey, int> kvp in _vocab)
{
VocabReverse.Add(kvp.Value, kvp.Key.Data!);
}
UnknownToken = unknownToken;
int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;
Merges = new();
for (int i = 0; i < merges.Count; i++)
{
(string a, string b) mergeValues = merges[i];
if (!_vocab.TryGetValue(mergeValues.a, out int aId))
{
throw new InvalidOperationException($"Trying to merge a token '{mergeValues.a}' which does not exist in the vocabulary.");
}
if (!_vocab.TryGetValue(mergeValues.b, out int bId))
{
throw new InvalidOperationException($"Trying to merge a token '{mergeValues.b}' which does not exist in the vocabulary.");
}
if (mergeValues.b.Length <= prefixLen)
{
throw new InvalidOperationException($"The merge value '{mergeValues.b}' is too short to be merged with a prefix of length {prefixLen}. This implies that the merge file is either damaged or missing the prefix in its entries.");
}
string newToken = $"{mergeValues.a}{mergeValues.b.Substring(prefixLen)}";
if (!_vocab.TryGetValue(newToken, out int newId))
{
throw new InvalidOperationException($"Trying to merge a token '{newToken}' which does not exist in the vocabulary.");
}
Merges.Add(new Pair<int>(aId, bId), (i, newId));
}
}
finally
{
if (disposeStreams)
{
vocabStream.Dispose();
mergesStream?.Dispose();
}
}
}
/// <summary>
/// Gets the PreTokenizer used by the Tokenizer.
/// </summary>
public override PreTokenizer? PreTokenizer => _preTokenizer;
/// <summary>
/// Gets the Normalizer in use by the Tokenizer.
/// </summary>
public override Normalizer? Normalizer => _normalizer;
/// <summary>
/// Encodes input text into a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to the normalized form of <paramRef name="text" />; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text, Span<char>.Empty, out normalizedString, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text into a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to the normalized form of <paramRef name="text" />; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(null, text, out normalizedString, considerPreTokenization, considerNormalization);
private IReadOnlyList<Token> Encode(string? text, ReadOnlySpan<char> textSpan, out string? normalizedString, bool considerPreTokenization, bool considerNormalization)
{
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedString = null;
return [];
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
List<Token> tokens = new();
PriorityQueue<Merge>? priorityQueue = null;
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
EncodeWithCache(textSpanToEncode.Slice(split.Offset, split.Length), tokens, split.Offset, ref priorityQueue);
}
}
else
{
EncodeWithCache(textSpanToEncode, tokens, 0, ref priorityQueue);
}
return tokens;
}
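A usage sketch for the encode path above. This is illustrative only: the `vocab.json`/`merges.txt` paths are hypothetical placeholders, and the calls use the `Bpe` and `Token` members declared in this file:

```csharp
using System;
using System.Collections.Generic;
using Microsoft.ML.Tokenizers;

class BpeEncodeSketch
{
    static void Main()
    {
        // Hypothetical file paths; any matching BPE vocab/merges pair works.
        var bpe = new Bpe("vocab.json", "merges.txt", unknownToken: "[UNK]");

        // Encode returns the token substrings, their ids, and their offsets in one pass.
        IReadOnlyList<Token> tokens = bpe.Encode("Hello world", out string? normalized);
        foreach (Token token in tokens)
        {
            Console.WriteLine($"{token.Value} -> {token.Id}");
        }
    }
}
```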
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to the normalized form of <paramRef name="text" />; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to the normalized form of <paramRef name="text" />; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
textLength = 0;
normalizedString = null;
return [];
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
List<int> ids = new();
PriorityQueue<Merge>? priorityQueue = null;
if (splits is not null)
{
textLength = 0;
foreach ((int Offset, int Length) split in splits)
{
EncodeToIdsWithCache(textSpanToEncode.Slice(split.Offset, split.Length), ids, maxTokenCount - ids.Count, out int length, ref priorityQueue);
textLength = split.Offset + length;
if (length < split.Length || ids.Count >= maxTokenCount)
{
break;
}
}
}
else
{
EncodeToIdsWithCache(textSpanToEncode, ids, maxTokenCount, out textLength, ref priorityQueue);
}
return ids;
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(null, text, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to the normalized form of <paramRef name="text" />; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The number of tokens generated, which will not exceed the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be the length of the text, or of the <paramref name="normalizedString"/> if normalization is enabled.
/// </returns>
public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
tokenCount = CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
return textLength;
}
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to the normalized form of <paramRef name="text" />; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The number of tokens generated, which will not exceed the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be the length of the text, or of the <paramref name="normalizedString"/> if normalization is enabled.
/// </returns>
public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
tokenCount = CountTokens(null, text, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
return textLength;
}
private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
}
textLength = 0;
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedString = null;
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
PriorityQueue<Merge>? priorityQueue = null;
int count = 0;
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
count += EncodeToIdsWithCache(textSpanToEncode.Slice(split.Offset, split.Length), null, maxTokenCount - count, out int length, ref priorityQueue);
textLength = split.Offset + length;
if (length < split.Length || count >= maxTokenCount)
{
break;
}
}
}
else
{
count = EncodeToIdsWithCache(textSpanToEncode, null, maxTokenCount, out textLength, ref priorityQueue);
}
return count;
}
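The counting and index-finding members above compose naturally into a truncation helper. A sketch, assuming an already-constructed `Bpe` instance (the helper name is hypothetical):

```csharp
using Microsoft.ML.Tokenizers;

static class TruncationSketch
{
    // Truncate `text` to at most `budget` tokens without a full Encode pass.
    static string TruncateToBudget(Bpe bpe, string text, int budget)
    {
        // `index` points just past the last character whose tokens fit the budget;
        // `tokenCount` reports how many tokens that prefix actually produces.
        int index = bpe.IndexOfTokenCount(text, budget, out string? normalized, out int tokenCount);
        return (normalized ?? text).Substring(0, index);
    }
}
```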
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to the normalized form of <paramRef name="text" />; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The number of tokens generated, which will not exceed the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the length of the text, or of the <paramref name="normalizedString"/> if normalization is enabled;
/// conversely, if all tokens fit, the result will be 0.
/// </returns>
public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> LastIndexOf(text, Span<char>.Empty, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to the normalized form of <paramRef name="text" />; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The number of tokens generated, which will not exceed the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the length of the <paramref name="normalizedString"/>; conversely, if all tokens fit, the result will be 0.
/// </returns>
public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> LastIndexOf(null, text, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int tokenCount)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedString = null;
tokenCount = 0;
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
PriorityQueue<Merge>? priorityQueue = null;
if (splits is not null)
{
tokenCount = 0;
foreach ((int Offset, int Length) split in splits.Reverse())
{
tokenCount += EncodeToIdsFromEndWithCache(textSpanToEncode.Slice(split.Offset, split.Length), null, maxTokenCount - tokenCount, out int textIndex, ref priorityQueue);
if (textIndex > 0 || tokenCount >= maxTokenCount)
{
return split.Offset + textIndex;
}
}
}
else
{
tokenCount = EncodeToIdsFromEndWithCache(textSpanToEncode, null, maxTokenCount, out int textLength, ref priorityQueue);
return textLength;
}
return 0;
}
/// <summary>
/// Map the token to encoded Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <returns>The mapped Id of the token.</returns>
public override int? MapTokenToId(ReadOnlySpan<char> token) => _vocab.TryGetValue(token, out int value) ? value : null;
/// <summary>
/// Map the encoded Id to the token.
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <returns>The mapped token of the Id.</returns>
public override string? MapIdToToken(int id)
{
if (VocabReverse.TryGetValue(id, out string? value))
{
return value;
}
return null;
}
/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
/// <summary>
/// Decode the given ids back to a string.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <returns>The decoded string.</returns>
public override string? Decode(IEnumerable<int> ids) => Decode(ids, considerSpecialTokens: true);
/// <summary>
/// Decode the given ids back to a string.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="considerSpecialTokens">Indicate whether to consider special tokens or not.</param>
/// <returns>The decoded string.</returns>
public string? Decode(IEnumerable<int> ids, bool considerSpecialTokens)
{
if (ids is null)
{
throw new ArgumentNullException(nameof(ids));
}
ValueStringBuilder sb = new ValueStringBuilder();
bool decodeUnknownToken = _unknownTokenId.HasValue && considerSpecialTokens;
if (decodeUnknownToken)
{
foreach (int id in ids)
{
if (MapIdToToken(id) is string s)
{
sb.Append(s);
}
}
}
else
{
foreach (int id in ids)
{
if (id == _unknownTokenId)
{
continue;
}
if (MapIdToToken(id) is string s)
{
sb.Append(s);
}
}
}
if (EndOfWordSuffix is not null)
{
sb.RemoveSuffix(EndOfWordSuffix);
sb.Replace(EndOfWordSuffix, " ");
}
if (ContinuingSubwordPrefix is not null)
{
sb.Replace(ContinuingSubwordPrefix, string.Empty);
}
return sb.ToString();
}
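// Illustrative example of the suffix/prefix cleanup above, assuming the tokenizer was configured
// with EndOfWordSuffix == "</w>" and ContinuingSubwordPrefix == "##" (both values are hypothetical;
// they depend on how the model was trained):
//
//   raw concatenation of tokens:  "hel##lo</w>wor##ld</w>"
//   after the suffix pass:        "hel##lo wor##ld"   (trailing suffix removed, inner one becomes a space)
//   after the prefix pass:        "hello world"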
/// Read the given streams to extract the vocab and merges
internal static (Dictionary<StringSpanOrdinalKey, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
{
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
Dictionary<StringSpanOrdinalKey, int>? dic = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocab, options);
return (dic, ConvertMergesToHashmap(merges));
}
/// The vocabulary assigns a number to each token.
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
private Dictionary<string, int>? _vocabOriginal;
/// Contains the mapping between Pairs and their (rank, newId).
internal Dictionary<Pair<int>, (int, int)> Merges { get; }
/// Contains the cache for optimizing the encoding step.
internal StringSpanOrdinalKeyCache<Word>? Cache { get; }
internal static readonly int DefaultCacheCapacity = 10_000;
/// Reversed vocabulary, to rebuild the text.
internal SortedDictionary<int, string> VocabReverse { get; }
/// Dropout probability for merges. The default, 0, means no dropout. At 1.0, tokenization will
/// perform no merges, so the result will just be characters.
internal float? Dropout { get; }
/// Converts the merges strings (for example from a `merges.txt` file) with the format
/// "{pair_a} {pair_b}" into the format expected by the BPE model
internal static Vec<(string, string)> ConvertMergesToHashmap(Stream? mergesStream)
{
if (mergesStream is null)
{
return new Vec<(string, string)>();
}
using StreamReader reader = new StreamReader(mergesStream);
Vec<(string, string)> merges = new(1000);
int lineNumber = 0;
while (true)
{
string? line = reader.ReadLine();
if (line is null)
{
break;
}
lineNumber++;
if (line.StartsWith("#version", StringComparison.Ordinal) || line.Length == 0)
{
continue;
}
int index = line.IndexOf(' ');
if (index < 0 || index == line.Length - 1 || line.IndexOf(' ', index + 1) >= 0)
{
throw new InvalidOperationException($"Invalid merges file format at line: {lineNumber}");
}
merges.Push((line.Substring(0, index), line.Substring(index + 1)));
}
return merges;
}
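// Illustrative example of a merges file accepted by ConvertMergesToHashmap. Each data line holds
// exactly one space-separated pair; a "#version" header line and empty lines are skipped:
//
//   #version: 0.2
//   h e
//   he l
//   hel lo
//
// A data line with no space, a trailing space, or more than one space (e.g. "h e l") throws
// InvalidOperationException reporting the offending line number.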
private readonly Dictionary<char, string> _charToString = new Dictionary<char, string>();
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal string CharToString(char c)
{
if (_charToString.TryGetValue(c, out string? v))
{
return v;
}
string s = c.ToString();
_charToString[c] = s;
return s;
}
internal Word MergeWord(ReadOnlySpan<char> w, ref PriorityQueue<Merge>? priorityQueue)
{
Word word = Word.WithCapacity(w.Length);
(int Id, int Len)? unk = null;
int i = 0;
Span<char> buffer = stackalloc char[256];
scoped ReadOnlySpan<char> s;
while (i < w.Length)
{
int length;
if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1]))
{
length = 2;
s = w.Slice(i, 2);
}
else
{
length = 1;
s = w.Slice(i, 1);
}
// Add the `continuing_subword_prefix` if relevant
if (i > 0 && ContinuingSubwordPrefix is not null)
{
if (ContinuingSubwordPrefix.Length + s.Length <= buffer.Length)
{
ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
s.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
s = buffer.Slice(0, ContinuingSubwordPrefix.Length + s.Length);
}
else
{
#if NETCOREAPP
s = $"{ContinuingSubwordPrefix}{s}".AsSpan();
#else
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
s = $"{ContinuingSubwordPrefix}{s1}".AsSpan();
#endif
}
}
// Add the `end_of_word_suffix` if relevant
if (i + length >= w.Length && EndOfWordSuffix is not null)
{
if (s.Length + EndOfWordSuffix.Length <= buffer.Length)
{
s.CopyTo(buffer);
EndOfWordSuffix.AsSpan().CopyTo(buffer.Slice(s.Length));
s = buffer.Slice(0, s.Length + EndOfWordSuffix.Length);
}
else
{
#if NETCOREAPP
s = $"{s}{EndOfWordSuffix}".AsSpan();
#else
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
s = $"{s1}{EndOfWordSuffix}".AsSpan();
#endif
}
}
if (_vocab.TryGetValue(s, out int id))
{
if (unk.HasValue)
{
word.Add(unk.Value.Id, unk.Value.Len);
unk = null;
}
word.Add(id, length);
}
else if (UnknownToken is not null)
{
if (unk.HasValue)
{
if (FuseUnknownTokens)
{
// Fuse unk
unk = (unk.Value.Id, unk.Value.Len + length);
}
else
{
// Do not fuse unk, add the previous one
word.Add(unk.Value.Id, unk.Value.Len);
if (!_vocab.TryGetValue(UnknownToken, out int value))
{
throw new InvalidOperationException("Unknown Token Out Of Vocabulary.");
}
unk = (value, length);
}
}
else
{
if (!_vocab.TryGetValue(UnknownToken, out int value))
{
throw new InvalidOperationException("Unknown Token Out Of Vocabulary.");
}
unk = (value, length);
}
}
i += length;
}
if (unk.HasValue)
{
word.Add(unk.Value.Id, unk.Value.Len);
}
word.MergeAll(Merges, Dropout, ref priorityQueue);
return word;
}
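// Illustrative example of the unknown-token handling in MergeWord above. Suppose the characters
// 'a' and 'b' are out of vocabulary and UnknownToken maps to id 0 (the ids are hypothetical):
//
//   "xaby" with FuseUnknownTokens == false  ->  one unknown id per character: [idOf(x), 0, 0, idOf(y)]
//   "xaby" with FuseUnknownTokens == true   ->  adjacent unknowns collapse:   [idOf(x), 0, idOf(y)]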
internal void WordToTokens(ref Word word, List<Token> tokens, int offset) => word.ToTokens(VocabReverse, tokens, offset);
internal void EncodeWithCache(ReadOnlySpan<char> text, List<Token> tokens, int offset, ref PriorityQueue<Merge>? priorityQueue)
{
Word word;
if (Cache is not null)
{
if (Cache.TryGetValue(text, out word))
{
WordToTokens(ref word, tokens, offset);
return;
}
word = MergeWord(text, ref priorityQueue);
if (text.Length <= MaxWordLengthToCache)
{
Cache.Set(text.ToString(), word);
}
}
else
{
word = MergeWord(text, ref priorityQueue);
}
WordToTokens(ref word, tokens, offset);
}
internal int WordToIds(ref Word word, IList<int>? accumulatedIds, out int textLength, int fullTextLength, int maxTokens)
{
if (word.SymbolsCount <= maxTokens)
{
textLength = fullTextLength;
if (accumulatedIds is not null)
{
word.PopulateIds(accumulatedIds);
}
return word.SymbolsCount;
}
if (accumulatedIds is not null)
{
return word.PopulateIdsUpToMax(accumulatedIds, maxTokens, out textLength);
}
return word.CountIdsUpToMax(maxTokens, out textLength);
}
internal int WordToIdsFromEnd(ref Word word, IList<int>? accumulatedIds, out int textIndex, int fullTextLength, int maxTokens)
{
if (word.SymbolsCount <= maxTokens)
{
textIndex = 0;
if (accumulatedIds is not null)
{
word.PopulateIds(accumulatedIds);
}
return word.SymbolsCount;
}
if (accumulatedIds is not null)
{
return word.PopulateIdsUpToMaxFromEnd(accumulatedIds, maxTokens, fullTextLength, out textIndex);
}
return word.CountIdsUpToMaxFromEnd(maxTokens, fullTextLength, out textIndex);
}
private int EncodeToIdsWithCache(ReadOnlySpan<char> text, List<int>? accumulatedIds, int maxTokens, out int textLength, ref PriorityQueue<Merge>? priorityQueue)
{
Word word;
if (Cache is not null)
{
if (Cache.TryGetValue(text, out Word hit))
{
return WordToIds(ref hit, accumulatedIds, out textLength, text.Length, maxTokens);
}
word = MergeWord(text, ref priorityQueue);
if (text.Length <= MaxWordLengthToCache)
{
Cache.Set(text.ToString(), word);
}
}
else
{
word = MergeWord(text, ref priorityQueue);
}
return WordToIds(ref word, accumulatedIds, out textLength, text.Length, maxTokens);
}
internal int EncodeToIdsFromEndWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textIndex, ref PriorityQueue<Merge>? priorityQueue)
{
Word word;
if (Cache is not null)
{
if (Cache.TryGetValue(text, out Word hit))
{
return WordToIdsFromEnd(ref hit, accumulatedIds, out textIndex, text.Length, maxTokens);
}
word = MergeWord(text, ref priorityQueue);
if (text.Length <= MaxWordLengthToCache)
{
Cache.Set(text.ToString(), word);
}
}
else
{
word = MergeWord(text, ref priorityQueue);
}
return WordToIdsFromEnd(ref word, accumulatedIds, out textIndex, text.Length, maxTokens);
}
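// Illustrative note on the caching used by the encode paths above: the result of MergeWord is
// memoized per word, so encoding repeated words (e.g. "hello hello hello") runs the merge loop
// only once and serves later occurrences from Cache.TryGetValue. Words longer than
// MaxWordLengthToCache are never cached, presumably to bound the memory spent on cache keys
// (the cache entry count itself defaults to DefaultCacheCapacity).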
}
}


@@ -13,7 +13,7 @@ namespace Microsoft.ML.Tokenizers
private readonly Dictionary<TKey, TValue> _map;
private object SyncObj => _map;
internal Cache() : this(Bpe.DefaultCacheCapacity) { }
internal Cache() : this(BpeTokenizer.DefaultCacheCapacity) { }
internal Cache(int capacity)
{


@ -17,14 +17,14 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Represent the Byte Pair Encoding model.
/// </summary>
public sealed class EnglishRoberta : Tokenizer
public sealed class EnglishRobertaTokenizer : Tokenizer
{
private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence;
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
private Dictionary<string, int>? _vocabOriginal;
private readonly SortedDictionary<int, StringSpanOrdinalKey> _vocabReverse;
private readonly Cache<(string, string), int> _mergeRanks;
private readonly StringSpanOrdinalKeyCache<List<Token>> _cache;
private readonly StringSpanOrdinalKeyCache<List<EncodedToken>> _cache;
private readonly PreTokenizer? _preTokenizer;
private readonly Normalizer? _normalizer;
@@ -33,6 +33,67 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
public bool FilterUnsupportedChars { get; }
/// <summary>
/// Create the tokenizer object to use with the English RoBERTa model.
/// </summary>
/// <param name="vocabularyPath">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergePath">The file path containing the token pairs list.</param>
/// <param name="highestOccurrenceMappingPath">The file path containing the mapping that remaps the original GPT-2 model Ids to high occurrence ranks and values.</param>
public static EnglishRobertaTokenizer Create(
string vocabularyPath,
string mergePath,
string highestOccurrenceMappingPath)
=> new EnglishRobertaTokenizer(vocabularyPath, mergePath, highestOccurrenceMappingPath);
/// <summary>
/// Create the tokenizer object to use with the English RoBERTa model.
/// </summary>
/// <param name="vocabularyPath">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergePath">The file path containing the token pairs list.</param>
/// <param name="highestOccurrenceMappingPath">The file path containing the mapping that remaps the original GPT-2 model Ids to high occurrence ranks and values.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="filterUnsupportedChars">Indicate whether to filter out unsupported characters during decoding.</param>
public static EnglishRobertaTokenizer Create(
string vocabularyPath,
string mergePath,
string highestOccurrenceMappingPath,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
bool filterUnsupportedChars = true)
=> new EnglishRobertaTokenizer(vocabularyPath, mergePath, highestOccurrenceMappingPath, preTokenizer, normalizer, filterUnsupportedChars);
/// <summary>
/// Create the tokenizer object to use with the English RoBERTa model.
/// </summary>
/// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
/// <param name="mergeStream">The stream of a file containing the token pairs list.</param>
/// <param name="highestOccurrenceMappingStream">The stream of a file containing the mapping that remaps the original GPT-2 model Ids to high occurrence ranks and values.</param>
public static EnglishRobertaTokenizer Create(
Stream vocabularyStream,
Stream mergeStream,
Stream highestOccurrenceMappingStream)
=> new EnglishRobertaTokenizer(vocabularyStream, mergeStream, highestOccurrenceMappingStream);
/// <summary>
/// Create the tokenizer object to use with the English RoBERTa model.
/// </summary>
/// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
/// <param name="mergeStream">The stream of a file containing the token pairs list.</param>
/// <param name="highestOccurrenceMappingStream">The stream of a file containing the mapping that remaps the original GPT-2 model Ids to high occurrence ranks and values.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="filterUnsupportedChars">Indicate whether to filter out unsupported characters during decoding.</param>
public static EnglishRobertaTokenizer Create(
Stream vocabularyStream,
Stream mergeStream,
Stream highestOccurrenceMappingStream,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
bool filterUnsupportedChars = true)
=> new EnglishRobertaTokenizer(vocabularyStream, mergeStream, highestOccurrenceMappingStream, preTokenizer, normalizer, filterUnsupportedChars);
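// Illustrative usage of the factory methods above (the file paths are hypothetical):
//
//   EnglishRobertaTokenizer tokenizer = EnglishRobertaTokenizer.Create(
//       "vocab.json",
//       "merges.txt",
//       "highest_occurrence_mapping.txt",
//       preTokenizer: null,
//       normalizer: null,
//       filterUnsupportedChars: true);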
/// <summary>
/// Construct the tokenizer object to use with the English RoBERTa model.
/// </summary>
@@ -42,11 +103,11 @@ namespace Microsoft.ML.Tokenizers
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="filterUnsupportedChars">Indicate whether to filter out unsupported characters during decoding.</param>
public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) :
internal EnglishRobertaTokenizer(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) :
this(vocabularyPath is null ? throw new ArgumentNullException(nameof(vocabularyPath)) : File.OpenRead(vocabularyPath),
mergePath is null ? throw new ArgumentNullException(nameof(mergePath)) : File.OpenRead(mergePath),
highestOccurrenceMappingPath is null ? throw new ArgumentNullException(nameof(highestOccurrenceMappingPath)) : File.OpenRead(highestOccurrenceMappingPath),
preTokenizer, normalizer, filterUnsupportedChars, true)
preTokenizer, normalizer, filterUnsupportedChars, disposeStream: true)
{
}
@@ -59,12 +120,12 @@ namespace Microsoft.ML.Tokenizers
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="filterUnsupportedChars">Indicate whether to filter out unsupported characters during decoding.</param>
public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) :
this(vocabularyStream, mergeStream, highestOccurrenceMappingStream, preTokenizer, normalizer, filterUnsupportedChars, false)
internal EnglishRobertaTokenizer(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) :
this(vocabularyStream, mergeStream, highestOccurrenceMappingStream, preTokenizer, normalizer, filterUnsupportedChars, disposeStream: false)
{
}
public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer, Normalizer? normalizer, bool filterUnsupportedChars, bool disposeStream)
private EnglishRobertaTokenizer(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer, Normalizer? normalizer, bool filterUnsupportedChars, bool disposeStream)
{
if (vocabularyStream is null)
{
@@ -93,7 +154,7 @@ namespace Microsoft.ML.Tokenizers
_vocab = GetVocabulary(vocabularyStream);
_vocabReverse = _vocab.ReverseSorted();
_mergeRanks = GetMergeRanks(mergeStream);
_cache = new StringSpanOrdinalKeyCache<List<Token>>();
_cache = new StringSpanOrdinalKeyCache<List<EncodedToken>>();
if (disposeStream)
{
@@ -190,7 +251,7 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
public IReadOnlyDictionary<string, int> Vocab => GetVocab();
public IReadOnlyDictionary<string, int> Vocabulary => GetVocab();
//
// Public Model interfaces implementation
@@ -201,7 +262,7 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
/// <param name="id">The Id to map to the string.</param>
/// <returns>The mapped token of the Id.</returns>
public override string? MapIdToToken(int id)
private string? MapIdToToken(int id)
{
if (_vocabReverse.TryGetValue(id, out var value))
{
@@ -234,50 +295,45 @@ namespace Microsoft.ML.Tokenizers
}
/// <summary>
/// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
/// Encodes input text to a list of <see cref="EncodedToken" />s.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text, Span<char>.Empty, out normalizedString, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(null, text, out normalizedString, considerPreTokenization, considerNormalization);
private IReadOnlyList<Token> Encode(string? text, ReadOnlySpan<char> textSpan, out string? normalizedString, bool considerPreTokenization, bool considerNormalization)
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
protected override EncodeResults<EncodedToken> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
{
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedString = null;
return [];
return new EncodeResults<EncodedToken> { Tokens = [], NormalizedText = null, CharsConsumed = 0 };
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
settings.ConsiderPreTokenization,
settings.ConsiderNormalization,
_normalizer,
_preTokenizer,
out string? normalizedString,
out ReadOnlySpan<char> textSpanToEncode,
out int charsConsumed);
if (splits is not null)
{
List<Token> tokens = new();
List<EncodedToken> tokens = new();
foreach ((int Offset, int Length) split in splits)
{
foreach (Token t in EncodeInternal(textSpanToEncode.Slice(split.Offset, split.Length)))
foreach (EncodedToken t in EncodeInternal(textSpanToEncode.Slice(split.Offset, split.Length)))
{
tokens.Add(new Token(t.Id, t.Value, (split.Offset + t.Offset.Index, t.Offset.Length)));
tokens.Add(new EncodedToken(t.Id, t.Value, (split.Offset + t.Offset.Index, t.Offset.Length)));
}
}
return tokens;
return new EncodeResults<EncodedToken> { Tokens = tokens, NormalizedText = normalizedString, CharsConsumed = charsConsumed };
}
else
{
return EncodeInternal(textSpanToEncode);
return new EncodeResults<EncodedToken> { Tokens = EncodeInternal(textSpanToEncode), NormalizedText = normalizedString, CharsConsumed = charsConsumed };
}
}
@@ -286,7 +342,7 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
/// <param name="text">The text to encode.</param>
/// <returns>The list of tokens generated from the text tokenization.</returns>
private IReadOnlyList<Token> EncodeInternal(ReadOnlySpan<char> text)
private IReadOnlyList<EncodedToken> EncodeInternal(ReadOnlySpan<char> text)
{
if (text.IsEmpty)
{
@@ -316,14 +372,14 @@ namespace Microsoft.ML.Tokenizers
return [];
}
if (_cache.TryGetValue(text, out List<Token>? hit))
if (_cache.TryGetValue(text, out List<EncodedToken>? hit))
{
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
return ModifyTokenListOffsets(hit, indexMapping);
}
List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
List<EncodedToken> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
_cache.Set(text.ToString(), result);
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
@@ -334,49 +390,13 @@ namespace Microsoft.ML.Tokenizers
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <returns>The encoded results containing the list of encoded Ids.</returns>
protected override EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
=> EncodeToIds(text, textSpan, settings.ConsiderPreTokenization, settings.ConsiderNormalization, settings.MaxTokenCount);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
private EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, int maxTokenCount = int.MaxValue)
{
if (maxTokenCount <= 0)
{
@@ -385,14 +405,23 @@ namespace Microsoft.ML.Tokenizers
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
textLength = 0;
normalizedString = null;
return [];
return new EncodeResults<int> { Tokens = [], NormalizedText = null, CharsConsumed = 0 };
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
considerPreTokenization,
considerNormalization,
_normalizer,
_preTokenizer,
out string? normalizedString,
out ReadOnlySpan<char> textSpanToEncode,
out _);
List<int> ids = new();
int textLength = 0;
if (splits is not null)
{
textLength = 0;
@@ -412,84 +441,70 @@ namespace Microsoft.ML.Tokenizers
EncodeToIdsInternal(textSpanToEncode, ids, out textLength, maxTokenCount);
}
return ids;
return new EncodeResults<int> { Tokens = ids, NormalizedText = normalizedString, CharsConsumed = textLength };
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
protected override int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
=> CountTokens(text, textSpan, settings.ConsiderPreTokenization, settings.ConsiderNormalization, out _, out _, settings.MaxTokenCount);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// Find the index of the maximum encoding capacity without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(null, text, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is <see langword="true"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <param name="fromEnd">Indicate whether to find the index from the end of the text.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="settings" /> has <see cref="EncodeSettings.ConsiderNormalization"/> set to <see langword="true"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="tokenCount">The number of tokens that can be generated, which will not exceed the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be the length of the text or of the <paramref name="normalizedString"/> if normalization is enabled.
/// If <paramRef name="fromEnd" /> is <see langword="false"/>, it represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be the length of the input text or of the <paramref name="normalizedString"/> if normalization is enabled.
/// If <paramRef name="fromEnd" /> is <see langword="true"/>, it represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely,
/// if all tokens fit, the result will be zero.
/// </returns>
public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
protected override int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedString, out int tokenCount)
{
tokenCount = CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
return textLength;
if (fromEnd)
{
return LastIndexOf(text, textSpan, settings.MaxTokenCount, settings.ConsiderPreTokenization, settings.ConsiderNormalization, out normalizedString, out tokenCount);
}
tokenCount = CountTokens(text, textSpan, settings.ConsiderPreTokenization, settings.ConsiderNormalization, out normalizedString, out int charsConsumed, settings.MaxTokenCount);
return charsConsumed;
}
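The index returned above can be used to truncate text to a token budget. A hedged sketch, where `tokenizer` and `prompt` are assumed to be an existing `Tokenizer` instance and input string:

```csharp
// Sketch: keep at most 'budget' tokens from the start of a prompt.
int budget = 100;
int index = tokenizer.IndexOfTokenCount(prompt, budget, out string? normalized, out int tokenCount);

// Slice the normalized text when normalization ran; otherwise the original.
string source = normalized ?? prompt;
// 'index' is the position right after the last character that fits the budget.
string truncated = source.Substring(0, index);
```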
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is <see langword="true"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="tokenCount">The number of tokens that can be generated, which will not exceed the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be the length of the text or of the <paramref name="normalizedString"/> if normalization is enabled.
/// </returns>
public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
tokenCount = CountTokens(null, text, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
return textLength;
}
private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int charsConsumed, int maxTokenCount = int.MaxValue)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
}
textLength = 0;
charsConsumed = 0;
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedString = null;
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
considerPreTokenization,
considerNormalization,
_normalizer,
_preTokenizer,
out normalizedString,
out ReadOnlySpan<char> textSpanToEncode,
out _);
int count = 0;
if (splits is not null)
@@ -497,7 +512,7 @@ namespace Microsoft.ML.Tokenizers
foreach ((int Offset, int Length) split in splits)
{
count += EncodeToIdsInternal(textSpanToEncode.Slice(split.Offset, split.Length), null, out int length, maxTokenCount - count);
textLength = split.Offset + length;
charsConsumed = split.Offset + length;
if (length < split.Length || count >= maxTokenCount)
{
@@ -507,45 +522,12 @@ namespace Microsoft.ML.Tokenizers
}
else
{
count += EncodeToIdsInternal(textSpanToEncode, null, out textLength, maxTokenCount);
count += EncodeToIdsInternal(textSpanToEncode, null, out charsConsumed, maxTokenCount);
}
return count;
}
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is <see langword="true"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="tokenCount">The number of tokens that can be generated, which will not exceed the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the length of the text or of the <paramref name="normalizedString"/> if normalization is enabled;
/// conversely, if all tokens fit, the result will be 0.
/// </returns>
public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> LastIndexOf(text, Span<char>.Empty, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is <see langword="true"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="tokenCount">The number of tokens that can be generated, which will not exceed the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the length of the <paramref name="normalizedString"/>; conversely, if all tokens fit, the result will be 0.
/// </returns>
public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> LastIndexOf(null, text, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int tokenCount)
{
if (maxTokenCount <= 0)
@@ -560,7 +542,16 @@ namespace Microsoft.ML.Tokenizers
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
considerPreTokenization,
considerNormalization,
_normalizer,
_preTokenizer,
out normalizedString,
out ReadOnlySpan<char> textSpanToEncode,
out _);
if (splits is not null)
{
@@ -576,16 +567,16 @@ namespace Microsoft.ML.Tokenizers
}
else
{
tokenCount = EncodeToIdsFromEndInternal(textSpanToEncode, null, out int textLength, maxTokenCount);
return textLength;
tokenCount = EncodeToIdsFromEndInternal(textSpanToEncode, null, out int charsConsumed, maxTokenCount);
return charsConsumed;
}
return 0;
}
private int EncodeToIdsResult(List<Token> tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textLength)
private int EncodeToIdsResult(List<EncodedToken> tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int charsConsumed)
{
textLength = 0;
charsConsumed = 0;
if (tokens.Count <= maxTokens)
{
@@ -597,7 +588,7 @@ namespace Microsoft.ML.Tokenizers
}
}
textLength = fullTextLength;
charsConsumed = fullTextLength;
return tokens.Count;
}
@@ -606,21 +597,21 @@ namespace Microsoft.ML.Tokenizers
for (int i = 0; i < maxTokens; i++)
{
accumulatedIds.Add(tokens[i].Id);
textLength += tokens[i].Offset.Length;
charsConsumed += tokens[i].Offset.Length;
}
}
else
{
for (int i = 0; i < maxTokens; i++)
{
textLength += tokens[i].Offset.Length;
charsConsumed += tokens[i].Offset.Length;
}
}
return maxTokens;
}
private int EncodeToIdsFromEndResult(List<Token> tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textIndex)
private int EncodeToIdsFromEndResult(List<EncodedToken> tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textIndex)
{
textIndex = fullTextLength;
@@ -657,17 +648,17 @@ namespace Microsoft.ML.Tokenizers
return maxTokens;
}
private int EncodeToIdsInternal(ReadOnlySpan<char> text, IList<int>? accumulatedIds, out int textLength, int maxTokens)
private int EncodeToIdsInternal(ReadOnlySpan<char> text, IList<int>? accumulatedIds, out int charsConsumed, int maxTokens)
{
if (text.IsEmpty)
{
textLength = 0;
charsConsumed = 0;
return 0;
}
if (_cache.TryGetValue(text, out List<Token>? hit))
if (_cache.TryGetValue(text, out List<EncodedToken>? hit))
{
return EncodeToIdsResult(hit, accumulatedIds, maxTokens, text.Length, out textLength);
return EncodeToIdsResult(hit, accumulatedIds, maxTokens, text.Length, out charsConsumed);
}
char[] token = ArrayPool<char>.Shared.Rent(text.Length);
@@ -690,16 +681,16 @@ namespace Microsoft.ML.Tokenizers
{
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
textLength = text.Length;
charsConsumed = text.Length;
return 0;
}
List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
List<EncodedToken> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
_cache.Set(text.ToString(), result);
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
return EncodeToIdsResult(result, accumulatedIds, maxTokens, text.Length, out textLength);
return EncodeToIdsResult(result, accumulatedIds, maxTokens, text.Length, out charsConsumed);
}
private int EncodeToIdsFromEndInternal(ReadOnlySpan<char> text, IList<int>? accumulatedIds, out int textIndex, int maxTokens)
@@ -710,7 +701,7 @@ namespace Microsoft.ML.Tokenizers
return 0;
}
if (_cache.TryGetValue(text, out List<Token>? hit))
if (_cache.TryGetValue(text, out List<EncodedToken>? hit))
{
return EncodeToIdsFromEndResult(hit, accumulatedIds, maxTokens, text.Length, out textIndex);
}
@@ -739,7 +730,7 @@ namespace Microsoft.ML.Tokenizers
return 0;
}
List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
List<EncodedToken> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
_cache.Set(text.ToString(), result);
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
@@ -752,7 +743,7 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <returns>The mapped Id of the token.</returns>
public override int? MapTokenToId(ReadOnlySpan<char> token) => _vocab.TryGetValue(token, out int value) ? value : null;
private int? MapTokenToId(ReadOnlySpan<char> token) => _vocab.TryGetValue(token, out int value) ? value : null;
/// <summary>
/// Decode the given ids, back to a String.
@@ -779,6 +770,46 @@ namespace Microsoft.ML.Tokenizers
return sb.ToString();
}
/// <summary>
/// Decode the given ids back to text and store the result in the <paramref name="destination"/> span.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="destination">The span to store the decoded text.</param>
/// <param name="idsConsumed">The number of ids consumed during the decoding.</param>
/// <param name="charsWritten">The number of characters written to the destination span.</param>
/// <returns>The operation status indicates whether all IDs were successfully decoded or if the <paramref name="destination"/> is too small to contain the entire decoded result.</returns>
public override OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten)
{
if (ids is null)
{
throw new ArgumentNullException(nameof(ids));
}
Span<char> buffer = destination;
idsConsumed = 0;
charsWritten = 0;
foreach (int id in ids)
{
if (MapIdToToken(id) is string s)
{
if (s.Length > buffer.Length)
{
return OperationStatus.DestinationTooSmall;
}
s.AsSpan().CopyTo(buffer);
buffer = buffer.Slice(s.Length);
charsWritten += s.Length;
}
idsConsumed++;
}
return OperationStatus.Done;
}
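A typical consumer of the span-based `Decode` overload retries with a larger buffer when `OperationStatus.DestinationTooSmall` is returned. A sketch, with `tokenizer` and `ids` assumed to exist:

```csharp
// Sketch: decode ids into a caller-owned buffer, growing it on demand.
char[] buffer = new char[256];
OperationStatus status;
int idsConsumed, charsWritten;

while ((status = tokenizer.Decode(ids, buffer, out idsConsumed, out charsWritten))
       == OperationStatus.DestinationTooSmall)
{
    // Decoding restarts from the first id with a bigger destination.
    buffer = new char[buffer.Length * 2];
}

string decoded = new string(buffer, 0, charsWritten);
```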
/// <summary>
/// Convert a list of token Ids to highest occurrence rankings.
/// </summary>
@@ -866,7 +897,7 @@ namespace Microsoft.ML.Tokenizers
// Private & Internal methods
//
private IReadOnlyList<Token> ModifyTokenListOffsets(IReadOnlyList<Token> tokens, Span<int> indexMapping)
private IReadOnlyList<EncodedToken> ModifyTokenListOffsets(IReadOnlyList<EncodedToken> tokens, Span<int> indexMapping)
{
int index = 0;
@@ -876,7 +907,7 @@ namespace Microsoft.ML.Tokenizers
if (tokens[i].Offset != (indexMapping[index], tokens[i].Value.Length))
{
List<Token> list = new List<Token>(tokens.Count);
List<EncodedToken> list = new List<EncodedToken>(tokens.Count);
for (int j = 0; j < i; j++)
{
list.Add(tokens[j]);
@@ -884,7 +915,7 @@ namespace Microsoft.ML.Tokenizers
for (int j = i; j < tokens.Count; j++)
{
list.Add(new Token(tokens[j].Id, tokens[j].Value, (indexMapping[index], tokens[j].Value.Length)));
list.Add(new EncodedToken(tokens[j].Id, tokens[j].Value, (indexMapping[index], tokens[j].Value.Length)));
index += tokens[j].Value.Length;
}
@@ -903,7 +934,7 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Encode a token into BPE-ed sub-tokens. E.g., "playing" into ["play", "ing"].
/// </summary>
private List<Token> EncodeToTokens(Span<char> token, Span<int> indexMapping)
private List<EncodedToken> EncodeToTokens(Span<char> token, Span<int> indexMapping)
{
if (token.Length == 0)
{
@@ -916,7 +947,7 @@ namespace Microsoft.ML.Tokenizers
{
Debug.Assert(token[0] < charToString.Length);
string tokenValue = charToString[token[0]];
return new List<Token> { new Token(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, (indexMapping[0], 1)) };
return new List<EncodedToken> { new EncodedToken(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, (indexMapping[0], 1)) };
}
List<string> word = new(token.Length);
@@ -1000,12 +1031,12 @@ namespace Microsoft.ML.Tokenizers
WordToPairs(word, pairs);
}
var tokens = new List<Token>(word.Count);
var tokens = new List<EncodedToken>(word.Count);
int index = 0;
foreach (string w in word)
{
tokens.Add(new Token(_vocab[new StringSpanOrdinalKey(w)], w, (indexMapping[index], w.Length)));
tokens.Add(new EncodedToken(_vocab[new StringSpanOrdinalKey(w)], w, (indexMapping[index], w.Length)));
index += w.Length;
}
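Conceptually, the merge loop above repeatedly applies the best-ranked adjacent pair until no known merge remains. A self-contained sketch of that idea (the merge-rank dictionary is a hypothetical illustration, not this class's internal representation):

```csharp
using System.Collections.Generic;
using System.Linq;

static List<string> BpeSketch(string token, Dictionary<(string, string), int> mergeRanks)
{
    // Start from single characters, e.g. "playing" -> ["p","l","a","y","i","n","g"].
    List<string> parts = token.Select(c => c.ToString()).ToList();
    while (parts.Count > 1)
    {
        int bestRank = int.MaxValue, bestIndex = -1;
        for (int i = 0; i < parts.Count - 1; i++)
        {
            // Pick the adjacent pair with the lowest (best) merge rank.
            if (mergeRanks.TryGetValue((parts[i], parts[i + 1]), out int rank) && rank < bestRank)
            {
                bestRank = rank;
                bestIndex = i;
            }
        }
        if (bestIndex < 0) break; // no known pair left to merge
        parts[bestIndex] += parts[bestIndex + 1];
        parts.RemoveAt(bestIndex + 1);
    }
    return parts; // e.g. ["play", "ing"] given suitable merges
}
```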


@@ -0,0 +1,64 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Sentencepiece;
using System;
using System.Collections.Generic;
using System.IO;
namespace Microsoft.ML.Tokenizers
{
// SentencePiece is under the Apache License 2.0 https://github.com/google/sentencepiece/blob/master/LICENSE
/// <summary>
/// LlamaTokenizer is a SentencePieceBpeTokenizer, implemented based on https://github.com/google/sentencepiece.
/// </summary>
public sealed class LlamaTokenizer : SentencePieceBpeTokenizer
{
internal LlamaTokenizer(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary<string, int>? addedTokens = null) : base(modelProto, addBos, addEos, addedTokens)
{
}
/// <summary>
/// Creates a LlamaTokenizer, based on SentencePieceBpeTokenizer, from the given model stream. The stream should contain a SentencePiece BPE model conforming to the
/// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto specification.
/// </summary>
/// <param name="modelStream">The stream containing the SentencePiece Bpe model.</param>
/// <param name="addBeginOfSentence">Indicate whether to emit the beginning-of-sentence token during encoding.</param>
/// <param name="addEndOfSentence">Indicate whether to emit the end-of-sentence token during encoding.</param>
/// <param name="specialTokens">The additional tokens to add to the vocabulary.</param>
public static LlamaTokenizer Create(
Stream modelStream,
bool addBeginOfSentence = true,
bool addEndOfSentence = false,
IReadOnlyDictionary<string, int>? specialTokens = null)
{
ModelProto modelProto = ModelProto.Parser.ParseFrom(modelStream);
if (modelProto is null)
{
throw new ArgumentNullException(nameof(modelProto));
}
if (modelProto.TrainerSpec.ModelType != TrainerSpec.Types.ModelType.Bpe)
{
throw new ArgumentException("The model type is not Bpe.", nameof(modelProto));
}
if (modelProto.NormalizerSpec.Name != "identity" && !string.IsNullOrEmpty(modelProto.NormalizerSpec.Name))
{
throw new ArgumentException($"Normalization '{modelProto.NormalizerSpec.Name}' is not supported.", nameof(modelProto));
}
SentencePieceNormalizer normalizer = new(
modelProto.NormalizerSpec.RemoveExtraWhitespaces,
modelProto.NormalizerSpec.AddDummyPrefix,
modelProto.NormalizerSpec.EscapeWhitespaces,
modelProto.TrainerSpec.TreatWhitespaceAsSuffix,
specialTokens);
return new LlamaTokenizer(modelProto, addBeginOfSentence, addEndOfSentence, specialTokens);
}
}
}
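Usage sketch for the factory above. The model file path is illustrative; any SentencePiece BPE model file in the expected format works:

```csharp
// Sketch: create a LlamaTokenizer from a SentencePiece BPE model file.
// "tokenizer.model" is a hypothetical local path.
using Stream modelStream = File.OpenRead("tokenizer.model");
LlamaTokenizer tokenizer = LlamaTokenizer.Create(modelStream, addBeginOfSentence: true);

IReadOnlyList<int> ids = tokenizer.EncodeToIds("Hello, world!");
```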


@@ -0,0 +1,120 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// Represent the Byte Pair Encoding model.
/// Implement the Phi2 tokenizer described in https://huggingface.co/microsoft/phi-2
/// </summary>
public sealed class Phi2Tokenizer : CodeGenTokenizer
{
/// <summary>
/// Initializes a new instance of the <see cref="Phi2Tokenizer"/> class.
/// </summary>
/// <summary>
/// Construct tokenizer's model object to use with the English RoBERTa model.
/// </summary>
/// <param name="vocabularyPath">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergePath">The file path containing the tokens' merge pairs list.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="addedTokens">The additional tokens to add to the vocabulary.</param>
/// <param name="addPrefixSpace">Indicate whether to include a leading space before encoding the text.</param>
/// <param name="addBeginningOfSentence">Indicate whether to include the beginning of sentence token in the encoding.</param>
/// <param name="addEndOfSentence">Indicate whether to include the end of sentence token in the encoding.</param>
/// <param name="unknownToken">The unknown token.</param>
/// <param name="beginningOfSentenceToken">The beginning of sentence token.</param>
/// <param name="endOfSentenceToken">The end of sentence token.</param>
internal Phi2Tokenizer(
string vocabularyPath,
string mergePath,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
IReadOnlyDictionary<string, int>? addedTokens = null,
bool addPrefixSpace = false,
bool addBeginningOfSentence = false,
bool addEndOfSentence = false,
string? unknownToken = DefaultSpecialToken,
string? beginningOfSentenceToken = DefaultSpecialToken,
string? endOfSentenceToken = DefaultSpecialToken) :
base(vocabularyPath, mergePath, preTokenizer, normalizer, addedTokens, addPrefixSpace, addBeginningOfSentence,
addEndOfSentence, unknownToken, beginningOfSentenceToken, endOfSentenceToken)
{
}
/// <summary>
/// Construct tokenizer's model object to use with the English RoBERTa model.
/// </summary>
/// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
/// <param name="mergeStream">The stream of a file containing the tokens' merge pairs list.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="addedTokens">The additional tokens to add to the vocabulary.</param>
/// <param name="addPrefixSpace">Indicate whether to include a leading space before encoding the text.</param>
/// <param name="addBeginningOfSentence">Indicate whether to include the beginning of sentence token in the encoding.</param>
/// <param name="addEndOfSentence">Indicate whether to include the end of sentence token in the encoding.</param>
/// <param name="unknownToken">The unknown token.</param>
/// <param name="beginningOfSentenceToken">The beginning of sentence token.</param>
/// <param name="endOfSentenceToken">The end of sentence token.</param>
internal Phi2Tokenizer(
Stream vocabularyStream,
Stream mergeStream,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
IReadOnlyDictionary<string, int>? addedTokens = null,
bool addPrefixSpace = false,
bool addBeginningOfSentence = false,
bool addEndOfSentence = false,
string? unknownToken = DefaultSpecialToken,
string? beginningOfSentenceToken = DefaultSpecialToken,
string? endOfSentenceToken = DefaultSpecialToken) :
base(vocabularyStream, mergeStream, preTokenizer, normalizer, addedTokens, addPrefixSpace, addBeginningOfSentence,
addEndOfSentence, unknownToken, beginningOfSentenceToken, endOfSentenceToken)
{
}
/// <summary>
/// Create a CodeGen Phi2 tokenizer from the given vocab and merges streams.
/// </summary>
/// <param name="vocabStream">The stream containing the vocab file.</param>
/// <param name="mergesStream">The stream containing the merges file.</param>
/// <param name="addPrefixSpace">Indicate whether to include a leading space before encoding the text.</param>
/// <param name="addBeginOfSentence">Indicate whether to emit the beginning-of-sentence token during encoding.</param>
/// <param name="addEndOfSentence">Indicate whether to emit the end-of-sentence token during encoding.</param>
/// <returns>The CodeGen tokenizer object.</returns>
/// <remarks>
/// The tokenizer will be created according to the configuration specified in https://huggingface.co/microsoft/phi-2/raw/main/tokenizer.json.
/// It is important to provide vocab and merges files similar to the ones used when training the model.
/// The vocab and merges files can be downloaded from the following links:
/// https://huggingface.co/microsoft/phi-2/resolve/main/vocab.json?download=true
/// https://huggingface.co/microsoft/phi-2/resolve/main/merges.txt?download=true
/// </remarks>
public new static Phi2Tokenizer Create(
Stream vocabStream,
Stream mergesStream,
bool addPrefixSpace = false,
bool addBeginOfSentence = false,
bool addEndOfSentence = false)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
if (mergesStream is null)
{
throw new ArgumentNullException(nameof(mergesStream));
}
return new Phi2Tokenizer(
vocabStream, mergesStream, new TiktokenPreTokenizer(TiktokenTokenizer.P50kBaseRegex(), CodeGenTokenizer.CodeGenAddedTokens), normalizer: null,
CodeGenTokenizer.CodeGenAddedTokens, addPrefixSpace: addPrefixSpace, addBeginningOfSentence: addBeginOfSentence, addEndOfSentence: addEndOfSentence);
}
}
}
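Usage sketch for `Phi2Tokenizer.Create`. The file names are illustrative; the remarks on `Create` list where the actual vocab and merges files can be downloaded:

```csharp
// Sketch: create a Phi2Tokenizer from local vocab/merges files.
using Stream vocab = File.OpenRead("vocab.json");   // hypothetical path
using Stream merges = File.OpenRead("merges.txt");  // hypothetical path
Phi2Tokenizer tokenizer = Phi2Tokenizer.Create(vocab, merges);

int count = tokenizer.CountTokens("def add(a, b): return a + b");
```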


@@ -202,16 +202,16 @@ namespace Microsoft.ML.Tokenizers
}
}
public int PopulateIdsUpToMax(IList<int> accumulatedIds, int maxTokens, out int textLength)
public int PopulateIdsUpToMax(IList<int> accumulatedIds, int maxTokens, out int charsConsumed)
{
textLength = 0;
charsConsumed = 0;
int count = Math.Min(SymbolsCount, maxTokens);
for (int i = 0; i < count; i++)
{
accumulatedIds.Add(_symbols[i].C);
textLength += _symbols[i].Len;
charsConsumed += _symbols[i].Len;
}
return count;
@@ -232,15 +232,15 @@ namespace Microsoft.ML.Tokenizers
return count;
}
public int CountIdsUpToMax(int maxTokens, out int textLength)
public int CountIdsUpToMax(int maxTokens, out int charsConsumed)
{
textLength = 0;
charsConsumed = 0;
int count = Math.Min(SymbolsCount, maxTokens);
for (int i = 0; i < count; i++)
{
textLength += _symbols[i].Len;
charsConsumed += _symbols[i].Len;
}
return count;
@@ -289,14 +289,14 @@ namespace Microsoft.ML.Tokenizers
return sb.ToString();
}
public void ToTokens(SortedDictionary<int, string> vocabReverse, List<Token> tokens, int offset)
public void ToTokens(SortedDictionary<int, string> vocabReverse, List<EncodedToken> tokens, int offset)
{
int index = 0;
for (int i = 0; i < SymbolsCount; i++)
{
int endIndex = index + _symbols[i].Len;
tokens.Add(new Token(_symbols[i].C, vocabReverse[_symbols[i].C], (index + offset, _symbols[i].Len)));
tokens.Add(new EncodedToken(_symbols[i].C, vocabReverse[_symbols[i].C], (index + offset, _symbols[i].Len)));
index += _symbols[i].Len;
}
}


@@ -18,6 +18,11 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
public LowerCaseNormalizer() { }
/// <summary>
/// Gets a singleton instance of the <see cref="LowerCaseNormalizer"/>.
/// </summary>
public static LowerCaseNormalizer Instance { get; } = new LowerCaseNormalizer();
/// <summary>
/// Lowercase the original string.
/// </summary>


@@ -1,9 +1,11 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
namespace Microsoft.ML.Tokenizers
{
@@ -17,12 +19,13 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Creates a SentencePieceNormalizer object.
/// </summary>
public SentencePieceNormalizer(bool removeExtraWhiteSpaces, bool addDummyPrefix, bool escapeWhiteSpaces, bool treatWhitespaceAsSuffix)
public SentencePieceNormalizer(bool removeExtraWhiteSpaces, bool addDummyPrefix, bool escapeWhiteSpaces, bool treatWhitespaceAsSuffix, IReadOnlyDictionary<string, int>? specialTokens)
{
RemoveExtraWhiteSpaces = removeExtraWhiteSpaces;
AddDummyPrefix = addDummyPrefix;
EscapeWhiteSpaces = escapeWhiteSpaces;
TreatWhitespaceAsSuffix = treatWhitespaceAsSuffix;
SpecialTokens = specialTokens;
}
/// <summary>
@@ -35,9 +38,20 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
public bool AddDummyPrefix { get; }
/// <summary>
/// Indicate whether to escape white spaces by replacing them with the character U+2581.
/// </summary>
public bool EscapeWhiteSpaces { get; }
public bool TreatWhitespaceAsSuffix { get; }
/// <summary>
/// Indicate whether to treat white space as a suffix.
/// </summary>
public bool TreatWhitespaceAsSuffix { get; private set; }
/// <summary>
/// The special tokens to preserve during normalization.
/// </summary>
public IReadOnlyDictionary<string, int>? SpecialTokens { get; }
/// <summary>
/// Normalize the original string according to SentencePiece normalization.
@@ -87,19 +101,32 @@ namespace Microsoft.ML.Tokenizers
Span<char> span = stackalloc char[512];
char[]? buffer = null;
if (span.Length < length + 1)
int spanLength = AddDummyPrefix ? length + 1 : length;
if (span.Length < spanLength)
{
// Add dummy prefix if needed
buffer = ArrayPool<char>.Shared.Rent(AddDummyPrefix ? length + 1 : length);
buffer = ArrayPool<char>.Shared.Rent(spanLength);
span = buffer;
}
span = span.Slice(0, spanLength);
int bufferIndex = 0;
if (AddDummyPrefix && !TreatWhitespaceAsSuffix)
{
span[bufferIndex++] = EscapeWhiteSpaces ? DummyPrefix : ' ';
if (SpecialTokens is not null)
{
InsertDummyPrefix(original, ref startIndex, endIndex, span, ref bufferIndex);
}
else
{
span[bufferIndex++] = EscapeWhiteSpaces ? DummyPrefix : ' ';
}
}
int originalStart = startIndex;
while (startIndex <= endIndex)
{
char c = original[startIndex++];
@@ -123,7 +150,15 @@ namespace Microsoft.ML.Tokenizers
if (AddDummyPrefix && TreatWhitespaceAsSuffix)
{
span[bufferIndex++] = EscapeWhiteSpaces ? DummyPrefix : ' ';
if (SpecialTokens is not null)
{
InsertDummyPrefixAtEnd(span, ref bufferIndex);
}
else
{
// Add dummy prefix if needed
span[bufferIndex++] = EscapeWhiteSpaces ? DummyPrefix : ' ';
}
}
string result = span.Slice(0, bufferIndex).ToString();
@@ -134,5 +169,75 @@ namespace Microsoft.ML.Tokenizers
}
return result;
}
private void InsertDummyPrefix(ReadOnlySpan<char> original, ref int startIndex, int endIndex, Span<char> span, ref int bufferIndex)
{
int currentStartIndex;
endIndex++;
do
{
currentStartIndex = startIndex;
foreach (var kvp in SpecialTokens!)
{
var token = kvp.Key;
var tokenLength = token.Length;
if (startIndex + tokenLength <= endIndex && original.Slice(startIndex, tokenLength).SequenceEqual(token.AsSpan()))
{
token.AsSpan().CopyTo(span.Slice(bufferIndex));
bufferIndex += tokenLength;
startIndex += tokenLength;
break;
}
}
} while (currentStartIndex < startIndex);
if (startIndex < endIndex)
{
// The prefix should be followed by more characters; otherwise, startIndex should be greater than endIndex
Debug.Assert(bufferIndex < span.Length - 1);
span[bufferIndex++] = EscapeWhiteSpaces ? DummyPrefix : ' ';
}
}
private void InsertDummyPrefixAtEnd(Span<char> span, ref int bufferIndex)
{
int currentIndex;
int currentBufferIndex = bufferIndex - 1;
if (currentBufferIndex < 0)
{
return;
}
do
{
currentIndex = currentBufferIndex;
foreach (var kvp in SpecialTokens!)
{
var token = kvp.Key;
var tokenLength = token.Length;
if (currentIndex >= tokenLength - 1 && span.Slice(currentIndex - tokenLength + 1, tokenLength).SequenceEqual(token.AsSpan()))
{
currentBufferIndex -= tokenLength;
break;
}
}
} while (currentBufferIndex > 0 && currentBufferIndex < currentIndex);
if (currentBufferIndex > 0)
{
// The prefix should be preceded by more characters; otherwise, currentBufferIndex should be 0 or less
Debug.Assert(bufferIndex < span.Length);
int i = bufferIndex;
while (i > currentBufferIndex + 1)
{
span[i] = span[i - 1];
i--;
}
span[currentBufferIndex + 1] = EscapeWhiteSpaces ? DummyPrefix : ' ';
bufferIndex++;
}
}
}
}
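The `InsertDummyPrefix` / `InsertDummyPrefixAtEnd` helpers above exist so that, when special tokens are present, the dummy prefix lands next to the regular text rather than inside a leading or trailing special token. As background, here is a simplified sketch of what `AddDummyPrefix` and `EscapeWhiteSpaces` do in the plain case (illustrative only; `NormalizeSketch` is not part of the library and ignores special tokens and extra-whitespace removal):

```csharp
using System;
using System.Text;

// Illustrative sketch of SentencePiece-style whitespace handling
// (not the Microsoft.ML.Tokenizers implementation).
static string NormalizeSketch(string text, bool addDummyPrefix, bool escapeWhiteSpaces)
{
    const char DummyPrefix = '\u2581'; // "▁", LOWER ONE EIGHTH BLOCK

    var sb = new StringBuilder(text.Length + 1);
    if (addDummyPrefix)
    {
        // Prefix the text so the first word tokenizes like a mid-sentence word.
        sb.Append(escapeWhiteSpaces ? DummyPrefix : ' ');
    }

    foreach (char c in text)
    {
        // When escaping is enabled, every space becomes U+2581.
        sb.Append(c == ' ' && escapeWhiteSpaces ? DummyPrefix : c);
    }

    return sb.ToString();
}

Console.WriteLine(NormalizeSketch("Hello world", addDummyPrefix: true, escapeWhiteSpaces: true));
// Prints: ▁Hello▁world
```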


@@ -18,6 +18,11 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
public UpperCaseNormalizer() { }
/// <summary>
/// Gets a singleton instance of the <see cref="UpperCaseNormalizer"/>.
/// </summary>
public static UpperCaseNormalizer Instance { get; } = new UpperCaseNormalizer();
/// <summary>
/// Uppercase the original string.
/// </summary>


@@ -9,6 +9,7 @@ Microsoft.ML.Tokenizers supports various implementations of tokenization
* English Roberta model
* Tiktoken model
* Llama model
* Phi2 model
## How to Use
@@ -22,18 +23,18 @@ using System.IO;
//
// initialize the tokenizer for `gpt-4` model
Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel("gpt-4");
Tokenizer tokenizer = TiktokenTokenizer.CreateForModel("gpt-4");
string source = "Text tokenization is the process of splitting a string into a list of tokens.";
Console.WriteLine($"Tokens: {tokenizer.CountTokens(source)}");
// print: Tokens: 16
var trimIndex = tokenizer.LastIndexOfTokenCount(source, 5, out string processedText, out _);
var trimIndex = tokenizer.GetIndexByTokenCountFromEnd(source, 5, out string processedText, out _);
Console.WriteLine($"5 tokens from end: {processedText.Substring(trimIndex)}");
// 5 tokens from end: a list of tokens.
trimIndex = tokenizer.IndexOfTokenCount(source, 5, out processedText, out _);
trimIndex = tokenizer.GetIndexByTokenCount(source, 5, out processedText, out _);
Console.WriteLine($"5 tokens from start: {processedText.Substring(0, trimIndex)}");
// 5 tokens from start: Text tokenization is the
@@ -51,7 +52,7 @@ const string modelUrl = @"https://huggingface.co/hf-internal-testing/llama-token
using Stream remoteStream = await httpClient.GetStreamAsync(modelUrl);
// Create the Llama tokenizer using the remote stream
Tokenizer llamaTokenizer = Tokenizer.CreateLlama(remoteStream);
Tokenizer llamaTokenizer = LlamaTokenizer.Create(remoteStream);
string input = "Hello, world!";
ids = llamaTokenizer.EncodeToIds(input);
Console.WriteLine(string.Join(", ", ids));
@@ -66,9 +67,9 @@ Console.WriteLine($"Tokens: {llamaTokenizer.CountTokens(input)}");
The main types provided by this library are:
* `Microsoft.ML.Tokenizers.Tokenizer`
* `Microsoft.ML.Tokenizers.Bpe`
* `Microsoft.ML.Tokenizers.EnglishRoberta`
* `Microsoft.ML.Tokenizers.Tiktoken`
* `Microsoft.ML.Tokenizers.BpeTokenizer`
* `Microsoft.ML.Tokenizers.EnglishRobertaTokenizer`
* `Microsoft.ML.Tokenizers.TiktokenTokenizer`
* `Microsoft.ML.Tokenizers.Normalizer`
* `Microsoft.ML.Tokenizers.PreTokenizer`
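Since the concrete tokenizers (`BpeTokenizer`, `EnglishRobertaTokenizer`, `TiktokenTokenizer`) all derive from the abstract `Tokenizer` base type, calling code can stay agnostic of the concrete model. A short sketch reusing `CreateForModel` from the sample above (treat it as illustrative):

```csharp
using System;
using Microsoft.ML.Tokenizers;

// Work against the abstract base type so the concrete tokenizer can be swapped.
Tokenizer tokenizer = TiktokenTokenizer.CreateForModel("gpt-4");

string text = "Text tokenization is the process of splitting a string into a list of tokens.";

// Either count directly or materialize the ids; both live on the base type.
Console.WriteLine(tokenizer.CountTokens(text));        // 16, per the sample above
Console.WriteLine(tokenizer.EncodeToIds(text).Count);  // same count
```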


@@ -29,7 +29,7 @@ namespace Microsoft.ML.Tokenizers
return [];
}
return SplitText(text, Tiktoken.P50kBaseRegex());
return SplitText(text, TiktokenTokenizer.P50kBaseRegex());
}
/// <summary>
@@ -44,7 +44,7 @@ namespace Microsoft.ML.Tokenizers
return [];
}
return SplitText(text, Tiktoken.P50kBaseRegex());
return SplitText(text, TiktokenTokenizer.P50kBaseRegex());
}
}
}


@@ -12,12 +12,12 @@ namespace Microsoft.ML.Tokenizers
/// The pre-tokenizer which splits the text at word boundaries.
/// A word is a sequence of alphabetic, numeric, and underscore characters.
/// </summary>
public sealed partial class WhiteSpace : PreTokenizer
public sealed partial class WhiteSpacePreTokenizer : PreTokenizer
{
/// <summary>
/// Gets a singleton instance of the WhiteSpace pre-tokenizer.
/// </summary>
public static WhiteSpace Instance { get; } = new WhiteSpace();
public static WhiteSpacePreTokenizer Instance { get; } = new WhiteSpacePreTokenizer();
private const string PretokenizePattern = /*lang=regex*/ @"\w+|[^\w\s]+";
#if NET7_0_OR_GREATER
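For context on what `WhiteSpacePreTokenizer` produces, the `PretokenizePattern` above alternates between runs of word characters and runs of punctuation. A standalone illustration using plain `Regex` rather than the library API:

```csharp
using System;
using System.Text.RegularExpressions;

// Same pattern as PretokenizePattern above: a run of word characters (\w+)
// or a run of non-word, non-whitespace characters ([^\w\s]+).
const string pattern = @"\w+|[^\w\s]+";

foreach (Match m in Regex.Matches("Don't stop!", pattern))
{
    Console.WriteLine($"{m.Index,2}: {m.Value}");
}
// Splits into: "Don", "'", "t", "stop", "!"
```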


@@ -2,16 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Sentencepiece;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers
{
@@ -20,6 +13,11 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
public abstract class Tokenizer
{
/// <summary>
/// Initializes a new instance of the <see cref="Tokenizer"/> class.
/// </summary>
protected Tokenizer() { }
/// <summary>
/// Gets the PreTokenizer used by the Tokenizer.
/// </summary>
@@ -31,24 +29,13 @@ namespace Microsoft.ML.Tokenizers
public virtual Normalizer? Normalizer => null;
/// <summary>
/// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public virtual IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text.AsSpan(), out normalizedString, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public abstract IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true);
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <returns>The encoded results containing the list of encoded Ids.</returns>
protected abstract EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);
/// <summary>
/// Encodes input text to token Ids.
@@ -57,7 +44,8 @@ namespace Microsoft.ML.Tokenizers
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public virtual IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) => EncodeToIds(text.AsSpan(), considerPreTokenization, considerNormalization);
public IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, ReadOnlySpan<char>.Empty, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization }).Tokens;
/// <summary>
/// Encodes input text to token Ids.
@@ -66,466 +54,256 @@ namespace Microsoft.ML.Tokenizers
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public abstract IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization }).Tokens;
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// </summary>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="normalizedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is <see langword="false"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="charsConsumed">The characters count of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public virtual IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text.AsSpan(), maxTokenCount, out normalizedText, out textLength, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public abstract IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public virtual int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(text.AsSpan(), considerPreTokenization, considerNormalization);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public abstract int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public virtual int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> IndexOfTokenCount(text.AsSpan(), maxTokenCount, out normalizedString, out tokenCount, considerPreTokenization, considerNormalization);
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public abstract int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="processedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="processedText"/>; conversely, if all tokens fit, the result will be 0.
/// </returns>
public virtual int LastIndexOfTokenCount(string text, int maxTokenCount, out string? processedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> LastIndexOfTokenCount(text.AsSpan(), maxTokenCount, out processedText, out tokenCount, considerPreTokenization, considerNormalization);
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="processedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="processedText"/>; conversely, if all tokens fit, the result will be 0.
/// </returns>
public abstract int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? processedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
/// <summary>
/// Map the token to encoded Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <returns>The mapped Id of the token.</returns>
public virtual int? MapTokenToId(string token)
public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true)
{
if (token is null)
{
throw new ArgumentNullException(nameof(token));
}
EncodeResults<int> result = EncodeToIds(text, ReadOnlySpan<char>.Empty,
new EncodeSettings
{
ConsiderPreTokenization = considerPreTokenization,
ConsiderNormalization = considerNormalization,
MaxTokenCount = maxTokenCount
});
return MapTokenToId(token.AsSpan());
normalizedText = result.NormalizedText;
charsConsumed = result.CharsConsumed;
return result.Tokens;
}
/// <summary>
/// Map the token to encoded Id.
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <returns>The mapped Id of the token.</returns>
public abstract int? MapTokenToId(ReadOnlySpan<char> token);
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is <see langword="false"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="charsConsumed">The characters count of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true)
{
EncodeResults<int> result = EncodeToIds(null, text,
new EncodeSettings
{
ConsiderPreTokenization = considerPreTokenization,
ConsiderNormalization = considerNormalization,
MaxTokenCount = maxTokenCount
});
normalizedText = result.NormalizedText;
charsConsumed = result.CharsConsumed;
return result.Tokens;
}
/// <summary>
/// Decodes the Id to the mapped token.
/// Encodes input text to a list of <see cref="EncodedToken" />s.
/// </summary>
/// <param name="id">The id to map to the token.</param>
/// <returns>The decoded string or null if there is no token mapped to the input id.</returns>
public abstract string? MapIdToToken(int id);
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
protected abstract EncodeResults<EncodedToken> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);
/// <summary>
/// Encodes input text to a list of <see cref="EncodedToken" />s.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is <see langword="false"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded <see cref="EncodedToken" />s.</returns>
public IReadOnlyList<EncodedToken> EncodeToTokens(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true)
{
EncodeResults<EncodedToken> result = EncodeToTokens(text, ReadOnlySpan<char>.Empty, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
normalizedString = result.NormalizedText;
return result.Tokens;
}
/// <summary>
/// Encodes input text to a list of <see cref="EncodedToken" />s.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is <see langword="false"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded <see cref="EncodedToken" />s.</returns>
public IReadOnlyList<EncodedToken> EncodeToTokens(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true)
{
EncodeResults<EncodedToken> result = EncodeToTokens(null, text, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
normalizedString = result.NormalizedText;
return result.Tokens;
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
protected abstract int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(text, ReadOnlySpan<char>.Empty, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(null, text, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
/// <summary>
/// Find the index of the maximum encoding capacity without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <param name="fromEnd">Indicate whether to find the index from the end of the text.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="settings" /> has <see cref="EncodeSettings.ConsiderNormalization"/> is <see langword="false"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// If <paramRef name="fromEnd" /> is <see langword="false"/>, it represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be the length of the input text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// If <paramRef name="fromEnd" /> is <see langword="true"/>, it represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely,
/// if all tokens fit, the result will be zero.
/// </returns>
protected abstract int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedString, out int tokenCount);
/// <summary>
/// Find the index of the maximum encoding capacity without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is <see langword="false"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be the length of the input text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public int GetIndexByTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> GetIndexByTokenCount(
text,
ReadOnlySpan<char>.Empty,
new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization, MaxTokenCount = maxTokenCount },
fromEnd: false,
out normalizedString,
out tokenCount);
/// <summary>
/// Find the index of the maximum encoding capacity without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is <see langword="true"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="tokenCount">The number of tokens produced by the returned index, which is guaranteed not to exceed <paramref name="maxTokenCount"/>.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be the length of the input text, or of the <paramref name="normalizedString"/> if normalization is enabled.
/// </returns>
public int GetIndexByTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> GetIndexByTokenCount(
null,
text,
new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization, MaxTokenCount = maxTokenCount },
fromEnd: false,
out normalizedString,
out tokenCount);
/// <summary>
/// Find the index of the maximum encoding capacity without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is <see langword="true"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="tokenCount">The number of tokens produced by the returned index, which is guaranteed not to exceed <paramref name="maxTokenCount"/>.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely,
/// if all tokens fit, the result will be zero.
/// </returns>
public int GetIndexByTokenCountFromEnd(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> GetIndexByTokenCount(
text,
ReadOnlySpan<char>.Empty,
new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization, MaxTokenCount = maxTokenCount },
fromEnd: true,
out normalizedString,
out tokenCount);
/// <summary>
/// Find the index of the maximum encoding capacity without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is <see langword="true"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="tokenCount">The number of tokens produced by the returned index, which is guaranteed not to exceed <paramref name="maxTokenCount"/>.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely,
/// if all tokens fit, the result will be zero.
/// </returns>
public int GetIndexByTokenCountFromEnd(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> GetIndexByTokenCount(
null,
text,
new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization, MaxTokenCount = maxTokenCount },
fromEnd: true,
out normalizedString,
out tokenCount);
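Together, these four overloads support truncating input to a token budget from either end. A minimal usage sketch (the model name and text are illustrative; `CreateTiktokenForModel` is the factory defined later in this file):

```csharp
using Microsoft.ML.Tokenizers;

Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel("gpt-4");
string text = "Some long document that must fit within a token budget.";

// Index just past the last character that fits within 10 tokens, counting from the start.
int index = tokenizer.GetIndexByTokenCount(text, maxTokenCount: 10, out string? normalized, out int tokenCount);
string head = (normalized ?? text).Substring(0, index);

// Index of the first character that fits within 10 tokens, counting from the end.
int fromEnd = tokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: 10, out normalized, out tokenCount);
string tail = (normalized ?? text).Substring(fromEnd);
```

Note that both calls index into the normalized string when normalization is enabled, which is why the sketch slices `normalized ?? text` rather than `text` directly.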
/// <summary>
/// Decode the given ids back to a string.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <returns>The decoded string.</returns>
public virtual string? Decode(IEnumerable<int> ids)
{
if (ids is null)
{
throw new ArgumentNullException(nameof(ids));
}
ValueStringBuilder sb = new ValueStringBuilder();
foreach (int id in ids)
{
if (MapIdToToken(id) is string s)
{
sb.Append(s);
}
}
return sb.ToString();
}
public abstract string? Decode(IEnumerable<int> ids);
//
// Factory Methods
//
/// <summary>
/// Create a new Tiktoken tokenizer's object asynchronously.
/// </summary>
/// <param name="vocabStream">The stream to the BPE vocab file.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="cacheSize">The size of the cache to use.</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer's object.</returns>
public static async Task<Tokenizer> CreateTiktokenAsync(
Stream vocabStream,
PreTokenizer? preTokenizer,
Normalizer? normalizer,
IReadOnlyDictionary<string, int>? specialTokens = null,
int cacheSize = LruCache<int[]>.DefaultCacheSize,
CancellationToken cancellationToken = default)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
await Tiktoken.LoadTiktokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);
return new Tiktoken(encoder, decoder, vocab, preTokenizer, specialTokens, normalizer, cacheSize);
}
/// <summary>
/// Create a new Tiktoken tokenizer's object asynchronously.
/// </summary>
/// <param name="vocabFilePath">The BPE vocab file.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="specialTokensEncoder">The dictionary mapping special tokens to Ids.</param>
/// <param name="cacheSize">The size of the cache to use.</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer's object.</returns>
public static async Task<Tokenizer> CreateTiktokenAsync(
string vocabFilePath,
PreTokenizer? preTokenizer,
Normalizer? normalizer,
IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
int cacheSize = LruCache<int[]>.DefaultCacheSize,
CancellationToken cancellationToken = default)
{
if (vocabFilePath is null)
{
throw new ArgumentNullException(nameof(vocabFilePath));
}
using Stream vocabStream = File.OpenRead(vocabFilePath);
return await CreateTiktokenAsync(vocabStream, preTokenizer, normalizer, specialTokensEncoder, cacheSize, cancellationToken).ConfigureAwait(false);
}
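A usage sketch for the two async factories above (the vocab file name is hypothetical; any BPE vocab file in the tiktoken format works):

```csharp
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Tokenizers;

// From a file path; the stream is opened and disposed internally.
Tokenizer fromFile = await Tokenizer.CreateTiktokenAsync(
    "cl100k_base.tiktoken", preTokenizer: null, normalizer: null);

// From a caller-owned stream; the caller remains responsible for disposal.
using Stream vocabStream = File.OpenRead("cl100k_base.tiktoken");
Tokenizer fromStream = await Tokenizer.CreateTiktokenAsync(
    vocabStream, preTokenizer: null, normalizer: null,
    specialTokens: new Dictionary<string, int> { ["<|endoftext|>"] = 100257 });
```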
/// <summary>
/// Create a Tiktoken tokenizer based on model name and vocab file.
/// </summary>
/// <param name="modelName">Model name</param>
/// <param name="vocabStream">The stream to the BPE vocab file.</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
/// <param name="cacheSize">The size of the cache to use.</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <returns>The tokenizer</returns>
public static Tokenizer CreateTiktokenForModel(
string modelName,
Stream vocabStream,
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
int cacheSize = LruCache<int[]>.DefaultCacheSize,
Normalizer? normalizer = null)
{
if (string.IsNullOrEmpty(modelName))
{
throw new ArgumentNullException(nameof(modelName));
}
(Dictionary<string, int> SpecialTokens, Regex Regex, string _) tiktokenConfiguration = Tiktoken.GetTiktokenConfigurations(modelName);
if (extraSpecialTokens is not null)
{
foreach (var extraSpecialToken in extraSpecialTokens)
{
tiktokenConfiguration.SpecialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value);
}
}
return new Tiktoken(vocabStream,
new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
tiktokenConfiguration.SpecialTokens,
normalizer,
cacheSize);
}
/// <summary>
/// Create a Tiktoken tokenizer based on model name and vocab file.
/// </summary>
/// <param name="modelName">Model name</param>
/// <param name="vocabStream">The stream to the BPE vocab file.</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
/// <param name="cacheSize">The size of the cache to use.</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer</returns>
public static async Task<Tokenizer> CreateTiktokenForModelAsync(
string modelName,
Stream vocabStream,
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
int cacheSize = LruCache<int[]>.DefaultCacheSize,
Normalizer? normalizer = null,
CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(modelName))
{
throw new ArgumentNullException(nameof(modelName));
}
(Dictionary<string, int> SpecialTokens, Regex Regex, string _) tiktokenConfiguration = Tiktoken.GetTiktokenConfigurations(modelName);
if (extraSpecialTokens is not null)
{
foreach (var extraSpecialToken in extraSpecialTokens)
{
tiktokenConfiguration.SpecialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value);
}
}
return await CreateTiktokenAsync(vocabStream,
new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
normalizer,
tiktokenConfiguration.SpecialTokens,
cacheSize, cancellationToken).ConfigureAwait(false);
}
/// <summary>
/// Create tokenizer based on model name
/// </summary>
/// <param name="modelName">Model name</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <returns>The tokenizer</returns>
public static Tokenizer CreateTiktokenForModel(string modelName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null)
=> Tiktoken.CreateForModel(Tiktoken.GetModelEncoding(modelName), modelName, extraSpecialTokens, normalizer);
/// <summary>
/// Create tokenizer based on encoding name
/// </summary>
/// <param name="encodingName">Encoding name</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoding</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <returns>The tokenizer</returns>
public static Tokenizer CreateTiktokenForEncoding(string encodingName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null)
{
if (string.IsNullOrEmpty(encodingName))
{
throw new ArgumentNullException(nameof(encodingName));
}
Tiktoken.ModelEncoding modelEncoding;
if (encodingName.Equals(Tiktoken.Cl100kBaseEncodingName, StringComparison.OrdinalIgnoreCase))
{
modelEncoding = Tiktoken.ModelEncoding.Cl100kBase;
}
else if (encodingName.Equals(Tiktoken.O200kBaseEncodingName, StringComparison.OrdinalIgnoreCase))
{
modelEncoding = Tiktoken.ModelEncoding.O200kBase;
}
else if (encodingName.Equals(Tiktoken.P50kBaseEncodingName, StringComparison.OrdinalIgnoreCase))
{
modelEncoding = Tiktoken.ModelEncoding.P50kBase;
}
else if (encodingName.Equals(Tiktoken.P50kEditEncodingName, StringComparison.OrdinalIgnoreCase))
{
modelEncoding = Tiktoken.ModelEncoding.P50kEdit;
}
else if (encodingName.Equals(Tiktoken.R50kBaseEncodingName, StringComparison.OrdinalIgnoreCase))
{
modelEncoding = Tiktoken.ModelEncoding.R50kBase;
}
else
{
throw new ArgumentException($"The encoding name '{encodingName}' is not supported. The only supported encoding names are: {Tiktoken.Cl100kBaseEncodingName}, {Tiktoken.O200kBaseEncodingName}, {Tiktoken.P50kBaseEncodingName}, {Tiktoken.P50kEditEncodingName}, and {Tiktoken.R50kBaseEncodingName}.", nameof(encodingName));
}
return Tiktoken.CreateForModel(modelEncoding, modelName: null, extraSpecialTokens, normalizer);
}
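The by-name factories resolve built-in vocabularies, so no vocab file is needed. A sketch (model and encoding names as handled above; `CountTokens` is the token-counting method on `Tokenizer`):

```csharp
using Microsoft.ML.Tokenizers;

// By model name: the matching built-in encoding (here cl100k_base) is resolved automatically.
Tokenizer byModel = Tokenizer.CreateTiktokenForModel("gpt-4");

// By encoding name, compared case-insensitively.
Tokenizer byEncoding = Tokenizer.CreateTiktokenForEncoding("cl100k_base");

int count = byModel.CountTokens("Hello, world!");
```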
/// <summary>
/// Create a SentencePieceBpe tokenizer from the given model stream. The model stream should contain the SentencePiece Bpe model according to
/// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto specification.
/// </summary>
/// <param name="modelStream">The stream containing the SentencePiece Bpe model.</param>
/// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
public static Tokenizer CreateLlama(
Stream modelStream,
bool addBeginOfSentence = true,
bool addEndOfSentence = false)
{
if (modelStream is null)
{
throw new ArgumentNullException(nameof(modelStream));
}
ModelProto modelProto = ModelProto.Parser.ParseFrom(modelStream);
if (modelProto.TrainerSpec.ModelType != TrainerSpec.Types.ModelType.Bpe)
{
throw new ArgumentException("The model type is not Bpe.", nameof(modelProto));
}
if (modelProto.NormalizerSpec.Name != "identity" && !string.IsNullOrEmpty(modelProto.NormalizerSpec.Name))
{
throw new ArgumentException($"Normalization '{modelProto.NormalizerSpec.Name}' is not supported.", nameof(modelProto));
}
SentencePieceNormalizer normalizer = new(
modelProto.NormalizerSpec.RemoveExtraWhitespaces,
modelProto.NormalizerSpec.AddDummyPrefix,
modelProto.NormalizerSpec.EscapeWhitespaces,
modelProto.TrainerSpec.TreatWhitespaceAsSuffix);
return new SentencePieceBpe(modelProto, addBeginOfSentence, addEndOfSentence);
}
/// <summary>
/// Create a CodeGen tokenizer from the given vocab and merges streams.
/// </summary>
/// <param name="vocabStream">The stream containing the vocab file.</param>
/// <param name="mergesStream">The stream containing the merges file.</param>
/// <param name="addPrefixSpace">Indicate whether to add a space before the token.</param>
/// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <returns>The CodeGen tokenizer object.</returns>
/// <remarks>
/// The tokenizer will be created according to the configuration specified in https://huggingface.co/Salesforce/codegen-350M-mono/raw/main/tokenizer.json.
/// It is important to provide vocab and merges files similar to the ones used when training the model.
/// The vocab and merges files can be downloaded from the following links:
/// https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/vocab.json?download=true
/// https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/merges.txt?download=true
/// </remarks>
public static Tokenizer CreateCodeGen(
Stream vocabStream,
Stream mergesStream,
bool addPrefixSpace = false,
bool addBeginOfSentence = false,
bool addEndOfSentence = false)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
if (mergesStream is null)
{
throw new ArgumentNullException(nameof(mergesStream));
}
return new CodeGen(
vocabStream,
mergesStream,
new TiktokenPreTokenizer(Tiktoken.P50kBaseRegex(), CodeGen.CodeGenAddedTokens),
normalizer: null,
CodeGen.CodeGenAddedTokens,
addPrefixSpace: addPrefixSpace,
addBeginningOfSentence: addBeginOfSentence,
addEndOfSentence: addEndOfSentence);
}
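A usage sketch for `CreateCodeGen` (the local file names are placeholders for the downloads linked in the remarks above):

```csharp
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Tokenizers;

using Stream vocab = File.OpenRead("vocab.json");
using Stream merges = File.OpenRead("merges.txt");

Tokenizer codeGen = Tokenizer.CreateCodeGen(vocab, merges);
IReadOnlyList<int> ids = codeGen.EncodeToIds("def add(a, b): return a + b");
```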
/// <summary>
/// Create a CodeGen Phi2 tokenizer from the given vocab and merges streams.
/// </summary>
/// <param name="vocabStream">The stream containing the vocab file.</param>
/// <param name="mergesStream">The stream containing the merges file.</param>
/// <param name="addPrefixSpace">Indicate whether to add a space before the token.</param>
/// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <returns>The CodeGen tokenizer object.</returns>
/// <remarks>
/// The tokenizer will be created according to the configuration specified in https://huggingface.co/microsoft/phi-2/raw/main/tokenizer.json.
/// It is important to provide vocab and merges files similar to the ones used when training the model.
/// The vocab and merges files can be downloaded from the following links:
/// https://huggingface.co/microsoft/phi-2/resolve/main/vocab.json?download=true
/// https://huggingface.co/microsoft/phi-2/resolve/main/merges.txt?download=true
/// </remarks>
public static Tokenizer CreatePhi2(
Stream vocabStream,
Stream mergesStream,
bool addPrefixSpace = false,
bool addBeginOfSentence = false,
bool addEndOfSentence = false)
=> CreateCodeGen(vocabStream, mergesStream, addPrefixSpace, addBeginOfSentence, addEndOfSentence);
/// <summary>
/// Decode the given ids back to text and store the result in the <paramref name="destination"/> span.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="destination">The span to store the decoded text.</param>
/// <param name="idsConsumed">The number of ids consumed during the decoding.</param>
/// <param name="charsWritten">The number of characters written to the destination span.</param>
/// <returns>The operation status indicates whether all IDs were successfully decoded or if the <paramref name="destination"/> is too small to contain the entire decoded result.</returns>
public abstract OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten);
internal static IEnumerable<(int Offset, int Length)>? InitializeForEncoding(
string? text,
@ -535,7 +313,8 @@ namespace Microsoft.ML.Tokenizers
Normalizer? normalizer,
PreTokenizer? preTokenizer,
out string? normalizedString,
out ReadOnlySpan<char> textSpanToEncode)
out ReadOnlySpan<char> textSpanToEncode,
out int fullTextLength)
{
normalizedString = null;
IEnumerable<(int Offset, int Length)>? splits = null;
@ -546,6 +325,7 @@ namespace Microsoft.ML.Tokenizers
{
normalizedString = normalizer.Normalize(textSpan.ToString());
textSpanToEncode = normalizedString.AsSpan();
fullTextLength = normalizedString.Length;
if (considerPreTokenization && preTokenizer is not null)
{
splits = preTokenizer.PreTokenize(normalizedString);
@ -554,6 +334,7 @@ namespace Microsoft.ML.Tokenizers
else
{
textSpanToEncode = textSpan;
fullTextLength = textSpan.Length;
if (considerPreTokenization && preTokenizer is not null)
{
splits = preTokenizer.PreTokenize(textSpan);
@ -566,6 +347,7 @@ namespace Microsoft.ML.Tokenizers
{
normalizedString = normalizer.Normalize(text);
textSpanToEncode = normalizedString.AsSpan();
fullTextLength = normalizedString.Length;
if (considerPreTokenization && preTokenizer is not null)
{
splits = preTokenizer.PreTokenize(normalizedString);
@ -574,6 +356,7 @@ namespace Microsoft.ML.Tokenizers
else
{
textSpanToEncode = text.AsSpan();
fullTextLength = text.Length;
if (considerPreTokenization && preTokenizer is not null)
{
splits = preTokenizer.PreTokenize(text);


@ -127,5 +127,88 @@ namespace Microsoft.ML.Tokenizers
return targetIndex;
}
public static bool ConvertUtf8ToUtf16(ReadOnlySpan<byte> utf8Bytes, Span<char> utf16Chars, out int bytesConsumed, out int charsWritten)
{
Debug.Assert(utf16Chars.Length >= Encoding.UTF8.GetMaxCharCount(utf8Bytes.Length));
int byteIndex = 0;
int charIndex = 0;
bytesConsumed = 0;
charsWritten = 0;
while (byteIndex < utf8Bytes.Length)
{
uint codePoint;
int additionalBytes;
byte firstByte = utf8Bytes[byteIndex];
if ((firstByte & 0x80) == 0)
{
// 1-byte sequence (ASCII)
codePoint = firstByte;
utf16Chars[charIndex++] = (char)firstByte;
charsWritten++;
bytesConsumed = ++byteIndex;
continue;
}
else if ((firstByte & 0xE0) == 0xC0)
{
// 2-byte sequence
codePoint = (uint)(firstByte & 0x1F);
additionalBytes = 1;
}
else if ((firstByte & 0xF0) == 0xE0)
{
// 3-byte sequence
codePoint = (uint)(firstByte & 0x0F);
additionalBytes = 2;
}
else if ((firstByte & 0xF8) == 0xF0)
{
// 4-byte sequence
codePoint = (uint)(firstByte & 0x07);
additionalBytes = 3;
}
else
{
return false;
}
if (byteIndex + additionalBytes >= utf8Bytes.Length)
{
// Trailing incomplete UTF-8 sequence: stop here; bytesConsumed excludes the partial character.
return true;
}
for (int i = 1; i <= additionalBytes; i++)
{
byte nextByte = utf8Bytes[byteIndex + i];
if ((nextByte & 0xC0) != 0x80)
{
return false;
}
codePoint = (codePoint << 6) | (uint)(nextByte & 0x3F);
}
byteIndex += additionalBytes + 1;
bytesConsumed = byteIndex;
if (codePoint <= 0xFFFF)
{
utf16Chars[charIndex++] = (char)codePoint;
}
else
{
codePoint -= 0x10000;
utf16Chars[charIndex++] = (char)((codePoint >> 10) + 0xD800);
utf16Chars[charIndex++] = (char)((codePoint & 0x3FF) + 0xDC00);
}
charsWritten = charIndex;
}
return true;
}
}
}
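The surrogate-pair arithmetic in `ConvertUtf8ToUtf16` can be sanity-checked against the BCL's own decoder; a small standalone sketch:

```csharp
using System.Text;

// U+1F600 encodes as four UTF-8 bytes and decodes to a UTF-16 surrogate pair.
byte[] utf8 = { 0xF0, 0x9F, 0x98, 0x80 };
char[] utf16 = new char[Encoding.UTF8.GetMaxCharCount(utf8.Length)];
int written = Encoding.UTF8.GetChars(utf8, 0, utf8.Length, utf16, 0);

// Same math as the helper: cp -= 0x10000; high = (cp >> 10) + 0xD800; low = (cp & 0x3FF) + 0xDC00.
uint cp = 0x1F600 - 0x10000;
char high = (char)((cp >> 10) + 0xD800); // 0xD83D
char low = (char)((cp & 0x3FF) + 0xDC00); // 0xDE00
// Expect: written == 2, utf16[0] == high, utf16[1] == low.
```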


@ -77,7 +77,7 @@ namespace Microsoft.ML.Tokenizers
private object SyncObj => _map;
internal StringSpanOrdinalKeyCache() : this(Bpe.DefaultCacheCapacity) { }
internal StringSpanOrdinalKeyCache() : this(BpeTokenizer.DefaultCacheCapacity) { }
internal StringSpanOrdinalKeyCache(int capacity)
{


@ -161,7 +161,7 @@ namespace System.Text
while (index <= _pos - oldLength)
{
ReadOnlySpan<char> buffer = _chars.Slice(index);
ReadOnlySpan<char> buffer = _chars.Slice(index, _pos - index);
int subIndex = buffer.IndexOf(oldValue.AsSpan(), StringComparison.Ordinal);
if (subIndex < 0)
{
@ -175,7 +175,8 @@ namespace System.Text
newValue.AsSpan().CopyTo(_chars.Slice(index));
if (oldLength > newLength)
{
_chars.Slice(index + oldLength).CopyTo(_chars.Slice(index + newLength));
int newIndex = index + oldLength;
_chars.Slice(newIndex, _pos - newIndex).CopyTo(_chars.Slice(index + newLength));
_pos -= oldLength - newLength;
}
}
@ -183,7 +184,8 @@ namespace System.Text
{
Insert(index, newValue);
_chars.Slice(index + newLength + oldLength).CopyTo(_chars.Slice(index + newLength));
int newIndex = index + newLength + oldLength;
_chars.Slice(newIndex, _pos - newIndex).CopyTo(_chars.Slice(index + newLength));
_pos -= oldLength;
}
@ -202,6 +204,16 @@ namespace System.Text
return false;
}
public void Remove(int start, int length)
{
if (length > 0 && start + length <= _pos)
{
int remaining = _pos - start - length;
_chars.Slice(start + length, remaining).CopyTo(_chars.Slice(start));
_pos -= length;
}
}
public bool EndsWith(string value)
{
int valueLength = value.Length;


@ -26,21 +26,21 @@ namespace Microsoft.ML.TorchSharp.Extensions
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt"
Assembly assembly = typeof(TokenizerExtensions).Assembly;
_instance = new EnglishRoberta(
_instance = EnglishRobertaTokenizer.Create(
assembly.GetManifestResourceStream("encoder.json"),
assembly.GetManifestResourceStream("vocab.bpe"),
assembly.GetManifestResourceStream("dict.txt"),
RobertaPreTokenizer.Instance);
(_instance as EnglishRoberta).AddMaskSymbol();
(_instance as EnglishRobertaTokenizer).AddMaskSymbol();
}
return _instance;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static EnglishRoberta RobertaModel(this Tokenizer tokenizer)
internal static EnglishRobertaTokenizer RobertaModel(this Tokenizer tokenizer)
{
EnglishRoberta model = tokenizer as EnglishRoberta;
EnglishRobertaTokenizer model = tokenizer as EnglishRobertaTokenizer;
if (model is null)
{
throw new InvalidOperationException($"The input tokenizer is not using the EnglishRoberta model.");


@ -201,7 +201,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
private protected override Module CreateModule(IChannel ch, IDataView input)
{
Tokenizer = TokenizerExtensions.GetInstance(ch);
EnglishRoberta tokenizerModel = Tokenizer.RobertaModel();
EnglishRobertaTokenizer tokenizerModel = Tokenizer.RobertaModel();
NasBertModel model;
if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)


@ -167,7 +167,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
Sentence1Getter(ref sentenceRom);
var sentence = sentenceRom.ToString();
Tensor t;
IReadOnlyList<Token> encoding = Tokenizer.Encode(sentence, out string normalizedString);
IReadOnlyList<EncodedToken> encoding = Tokenizer.EncodeToTokens(sentence, out string normalizedString);
if (target.Length != encoding.Count)
{
@ -327,7 +327,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
var ch = env.Start("Load Model");
var tokenizer = TokenizerExtensions.GetInstance(ch);
EnglishRoberta tokenizerModel = tokenizer.RobertaModel();
EnglishRobertaTokenizer tokenizerModel = tokenizer.RobertaModel();
var model = new NerModel(options, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, options.NumberOfClasses);
if (!ctx.TryLoadBinaryStream("TSModel", r => model.load(r)))
@ -377,7 +377,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
private void CondenseOutput(ref VBuffer<UInt32> dst, string sentence, Tokenizer tokenizer, TensorCacher outputCacher)
{
var pre = tokenizer.PreTokenizer.PreTokenize(sentence);
IReadOnlyList<Token> encoding = tokenizer.Encode(sentence, out string normalizedString);
IReadOnlyList<EncodedToken> encoding = tokenizer.EncodeToTokens(sentence, out string normalizedString);
var argmax = (outputCacher as BertTensorCacher).Result.argmax(-1);
var prediction = argmax.ToArray<long>();


@ -239,7 +239,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
var ch = env.Start("Load Model");
var tokenizer = TokenizerExtensions.GetInstance(ch);
EnglishRoberta tokenizerModel = tokenizer.RobertaModel();
EnglishRobertaTokenizer tokenizerModel = tokenizer.RobertaModel();
var model = new ModelForPrediction(options, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, options.NumberOfClasses);
if (!ctx.TryLoadBinaryStream("TSModel", r => model.load(r)))


@ -263,7 +263,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
var ch = env.Start("Load Model");
var tokenizer = TokenizerExtensions.GetInstance(ch);
EnglishRoberta tokenizerModel = tokenizer.RobertaModel();
EnglishRobertaTokenizer tokenizerModel = tokenizer.RobertaModel();
var model = new ModelForPrediction(options, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, options.NumberOfClasses);
if (!ctx.TryLoadBinaryStream("TSModel", r => model.load(r)))


@ -401,7 +401,7 @@ namespace Microsoft.ML.TorchSharp.Roberta
answerIndexGetter(ref answerIndex);
var contextString = context.ToString();
var contextTokens = Tokenizer.Encode(contextString, out string normalized);
var contextTokens = Tokenizer.EncodeToTokens(contextString, out string normalized);
var contextToken = contextTokens.Select(t => t.Value).ToArray();
var contextTokenId = Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(contextTokens.Select(t => t.Id).ToArray());
@ -437,7 +437,7 @@ namespace Microsoft.ML.TorchSharp.Roberta
private Dictionary<int, int> AlignAnswerPosition(IReadOnlyList<string> tokens, string text)
{
EnglishRoberta robertaModel = Tokenizer as EnglishRoberta;
EnglishRobertaTokenizer robertaModel = Tokenizer as EnglishRobertaTokenizer;
Debug.Assert(robertaModel is not null);
var mapping = new Dictionary<int, int>();

Просмотреть файл

@ -3,10 +3,13 @@
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Buffers.Binary;
using System.Collections.Generic;
using System.Net.Http;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net;
using System.Text;
using System.Text.Json;
using Xunit;
@ -247,9 +250,10 @@ namespace Microsoft.ML.Tokenizers.Tests
try
{
Bpe bpe = new Bpe(vocabFile, mergesFile, unknownToken: unknownToken, continuingSubwordPrefix: continuingSubwordPrefix, endOfWordSuffix: endOfWordSuffix, fuseUnknownTokens: fuseUnknownToken);
BpeTokenizer bpe = BpeTokenizer.Create(vocabFile: vocabFile, mergesFile: mergesFile, preTokenizer: WhiteSpacePreTokenizer.Instance, normalizer: null, unknownToken: unknownToken,
continuingSubwordPrefix: continuingSubwordPrefix, endOfWordSuffix: endOfWordSuffix, fuseUnknownTokens: fuseUnknownToken);
Tokenizer tokenizer = bpe;
IReadOnlyList<Token> encoding = tokenizer.Encode(sentence, out _);
IReadOnlyList<EncodedToken> encoding = tokenizer.EncodeToTokens(sentence, out _);
int[] encodingIds = encoding.Select(t => t.Id).ToArray();
IReadOnlyList<int> idsList = tokenizer.EncodeToIds(sentence);
@ -261,14 +265,19 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(decodedTokens, tokenizer.Decode(encodingIds));
Assert.Equal(decodedTokensWithoutUnknownToken, bpe.Decode(encodingIds, considerSpecialTokens: false));
TestDecodingWithSpan(bpe, encodingIds, considerSpecialTokens: true, decodedTokens);
TestDecodingWithSpan(bpe, encodingIds, considerSpecialTokens: false, decodedTokensWithoutUnknownToken);
var reverseVocabulary = bpe.Vocabulary.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
for (int i = 0; i < encoding.Count; i++)
{
Assert.Equal(expectedTokens[i], encoding[i].Value);
Assert.Equal(offsets[i], encoding[i].Offset);
Assert.Equal(ids[i], encoding[i].Id);
Assert.Equal(ids[i], idsList[i]);
Assert.Equal(encoding[i].Value, tokenizer.MapIdToToken(encodingIds[i]));
Assert.Equal(encodingIds[i], tokenizer.MapTokenToId(encoding[i].Value.AsSpan()));
Assert.Equal(encoding[i].Value, reverseVocabulary[encodingIds[i]]);
Assert.Equal(encodingIds[i], bpe.Vocabulary[encoding[i].Value]);
}
}
finally
@ -281,6 +290,35 @@ namespace Microsoft.ML.Tokenizers.Tests
}
}
private void TestDecodingWithSpan(BpeTokenizer bpe, int[] ids, bool considerSpecialTokens, string expectedDecoded)
{
char[] destinationBuffer = new char[expectedDecoded.Length];
OperationStatus status;
int lastIdsConsumed = 0;
int lastCharactersWritten = 0;
int idsConsumed;
int charactersWritten;
for (int i = 1; i < destinationBuffer.Length - 1; i += Math.Max(1, destinationBuffer.Length - 3)) // enough to test length 1, and destinationBuffer.Length - 2 only.
{
status = bpe.Decode(ids, destinationBuffer.AsSpan().Slice(0, i), considerSpecialTokens, out idsConsumed, out charactersWritten);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.True(idsConsumed < ids.Length);
Assert.True(idsConsumed >= lastIdsConsumed);
Assert.True(charactersWritten < expectedDecoded.Length);
Assert.True(charactersWritten >= lastCharactersWritten);
lastIdsConsumed = idsConsumed;
lastCharactersWritten = charactersWritten;
}
status = bpe.Decode(ids, destinationBuffer.AsSpan(), considerSpecialTokens, out idsConsumed, out charactersWritten);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids.Length, idsConsumed);
Assert.Equal(expectedDecoded.Length, charactersWritten);
Assert.Equal(expectedDecoded, destinationBuffer.AsSpan().ToString());
}
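The loop above exercises the OperationStatus-based Decode overload with deliberately undersized buffers. The same overload can be driven in a loop to stream a long id sequence through a small reusable buffer. This is a hedged sketch, not part of the test file; `bpe` is assumed to be a BpeTokenizer created as in these tests, and the buffer size is arbitrary.

```csharp
using System;
using System.Buffers;
using System.Text;

static string DecodeInChunks(BpeTokenizer bpe, ReadOnlySpan<int> ids)
{
    Span<char> buffer = stackalloc char[32];
    var sb = new StringBuilder();
    while (!ids.IsEmpty)
    {
        OperationStatus status = bpe.Decode(
            ids, buffer, considerSpecialTokens: false,
            out int idsConsumed, out int charsWritten);
        sb.Append(buffer.Slice(0, charsWritten));
        if (status == OperationStatus.Done)
            break;
        if (idsConsumed == 0 && charsWritten == 0)
            throw new InvalidOperationException("Buffer too small for a single token.");
        ids = ids.Slice(idsConsumed); // retry with the ids that did not fit
    }
    return sb.ToString();
}
```

The `DestinationTooSmall` status plus the `idsConsumed` out-parameter is what makes this incremental pattern possible without re-decoding from the start.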
private static Tokenizer? _gpt2Tokenizer = null;
private static Tokenizer GetGpt2Tokenizer()
@ -292,20 +330,49 @@ namespace Microsoft.ML.Tokenizers.Tests
using Stream vocabStream = File.OpenRead(Path.Combine(@"Gpt-2", "vocab.json"));
using Stream mergesStream = File.OpenRead(Path.Combine(@"Gpt-2", "merges.txt"));
_gpt2Tokenizer = new Bpe(vocabStream, mergesStream);
_gpt2Tokenizer = BpeTokenizer.Create(vocabStream, mergesStream);
}
return _gpt2Tokenizer;
}
[Fact]
public async Task TestBpeCreation()
{
// "https://huggingface.co/openai-community/gpt2/raw/main/vocab.json";
// "https://huggingface.co/openai-community/gpt2/raw/main/merges.txt";
string vocabFile = Path.Combine(@"Gpt-2", "vocab.json");
string mergesFile = Path.Combine(@"Gpt-2", "merges.txt");
BpeTokenizer bpe = BpeTokenizer.Create(vocabFile, mergesFile);
ValidateTokenizer(bpe);
using Stream vocabStream = File.OpenRead(vocabFile);
using Stream mergesStream = File.OpenRead(mergesFile);
bpe = BpeTokenizer.Create(vocabStream, mergesStream);
ValidateTokenizer(bpe);
// Reset the stream positions so the streams can be reused, and verify that Create did not dispose them.
vocabStream.Position = 0;
mergesStream.Position = 0;
bpe = await BpeTokenizer.CreateAsync(vocabStream, mergesStream);
ValidateTokenizer(bpe);
}
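The creation paths tested above follow the pattern used throughout this migration: public constructors were replaced by static factory methods. A minimal sketch of the new surface, assuming local `vocab.json`/`merges.txt` files as laid out in the test data:

```csharp
using System.IO;

// Old API: new Bpe(vocabStream, mergesStream)
// New API: static factories on BpeTokenizer.
using Stream vocab = File.OpenRead(Path.Combine("Gpt-2", "vocab.json"));
using Stream merges = File.OpenRead(Path.Combine("Gpt-2", "merges.txt"));

// Synchronous creation from streams (file-path overloads also exist).
BpeTokenizer fromStreams = BpeTokenizer.Create(vocab, merges);

// Async variant; the factories leave the streams open, so reset
// the positions before reusing them.
vocab.Position = 0;
merges.Position = 0;
BpeTokenizer fromStreamsAsync = await BpeTokenizer.CreateAsync(vocab, merges);
```

Creation does not take ownership of the streams, which is why the test can reuse and later dispose them itself.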
[Fact]
public void TestGpt2Vocab()
{
Tokenizer tokenizer = GetGpt2Tokenizer();
ValidateTokenizer(tokenizer);
}
private void ValidateTokenizer(Tokenizer tokenizer)
{
string text = "The quick brown fox jumps over the lazy dog!";
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
IReadOnlyList<EncodedToken> encoding = tokenizer.EncodeToTokens(text, out _);
IReadOnlyList<int> ids = tokenizer.EncodeToIds(text);
Assert.Equal(12, encoding.Count);
@ -358,8 +425,8 @@ namespace Microsoft.ML.Tokenizers.Tests
{
Tokenizer tokenizer = GetGpt2Tokenizer();
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);
IReadOnlyList<EncodedToken> encoding = tokenizer.EncodeToTokens(text, out _);
IReadOnlyList<EncodedToken> encoding1 = tokenizer.EncodeToTokens(text.AsSpan(), out _);
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
@ -389,17 +456,17 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text));
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan()));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.GetIndexByTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length - 3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.GetIndexByTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length - 3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text, 3, out normalizedString, out tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.GetIndexByTokenCountFromEnd(text, 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 3, out normalizedString, out tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(3, tokenCount);
}
@ -423,7 +490,7 @@ namespace Microsoft.ML.Tokenizers.Tests
return fileName;
}
internal static Bpe CreateEmptyBpe(PreTokenizer? preTokenizer = null, Normalizer? normalizer = null)
internal static BpeTokenizer CreateEmptyBpe(PreTokenizer? preTokenizer = null, Normalizer? normalizer = null)
{
using MemoryStream emptyVocabStream = new MemoryStream();
using StreamWriter writer = new StreamWriter(emptyVocabStream);
@ -431,7 +498,8 @@ namespace Microsoft.ML.Tokenizers.Tests
writer.Flush();
emptyVocabStream.Position = 0;
return new Bpe(vocabStream: emptyVocabStream, mergesStream: null, preTokenizer: preTokenizer ?? WhiteSpace.Instance, normalizer: normalizer, unknownToken: "Ukn");
return BpeTokenizer.Create(
vocabStream: emptyVocabStream, mergesStream: null, preTokenizer: preTokenizer ?? WhiteSpacePreTokenizer.Instance, normalizer: normalizer, unknownToken: "Ukn");
}
}
}


@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
@ -27,7 +28,7 @@ namespace Microsoft.ML.Tokenizers.Tests
using Stream vocabStream = File.OpenRead(Path.Combine(@"Codegen-350M-mono", "vocab.json"));
using Stream mergesStream = File.OpenRead(Path.Combine(@"Codegen-350M-mono", "merges.txt"));
return Tokenizer.CreateCodeGen(vocabStream, mergesStream, addPrefixSpace, bos, eos);
return CodeGenTokenizer.Create(vocabStream, mergesStream, addPrefixSpace, bos, eos);
}
private static Tokenizer CreateCodegenPhi2Tokenizer()
@ -38,7 +39,7 @@ namespace Microsoft.ML.Tokenizers.Tests
using Stream vocabStream = File.OpenRead(Path.Combine(@"Phi-2", "vocab.json"));
using Stream mergesStream = File.OpenRead(Path.Combine(@"Phi-2", "merges.txt"));
return Tokenizer.CreateCodeGen(vocabStream, mergesStream);
return CodeGenTokenizer.Create(vocabStream, mergesStream);
}
public static IEnumerable<object?[]> CodeGenTestData
@ -227,7 +228,7 @@ namespace Microsoft.ML.Tokenizers.Tests
TestDecoding(phi2Tokenizer, text);
}
private void ValidateEncoding(IReadOnlyList<Token> encoding, bool addPrefixSpace, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds,
private void ValidateEncoding(IReadOnlyList<EncodedToken> encoding, bool addPrefixSpace, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds,
string[] expectedTokensWithSpace, (int Index, int Length)[] expectedOffsetsWithSpace, int[] expectedIdsWithSpace)
{
if (addPrefixSpace)
@ -246,38 +247,85 @@ namespace Microsoft.ML.Tokenizers.Tests
private void TestDecoding(Tokenizer tokenizer, string text)
{
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
Assert.Equal(text, tokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
CodeGenTokenizer codeGenTokenizer = (tokenizer as CodeGenTokenizer)!;
encoding = tokenizer.Encode(text.AsSpan(), out _);
Assert.Equal(text, tokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
IReadOnlyList<EncodedToken> encoding = tokenizer.EncodeToTokens(text, out _);
int[] ids = encoding.Select(t => t.Id).ToArray();
Assert.Equal(text, tokenizer.Decode(ids));
encoding = tokenizer.EncodeToTokens(text.AsSpan(), out _);
ids = encoding.Select(t => t.Id).ToArray();
Assert.Equal(text, tokenizer.Decode(ids));
TestDecodingWithSpan(codeGenTokenizer, ids, codeGenTokenizer.AddPrefixSpace, considerSpecialTokens: false, text);
CodeGen codeGenTokenizer = (tokenizer as CodeGen)!;
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: false, out _);
ids = encoding.Select(t => t.Id).ToArray();
Assert.Equal(text, codeGenTokenizer.Decode(ids));
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: false, out _);
ids = encoding.Select(t => t.Id).ToArray();
Assert.Equal(text, codeGenTokenizer.Decode(ids));
TestDecodingWithSpan(codeGenTokenizer, ids, codeGenTokenizer.AddPrefixSpace, considerSpecialTokens: false, text);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: false, addEndOfSentence: true, out _);
ids = encoding.Select(t => t.Id).ToArray();
Assert.Equal(text, codeGenTokenizer.Decode(ids));
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: false, addEndOfSentence: true, out _);
ids = encoding.Select(t => t.Id).ToArray();
Assert.Equal(text, codeGenTokenizer.Decode(ids));
TestDecodingWithSpan(codeGenTokenizer, ids, codeGenTokenizer.AddPrefixSpace, considerSpecialTokens: false, text);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: true, out _);
ids = encoding.Select(t => t.Id).ToArray();
Assert.Equal(text, codeGenTokenizer.Decode(ids));
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: true, out _);
ids = encoding.Select(t => t.Id).ToArray();
Assert.Equal(text, codeGenTokenizer.Decode(ids));
TestDecodingWithSpan(codeGenTokenizer, ids, codeGenTokenizer.AddPrefixSpace, considerSpecialTokens: false, text);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
ids = encoding.Select(t => t.Id).ToArray();
Assert.Equal(text, codeGenTokenizer.Decode(ids, hasPrefixSpace: true, considerSpecialTokens: false));
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
ids = encoding.Select(t => t.Id).ToArray();
Assert.Equal(text, codeGenTokenizer.Decode(ids, hasPrefixSpace: true, considerSpecialTokens: false));
TestDecodingWithSpan(codeGenTokenizer, ids, hasPrefixSpace: true, considerSpecialTokens: false, text);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray(), hasPrefixSpace: true, considerSpecialTokens: false));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray(), hasPrefixSpace: true, considerSpecialTokens: false));
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
ids = encoding.Select(t => t.Id).ToArray();
string targetText = $"{codeGenTokenizer.BeginningOfSentenceToken}{text}{codeGenTokenizer.EndOfSentenceToken}";
Assert.Equal(targetText, codeGenTokenizer.Decode(ids, hasPrefixSpace: true, considerSpecialTokens: true));
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
ids = encoding.Select(t => t.Id).ToArray();
Assert.Equal(targetText, codeGenTokenizer.Decode(ids, hasPrefixSpace: true, considerSpecialTokens: true));
TestDecodingWithSpan(codeGenTokenizer, ids, hasPrefixSpace: true, considerSpecialTokens: true, targetText);
}
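The combinations exercised above reduce to one round-trip invariant: decoding with special tokens suppressed restores the original text, while keeping them surfaces the BOS/EOS markers. A hedged sketch, assuming `codeGen` is a CodeGenTokenizer and `text` an input string:

```csharp
using System.Collections.Generic;
using System.Linq;

IReadOnlyList<EncodedToken> tokens = codeGen.EncodeToTokens(
    text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
int[] ids = tokens.Select(t => t.Id).ToArray();

// Skipping special tokens restores the original text...
string plain = codeGen.Decode(ids, hasPrefixSpace: true, considerSpecialTokens: false);

// ...while keeping them wraps it in the BOS/EOS markers.
string withSpecials = codeGen.Decode(ids, hasPrefixSpace: true, considerSpecialTokens: true);
// withSpecials == $"{codeGen.BeginningOfSentenceToken}{text}{codeGen.EndOfSentenceToken}"
```

Note that `hasPrefixSpace` on Decode must mirror the `addPrefixSpace` used at encode time, otherwise a leading space is spuriously added or stripped.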
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal($"{codeGenTokenizer.BeginningOfSentenceToken}{text}{codeGenTokenizer.EndOfSentenceToken}", codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray(), hasPrefixSpace: true, considerSpecialTokens: true));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal($"{codeGenTokenizer.BeginningOfSentenceToken}{text}{codeGenTokenizer.EndOfSentenceToken}", codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray(), hasPrefixSpace: true, considerSpecialTokens: true));
private void TestDecodingWithSpan(CodeGenTokenizer tokenizer, int[] ids, bool hasPrefixSpace, bool considerSpecialTokens, string expectedDecoded)
{
char[] destinationBuffer = new char[expectedDecoded.Length];
OperationStatus status;
int lastIdsConsumed = 0;
int lastCharactersWritten = 0;
int idsConsumed;
int charactersWritten;
for (int i = 1; i < destinationBuffer.Length - 1; i += Math.Max(1, destinationBuffer.Length - 3)) // the step size ensures only lengths 1 and destinationBuffer.Length - 2 are exercised.
{
status = tokenizer.Decode(ids, destinationBuffer.AsSpan().Slice(0, i), hasPrefixSpace, considerSpecialTokens, out idsConsumed, out charactersWritten);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.True(idsConsumed < ids.Length);
Assert.True(idsConsumed >= lastIdsConsumed);
Assert.True(charactersWritten < expectedDecoded.Length);
Assert.True(charactersWritten >= lastCharactersWritten);
lastIdsConsumed = idsConsumed;
lastCharactersWritten = charactersWritten;
}
status = tokenizer.Decode(ids, destinationBuffer.AsSpan(), hasPrefixSpace, considerSpecialTokens, out idsConsumed, out charactersWritten);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids.Length, idsConsumed);
Assert.Equal(expectedDecoded.Length, charactersWritten);
Assert.Equal(expectedDecoded, destinationBuffer.AsSpan().ToString());
}
private void TestTokenizer(
@ -290,28 +338,28 @@ namespace Microsoft.ML.Tokenizers.Tests
(int Index, int Length)[] expectedOffsetsWithSpace,
int[] expectedIdsWithSpace)
{
CodeGen codeGenTokenizer = (tokenizer as CodeGen)!;
CodeGenTokenizer codeGenTokenizer = (tokenizer as CodeGenTokenizer)!;
//
// Full Encoding
//
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
IReadOnlyList<EncodedToken> encoding = tokenizer.EncodeToTokens(text, out _);
ValidateEncoding(encoding, codeGenTokenizer.AddPrefixSpace, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
encoding = tokenizer.Encode(text.AsSpan(), out _);
encoding = tokenizer.EncodeToTokens(text.AsSpan(), out _);
ValidateEncoding(encoding, codeGenTokenizer.AddPrefixSpace, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
ValidateEncoding(encoding, addPrefixSpace: false, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
ValidateEncoding(encoding, addPrefixSpace: false, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
ValidateEncoding(encoding, addPrefixSpace: true, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
ValidateEncoding(encoding, addPrefixSpace: true, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
//
@ -392,24 +440,24 @@ namespace Microsoft.ML.Tokenizers.Tests
offsets = codeGenTokenizer.AddPrefixSpace ? expectedOffsetsWithSpace : expectedOffsets;
Assert.Equal(offsets[offsets.Length - 1].Index + offsets[offsets.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text, ids.Length, out normalizedString, out int tokenCount));
Assert.Equal(offsets[offsets.Length - 1].Index + offsets[offsets.Length - 1].Length, codeGenTokenizer.GetIndexByTokenCount(text, ids.Length, out normalizedString, out int tokenCount));
Assert.Null(normalizedString);
Assert.Equal(ids.Length, tokenCount);
Assert.Equal(offsets[offsets.Length - 1].Index + offsets[offsets.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), ids.Length, out normalizedString, out tokenCount));
Assert.Equal(offsets[offsets.Length - 1].Index + offsets[offsets.Length - 1].Length, codeGenTokenizer.GetIndexByTokenCount(text.AsSpan(), ids.Length, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(ids.Length, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 1].Index + expectedOffsets[expectedOffsets.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text, expectedIds.Length, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 1].Index + expectedOffsets[expectedOffsets.Length - 1].Length, codeGenTokenizer.GetIndexByTokenCount(text, expectedIds.Length, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 1].Index + expectedOffsets[expectedOffsets.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 1].Index + expectedOffsets[expectedOffsets.Length - 1].Length, codeGenTokenizer.GetIndexByTokenCount(text.AsSpan(), expectedIds.Length, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length, tokenCount);
Assert.Equal(expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index + expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text, expectedIdsWithSpace.Length, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Equal(expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index + expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Length, codeGenTokenizer.GetIndexByTokenCount(text, expectedIdsWithSpace.Length, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIdsWithSpace.Length, tokenCount);
Assert.Equal(expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index + expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), expectedIdsWithSpace.Length, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Equal(expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index + expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Length, codeGenTokenizer.GetIndexByTokenCount(text.AsSpan(), expectedIdsWithSpace.Length, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIdsWithSpace.Length, tokenCount);
@ -419,26 +467,26 @@ namespace Microsoft.ML.Tokenizers.Tests
int expectedIndex = offsets.Length > 1 && offsets[offsets.Length - 1].Index == offsets[offsets.Length - 2].Index ? text.Length : offsets[offsets.Length - 1].Index;
int expectedTokenCount = expectedIndex == text.Length ? 0 : 1;
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text, 1, out normalizedString, out tokenCount));
Assert.Equal(expectedIndex, codeGenTokenizer.GetIndexByTokenCountFromEnd(text, 1, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), 1, out normalizedString, out tokenCount));
Assert.Equal(expectedIndex, codeGenTokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), 1, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text, 1, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Equal(expectedIndex, codeGenTokenizer.GetIndexByTokenCountFromEnd(text, 1, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), 1, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Equal(expectedIndex, codeGenTokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), 1, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedTokenCount, tokenCount);
expectedIndex = offsets.Length > 1 && expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index == expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 2].Index ? text.Length : expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index;
expectedTokenCount = expectedIndex == text.Length ? 0 : 1;
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text, 1, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Equal(expectedIndex, codeGenTokenizer.GetIndexByTokenCountFromEnd(text, 1, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), 1, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Equal(expectedIndex, codeGenTokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), 1, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedTokenCount, tokenCount);
@ -447,10 +495,33 @@ namespace Microsoft.ML.Tokenizers.Tests
//
var tokens = codeGenTokenizer.AddPrefixSpace ? expectedTokensWithSpace : expectedTokens;
var reverseVocab = codeGenTokenizer.Vocabulary.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
var reverseAddedTokens = codeGenTokenizer.AddedTokens?.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
for (int i = 0; i < tokens.Length; i++)
{
Assert.Equal(tokens[i], codeGenTokenizer.MapIdToToken(ids[i]));
Assert.Equal(ids[i], codeGenTokenizer.MapTokenToId(tokens[i]));
Assert.Equal(tokens[i], MapIdToToken(ids[i]));
Assert.Equal(ids[i], MapTokenId(tokens[i]));
}
string MapIdToToken(int id)
{
if (reverseVocab.TryGetValue(id, out string? token))
{
return token;
}
return reverseAddedTokens![id];
}
int MapTokenId(string token)
{
if (codeGenTokenizer.Vocabulary.TryGetValue(token, out int id))
{
return id;
}
return codeGenTokenizer.AddedTokens![token];
}
}
@ -473,9 +544,9 @@ namespace Microsoft.ML.Tokenizers.Tests
// Beginning of Sentence
//
CodeGen codeGenTokenizer = (_codegen350MMonoTokenizerWithBeginningOfSentence as CodeGen)!;
CodeGenTokenizer codeGenTokenizer = (_codegen350MMonoTokenizerWithBeginningOfSentence as CodeGenTokenizer)!;
IReadOnlyList<Token> encoding = codeGenTokenizer.Encode(text, out _);
IReadOnlyList<EncodedToken> encoding = codeGenTokenizer.EncodeToTokens(text, out _);
Assert.True(codeGenTokenizer.BeginningOfSentenceToken is not null);
Assert.True(codeGenTokenizer.BeginningOfSentenceId.HasValue);
var idList = new List<int>(expectedIds);
@ -486,17 +557,17 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
@ -505,32 +576,32 @@ namespace Microsoft.ML.Tokenizers.Tests
idList.Insert(0, codeGenTokenizer.BeginningOfSentenceId!.Value);
tokensList = new List<string>(expectedTokensWithSpace);
tokensList.Insert(0, codeGenTokenizer.BeginningOfSentenceToken!);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
@ -547,9 +618,9 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.NotEqual(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.NotEqual(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
ids = codeGenTokenizer.EncodeToIds(text, maxTokenCount: 5, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out string? normalizedString, out int textLength);
ids = codeGenTokenizer.EncodeToIds(text, maxTokenCount: 5, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out string? normalizedString, out int charsConsumed);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), maxTokenCount: 5, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out textLength);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), maxTokenCount: 5, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out charsConsumed);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
int tokenCount = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
@@ -564,41 +635,41 @@ namespace Microsoft.ML.Tokenizers.Tests
count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false);
Assert.Equal(tokenCount + 1, count);
int length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
int length = codeGenTokenizer.GetIndexByTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);
int index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
int index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
@@ -606,9 +677,9 @@ namespace Microsoft.ML.Tokenizers.Tests
// End of Sentence
//
codeGenTokenizer = (_codegen350MMonoTokenizerWithEndOfSentence as CodeGen)!;
codeGenTokenizer = (_codegen350MMonoTokenizerWithEndOfSentence as CodeGenTokenizer)!;
encoding = codeGenTokenizer.Encode(text, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, out _);
Assert.True(codeGenTokenizer.EndOfSentenceToken is not null);
Assert.True(codeGenTokenizer.EndOfSentenceId.HasValue);
idList = new List<int>(expectedIds);
@@ -619,17 +690,17 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
@@ -638,32 +709,32 @@ namespace Microsoft.ML.Tokenizers.Tests
idList.Add(codeGenTokenizer.EndOfSentenceId!.Value);
tokensList = new List<string>(expectedTokensWithSpace);
tokensList.Add(codeGenTokenizer.EndOfSentenceToken!);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: true, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: true, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
@@ -680,9 +751,9 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.NotEqual(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.NotEqual(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out textLength);
ids = codeGenTokenizer.EncodeToIds(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out charsConsumed);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out textLength);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out charsConsumed);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
tokenCount = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
@@ -697,41 +768,41 @@ namespace Microsoft.ML.Tokenizers.Tests
count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true);
Assert.Equal(tokenCount + 1, count);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
@@ -739,9 +810,9 @@ namespace Microsoft.ML.Tokenizers.Tests
// Beginning & End of Sentence
//
codeGenTokenizer = (_codegen350MMonoTokenizerWithBeginningAndEndOfSentence as CodeGen)!;
codeGenTokenizer = (_codegen350MMonoTokenizerWithBeginningAndEndOfSentence as CodeGenTokenizer)!;
encoding = codeGenTokenizer.Encode(text, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, out _);
Assert.True(codeGenTokenizer.BeginningOfSentenceToken is not null);
Assert.True(codeGenTokenizer.BeginningOfSentenceId.HasValue);
idList = new List<int>(expectedIds);
@@ -755,19 +826,19 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
@@ -779,37 +850,37 @@ namespace Microsoft.ML.Tokenizers.Tests
tokensList = new List<string>(expectedTokensWithSpace);
tokensList.Insert(0, codeGenTokenizer.BeginningOfSentenceToken!);
tokensList.Add(codeGenTokenizer.EndOfSentenceToken!);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
@@ -833,10 +904,10 @@ namespace Microsoft.ML.Tokenizers.Tests
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.NotEqual(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.NotEqual(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out textLength);
ids = codeGenTokenizer.EncodeToIds(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out charsConsumed);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out textLength);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out charsConsumed);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
@@ -851,41 +922,41 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(tokenCount + 2, count);
count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true);
Assert.Equal(tokenCount + 2, count);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
length = codeGenTokenizer.GetIndexByTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
index = codeGenTokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
}
@@ -895,14 +966,14 @@ namespace Microsoft.ML.Tokenizers.Tests
[Fact]
public void TestDefaultValues()
{
CodeGen codeGenTokenizer = (_codegen350MMonoTokenizer as CodeGen)!;
CodeGenTokenizer codeGenTokenizer = (_codegen350MMonoTokenizer as CodeGenTokenizer)!;
Assert.False(codeGenTokenizer.AddPrefixSpace);
Assert.False(codeGenTokenizer.AddBeginningOfSentence);
Assert.False(codeGenTokenizer.AddEndOfSentence);
Assert.Equal(codeGenTokenizer.MapTokenToId(DefaultSpecialToken), codeGenTokenizer.BeginningOfSentenceId!.Value);
Assert.Equal(codeGenTokenizer.MapTokenToId(DefaultSpecialToken), codeGenTokenizer.EndOfSentenceId!.Value);
Assert.Equal(codeGenTokenizer.MapTokenToId(DefaultSpecialToken), codeGenTokenizer.UnknownTokenId!.Value);
Assert.Equal(codeGenTokenizer.EncodeToIds(DefaultSpecialToken)[0], codeGenTokenizer.BeginningOfSentenceId!.Value);
Assert.Equal(codeGenTokenizer.EncodeToIds(DefaultSpecialToken)[0], codeGenTokenizer.EndOfSentenceId!.Value);
Assert.Equal(codeGenTokenizer.EncodeToIds(DefaultSpecialToken)[0], codeGenTokenizer.UnknownTokenId!.Value);
Assert.Equal(DefaultSpecialToken, codeGenTokenizer.BeginningOfSentenceToken);
Assert.Equal(DefaultSpecialToken, codeGenTokenizer.EndOfSentenceToken);
@@ -923,35 +994,35 @@ namespace Microsoft.ML.Tokenizers.Tests
(int Index, int Length)[] offsets = [(0, 0), (0, 1), (1, 0), (1, 2)];
int calculatedLengthUsingOffsets = expectedTokenCount > 0 ? offsets[expectedTokenCount - 1].Index + offsets[expectedTokenCount - 1].Length : 0;
IReadOnlyList<int> ids = _codegen350MMonoTokenizer.EncodeToIds(input, maxTokenCount, out _, out int textLength);
IReadOnlyList<int> ids = _codegen350MMonoTokenizer.EncodeToIds(input, maxTokenCount, out _, out int charsConsumed);
Assert.Equal(expectedTokenCount, ids.Count);
Assert.Equal(expectedTextLength, textLength);
Assert.Equal(expectedTextLength, charsConsumed);
Assert.Equal(encodingIds.Take(expectedTokenCount), ids);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
ids = _codegen350MMonoTokenizer.EncodeToIds(input.AsSpan(), maxTokenCount, out _, out textLength);
Assert.Equal(calculatedLengthUsingOffsets, charsConsumed);
ids = _codegen350MMonoTokenizer.EncodeToIds(input.AsSpan(), maxTokenCount, out _, out charsConsumed);
Assert.Equal(expectedTokenCount, ids.Count);
Assert.Equal(expectedTextLength, textLength);
Assert.Equal(expectedTextLength, charsConsumed);
Assert.Equal(encodingIds.Take(expectedTokenCount), ids);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
Assert.Equal(calculatedLengthUsingOffsets, charsConsumed);
textLength = _codegen350MMonoTokenizer.IndexOfTokenCount(input, maxTokenCount, out _, out int tokenCount);
charsConsumed = _codegen350MMonoTokenizer.GetIndexByTokenCount(input, maxTokenCount, out _, out int tokenCount);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedTextLength, textLength);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
textLength = _codegen350MMonoTokenizer.IndexOfTokenCount(input.AsSpan(), maxTokenCount, out _, out tokenCount);
Assert.Equal(expectedTextLength, charsConsumed);
Assert.Equal(calculatedLengthUsingOffsets, charsConsumed);
charsConsumed = _codegen350MMonoTokenizer.GetIndexByTokenCount(input.AsSpan(), maxTokenCount, out _, out tokenCount);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedTextLength, textLength);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
Assert.Equal(expectedTextLength, charsConsumed);
Assert.Equal(calculatedLengthUsingOffsets, charsConsumed);
calculatedLengthUsingOffsets = expectedTokenCountFromEnd > 0 ? offsets[offsets.Length - expectedTokenCountFromEnd].Index : input.Length;
textLength = _codegen350MMonoTokenizer.LastIndexOfTokenCount(input, maxTokenCount, out _, out tokenCount);
charsConsumed = _codegen350MMonoTokenizer.GetIndexByTokenCountFromEnd(input, maxTokenCount, out _, out tokenCount);
Assert.Equal(expectedTokenCountFromEnd, tokenCount);
Assert.Equal(expectedTextIndexFromEnd, textLength);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
textLength = _codegen350MMonoTokenizer.LastIndexOfTokenCount(input.AsSpan(), maxTokenCount, out _, out tokenCount);
Assert.Equal(expectedTextIndexFromEnd, charsConsumed);
Assert.Equal(calculatedLengthUsingOffsets, charsConsumed);
charsConsumed = _codegen350MMonoTokenizer.GetIndexByTokenCountFromEnd(input.AsSpan(), maxTokenCount, out _, out tokenCount);
Assert.Equal(expectedTokenCountFromEnd, tokenCount);
Assert.Equal(expectedTextIndexFromEnd, textLength);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
Assert.Equal(expectedTextIndexFromEnd, charsConsumed);
Assert.Equal(calculatedLengthUsingOffsets, charsConsumed);
}
}
}
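The renames exercised in the tests above follow a consistent pattern: `IndexOfTokenCount` becomes `GetIndexByTokenCount`, `LastIndexOfTokenCount` becomes `GetIndexByTokenCountFromEnd`, and the out parameter previously called `textLength` is now `charsConsumed`. A minimal migration sketch, assuming a `Tokenizer tokenizer` and a `string input` are already in scope (both names are illustrative):

```csharp
// Hypothetical locals; any Tokenizer instance and input string will do.
int maxTokenCount = 10;

// Old: int textLength = tokenizer.IndexOfTokenCount(input, maxTokenCount, out _, out int tokenCount);
int charsConsumed = tokenizer.GetIndexByTokenCount(
    input, maxTokenCount, out string? normalizedText, out int tokenCount);

// Old: tokenizer.LastIndexOfTokenCount(...) — counts up to maxTokenCount tokens from the end
// of the text and returns the start index of that suffix.
int startIndex = tokenizer.GetIndexByTokenCountFromEnd(
    input, maxTokenCount, out normalizedText, out tokenCount);
```

Both methods also accept `ReadOnlySpan<char>`, as the span-based calls in the tests show.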

@@ -8,6 +8,7 @@ using System.Collections.Generic;
using System.Linq;
using Xunit;
using System.Buffers;
namespace Microsoft.ML.Tokenizers.Tests
{
@@ -86,7 +87,7 @@ namespace Microsoft.ML.Tokenizers.Tests
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe";
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt";
_robertaTokenizer = new EnglishRoberta(
_robertaTokenizer = EnglishRobertaTokenizer.Create(
Path.Combine(@"Gpt-2", "vocab.json"),
Path.Combine(@"Gpt-2", "merges.txt"),
Path.Combine(@"Gpt-2", "dict.txt"),
@@ -109,28 +110,28 @@ namespace Microsoft.ML.Tokenizers.Tests
string mergeFile = Path.Combine(@"Gpt-2", "merges.txt");
string translationFile = Path.Combine(@"Gpt-2", "dict.txt");
Tokenizer tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
Tokenizer tokenizer = EnglishRobertaTokenizer.Create(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer);
TokenizerTests.TestTokenLimits(tokenizer);
tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false);
tokenizer = EnglishRobertaTokenizer.Create(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false);
TestTokenizer(tokenizer);
using Stream vocabStream = File.OpenRead(vocabFile);
using Stream mergeStream = File.OpenRead(mergeFile);
using Stream translationStream = File.OpenRead(translationFile);
tokenizer = new EnglishRoberta(vocabStream, mergeStream, translationStream, RobertaPreTokenizer.Instance);
tokenizer = EnglishRobertaTokenizer.Create(vocabStream, mergeStream, translationStream, RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer);
// Ensure caching works regardless of which method is called first.
for (CallingOrder order = CallingOrder.Encode; order <= CallingOrder.CountTokens; order++)
{
tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
tokenizer = EnglishRobertaTokenizer.Create(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer, order);
tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false);
tokenizer = EnglishRobertaTokenizer.Create(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false);
TestTokenizer(tokenizer, order);
}
}
@@ -177,8 +178,8 @@ namespace Microsoft.ML.Tokenizers.Tests
{
Tokenizer tokenizer = GetRobertaTokenizer();
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);
IReadOnlyList<EncodedToken> encoding = tokenizer.EncodeToTokens(text, out _);
IReadOnlyList<EncodedToken> encoding1 = tokenizer.EncodeToTokens(text.AsSpan(), out _);
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
@@ -208,17 +209,17 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text));
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan()));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.GetIndexByTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length - 3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.GetIndexByTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length - 3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text, 3, out normalizedString, out tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.GetIndexByTokenCountFromEnd(text, 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 3, out normalizedString, out tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(3, tokenCount);
}
@@ -234,32 +235,32 @@ namespace Microsoft.ML.Tokenizers.Tests
// Calling with callIdsFirst = true will test the other way around.
private void TestTokenizer(Tokenizer tokenizer, CallingOrder callingOrder = CallingOrder.Encode)
{
Assert.True(tokenizer is EnglishRoberta);
Assert.True(tokenizer is EnglishRobertaTokenizer);
Assert.True(tokenizer.PreTokenizer is RobertaPreTokenizer);
foreach (object[] p in BertaData)
{
IReadOnlyList<int> ids;
IReadOnlyList<Token> encoding;
IReadOnlyList<EncodedToken> encoding;
int idsCount;
if (callingOrder == CallingOrder.Encode)
{
encoding = tokenizer.Encode((string)p[0], out _);
encoding = tokenizer.EncodeToTokens((string)p[0], out _);
ids = tokenizer.EncodeToIds((string)p[0]);
idsCount = tokenizer.CountTokens((string)p[0]);
}
else if (callingOrder == CallingOrder.EncodeToIds)
{
ids = tokenizer.EncodeToIds((string)p[0]);
encoding = tokenizer.Encode((string)p[0], out _);
encoding = tokenizer.EncodeToTokens((string)p[0], out _);
idsCount = tokenizer.CountTokens((string)p[0]);
}
else // CountTokens
{
idsCount = tokenizer.CountTokens((string)p[0]);
ids = tokenizer.EncodeToIds((string)p[0]);
encoding = tokenizer.Encode((string)p[0], out _);
encoding = tokenizer.EncodeToTokens((string)p[0], out _);
}
int[] encodingIds = encoding.Select(t => t.Id).ToArray();
@@ -271,32 +272,63 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(((int[])p[1]).Length, idsCount);
Assert.Equal(p[3], offsets);
EnglishRoberta? robertaModel = tokenizer as EnglishRoberta;
EnglishRobertaTokenizer? robertaModel = tokenizer as EnglishRobertaTokenizer;
Assert.Equal(p[2], tokens);
Assert.Equal(string.Concat((string[])(p[robertaModel!.FilterUnsupportedChars ? 5 : 2])), tokenizer.Decode(encodingIds));
string expectedDecodedString = string.Concat((string[])(p[robertaModel!.FilterUnsupportedChars ? 5 : 2]));
Assert.Equal(expectedDecodedString, tokenizer.Decode(encodingIds));
TestDecodingWithSpan(robertaModel, encodingIds, expectedDecodedString);
Assert.NotNull(robertaModel);
Assert.Equal(encodingIds, robertaModel!.ConvertOccurrenceRanksToIds(robertaModel!.ConvertIdsToOccurrenceRanks(encodingIds)));
Assert.Equal(p[4], robertaModel.ConvertIdsToOccurrenceValues(encodingIds));
var reverseVocab = robertaModel.Vocabulary.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
for (int i = 0; i < tokens.Length; i++)
{
if (robertaModel.FilterUnsupportedChars)
{
string[]? filteredToken = p[5] as string[];
Assert.Equal(filteredToken![i], tokenizer.MapIdToToken(encodingIds[i]));
Assert.Equal(filteredToken![i], reverseVocab[encodingIds[i]].Replace("\u0120", " "));
}
else
{
Assert.Equal(tokens[i], tokenizer.MapIdToToken(encodingIds[i]));
Assert.Equal(tokens[i], reverseVocab[encodingIds[i]]);
string[]? unfilteredToken = p[2] as string[];
Assert.Equal(unfilteredToken![i], tokenizer.MapIdToToken(encodingIds[i]));
Assert.Equal(unfilteredToken![i], reverseVocab[encodingIds[i]]);
}
Assert.Equal(encodingIds[i], tokenizer.MapTokenToId(tokens[i].AsSpan()));
Assert.Equal(encodingIds[i], robertaModel.Vocabulary[tokens[i]]);
}
}
}
private void TestDecodingWithSpan(EnglishRobertaTokenizer tokenizer, int[] ids, string expectedDecoded)
{
char[] destinationBuffer = new char[expectedDecoded.Length];
OperationStatus status;
int lastIdsConsumed = 0;
int lastCharactersWritten = 0;
int idsConsumed;
int charactersWritten;
for (int i = 1; i < destinationBuffer.Length - 1; i += Math.Max(1, destinationBuffer.Length - 3)) // tests destination lengths 1 and destinationBuffer.Length - 2 only.
{
status = tokenizer.Decode(ids, destinationBuffer.AsSpan().Slice(0, i), out idsConsumed, out charactersWritten);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.True(idsConsumed < ids.Length);
Assert.True(idsConsumed >= lastIdsConsumed);
Assert.True(charactersWritten < expectedDecoded.Length);
Assert.True(charactersWritten >= lastCharactersWritten);
lastIdsConsumed = idsConsumed;
lastCharactersWritten = charactersWritten;
}
status = tokenizer.Decode(ids, destinationBuffer.AsSpan(), out idsConsumed, out charactersWritten);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids.Length, idsConsumed);
Assert.Equal(expectedDecoded.Length, charactersWritten);
Assert.Equal(expectedDecoded, destinationBuffer.AsSpan().ToString());
}
}
}
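The `TestDecodingWithSpan` helper above relies on the new span-based `Decode` overload, which reports progress through `OperationStatus` instead of throwing when the destination is too small. A sketch of the typical consumption pattern, assuming `tokenizer` is an `EnglishRobertaTokenizer` (or `LlamaTokenizer`) and `ids` is an `int[]` already in scope:

```csharp
// Grow-and-retry decode loop; the initial buffer size is illustrative.
char[] buffer = new char[64];
string decoded;
while (true)
{
    OperationStatus status = tokenizer.Decode(ids, buffer, out int idsConsumed, out int charsWritten);
    if (status == OperationStatus.Done)
    {
        decoded = new string(buffer, 0, charsWritten);
        break;
    }
    // OperationStatus.DestinationTooSmall: nothing is lost; retry with more space.
    buffer = new char[buffer.Length * 2];
}
```

As the assertions in the helper verify, partial calls consume ids and write characters monotonically, so callers can also stream chunks out of a fixed-size buffer instead of growing it.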

@@ -1,15 +1,17 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Tokenizers;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Net.Http;
using System.Text.Json;
using System.Linq;
using System.IO;
using System.Threading.Tasks;
using System.Linq;
using System.Net.Http;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using Xunit;
namespace Microsoft.ML.Tokenizers.Tests
@@ -19,20 +21,65 @@ namespace Microsoft.ML.Tokenizers.Tests
private static readonly HttpClient _httpClient = new HttpClient() { Timeout = TimeSpan.FromMinutes(5) };
private static Tokenizer _llamaTokenizer = CreateLlamaTokenizer();
private static Tokenizer _llamaMistralTokenizer = CreateLMistralTokenizer();
private static Tokenizer _llamaPhi3Tokenizer = CreateLPhi3Tokenizer();
private static Tokenizer _llamaPhi3TokenizerWithTreatSpaceSuffix = CreateLPhi3Tokenizer(treatWhitespaceAsSuffix: true);
internal const string DummyPrefix = "\u2581"; // '▁' (LOWER ONE EIGHTH BLOCK)
private static Tokenizer CreateLlamaTokenizer()
{
// @"https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.model?download=true";
// @"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model";
using Stream remoteStream = File.OpenRead(Path.Combine(@"Llama", "tokenizer.model"));
return Tokenizer.CreateLlama(remoteStream);
return LlamaTokenizer.Create(remoteStream);
}
private static Tokenizer CreateLMistralTokenizer()
{
// @"https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/tokenizer.model?download=true";
using Stream remoteStream = File.OpenRead(Path.Combine(@"Mistral", "tokenizer.model"));
return Tokenizer.CreateLlama(remoteStream);
return LlamaTokenizer.Create(remoteStream);
}
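Both helpers above show the move from the `Tokenizer.CreateLlama(stream)` static method to the `LlamaTokenizer.Create` factory. A minimal sketch, assuming a local `tokenizer.model` file (the path is illustrative):

```csharp
using System.IO;
using Microsoft.ML.Tokenizers;

// Old: Tokenizer tokenizer = Tokenizer.CreateLlama(modelStream);
using Stream modelStream = File.OpenRead("tokenizer.model"); // illustrative path
LlamaTokenizer tokenizer = LlamaTokenizer.Create(modelStream);
```

As `CreateLPhi3Tokenizer` below it shows, the factory also accepts `addBeginOfSentence`, `addEndOfSentence`, and a `specialTokens` dictionary for models that extend the base vocabulary.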
private static Tokenizer CreateLPhi3Tokenizer(bool treatWhitespaceAsSuffix = false)
{
// Phi-3 uses the same tokenizer.model as Llama. https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/tree/main
using Stream remoteStream = File.OpenRead(Path.Combine(@"Llama", "tokenizer.model"));
LlamaTokenizer tokenizer = LlamaTokenizer.Create(remoteStream, addBeginOfSentence: true, addEndOfSentence: false,
specialTokens: new Dictionary<string, int>
{
// added tokens are picked up from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/tokenizer_config.json
{ "<unk>", 0 },
{ "<s>", 1 },
{ "</s>", 2 },
{ "<|endoftext|>" , 32000 },
{ "<|assistant|>", 32001 },
{ "<|placeholder1|>", 32002 },
{ "<|placeholder2|>", 32003 },
{ "<|placeholder3|>", 32004 },
{ "<|placeholder4|>", 32005 },
{ "<|system|>", 32006 },
{ "<|end|>", 32007 },
{ "<|placeholder5|>", 32008 },
{ "<|placeholder6|>", 32009 },
{ "<|user|>", 32010 },
});
if (treatWhitespaceAsSuffix)
{
PropertyInfo? propertyInfo = typeof(SentencePieceBpeTokenizer).GetProperty("TreatWhitespaceAsSuffix", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public);
if (propertyInfo != null)
{
propertyInfo.SetValue(tokenizer, true);
}
propertyInfo = typeof(SentencePieceNormalizer).GetProperty("TreatWhitespaceAsSuffix", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public);
if (propertyInfo != null)
{
propertyInfo.SetValue(tokenizer.Normalizer, true);
}
}
return tokenizer;
}
public static IEnumerable<object[]> LlamaTestData()
@@ -184,68 +231,112 @@ namespace Microsoft.ML.Tokenizers.Tests
[Theory]
[MemberData(nameof(LlamaTestData))]
public void TestLlamaTokenizer(Tokenizer llamaTokenizer, string input, int[] ids, string[] tokens, (int Index, int Length)[] offsets)
public void TestLlamaTokenizer(Tokenizer tokenizer, string input, int[] ids, string[] tokens, (int Index, int Length)[] offsets)
{
SentencePieceBpe? bpe = llamaTokenizer as SentencePieceBpe;
Assert.NotNull(bpe);
// Phi-3 and Llama use the same tokenizer.model, so both can be tested with the same data, as long as no added tokens are involved; those behave differently for Phi-3.
Tokenizer[] tokenizers = tokenizer == _llamaTokenizer ? new[] { tokenizer, _llamaPhi3Tokenizer } : new[] { tokenizer };
IReadOnlyList<Token> result = llamaTokenizer.Encode(input, out _);
Assert.Equal(ids, result.Select(t => t.Id).ToArray());
Assert.Equal(tokens, result.Select(t => t.Value).ToArray());
Assert.Equal(offsets, result.Select(t => t.Offset).ToArray());
Assert.Equal(input, llamaTokenizer.Decode(ids));
Assert.Equal(ids, llamaTokenizer.EncodeToIds(input));
Assert.Equal(ids.Length, llamaTokenizer.CountTokens(input));
for (int i = 0; i < tokens.Length; i++)
foreach (Tokenizer llamaTokenizer in tokenizers)
{
Assert.Equal(tokens[i], bpe!.MapIdToToken(ids[i]));
Assert.Equal(ids[i], bpe!.MapTokenToId(tokens[i].AsSpan()));
Assert.Equal(ids[i], bpe!.Vocab[tokens[i]]);
LlamaTokenizer bpe = (llamaTokenizer as LlamaTokenizer)!;
Assert.NotNull(bpe);
IReadOnlyList<EncodedToken> result = llamaTokenizer.EncodeToTokens(input, out _);
Assert.Equal(ids, result.Select(t => t.Id).ToArray());
Assert.Equal(tokens, result.Select(t => t.Value).ToArray());
Assert.Equal(offsets, result.Select(t => t.Offset).ToArray());
Assert.Equal(input, llamaTokenizer.Decode(ids));
TestDecodingWithSpan(bpe, ids, input);
Assert.Equal(ids, llamaTokenizer.EncodeToIds(input));
Assert.Equal(ids.Length, llamaTokenizer.CountTokens(input));
var reverseVocabulary = bpe.Vocabulary.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
for (int i = 0; i < tokens.Length; i++)
{
Assert.Equal(tokens[i], reverseVocabulary[ids[i]]);
Assert.Equal(ids[i], bpe.Vocabulary[tokens[i]]);
}
Assert.NotNull(llamaTokenizer.Normalizer);
string normalizedInput = llamaTokenizer.Normalizer!.Normalize(input);
bool isEmptyInput = string.IsNullOrEmpty(input);
IReadOnlyList<EncodedToken> bpeTokens = bpe.EncodeToTokens(normalizedInput.AsSpan(), out _, addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
Assert.Equal(ids.Skip(1), bpeTokens.Select(token => token.Id));
Assert.Equal(tokens.Skip(1), bpeTokens.Select(token => token.Value));
int[] extractedIds = bpeTokens.Select(token => token.Id).ToArray();
Assert.Equal(input, llamaTokenizer.Decode(extractedIds));
TestDecodingWithSpan(bpe, extractedIds, input);
IReadOnlyList<int> encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
Assert.Equal(ids.Skip(1), encodedIds);
Assert.Equal(isEmptyInput ? 0 : ids.Length - 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false));
bpeTokens = bpe.EncodeToTokens(normalizedInput.AsSpan(), out _, addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id));
Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Skip(1).Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value));
extractedIds = bpeTokens.Select(token => token.Id).ToArray();
Assert.Equal(input, llamaTokenizer.Decode(extractedIds));
TestDecodingWithSpan(bpe, extractedIds, input);
encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), encodedIds);
Assert.Equal(isEmptyInput ? 0 : ids.Length, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false));
bpeTokens = bpe.EncodeToTokens(normalizedInput.AsSpan(), out _, addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id));
Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value));
extractedIds = bpeTokens.Select(token => token.Id).ToArray();
Assert.Equal(input, llamaTokenizer.Decode(extractedIds));
TestDecodingWithSpan(bpe, extractedIds, input);
encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), encodedIds);
Assert.Equal(isEmptyInput ? 0 : ids.Length + 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false));
}
}
private void TestDecodingWithSpan(LlamaTokenizer tokenizer, int[] ids, string expectedDecoded)
{
char[] destinationBuffer = new char[expectedDecoded.Length];
OperationStatus status;
int lastIdsConsumed = 0;
int lastCharactersWritten = 0;
int idsConsumed;
int charactersWritten;
for (int i = 1; i < destinationBuffer.Length - 1; i += Math.Max(1, destinationBuffer.Length - 3)) // tests destination lengths 1 and destinationBuffer.Length - 2 only.
{
status = tokenizer.Decode(ids, destinationBuffer.AsSpan().Slice(0, i), out idsConsumed, out charactersWritten);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.True(idsConsumed < ids.Length);
Assert.True(idsConsumed >= lastIdsConsumed);
Assert.True(charactersWritten < expectedDecoded.Length);
Assert.True(charactersWritten >= lastCharactersWritten);
lastIdsConsumed = idsConsumed;
lastCharactersWritten = charactersWritten;
}
Assert.NotNull(llamaTokenizer.Normalizer);
string normalizedInput = llamaTokenizer.Normalizer!.Normalize(input);
bool isEmptyInput = string.IsNullOrEmpty(input);
IReadOnlyList<Token> bpeTokens = bpe.Encode(normalizedInput.AsSpan(), out _, addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
Assert.Equal(ids.Skip(1), bpeTokens.Select(token => token.Id));
Assert.Equal(tokens.Skip(1), bpeTokens.Select(token => token.Value));
Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
IReadOnlyList<int> encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
Assert.Equal(ids.Skip(1), encodedIds);
Assert.Equal(isEmptyInput ? 0 : ids.Length - 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false));
bpeTokens = bpe.Encode(normalizedInput.AsSpan(), out _, addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id));
Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Skip(1).Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value));
Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), encodedIds);
Assert.Equal(isEmptyInput ? 0 : ids.Length, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false));
bpeTokens = bpe.Encode(normalizedInput.AsSpan(), out _, addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id));
Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value));
Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), encodedIds);
Assert.Equal(isEmptyInput ? 0 : ids.Length + 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false));
status = tokenizer.Decode(ids, destinationBuffer.AsSpan(), out idsConsumed, out charactersWritten);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids.Length, idsConsumed);
Assert.Equal(expectedDecoded.Length, charactersWritten);
Assert.Equal(expectedDecoded, destinationBuffer.AsSpan().ToString());
}
public static IEnumerable<object[]> LlamaTokenizersListData()
{
yield return new object[] { _llamaTokenizer };
yield return new object[] { _llamaMistralTokenizer };
yield return new object[] { _llamaPhi3Tokenizer };
}
[Theory]
[MemberData(nameof(LlamaTokenizersListData))]
public void TestLlamaTokenizerWithEmptyInput(Tokenizer llamaTokenizer)
{
Assert.Equal([], llamaTokenizer.Encode((string)null!, out _));
Assert.Equal([], llamaTokenizer.Encode(Span<char>.Empty, out _));
Assert.Equal([], llamaTokenizer.EncodeToTokens((string)null!, out _));
Assert.Equal([], llamaTokenizer.EncodeToTokens(Span<char>.Empty, out _));
Assert.Equal([], llamaTokenizer.EncodeToIds((string)null!));
Assert.Equal([], llamaTokenizer.EncodeToIds(Span<char>.Empty));
@@ -260,14 +351,14 @@ namespace Microsoft.ML.Tokenizers.Tests
[MemberData(nameof(LlamaTokenizersListData))]
public void TestLlamaTokenizerProperties(Tokenizer llamaTokenizer)
{
SentencePieceBpe? bpe = llamaTokenizer as SentencePieceBpe;
LlamaTokenizer? bpe = llamaTokenizer as LlamaTokenizer;
Assert.NotNull(bpe);
Assert.NotNull(llamaTokenizer.Normalizer);
Assert.Equal("▁Hello,▁World!", llamaTokenizer.Normalizer.Normalize("Hello, World!"));
Assert.True(bpe.Vocab.Count > 0);
Assert.True(bpe.Vocab.TryGetValue("▁", out _));
Assert.True(bpe.Vocabulary.Count > 0);
Assert.True(bpe.Vocabulary.TryGetValue("▁", out _));
Assert.Equal(0, bpe.UnknownId);
Assert.Equal("<unk>", bpe.UnknownToken);
@@ -276,8 +367,6 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(2, bpe.EndOfSentenceId);
Assert.Equal("</s>", bpe.EndOfSentenceToken);
Assert.Equal(bpe.Vocab["▁"], bpe.MapTokenToId("▁".AsSpan()));
Assert.Equal("▁", bpe.MapIdToToken(bpe.Vocab["▁"]));
Assert.True(bpe.ByteFallback);
Assert.True(bpe.AddDummyPrefix);
@@ -290,41 +379,58 @@ namespace Microsoft.ML.Tokenizers.Tests
[Fact]
public void TestSentencePieceNormalizer()
{
SentencePieceNormalizer normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false);
SentencePieceNormalizer normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false, specialTokens: null);
Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!"));
Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false);
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false, specialTokens: null);
Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!"));
Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false);
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false, specialTokens: null);
Assert.Equal(" Hello, World!", normalizer.Normalize("Hello, World!"));
Assert.Equal(" Hello, World!", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false);
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false, specialTokens: null);
Assert.Equal("▁Hello,▁World!", normalizer.Normalize("Hello, World!"));
Assert.Equal("▁Hello,▁World!", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false);
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false, specialTokens: null);
Assert.Equal("▁Hello,▁▁▁▁▁▁World!", normalizer.Normalize("Hello, World!"));
Assert.Equal("▁Hello,▁▁▁▁▁▁World!", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true);
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true, specialTokens: null);
Assert.Equal("Hello,▁World!▁", normalizer.Normalize("Hello, World!"));
Assert.Equal("Hello,▁World!▁", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true);
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true, specialTokens: null);
Assert.Equal("Hello,▁World!", normalizer.Normalize("Hello, World!"));
Assert.Equal("Hello,▁World!", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true);
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true, specialTokens: null);
Assert.Equal("Hello,▁▁▁▁▁▁World!▁", normalizer.Normalize("Hello, World!"));
Assert.Equal("Hello,▁▁▁▁▁▁World!▁", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: true);
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: true, specialTokens: null);
Assert.Equal("Hello, World! ", normalizer.Normalize("Hello, World!"));
Assert.Equal("Hello, World! ", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false, specialTokens: (_llamaPhi3Tokenizer as LlamaTokenizer)!.SpecialTokens);
Assert.Equal("<|user|>", normalizer.Normalize("<|user|>"));
Assert.Equal("<|user|><|system|><|assistant|><|endoftext|>", normalizer.Normalize("<|user|><|system|><|assistant|><|endoftext|>"));
Assert.Equal("▁Hello<|user|>", normalizer.Normalize("Hello<|user|>"));
Assert.Equal("▁Hello,▁<|user|>World", normalizer.Normalize("Hello, <|user|>World"));
Assert.Equal("<|endoftext|>▁Hello<|user|>", normalizer.Normalize("<|endoftext|>Hello<|user|>"));
Assert.Equal("", normalizer.Normalize(""));
normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true, specialTokens: (_llamaPhi3Tokenizer as LlamaTokenizer)!.SpecialTokens);
Assert.Equal("<|user|>", normalizer.Normalize("<|user|>"));
Assert.Equal("<|user|><|system|><|assistant|><|endoftext|>", normalizer.Normalize("<|user|><|system|><|assistant|><|endoftext|>"));
Assert.Equal("Hello▁<|user|>", normalizer.Normalize("Hello<|user|>"));
Assert.Equal("Hello,▁<|user|>World▁", normalizer.Normalize("Hello, <|user|>World"));
Assert.Equal("<|endoftext|>Hello▁<|user|>", normalizer.Normalize("<|endoftext|>Hello<|user|>"));
Assert.Equal("", normalizer.Normalize(""));
}
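As the updated constructor calls in this test show, `SentencePieceNormalizer` now takes an explicit `specialTokens` argument: passing `null` keeps the previous behavior, while passing a tokenizer's `SpecialTokens` map leaves those tokens intact during normalization. A minimal sketch:

```csharp
using Microsoft.ML.Tokenizers;

var normalizer = new SentencePieceNormalizer(
    removeExtraWhiteSpaces: true,
    addDummyPrefix: true,
    escapeWhiteSpaces: true,
    treatWhitespaceAsSuffix: false,
    specialTokens: null); // or e.g. llamaTokenizer.SpecialTokens to keep "<|user|>" unsplit

// Per the assertions above: whitespace is escaped to '▁' and a dummy prefix is added.
string normalized = normalizer.Normalize("Hello, World!"); // "▁Hello,▁World!"
```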
public static IEnumerable<object?[]> TokenizerTestData
@@ -376,8 +482,8 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.NotNull(tokenizer.Normalizer);
Assert.Null(tokenizer.PreTokenizer);
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);
IReadOnlyList<EncodedToken> encoding = tokenizer.EncodeToTokens(text, out _);
IReadOnlyList<EncodedToken> encoding1 = tokenizer.EncodeToTokens(text.AsSpan(), out _);
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
@@ -387,12 +493,12 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray());
Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray());
SentencePieceBpe sentencePieceBpe = (tokenizer as SentencePieceBpe)!;
SentencePieceBpeTokenizer sentencePieceBpe = (tokenizer as SentencePieceBpeTokenizer)!;
foreach (bool considerNormalization in new[] { true, false })
foreach (bool addBeginningOfSentence in new[] { true, false })
foreach (bool addEndOfSentence in new[] { true, false })
{
encoding = sentencePieceBpe.Encode(
encoding = sentencePieceBpe.EncodeToTokens(
considerNormalization ? text : normalizedText,
out _,
addBeginningOfSentence: addBeginningOfSentence,
@@ -400,7 +506,7 @@ namespace Microsoft.ML.Tokenizers.Tests
considerPreTokenization: false,
considerNormalization: considerNormalization);
encoding1 = sentencePieceBpe.Encode(
encoding1 = sentencePieceBpe.EncodeToTokens(
considerNormalization ? text.AsSpan() : normalizedText.AsSpan(),
out _,
addBeginningOfSentence: addBeginningOfSentence,
@@ -441,12 +547,12 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(normalizedText, normalizedString);
Assert.Equal(normalizedText.Length, length);
SentencePieceBpe sentencePieceBpe = (tokenizer as SentencePieceBpe)!;
SentencePieceBpeTokenizer sentencePieceBpe = (tokenizer as SentencePieceBpeTokenizer)!;
foreach (bool considerNormalization in new[] { true, false })
foreach (bool addBeginningOfSentence in new[] { true, false })
foreach (bool addEndOfSentence in new[] { true, false })
{
// (string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
// (string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true)
int[] expectedIds1 = addBeginningOfSentence ? expectedIds : expectedIds.Skip(1).ToArray();
expectedIds1 = addEndOfSentence ? expectedIds1.Concat(new[] { sentencePieceBpe.EndOfSentenceId }).ToArray() : expectedIds1;
@@ -511,19 +617,286 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text));
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan()));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index + expectedOffsets[expectedOffsets.Length - 7].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 6, out string? normalizedString, out int tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index + expectedOffsets[expectedOffsets.Length - 7].Length, tokenizer.GetIndexByTokenCount(text, expectedIds.Length - 6, out string? normalizedString, out int tokenCount));
Assert.Equal(normalizedText, normalizedString);
Assert.Equal(expectedIds.Length - 6, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index + expectedOffsets[expectedOffsets.Length - 7].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 6, out normalizedString, out tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index + expectedOffsets[expectedOffsets.Length - 7].Length, tokenizer.GetIndexByTokenCount(text.AsSpan(), expectedIds.Length - 6, out normalizedString, out tokenCount));
Assert.Equal(normalizedText, normalizedString);
Assert.Equal(expectedIds.Length - 6, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index, tokenizer.LastIndexOfTokenCount(text, 7, out normalizedString, out tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index, tokenizer.GetIndexByTokenCountFromEnd(text, 7, out normalizedString, out tokenCount));
Assert.Equal(normalizedText, normalizedString);
Assert.Equal(7, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 7, out normalizedString, out tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index, tokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), 7, out normalizedString, out tokenCount));
Assert.Equal(normalizedText, normalizedString);
Assert.Equal(7, tokenCount);
}
[Fact]
public void TestPhi3Tokenizer()
{
LlamaTokenizer tokenizer = (_llamaPhi3Tokenizer as LlamaTokenizer)!;
Assert.True(tokenizer.SpecialTokens is not null);
StringBuilder sb = new(); // Create bigger string containing all Added Tokens
IReadOnlyList<EncodedToken> encodedTokens;
IReadOnlyList<int> encodedIds;
int tokenCount;
string? normalizedString;
foreach (var kvp in tokenizer.SpecialTokens)
{
encodedTokens = tokenizer.EncodeToTokens(kvp.Key, out normalizedString);
Assert.Equal(new[] { tokenizer.BeginningOfSentenceToken, kvp.Key }, encodedTokens.Select(et => et.Value).ToArray());
Assert.Equal(new[] { tokenizer.BeginningOfSentenceId, kvp.Value }, encodedTokens.Select(et => et.Id).ToArray());
Assert.Equal($"{kvp.Key}", normalizedString);
encodedIds = tokenizer.EncodeToIds(kvp.Key);
Assert.Equal(encodedIds, encodedTokens.Select(et => et.Id).ToArray());
tokenCount = tokenizer.CountTokens(kvp.Key);
Assert.Equal(tokenCount, encodedTokens.Count);
sb.Append($" Hello{kvp.Key}");
}
string s = sb.ToString();
string expectedNormalizedString = $"{DummyPrefix}{s.Replace(' ', DummyPrefix[0])}";
encodedTokens = tokenizer.EncodeToTokens(s, out normalizedString, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.Equal(expectedNormalizedString, normalizedString);
string[] specialTokens = tokenizer.SpecialTokens.Keys.ToArray();
string accumulatedString = DummyPrefix;
string accumulatedStringFromEnd = "";
for (int i = 1; i <= encodedTokens.Count; i++)
{
int index = tokenizer.GetIndexByTokenCount(s, addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, out normalizedString, out tokenCount);
Assert.Equal(index, accumulatedString.Length);
Assert.Equal(i, tokenCount);
accumulatedString += i % 2 != 0 ? $"Hello{DummyPrefix}" : specialTokens[i / 2 - 1];
accumulatedStringFromEnd = (encodedTokens.Count == i ? DummyPrefix : (i % 2 == 0 ? $"{DummyPrefix}Hello" : specialTokens[specialTokens.Length - 1 - (i / 2)])) + accumulatedStringFromEnd;
index = tokenizer.GetIndexByTokenCountFromEnd(s, addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, considerNormalization: true, out normalizedString, out tokenCount);
Assert.Equal(i, tokenCount);
Assert.Equal(index, normalizedString!.Length - accumulatedStringFromEnd.Length);
}
}
public static IEnumerable<object[]> Phi3TestData()
{
// text to tokenize,
// Decode text without special tokens,
// expected ids
// expected ids when using space suffix
yield return new object[]
{
"Can you provide ways to eat combinations of bananas and dragonfruits?",
"Can you provide ways to eat combinations of bananas and dragonfruits?",
new int[]
{
1, 1815, 366, 3867, 5837, 304, 17545, 18240, 310, 9892, 16397, 322, 8338, 265, 29888, 21211, 29973
},
new int[]
{
1, 6028, 366, 3867, 5837, 304, 17545, 18240, 310, 9892, 16397, 322, 8338, 265, 29888, 21211, 29973, 29871
}
};
yield return new object[]
{
"Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2." +
" Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey.",
"Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2." +
" Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey.",
new int[]
{
1, 18585, 29991, 2266, 526, 777, 5837, 304, 17545, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 29901, 29871, 29896, 29889, 10765, 1648, 322, 8338, 265, 29888,
9216, 10597, 347, 29901, 3164, 355, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 411, 777, 27274, 322, 298, 4992, 29889, 29871, 29906, 29889, 10765, 1648, 322,
8338, 265, 29888, 9216, 4497, 328, 29901, 23478, 269, 506, 287, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 411, 777, 454, 3712, 3623, 625, 322, 298, 4992, 29889
},
new int[]
{
1, 29903, 545, 29991, 2266, 526, 777, 5837, 304, 17545, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 29901, 29871, 29896, 29889, 10765, 1648, 322, 8338, 265, 29888,
9216, 10597, 347, 29901, 3164, 355, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 411, 777, 27274, 322, 298, 4992, 29889, 29871, 29906, 29889, 10765, 1648, 322, 8338,
265, 29888, 9216, 4497, 328, 29901, 23478, 269, 506, 287, 9892, 16397, 322, 8338, 265, 29888, 21211, 4208, 411, 777, 454, 3712, 3623, 625, 322, 298, 4992, 29889, 29871
}
};
yield return new object[]
{
"What about solving an 2x + 3 = 7 equation?",
"What about solving an 2x + 3 = 7 equation?",
new int[]
{
1, 1724, 1048, 17069, 385, 29871, 29906, 29916, 718, 29871, 29941, 353, 29871, 29955, 6306, 29973
},
new int[]
{
1, 5618, 1048, 17069, 385, 29871, 29906, 29916, 718, 29871, 29941, 353, 29871, 29955, 6306, 29973, 29871
}
};
yield return new object[]
{
"\nCount to 3\n",
"\nCount to 3\n",
new int[]
{
1, 29871, 13, 3981, 304, 29871, 29941, 13
},
new int[]
{
1, 13, 3981, 304, 29871, 29941, 13, 29871
}
};
yield return new object[]
{
"<|user|>",
"",
new int[]
{
1, 32010
},
new int[]
{
1, 32010
}
};
yield return new object[]
{
"<|end|>",
"",
new int[]
{
1, 32007
},
new int[]
{
1, 32007
}
};
yield return new object[]
{
"<|assistant|>",
"",
new int[]
{
1, 32001
},
new int[]
{
1, 32001
}
};
yield return new object[]
{
"<|user|>\nCount to 3<|end|>\n<|assistant|>",
"\nCount to 3\n",
new int[]
{
1, 32010, 29871, 13, 3981, 304, 29871, 29941, 32007, 13, 32001
},
new int[]
{
1, 32010, 13, 3981, 304, 29871, 29941, 32007, 13, 29871, 32001
}
};
}
[Theory]
[MemberData(nameof(Phi3TestData))]
public void TestPhi3TokenizerIdEncoding(string text, string decodedWithNoSpecialTokens, int[] expectedIds, int[] expectedIdsWithSuffix)
{
LlamaTokenizer tokenizer = (_llamaPhi3Tokenizer as LlamaTokenizer)!;
var ids = tokenizer.EncodeToIds(text);
Assert.Equal(expectedIds, ids);
Assert.Equal(decodedWithNoSpecialTokens, tokenizer.Decode(expectedIds));
string textWithSpecialTokens = $"{tokenizer.BeginningOfSentenceToken}{text}";
Assert.Equal(textWithSpecialTokens, tokenizer.Decode(expectedIds, considerSpecialTokens: true));
char[] destinationBuffer = new char[decodedWithNoSpecialTokens.Length];
int idsConsumed;
int charactersWritten;
for (int i = 1; i < destinationBuffer.Length - 1; i += Math.Max(1, destinationBuffer.Length - 3)) // enough to test length 1, and destinationBuffer.Length - 2 only.
{
Assert.Equal(OperationStatus.DestinationTooSmall, tokenizer.Decode(ids, destinationBuffer.AsSpan().Slice(0, i), out idsConsumed, out charactersWritten));
Assert.True(idsConsumed < ids.Count);
Assert.True(decodedWithNoSpecialTokens.AsSpan().StartsWith(destinationBuffer.AsSpan().Slice(0, charactersWritten)));
}
Assert.Equal(OperationStatus.Done, tokenizer.Decode(ids, destinationBuffer.AsSpan(), out idsConsumed, out charactersWritten));
Assert.Equal(ids.Count, idsConsumed);
Assert.Equal(decodedWithNoSpecialTokens.Length, charactersWritten);
Assert.Equal(decodedWithNoSpecialTokens, destinationBuffer.AsSpan().ToString());
destinationBuffer = new char[textWithSpecialTokens.Length];
for (int i = 1; i < destinationBuffer.Length - 1; i += Math.Max(1, destinationBuffer.Length - 3)) // enough to test length 1, and destinationBuffer.Length - 2 only.
{
Assert.Equal(OperationStatus.DestinationTooSmall, tokenizer.Decode(ids, destinationBuffer.AsSpan().Slice(0, i), considerSpecialTokens: true, out idsConsumed, out charactersWritten));
Assert.True(idsConsumed < ids.Count);
Assert.True(textWithSpecialTokens.AsSpan().StartsWith(destinationBuffer.AsSpan().Slice(0, charactersWritten)));
}
Assert.Equal(OperationStatus.Done, tokenizer.Decode(ids, destinationBuffer.AsSpan(), considerSpecialTokens: true, out idsConsumed, out charactersWritten));
Assert.Equal(ids.Count, idsConsumed);
Assert.Equal(textWithSpecialTokens.Length, charactersWritten);
Assert.Equal(textWithSpecialTokens, destinationBuffer.AsSpan().ToString());
LlamaTokenizer tokenizerWithSuffix = (_llamaPhi3TokenizerWithTreatSpaceSuffix as LlamaTokenizer)!;
Assert.True(tokenizerWithSuffix.TreatWhitespaceAsSuffix);
ids = tokenizerWithSuffix.EncodeToIds(text);
Assert.Equal(expectedIdsWithSuffix, ids);
Assert.Equal(decodedWithNoSpecialTokens, tokenizerWithSuffix.Decode(expectedIdsWithSuffix));
Assert.Equal(textWithSpecialTokens, tokenizerWithSuffix.Decode(expectedIdsWithSuffix, considerSpecialTokens: true));
//
// Test with suffix instead of prefix
//
destinationBuffer = new char[decodedWithNoSpecialTokens.Length + 1]; // one extra for suffix
for (int i = 1; i < destinationBuffer.Length - 1; i += Math.Max(1, destinationBuffer.Length - 3)) // enough to test length 1, and destinationBuffer.Length - 2 only.
{
Assert.Equal(OperationStatus.DestinationTooSmall, tokenizerWithSuffix.Decode(ids, destinationBuffer.AsSpan().Slice(0, i), out idsConsumed, out charactersWritten));
Assert.True(idsConsumed < ids.Count);
Assert.True(decodedWithNoSpecialTokens.AsSpan().StartsWith(destinationBuffer.AsSpan().Slice(0, charactersWritten)));
}
Assert.Equal(OperationStatus.Done, tokenizerWithSuffix.Decode(ids, destinationBuffer.AsSpan(), out idsConsumed, out charactersWritten));
Assert.Equal(ids.Count, idsConsumed);
Assert.Equal(decodedWithNoSpecialTokens.Length, charactersWritten);
Assert.Equal(decodedWithNoSpecialTokens, destinationBuffer.AsSpan().Slice(0, charactersWritten).ToString());
destinationBuffer = new char[textWithSpecialTokens.Length + 1];
for (int i = 1; i < destinationBuffer.Length - 1; i += Math.Max(1, destinationBuffer.Length - 3)) // enough to test length 1, and destinationBuffer.Length - 2 only.
{
Assert.Equal(OperationStatus.DestinationTooSmall, tokenizerWithSuffix.Decode(ids, destinationBuffer.AsSpan().Slice(0, i), considerSpecialTokens: true, out idsConsumed, out charactersWritten));
Assert.True(idsConsumed < ids.Count);
var sp = destinationBuffer.AsSpan().Slice(0, charactersWritten);
if (sp.Length > 0 && sp[sp.Length - 1] == ' ')
{
sp = sp.Slice(0, sp.Length - 1);
}
Assert.True(textWithSpecialTokens.AsSpan().StartsWith(sp));
}
Assert.Equal(OperationStatus.Done, tokenizerWithSuffix.Decode(ids, destinationBuffer.AsSpan(), considerSpecialTokens: true, out idsConsumed, out charactersWritten));
Assert.Equal(ids.Count, idsConsumed);
Assert.Equal(textWithSpecialTokens.Length, charactersWritten);
Assert.Equal(textWithSpecialTokens, destinationBuffer.AsSpan(0, charactersWritten).ToString());
}
}
}
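
The changes above track a set of API renames in Microsoft.ML.Tokenizers. As a rough before/after sketch of the renames this file's tests exercise (assuming a `tokenizer` instance and a `text` string are in scope; signatures are taken from this diff, not verified against the shipped package):

```csharp
// Old: IReadOnlyList<Token> tokens = tokenizer.Encode(text, out string? normalized);
IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens(text, out string? normalized);

// Old: tokenizer.IndexOfTokenCount(text, maxTokenCount, out normalized, out tokenCount)
int index = tokenizer.GetIndexByTokenCount(text, maxTokenCount: 5, out normalized, out int tokenCount);

// Old: tokenizer.LastIndexOfTokenCount(text, maxTokenCount, out normalized, out tokenCount)
int indexFromEnd = tokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: 5, out normalized, out tokenCount);

// Type renames: Token -> EncodedToken, SentencePieceBpe -> SentencePieceBpeTokenizer
```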


@@ -19,14 +19,14 @@ namespace Microsoft.ML.Tokenizers.Tests
{
yield return new object?[]
{
new LowerCaseNormalizer(),
LowerCaseNormalizer.Instance,
"How Are You Doing?",
"how are you doing?",
};
yield return new object?[]
{
new UpperCaseNormalizer(),
UpperCaseNormalizer.Instance,
"How Are You Doing?",
"HOW ARE YOU DOING?",
};
@@ -62,7 +62,7 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(normalized, normalizedText);
Tokenizer tokenizer = BpeTests.CreateEmptyBpe(preTokenizer: null, normalizer);
IReadOnlyList<Token> tokens = tokenizer.Encode(text, out string? normalizedString);
IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens(text, out string? normalizedString);
Assert.Equal(normalized, normalizedString);
}
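
The normalizer tests now obtain instances through singleton accessors instead of constructors. A minimal sketch of that change (names as shown in this diff; the expected output is taken from the test data above):

```csharp
// Old: Normalizer lower = new LowerCaseNormalizer();
Normalizer lower = LowerCaseNormalizer.Instance;   // new singleton accessor
Normalizer upper = UpperCaseNormalizer.Instance;

string s = lower.Normalize("How Are You Doing?");  // "how are you doing?" per the test data
```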


@@ -18,14 +18,14 @@ namespace Microsoft.ML.Tokenizers.Tests
{
yield return new object[]
{
WhiteSpace.Instance,
WhiteSpacePreTokenizer.Instance,
"How are you doing?",
new (int Offset, int Length)[] { (0, 3), (4, 3), (8, 3), (12, 5), (17, 1), }
};
yield return new object[]
{
WhiteSpace.Instance,
WhiteSpacePreTokenizer.Instance,
"I_am_Just_Fine!",
new (int Offset, int Length)[] { (0, 14), (14, 1) }
};
@@ -56,14 +56,14 @@ namespace Microsoft.ML.Tokenizers.Tests
// Empty tokenizer which tokenize all parts as unknown tokens.
Tokenizer tokenizer = BpeTests.CreateEmptyBpe(normalizer: null, preTokenizer: preTokenizer);
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
IReadOnlyList<EncodedToken> encoding = tokenizer.EncodeToTokens(text, out _);
Assert.True(encoding.Count >= splitParts.Length, $"Expected to have {encoding.Count} >= {splitParts.Length}");
}
[Fact]
public void TestWhiteSpacePreTokenizer()
{
Assert.Empty(WhiteSpace.Instance.PreTokenize((string)null!));
Assert.Empty(WhiteSpacePreTokenizer.Instance.PreTokenize((string)null!));
}
public class SpacePreTokenizer : PreTokenizer
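
The pre-tokenizer type was renamed from `WhiteSpace` to `WhiteSpacePreTokenizer` in these tests. A minimal sketch (the result shape is inferred from the `(Offset, Length)` test data above, not verified against the shipped package):

```csharp
// Old: WhiteSpace.Instance.PreTokenize(text)
var parts = WhiteSpacePreTokenizer.Instance.PreTokenize("How are you doing?");
// Per the test data above, the expected splits are: (0, 3), (4, 3), (8, 3), (12, 5), (17, 1)
```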


@@ -3,12 +3,13 @@
// See the LICENSE file in the project root for more information.
using Microsoft.DotNet.RemoteExecutor;
using Microsoft.ML.Tokenizers;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -27,20 +28,20 @@ namespace Microsoft.ML.Tokenizers.Tests
{ IMEnd, 100265},
};
public static Tokenizer GPT4 { get; } = Tokenizer.CreateTiktokenForModel("gpt-4", _specialTokens);
public static Tokenizer GPT2 { get; } = Tokenizer.CreateTiktokenForModel("gpt2");
public static Tokenizer P50kBase { get; } = Tokenizer.CreateTiktokenForModel("text-davinci-003");
public static Tokenizer R50kBase { get; } = Tokenizer.CreateTiktokenForModel("ada");
public static Tokenizer P50kEdit { get; } = Tokenizer.CreateTiktokenForModel("text-davinci-edit-001");
public static Tokenizer GPT4o { get; } = Tokenizer.CreateTiktokenForModel("gpt-4o");
public static Tokenizer GPT4 { get; } = TiktokenTokenizer.CreateForModel("gpt-4", _specialTokens);
public static Tokenizer GPT2 { get; } = TiktokenTokenizer.CreateForModel("gpt2");
public static Tokenizer P50kBase { get; } = TiktokenTokenizer.CreateForModel("text-davinci-003");
public static Tokenizer R50kBase { get; } = TiktokenTokenizer.CreateForModel("ada");
public static Tokenizer P50kEdit { get; } = TiktokenTokenizer.CreateForModel("text-davinci-edit-001");
public static Tokenizer GPT4o { get; } = TiktokenTokenizer.CreateForModel("gpt-4o");
[Fact]
public async void TestTokenizerCreation()
{
TestGPT4TokenizationEncoding(GPT4);
Assert.True(GPT4 is Tiktoken);
IReadOnlyDictionary<string, int>? specialTokensEncoder = (GPT4 as Tiktoken)!.SpecialTokens;
Assert.True(GPT4 is TiktokenTokenizer);
IReadOnlyDictionary<string, int>? specialTokensEncoder = (GPT4 as TiktokenTokenizer)!.SpecialTokens;
string tokenizerDataFileName = Utils.CreateTemporaryFile("tiktoken");
@@ -54,37 +55,37 @@ namespace Microsoft.ML.Tokenizers.Tests
try
{
Tokenizer tokenizer = new Tiktoken(tokenizerDataFileName, GPT4.PreTokenizer, specialTokensEncoder);
Tokenizer tokenizer = TiktokenTokenizer.Create(tokenizerDataFileName, GPT4.PreTokenizer, null, specialTokensEncoder);
TestGPT4TokenizationEncoding(tokenizer);
using (Stream stream = File.OpenRead(tokenizerDataFileName))
{
tokenizer = new Tiktoken(stream, GPT4.PreTokenizer, specialTokensEncoder);
tokenizer = TiktokenTokenizer.Create(stream, GPT4.PreTokenizer, null, specialTokensEncoder);
}
TestGPT4TokenizationEncoding(tokenizer);
tokenizer = await Tokenizer.CreateTiktokenAsync(tokenizerDataFileName, GPT4.PreTokenizer, normalizer: null, specialTokensEncoder);
tokenizer = await TiktokenTokenizer.CreateAsync(tokenizerDataFileName, GPT4.PreTokenizer, normalizer: null, specialTokensEncoder);
TestGPT4TokenizationEncoding(tokenizer);
using (Stream stream = File.OpenRead(tokenizerDataFileName))
{
tokenizer = await Tokenizer.CreateTiktokenAsync(stream, GPT4.PreTokenizer, normalizer: null, specialTokensEncoder);
tokenizer = await TiktokenTokenizer.CreateAsync(stream, GPT4.PreTokenizer, normalizer: null, specialTokensEncoder);
}
TestGPT4TokenizationEncoding(tokenizer);
using (Stream stream = File.OpenRead(tokenizerDataFileName))
{
tokenizer = Tokenizer.CreateTiktokenForModel("gpt-4", stream);
tokenizer = TiktokenTokenizer.CreateForModel("gpt-4", stream);
}
TestGPT4TokenizationEncoding(tokenizer);
using (Stream stream = File.OpenRead(tokenizerDataFileName))
{
tokenizer = await Tokenizer.CreateTiktokenForModelAsync("gpt-3.5-turbo", stream);
tokenizer = await TiktokenTokenizer.CreateForModelAsync("gpt-3.5-turbo", stream);
}
TestGPT4TokenizationEncoding(tokenizer);
tokenizer = Tokenizer.CreateTiktokenForModel("gpt-4");
tokenizer = TiktokenTokenizer.CreateForModel("gpt-4");
TestGPT4TokenizationEncoding(tokenizer);
}
finally
@@ -111,11 +112,11 @@ namespace Microsoft.ML.Tokenizers.Tests
try
{
Tiktoken tiktoken = (tokenizer as Tiktoken)!;
Tokenizer externalTokenizer = new Tiktoken(tokenizerDataFileName, tokenizer.PreTokenizer, tiktoken.SpecialTokens);
TiktokenTokenizer tiktoken = (tokenizer as TiktokenTokenizer)!;
TiktokenTokenizer externalTokenizer = TiktokenTokenizer.Create(tokenizerDataFileName, tokenizer.PreTokenizer, null, tiktoken.SpecialTokens);
IReadOnlyDictionary<ReadOnlyMemory<byte>, int> encoder = tiktoken.Encoder;
IReadOnlyDictionary<ReadOnlyMemory<byte>, int> externalEncoder = (externalTokenizer as Tiktoken)!.Encoder;
IReadOnlyDictionary<ReadOnlyMemory<byte>, int> encoder = GetEncoder(tiktoken)!;
IReadOnlyDictionary<ReadOnlyMemory<byte>, int> externalEncoder = GetEncoder(externalTokenizer)!;
Assert.Equal(externalEncoder.Count, encoder.Count);
foreach (KeyValuePair<ReadOnlyMemory<byte>, int> kvp in encoder)
@@ -136,8 +137,9 @@ namespace Microsoft.ML.Tokenizers.Tests
IReadOnlyList<int> encoded = tokenizer.EncodeToIds(text);
Assert.Equal(new List<int>() { 9906, 4435 }, encoded);
Assert.Equal(text, tokenizer.Decode(encoded)!);
TestDecodingWithSpan((tokenizer as TiktokenTokenizer)!, encoded.ToArray(), text);
IReadOnlyList<Token> result = tokenizer.Encode(text, out string? normalizedString);
IReadOnlyList<EncodedToken> result = tokenizer.EncodeToTokens(text, out string? normalizedString);
int idsCount = tokenizer.CountTokens(text);
int[] ids = result.Select(token => token.Id).ToArray();
@@ -152,6 +154,35 @@ namespace Microsoft.ML.Tokenizers.Tests
TestGPT4Tokenizer(tokenizer);
}
private void TestDecodingWithSpan(TiktokenTokenizer tokenizer, int[] ids, string expectedDecoded)
{
char[] destinationBuffer = new char[expectedDecoded.Length];
OperationStatus status;
int lastIdsConsumed = 0;
int lastCharactersWritten = 0;
int idsConsumed;
int charactersWritten;
for (int i = 1; i < destinationBuffer.Length - 1; i += Math.Max(1, destinationBuffer.Length - 3)) // enough to test length 1, and destinationBuffer.Length - 2 only.
{
status = tokenizer.Decode(ids, destinationBuffer.AsSpan().Slice(0, i), out idsConsumed, out charactersWritten);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.True(idsConsumed < ids.Length);
Assert.True(idsConsumed >= lastIdsConsumed);
Assert.True(charactersWritten < expectedDecoded.Length);
Assert.True(charactersWritten >= lastCharactersWritten);
lastIdsConsumed = idsConsumed;
lastCharactersWritten = charactersWritten;
}
status = tokenizer.Decode(ids, destinationBuffer.AsSpan(), out idsConsumed, out charactersWritten);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids.Length, idsConsumed);
Assert.Equal(expectedDecoded.Length, charactersWritten);
Assert.Equal(expectedDecoded, destinationBuffer.AsSpan().ToString());
}
[Fact]
public void TestEncode1()
{
@@ -159,8 +190,9 @@ namespace Microsoft.ML.Tokenizers.Tests
IReadOnlyList<int> encoded = GPT4.EncodeToIds(text);
Assert.Equal(new List<int>() { 100264, 9906, 4435, 100265 }, encoded);
Assert.Equal(text, GPT4.Decode(encoded));
TestDecodingWithSpan((GPT4 as TiktokenTokenizer)!, encoded.ToArray(), text);
IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
IReadOnlyList<EncodedToken> result = GPT4.EncodeToTokens(text, out string? normalizedString);
int idsCount = GPT4.CountTokens(text);
int[] ids = result.Select(token => token.Id).ToArray();
@@ -188,8 +220,8 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(expected!, encoded);
}
string? decoded = gpt4Tokenizer.Decode(encoded);
Assert.Equal(text, decoded!);
Assert.Equal(text, gpt4Tokenizer.Decode(encoded));
TestDecodingWithSpan((gpt4Tokenizer as TiktokenTokenizer)!, encoded.ToArray(), text);
TokenizerTests.TestTokenLimits(gpt4Tokenizer);
}
@@ -200,10 +232,10 @@ namespace Microsoft.ML.Tokenizers.Tests
string text = "<|im_start|>Hello<|im_end|> World";
IReadOnlyList<int> encoded = GPT4.EncodeToIds(text);
Assert.Equal(new List<int>() { 100264, 9906, 100265, 4435 }, encoded);
string? decoded = GPT4.Decode(encoded);
Assert.Equal(text, decoded);
Assert.Equal(text, GPT4.Decode(encoded));
TestDecodingWithSpan((GPT4 as TiktokenTokenizer)!, encoded.ToArray(), text);
IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
IReadOnlyList<EncodedToken> result = GPT4.EncodeToTokens(text, out string? normalizedString);
int[] ids = result.Select(token => token.Id).ToArray();
string[] tokens = result.Select(token => token.Value).ToArray();
(int, int)[] offsets = result.Select(token => token.Offset).ToArray();
@@ -222,7 +254,7 @@ namespace Microsoft.ML.Tokenizers.Tests
IReadOnlyList<int> encoded = GPT4.EncodeToIds(text);
Assert.Empty(encoded);
IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
IReadOnlyList<EncodedToken> result = GPT4.EncodeToTokens(text, out string? normalizedString);
int idsCount = GPT4.CountTokens(text);
Assert.Empty(result);
Assert.Equal(0, idsCount);
@@ -236,8 +268,9 @@ namespace Microsoft.ML.Tokenizers.Tests
int idsCount = GPT4.CountTokens(text);
Assert.Equal(new List<int>() { 100264, 9906, 2928, 99834, 4435, 100265 }, encoded);
Assert.Equal(text, GPT4.Decode(encoded));
TestDecodingWithSpan((GPT4 as TiktokenTokenizer)!, encoded.ToArray(), text);
IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
IReadOnlyList<EncodedToken> result = GPT4.EncodeToTokens(text, out string? normalizedString);
Assert.Equal(encoded, result.Select(token => token.Id).ToArray());
Assert.Equal(encoded.Count, idsCount);
Assert.Equal(new string[] { "<|im_start|>", "Hello", " ⭐", "⭐", " World", "<|im_end|>" }, result.Select(token => token.Value).ToArray());
@@ -260,8 +293,8 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(expected!, encoded);
}
string? decoded = GPT4o.Decode(encoded);
Assert.Equal(text, decoded);
Assert.Equal(text, GPT4o.Decode(encoded));
TestDecodingWithSpan((GPT4o as TiktokenTokenizer)!, encoded.ToArray(), text);
text = "<|endoftext|>Hello ⭐ World<|endofprompt|>";
@@ -269,8 +302,9 @@ namespace Microsoft.ML.Tokenizers.Tests
idsCount = GPT4o.CountTokens(text);
Assert.Equal(new List<int>() { 199999, 13225, 161181, 5922, 200018 }, encoded);
Assert.Equal(text, GPT4o.Decode(encoded));
TestDecodingWithSpan((GPT4o as TiktokenTokenizer)!, encoded.ToArray(), text);
IReadOnlyList<Token> result = GPT4o.Encode(text, out string? normalizedString);
IReadOnlyList<EncodedToken> result = GPT4o.EncodeToTokens(text, out string? normalizedString);
Assert.Equal(encoded, result.Select(token => token.Id).ToArray());
Assert.Equal(encoded.Count, idsCount);
@@ -295,8 +329,8 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(expected!, encoded);
}
string? decoded = GPT2.Decode(encoded);
Assert.Equal(text, decoded);
Assert.Equal(text, GPT2.Decode(encoded));
TestDecodingWithSpan((GPT2 as TiktokenTokenizer)!, encoded.ToArray(), text);
}
[Fact]
@@ -314,8 +348,8 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(expected!, encoded);
}
string? decoded = P50kBase.Decode(encoded);
Assert.Equal(text, decoded);
Assert.Equal(text, P50kBase.Decode(encoded));
TestDecodingWithSpan((P50kBase as TiktokenTokenizer)!, encoded.ToArray(), text);
}
[Fact]
@@ -333,8 +367,8 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(expected!, encoded);
}
string? decoded = P50kEdit.Decode(encoded);
Assert.Equal(text, decoded);
Assert.Equal(text, P50kEdit.Decode(encoded));
TestDecodingWithSpan((P50kEdit as TiktokenTokenizer)!, encoded.ToArray(), text);
}
[Fact]
@@ -352,8 +386,8 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(expected!, encoded);
}
string? decoded = R50kBase.Decode(encoded);
Assert.Equal(text, decoded);
Assert.Equal(text, R50kBase.Decode(encoded));
TestDecodingWithSpan((R50kBase as TiktokenTokenizer)!, encoded.ToArray(), text);
}
[Theory]
@@ -404,8 +438,8 @@ namespace Microsoft.ML.Tokenizers.Tests
[InlineData("gpt2")]
public void TestAllSupportedModelNames(string modelName)
{
Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel(modelName);
Assert.True(tokenizer is Tiktoken);
Tokenizer tokenizer = TiktokenTokenizer.CreateForModel(modelName);
Assert.True(tokenizer is TiktokenTokenizer);
Assert.NotNull(tokenizer.PreTokenizer);
}
@@ -417,8 +451,8 @@ namespace Microsoft.ML.Tokenizers.Tests
[InlineData("o200k_base")]
public void TestAllSupportedEncodingNames(string encodingName)
{
Tokenizer tokenizer = Tokenizer.CreateTiktokenForEncoding(encodingName);
Assert.True(tokenizer is Tiktoken);
Tokenizer tokenizer = TiktokenTokenizer.CreateForEncoding(encodingName);
Assert.True(tokenizer is TiktokenTokenizer);
Assert.NotNull(tokenizer.PreTokenizer);
string modelName = encodingName.ToLowerInvariant() switch
@@ -431,29 +465,29 @@ namespace Microsoft.ML.Tokenizers.Tests
_ => throw new ArgumentException("Invalid encoding name"),
};
Tokenizer tokenizer1 = Tokenizer.CreateTiktokenForModel(modelName);
Tokenizer tokenizer1 = TiktokenTokenizer.CreateForModel(modelName);
Assert.True(tokenizer is Tiktoken);
Assert.True(tokenizer1 is Tiktoken);
Assert.True(tokenizer is TiktokenTokenizer);
Assert.True(tokenizer1 is TiktokenTokenizer);
Tiktoken tiktoken = (tokenizer as Tiktoken)!;
Tiktoken tiktoken1 = (tokenizer1 as Tiktoken)!;
TiktokenTokenizer tiktoken = (tokenizer as TiktokenTokenizer)!;
TiktokenTokenizer tiktoken1 = (tokenizer1 as TiktokenTokenizer)!;
Assert.Equal(tiktoken1.Encoder, tiktoken.Encoder);
Assert.Equal(tiktoken1.Decoder, tiktoken.Decoder);
Assert.Equal(GetEncoder(tiktoken1), GetEncoder(tiktoken));
Assert.Equal(GetDecoder(tiktoken1), GetDecoder(tiktoken));
Assert.Equal(tiktoken1.SpecialTokens, tiktoken.SpecialTokens);
Assert.Equal(tiktoken1.Vocab, tiktoken.Vocab);
Assert.Equal(GetVocabulary(tiktoken1), GetVocabulary(tiktoken));
}
[Fact]
public void TestEncodingNamesNegativeCases()
{
Assert.Throws<ArgumentNullException>(() => Tokenizer.CreateTiktokenForEncoding(null!));
Assert.Throws<ArgumentException>(() => Tokenizer.CreateTiktokenForEncoding("r50k_base_"));
Assert.Throws<ArgumentException>(() => Tokenizer.CreateTiktokenForEncoding("p50k_base_"));
Assert.Throws<ArgumentException>(() => Tokenizer.CreateTiktokenForEncoding("p50k_edit_"));
Assert.Throws<ArgumentException>(() => Tokenizer.CreateTiktokenForEncoding("cl100k_base_"));
Assert.Throws<ArgumentException>(() => Tokenizer.CreateTiktokenForEncoding("o200k_base_"));
Assert.Throws<ArgumentNullException>(() => TiktokenTokenizer.CreateForEncoding(null!));
Assert.Throws<ArgumentException>(() => TiktokenTokenizer.CreateForEncoding("r50k_base_"));
Assert.Throws<ArgumentException>(() => TiktokenTokenizer.CreateForEncoding("p50k_base_"));
Assert.Throws<ArgumentException>(() => TiktokenTokenizer.CreateForEncoding("p50k_edit_"));
Assert.Throws<ArgumentException>(() => TiktokenTokenizer.CreateForEncoding("cl100k_base_"));
Assert.Throws<ArgumentException>(() => TiktokenTokenizer.CreateForEncoding("o200k_base_"));
}
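
The Tiktoken factory methods moved from static members on `Tokenizer` to the `TiktokenTokenizer` type. A rough before/after mapping, using only calls that appear in this diff (assuming `stream`, `preTokenizer`, and `specialTokens` are in scope):

```csharp
// Old: Tokenizer gpt4 = Tokenizer.CreateTiktokenForModel("gpt-4");
Tokenizer gpt4 = TiktokenTokenizer.CreateForModel("gpt-4");

// Old: Tokenizer cl100k = Tokenizer.CreateTiktokenForEncoding("cl100k_base");
Tokenizer cl100k = TiktokenTokenizer.CreateForEncoding("cl100k_base");

// Old: new Tiktoken(stream, preTokenizer, specialTokens)
TiktokenTokenizer fromStream = TiktokenTokenizer.Create(stream, preTokenizer, null, specialTokens);

// Old: await Tokenizer.CreateTiktokenAsync(...) -> await TiktokenTokenizer.CreateAsync(...)
```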
[InlineData("gpt-4")]
@@ -466,8 +500,8 @@ namespace Microsoft.ML.Tokenizers.Tests
{
RemoteExecutor.Invoke(static (name) =>
{
Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel(name);
Assert.True(tokenizer is Tiktoken);
Tokenizer tokenizer = TiktokenTokenizer.CreateForModel(name);
Assert.True(tokenizer is TiktokenTokenizer);
Assert.NotNull(tokenizer.PreTokenizer);
}, modelName).Dispose();
}
@@ -514,8 +548,8 @@ namespace Microsoft.ML.Tokenizers.Tests
{
Tokenizer tokenizer = GPT4;
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);
IReadOnlyList<EncodedToken> encoding = tokenizer.EncodeToTokens(text, out _);
IReadOnlyList<EncodedToken> encoding1 = tokenizer.EncodeToTokens(text.AsSpan(), out _);
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
@@ -545,17 +579,17 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text));
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan()));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.GetIndexByTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length - 3, tokenCount);
-Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount));
+Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.GetIndexByTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length - 3, tokenCount);
-Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text, 3, out normalizedString, out tokenCount));
+Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.GetIndexByTokenCountFromEnd(text, 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(3, tokenCount);
-Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 3, out normalizedString, out tokenCount));
+Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.GetIndexByTokenCountFromEnd(text.AsSpan(), 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(3, tokenCount);
}
@@ -629,7 +663,7 @@ namespace Microsoft.ML.Tokenizers.Tests
[MemberData(nameof(TokenizerLimitsTestData))]
public void TestPreciseTokenLimits(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
{
-IReadOnlyList<Token> result = GPT4.Encode(text, out _);
+IReadOnlyList<EncodedToken> result = GPT4.EncodeToTokens(text, out _);
int[] ids = result.Select(r => r.Id).ToArray();
(int Index, int Length)[] offsets = result.Select(r => r.Offset).ToArray();
Assert.Equal(expectedTokens, result.Select(r => r.Value));
@@ -640,7 +674,7 @@ namespace Microsoft.ML.Tokenizers.Tests
for (int tokenCount = 1; tokenCount <= ids.Length; tokenCount++)
{
-int length = GPT4.IndexOfTokenCount(text, tokenCount, out _, out int count);
+int length = GPT4.GetIndexByTokenCount(text, tokenCount, out _, out int count);
Assert.True(count <= ids.Length);
if (count < tokenCount)
@@ -658,7 +692,7 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(0, length);
}
-int index = GPT4.LastIndexOfTokenCount(text, tokenCount, out _, out count);
+int index = GPT4.GetIndexByTokenCountFromEnd(text, tokenCount, out _, out count);
Assert.True(count <= ids.Length);
if (count < tokenCount)
@@ -677,6 +711,16 @@ namespace Microsoft.ML.Tokenizers.Tests
}
}
}
// We are not exposing the Encoder, Decoder, or Vocabulary properties so far. For now, use reflection to test them.
private static IReadOnlyDictionary<ReadOnlyMemory<byte>, int>? GetEncoder(TiktokenTokenizer tiktoken)
=> typeof(TiktokenTokenizer).GetProperty("Encoder", BindingFlags.Instance | BindingFlags.NonPublic)?.GetValue(tiktoken) as IReadOnlyDictionary<ReadOnlyMemory<byte>, int>;
private static IReadOnlyDictionary<int, ReadOnlyMemory<byte>>? GetDecoder(TiktokenTokenizer tiktoken)
=> typeof(TiktokenTokenizer).GetProperty("Decoder", BindingFlags.Instance | BindingFlags.NonPublic)?.GetValue(tiktoken) as IReadOnlyDictionary<int, ReadOnlyMemory<byte>>;
private static IReadOnlyDictionary<string, int>? GetVocabulary(TiktokenTokenizer tiktoken)
=> typeof(TiktokenTokenizer).GetProperty("Vocabulary", BindingFlags.Instance | BindingFlags.NonPublic)?.GetValue(tiktoken) as IReadOnlyDictionary<string, int>;
}
}
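Taken together, the renames exercised by these tests map the old Tiktoken API surface onto the new one. A minimal before/after sketch of the migration (the model name and input string are illustrative only; creating a tokenizer for a model may download its encoding data):

```csharp
using System;
using System.Collections.Generic;
using Microsoft.ML.Tokenizers;

// Old API                                     New API
// Tokenizer.CreateTiktokenForModel("gpt-4")   -> TiktokenTokenizer.CreateForModel("gpt-4")
// tokenizer.Encode(text, out _)               -> tokenizer.EncodeToTokens(text, out _)   (Token -> EncodedToken)
// tokenizer.IndexOfTokenCount(...)            -> tokenizer.GetIndexByTokenCount(...)
// tokenizer.LastIndexOfTokenCount(...)        -> tokenizer.GetIndexByTokenCountFromEnd(...)

TiktokenTokenizer tokenizer = TiktokenTokenizer.CreateForModel("gpt-4");

// Encode to tokens; the out parameter receives the normalized text (null when no normalizer runs).
IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens("Hello World", out string? normalized);

// Find the character index covering at most the given number of tokens, from the start or the end.
int fromStart = tokenizer.GetIndexByTokenCount("Hello World", maxTokenCount: 1, out _, out int tokenCount);
int fromEnd = tokenizer.GetIndexByTokenCountFromEnd("Hello World", maxTokenCount: 1, out _, out tokenCount);
```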


@@ -26,11 +26,11 @@ namespace Microsoft.ML.Tokenizers.Tests
for (int i = 1; i <= fullIdsList.Count; i++)
{
-int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string? processedText1, out int tokenCount1);
-int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string? processedText2, out int tokenCount2);
-IReadOnlyList<int> partialIdsList = tokenizer.EncodeToIds(input, maxTokenCount: i, out string? processedText, out int textLength);
+int index1 = tokenizer.GetIndexByTokenCount(input, maxTokenCount: i, out string? processedText1, out int tokenCount1);
+int index2 = tokenizer.GetIndexByTokenCountFromEnd(input, maxTokenCount: i, out string? processedText2, out int tokenCount2);
+IReadOnlyList<int> partialIdsList = tokenizer.EncodeToIds(input, maxTokenCount: i, out string? processedText, out int charsConsumed);
-Assert.True(processedText is null || textLength <= processedText.Length);
+Assert.True(processedText is null || charsConsumed <= processedText.Length);
Assert.True(tokenizer.Normalizer is not null || processedText is null);
Assert.Equal(fullIdsList.Take(partialIdsList.Count), partialIdsList);
@@ -44,7 +44,7 @@ namespace Microsoft.ML.Tokenizers.Tests
{
string prefixString = (processedText1 ?? input).Substring(0, index1);
-if (tokenizer is SentencePieceBpe)
+if (tokenizer is SentencePieceBpeTokenizer)
{
// The SentencePieceBpe model normalizes the text and inserts more characters.
// We call the model directly to bypass the normalization step.
@@ -62,7 +62,7 @@ namespace Microsoft.ML.Tokenizers.Tests
{
string suffixString = (processedText2 ?? input).Substring(index2);
-if (tokenizer is SentencePieceBpe)
+if (tokenizer is SentencePieceBpeTokenizer)
{
// The SentencePieceBpe model normalizes the text and inserts more characters.
// We call the model directly to bypass the normalization step.
@@ -105,15 +105,15 @@ namespace Microsoft.ML.Tokenizers.Tests
}
}
-Assert.Equal(0, tokenizer.IndexOfTokenCount((string)null!, maxTokenCount: 10, out _, out _));
-Assert.Equal(0, tokenizer.LastIndexOfTokenCount((string)null!, maxTokenCount: 10, out _, out _));
-Assert.Equal(0, tokenizer.IndexOfTokenCount(Span<char>.Empty, maxTokenCount: 10, out _, out _));
-Assert.Equal(0, tokenizer.LastIndexOfTokenCount(Span<char>.Empty, maxTokenCount: 10, out _, out _));
+Assert.Equal(0, tokenizer.GetIndexByTokenCount((string)null!, maxTokenCount: 10, out _, out _));
+Assert.Equal(0, tokenizer.GetIndexByTokenCountFromEnd((string)null!, maxTokenCount: 10, out _, out _));
+Assert.Equal(0, tokenizer.GetIndexByTokenCount(Span<char>.Empty, maxTokenCount: 10, out _, out _));
+Assert.Equal(0, tokenizer.GetIndexByTokenCountFromEnd(Span<char>.Empty, maxTokenCount: 10, out _, out _));
-Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: 0, out _, out _));
-Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: -1, out _, out _));
-Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: 0, out _, out _));
-Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: -1, out _, out _));
+Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.GetIndexByTokenCount(input, maxTokenCount: 0, out _, out _));
+Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.GetIndexByTokenCount(input, maxTokenCount: -1, out _, out _));
+Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.GetIndexByTokenCountFromEnd(input, maxTokenCount: 0, out _, out _));
+Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.GetIndexByTokenCountFromEnd(input, maxTokenCount: -1, out _, out _));
}
}
}