* Tokenizer's APIs Update

* Address the feedback

* Address the feedback and use the new TestTokenizers package
This commit is contained in:
Tarek Mahmoud Sayed 2024-04-19 09:21:27 -07:00 committed by GitHub
Parent fac1e1018b
Commit 72cfdf611a
No known key found for this signature
GPG key ID: B5690EEEBB952194
31 changed files with 2826 additions and 1137 deletions

View file

@ -41,7 +41,7 @@
<MicrosoftDotNetInteractiveVersion>1.0.0-beta.23509.3</MicrosoftDotNetInteractiveVersion> <MicrosoftDotNetInteractiveVersion>1.0.0-beta.23509.3</MicrosoftDotNetInteractiveVersion>
<MicrosoftMLOnnxRuntimeVersion>1.16.3</MicrosoftMLOnnxRuntimeVersion> <MicrosoftMLOnnxRuntimeVersion>1.16.3</MicrosoftMLOnnxRuntimeVersion>
<MlNetMklDepsVersion>0.0.0.12</MlNetMklDepsVersion> <MlNetMklDepsVersion>0.0.0.12</MlNetMklDepsVersion>
<!-- <!--
@("inteltbb.devel", "win", "2021.7.1.15305") @("inteltbb.devel", "win", "2021.7.1.15305")
--> -->
<OneDalPkgVersion Condition="'$(OS)' == 'Windows_NT'">2023.0.0.23189</OneDalPkgVersion> <OneDalPkgVersion Condition="'$(OS)' == 'Windows_NT'">2023.0.0.23189</OneDalPkgVersion>
@ -87,6 +87,7 @@
<MicrosoftMLTensorFlowTestModelsVersion>0.0.13-test</MicrosoftMLTensorFlowTestModelsVersion> <MicrosoftMLTensorFlowTestModelsVersion>0.0.13-test</MicrosoftMLTensorFlowTestModelsVersion>
<MicrosoftMLTestDatabasesVersion>0.0.6-test</MicrosoftMLTestDatabasesVersion> <MicrosoftMLTestDatabasesVersion>0.0.6-test</MicrosoftMLTestDatabasesVersion>
<MicrosoftMLTestModelsVersion>0.0.7-test</MicrosoftMLTestModelsVersion> <MicrosoftMLTestModelsVersion>0.0.7-test</MicrosoftMLTestModelsVersion>
<MicrosoftMLTestTokenizersVersion>2.0.0-beta.24218.2</MicrosoftMLTestTokenizersVersion>
<SystemDataSqlClientVersion>4.8.6</SystemDataSqlClientVersion> <SystemDataSqlClientVersion>4.8.6</SystemDataSqlClientVersion>
<SystemDataSQLiteCoreVersion>1.0.118</SystemDataSQLiteCoreVersion> <SystemDataSQLiteCoreVersion>1.0.118</SystemDataSQLiteCoreVersion>
<XunitCombinatorialVersion>1.2.7</XunitCombinatorialVersion> <XunitCombinatorialVersion>1.2.7</XunitCombinatorialVersion>

View file

@ -1,152 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
namespace Microsoft.ML.Tokenizers
{
    /// <summary>
    /// The Encoding represents the output of a Tokenizer: the tokens produced for one input
    /// string, exposed lazily as token strings, ids, and offsets.
    /// </summary>
    public sealed class EncodingResult
    {
        /// <summary>
        /// Create a new instance of the <see cref="EncodingResult"/> class.
        /// </summary>
        /// <param name="originalString">The original string that was tokenized.</param>
        /// <param name="normalizedString">The normalized form of the original string.</param>
        /// <param name="splits">The splits produced by the pre-tokenizer for the input.</param>
        /// <param name="offsetsMappedToOriginalString">Indicate whether the offsets are mapped to the original string or the normalized string.</param>
        public EncodingResult(string originalString, string normalizedString, IEnumerable<Split> splits, bool offsetsMappedToOriginalString)
        {
            OriginalString = originalString;
            NormalizedString = normalizedString;
            Splits = splits;
            OffsetsMappedToOriginalString = offsetsMappedToOriginalString;
        }

        /// <summary>
        /// Gets the original tokenized string.
        /// </summary>
        public string? OriginalString { get; }

        /// <summary>
        /// Gets the normalized form of the original string.
        /// </summary>
        public string? NormalizedString { get; }

        /// <summary>
        /// Gets whether the offsets are mapped to the original string (true) or to the normalized string (false).
        /// </summary>
        public bool OffsetsMappedToOriginalString { get; }

        // Pre-tokenizer splits for the encoded input; consumed internally by the tokenizer.
        internal IEnumerable<Split> Splits { get; }

        // Accumulated tokens plus lazily-built projections of them. Each projection is
        // computed once, on first access, and cached.
        private List<Token>? _tokenList;
        private List<string>? _tokenValues;
        private List<int>? _tokenIds;
        private List<(int Index, int Length)>? _tokenOffsets;

        // Appends tokens to the accumulated token list, creating it on first use.
        internal void AddTokens(IReadOnlyList<Token> addedTokens)
        {
            if (_tokenList is null)
            {
                _tokenList = new List<Token>(addedTokens);
            }
            else
            {
                for (int i = 0; i < addedTokens.Count; i++)
                {
                    _tokenList.Add(addedTokens[i]);
                }
            }
        }

        /// <summary>
        /// Gets list of the tokens Ids.
        /// The Ids are the main input to a Language Model. They are the token indices, the numerical representations that a LM understands.
        /// </summary>
        public IReadOnlyList<int> Ids
        {
            get
            {
                if (_tokenIds is null)
                {
                    if (_tokenList is null)
                    {
                        return Array.Empty<int>();
                    }

                    var ids = new List<int>(_tokenList.Count);
                    for (int i = 0; i < _tokenList.Count; i++)
                    {
                        ids.Add(_tokenList[i].Id);
                    }

                    _tokenIds = ids;
                }

                return _tokenIds;
            }
        }

        /// <summary>
        /// Gets the generated tokens. They are the string representation of the Ids.
        /// </summary>
        public IReadOnlyList<string> Tokens
        {
            get
            {
                if (_tokenValues is null)
                {
                    if (_tokenList is null)
                    {
                        return Array.Empty<string>();
                    }

                    var values = new List<string>(_tokenList.Count);
                    for (int i = 0; i < _tokenList.Count; i++)
                    {
                        values.Add(_tokenList[i].Value);
                    }

                    _tokenValues = values;
                }

                return _tokenValues;
            }
        }

        /// <summary>
        /// Gets the list of offsets. These offsets let you slice the input string, and thus retrieve
        /// the original part that led to producing the corresponding token.
        /// </summary>
        public IReadOnlyList<(int Index, int Length)> Offsets
        {
            get
            {
                if (_tokenOffsets is null)
                {
                    if (_tokenList is null)
                    {
                        return Array.Empty<(int, int)>();
                    }

                    var offsets = new List<(int Index, int Length)>(_tokenList.Count);
                    for (int i = 0; i < _tokenList.Count; i++)
                    {
                        offsets.Add(_tokenList[i].Offset);
                    }

                    _tokenOffsets = offsets;
                }

                return _tokenOffsets;
            }
        }
    }
}

View file

@ -17,13 +17,15 @@ namespace Microsoft.ML.Tokenizers
/// <summary> /// <summary>
/// Represent the Byte Pair Encoding model. /// Represent the Byte Pair Encoding model.
/// </summary> /// </summary>
public sealed class Bpe : Model public sealed class Bpe : Tokenizer
{ {
/// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. /// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
private const int MaxWordLengthToCache = 15; private const int MaxWordLengthToCache = 15;
private string? _unknownToken; private string? _unknownToken;
private int? _unknownTokenId; private int? _unknownTokenId;
private readonly PreTokenizer? _preTokenizer;
private readonly Normalizer? _normalizer;
/// <summary> /// <summary>
/// Gets or Sets unknown token. The unknown token to be used when we encounter an unknown char /// Gets or Sets unknown token. The unknown token to be used when we encounter an unknown char
@ -74,13 +76,15 @@ namespace Microsoft.ML.Tokenizers
/// </summary> /// </summary>
/// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param> /// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergesFile">The file path containing the tokens's pairs list.</param> /// <param name="mergesFile">The file path containing the tokens's pairs list.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="unknownToken"> The unknown token to be used by the model.</param> /// <param name="unknownToken"> The unknown token to be used by the model.</param>
/// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that dont represent a beginning of word.</param> /// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that dont represent a beginning of word.</param>
/// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param> /// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
/// <param name="fuseUnknownTokens">Indicate whether allowing multiple unknown tokens get fused.</param> /// <param name="fuseUnknownTokens">Indicate whether allowing multiple unknown tokens get fused.</param>
public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) : public Bpe(string vocabFile, string? mergesFile, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
this(vocabFile is null ? throw new ArgumentNullException(nameof(vocabFile)) : File.Open(vocabFile, FileMode.Open, FileAccess.Read), this(vocabFile is null ? throw new ArgumentNullException(nameof(vocabFile)) : File.Open(vocabFile, FileMode.Open, FileAccess.Read),
mergesFile is null ? null : File.Open(mergesFile, FileMode.Open, FileAccess.Read), unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: true) mergesFile is null ? null : File.Open(mergesFile, FileMode.Open, FileAccess.Read), preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: true)
{ {
} }
@ -89,16 +93,18 @@ namespace Microsoft.ML.Tokenizers
/// </summary> /// </summary>
/// <param name="vocabStream">The JSON stream containing the dictionary of string keys and their ids.</param> /// <param name="vocabStream">The JSON stream containing the dictionary of string keys and their ids.</param>
/// <param name="mergesStream">The stream containing the tokens's pairs list.</param> /// <param name="mergesStream">The stream containing the tokens's pairs list.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="unknownToken"> The unknown token to be used by the model.</param> /// <param name="unknownToken"> The unknown token to be used by the model.</param>
/// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that dont represent a beginning of word.</param> /// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that dont represent a beginning of word.</param>
/// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param> /// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
/// <param name="fuseUnknownTokens">Indicate whether allowing multiple unknown tokens get fused.</param> /// <param name="fuseUnknownTokens">Indicate whether allowing multiple unknown tokens get fused.</param>
public Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) : public Bpe(Stream vocabStream, Stream? mergesStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
this(vocabStream, mergesStream, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: false) this(vocabStream, mergesStream, preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: false)
{ {
} }
private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, string? continuingSubwordPrefix, string? endOfWordSuffix, bool fuseUnknownTokens, bool disposeStreams) private Bpe(Stream vocabStream, Stream? mergesStream, PreTokenizer? preTokenizer, Normalizer? normalizer, string? unknownToken, string? continuingSubwordPrefix, string? endOfWordSuffix, bool fuseUnknownTokens, bool disposeStreams)
{ {
try try
{ {
@ -110,6 +116,8 @@ namespace Microsoft.ML.Tokenizers
FuseUnknownTokens = fuseUnknownTokens; FuseUnknownTokens = fuseUnknownTokens;
ContinuingSubwordPrefix = continuingSubwordPrefix; ContinuingSubwordPrefix = continuingSubwordPrefix;
EndOfWordSuffix = endOfWordSuffix; EndOfWordSuffix = endOfWordSuffix;
_preTokenizer = preTokenizer ?? WhiteSpace.Instance; // Default to WhiteSpace pre-tokenizer
_normalizer = normalizer;
(Dictionary<StringSpanOrdinalKey, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream); (Dictionary<StringSpanOrdinalKey, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
_vocab = vocab1 ?? new Dictionary<StringSpanOrdinalKey, int>(); _vocab = vocab1 ?? new Dictionary<StringSpanOrdinalKey, int>();
@ -166,47 +174,320 @@ namespace Microsoft.ML.Tokenizers
} }
/// <summary> /// <summary>
/// Encode a text string to a list of tokens. /// Gets the PreTokenizer used by the Tokenizer.
/// </summary>
public override PreTokenizer? PreTokenizer => _preTokenizer;
/// <summary>
/// Gets the Normalizer in use by the Tokenizer.
/// </summary>
public override Normalizer? Normalizer => _normalizer;
/// <summary>
/// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <returns>The list of tokens generated from the text tokenization.</returns> /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text) /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text, Span<char>.Empty, out normalizedString, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(null, text, out normalizedString, considerPreTokenization, considerNormalization);
private IReadOnlyList<Token> Encode(string? text, ReadOnlySpan<char> textSpan, out string? normalizedString, bool considerPreTokenization, bool considerNormalization)
{ {
if (text.Length == 0) if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{ {
return EmptyTokensList; normalizedString = null;
return [];
} }
return EncodeWithCache(text); IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
List<Token> tokens = new();
PriorityQueue<Merge>? priorityQueue = null;
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
EncodeWithCache(textSpanToEncode.Slice(split.Offset, split.Length), tokens, split.Offset, ref priorityQueue);
}
}
else
{
EncodeWithCache(textSpanToEncode, tokens, 0, ref priorityQueue);
}
return tokens;
} }
/// <summary> /// <summary>
/// Encode a split text string to a list of Ids and add them to the accumulatedIds list. /// Encodes input text to token Ids.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param> /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsWithCache(text, accumulatedIds, maxTokens, out textLength); /// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
textLength = 0;
normalizedString = null;
return [];
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
List<int> ids = new();
PriorityQueue<Merge>? priorityQueue = null;
if (splits is not null)
{
textLength = 0;
foreach ((int Offset, int Length) split in splits)
{
EncodeToIdsWithCache(textSpanToEncode.Slice(split.Offset, split.Length), ids, maxTokenCount - ids.Count, out int length, ref priorityQueue);
textLength = split.Offset + length;
if (length < split.Length || ids.Count >= maxTokenCount)
{
break;
}
}
}
else
{
EncodeToIdsWithCache(textSpanToEncode, ids, maxTokenCount, out textLength, ref priorityQueue);
}
return ids;
}
/// <summary> /// <summary>
/// Get the number of tokens that the input text will be encoded to. /// Get the number of tokens that the input text will be encoded to.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <returns>The number of token Ids that the input text will be encoded to.</returns>
public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsWithCache(text, null, maxTokens, out textLength); public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
/// <summary> /// <summary>
/// Get the number of tokens that the input text will be encoded to. /// Get the number of tokens that the input text will be encoded to.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <returns>The number of token Ids that the input text will be encoded to.</returns>
public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => EncodeToIdsFromEndWithCache(text, null, maxTokens, out textIndex); public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(null, text, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
tokenCount = CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
return textLength;
}
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
tokenCount = CountTokens(null, text, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
return textLength;
}
private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
}
textLength = 0;
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedString = null;
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
PriorityQueue<Merge>? priorityQueue = null;
int count = 0;
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
count += EncodeToIdsWithCache(textSpanToEncode.Slice(split.Offset, split.Length), null, maxTokenCount - count, out int length, ref priorityQueue);
textLength = split.Offset + length;
if (length < split.Length || count >= maxTokenCount)
{
break;
}
}
}
else
{
count = EncodeToIdsWithCache(textSpanToEncode, null, maxTokenCount, out textLength, ref priorityQueue);
}
return count;
}
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if normalization is enabled;
/// conversely, if all tokens fit, the result will be 0.
/// </returns>
public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> LastIndexOf(text, Span<char>.Empty, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="normalizedString"/>; conversely, if all tokens fit, the result will be 0.
/// </returns>
public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> LastIndexOf(null, text, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
/// <summary>
/// Shared implementation for both <c>LastIndexOfTokenCount</c> overloads: walks the text from the end
/// and finds the start index of the longest suffix that encodes to at most <paramref name="maxTokenCount"/> tokens.
/// Exactly one of <paramref name="text"/> / <paramref name="textSpan"/> carries the input.
/// </summary>
private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int tokenCount)
{
    if (maxTokenCount <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
    }

    // Empty input: nothing to encode, the whole (empty) text trivially fits.
    if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
    {
        normalizedString = null;
        tokenCount = 0;
        return 0;
    }

    IEnumerable<(int Offset, int Length)>? splitRanges = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);

    PriorityQueue<Merge>? mergeQueue = null;

    if (splitRanges is null)
    {
        // No pre-tokenization splits: encode the whole processed text from the end in a single pass.
        tokenCount = EncodeToIdsFromEndWithCache(textSpanToEncode, null, maxTokenCount, out int startIndex, ref mergeQueue);
        return startIndex;
    }

    tokenCount = 0;

    // Consume the pre-tokenization splits from the end of the text, spending the remaining token budget on each.
    foreach ((int Offset, int Length) range in splitRanges.Reverse())
    {
        tokenCount += EncodeToIdsFromEndWithCache(textSpanToEncode.Slice(range.Offset, range.Length), null, maxTokenCount - tokenCount, out int index, ref mergeQueue);

        // A non-zero index means the budget ran out inside this split; in either stop case the
        // first included character of the suffix is at range.Offset + index.
        if (index > 0 || tokenCount >= maxTokenCount)
        {
            return range.Offset + index;
        }
    }

    // Every split fit within the budget, so the suffix starts at the beginning of the text.
    return 0;
}
/// <summary> /// <summary>
/// Map the token to encoded Id. /// Map the token to encoded Id.
@ -383,7 +664,7 @@ namespace Microsoft.ML.Tokenizers
return s; return s;
} }
internal Word MergeWord(ReadOnlySpan<char> w) internal Word MergeWord(ReadOnlySpan<char> w, ref PriorityQueue<Merge>? priorityQueue)
{ {
Word word = Word.WithCapacity(w.Length); Word word = Word.WithCapacity(w.Length);
(int Id, int Len)? unk = null; (int Id, int Len)? unk = null;
@ -494,23 +775,24 @@ namespace Microsoft.ML.Tokenizers
word.Add(unk.Value.Id, unk.Value.Len); word.Add(unk.Value.Id, unk.Value.Len);
} }
word.MergeAll(Merges, Dropout); word.MergeAll(Merges, Dropout, ref priorityQueue);
return word; return word;
} }
internal List<Token> WordToTokens(ref Word word) => word.ToTokens(VocabReverse); internal void WordToTokens(ref Word word, List<Token> tokens, int offset) => word.ToTokens(VocabReverse, tokens, offset);
internal List<Token> EncodeWithCache(ReadOnlySpan<char> text) internal void EncodeWithCache(ReadOnlySpan<char> text, List<Token> tokens, int offset, ref PriorityQueue<Merge>? priorityQueue)
{ {
Word word; Word word;
if (Cache is not null) if (Cache is not null)
{ {
if (Cache.TryGetValue(text, out word)) if (Cache.TryGetValue(text, out word))
{ {
return WordToTokens(ref word); WordToTokens(ref word, tokens, offset);
return;
} }
word = MergeWord(text); word = MergeWord(text, ref priorityQueue);
if (text.Length <= MaxWordLengthToCache) if (text.Length <= MaxWordLengthToCache)
{ {
@ -519,15 +801,15 @@ namespace Microsoft.ML.Tokenizers
} }
else else
{ {
word = MergeWord(text); word = MergeWord(text, ref priorityQueue);
} }
return WordToTokens(ref word); WordToTokens(ref word, tokens, offset);
} }
internal int WordToIds(ref Word word, IList<int>? accumulatedIds, out int textLength, int fullTextLength, int maxTokens) internal int WordToIds(ref Word word, IList<int>? accumulatedIds, out int textLength, int fullTextLength, int maxTokens)
{ {
if (word.SymbolsCount < maxTokens) if (word.SymbolsCount <= maxTokens)
{ {
textLength = fullTextLength; textLength = fullTextLength;
if (accumulatedIds is not null) if (accumulatedIds is not null)
@ -548,7 +830,7 @@ namespace Microsoft.ML.Tokenizers
internal int WordToIdsFromEnd(ref Word word, IList<int>? accumulatedIds, out int textIndex, int fullTextLength, int maxTokens) internal int WordToIdsFromEnd(ref Word word, IList<int>? accumulatedIds, out int textIndex, int fullTextLength, int maxTokens)
{ {
if (word.SymbolsCount < maxTokens) if (word.SymbolsCount <= maxTokens)
{ {
textIndex = 0; textIndex = 0;
if (accumulatedIds is not null) if (accumulatedIds is not null)
@ -567,7 +849,7 @@ namespace Microsoft.ML.Tokenizers
return word.CountIdsUpToMaxFromEnd(maxTokens, fullTextLength, out textIndex); return word.CountIdsUpToMaxFromEnd(maxTokens, fullTextLength, out textIndex);
} }
internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textLength) private int EncodeToIdsWithCache(ReadOnlySpan<char> text, List<int>? accumulatedIds, int maxTokens, out int textLength, ref PriorityQueue<Merge>? priorityQueue)
{ {
Word word; Word word;
@ -578,7 +860,7 @@ namespace Microsoft.ML.Tokenizers
return WordToIds(ref hit, accumulatedIds, out textLength, text.Length, maxTokens); return WordToIds(ref hit, accumulatedIds, out textLength, text.Length, maxTokens);
} }
word = MergeWord(text); word = MergeWord(text, ref priorityQueue);
if (text.Length <= MaxWordLengthToCache) if (text.Length <= MaxWordLengthToCache)
{ {
@ -587,13 +869,13 @@ namespace Microsoft.ML.Tokenizers
} }
else else
{ {
word = MergeWord(text); word = MergeWord(text, ref priorityQueue);
} }
return WordToIds(ref word, accumulatedIds, out textLength, text.Length, maxTokens); return WordToIds(ref word, accumulatedIds, out textLength, text.Length, maxTokens);
} }
internal int EncodeToIdsFromEndWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textIndex) internal int EncodeToIdsFromEndWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textIndex, ref PriorityQueue<Merge>? priorityQueue)
{ {
Word word; Word word;
@ -604,7 +886,7 @@ namespace Microsoft.ML.Tokenizers
return WordToIdsFromEnd(ref hit, accumulatedIds, out textIndex, text.Length, maxTokens); return WordToIdsFromEnd(ref hit, accumulatedIds, out textIndex, text.Length, maxTokens);
} }
word = MergeWord(text); word = MergeWord(text, ref priorityQueue);
if (text.Length <= MaxWordLengthToCache) if (text.Length <= MaxWordLengthToCache)
{ {
@ -613,12 +895,10 @@ namespace Microsoft.ML.Tokenizers
} }
else else
{ {
word = MergeWord(text); word = MergeWord(text, ref priorityQueue);
} }
return WordToIdsFromEnd(ref word, accumulatedIds, out textIndex, text.Length, maxTokens); return WordToIdsFromEnd(ref word, accumulatedIds, out textIndex, text.Length, maxTokens);
} }
internal static readonly List<Token> EmptyTokensList = new();
} }
} }

Просмотреть файл

@ -8,6 +8,7 @@ using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Text;
using System.Text.Json; using System.Text.Json;
namespace Microsoft.ML.Tokenizers namespace Microsoft.ML.Tokenizers
@ -15,7 +16,7 @@ namespace Microsoft.ML.Tokenizers
/// <summary> /// <summary>
/// Represent the Byte Pair Encoding model. /// Represent the Byte Pair Encoding model.
/// </summary> /// </summary>
public sealed class EnglishRoberta : Model public sealed class EnglishRoberta : Tokenizer
{ {
private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence; private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence;
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab; private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
@ -26,6 +27,8 @@ namespace Microsoft.ML.Tokenizers
private readonly IReadOnlyDictionary<char, char> _unicodeToByte; private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
private readonly string[] _charToString; private readonly string[] _charToString;
private readonly StringSpanOrdinalKeyCache<List<Token>> _cache; private readonly StringSpanOrdinalKeyCache<List<Token>> _cache;
private readonly PreTokenizer? _preTokenizer;
private readonly Normalizer? _normalizer;
/// <summary> /// <summary>
/// Indicate if want to filter the unsupported characters during the decoding. /// Indicate if want to filter the unsupported characters during the decoding.
@ -38,29 +41,51 @@ namespace Microsoft.ML.Tokenizers
/// <param name="vocabularyPath">The JSON file path containing the dictionary of string keys and their ids.</param> /// <param name="vocabularyPath">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergePath">The file path containing the tokens's pairs list.</param> /// <param name="mergePath">The file path containing the tokens's pairs list.</param>
/// <param name="highestOccurrenceMappingPath">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param> /// <param name="highestOccurrenceMappingPath">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param> /// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath, bool filterUnsupportedChars = true) public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) :
this(vocabularyPath is null ? throw new ArgumentNullException(nameof(vocabularyPath)) : File.OpenRead(vocabularyPath),
mergePath is null ? throw new ArgumentNullException(nameof(mergePath)) : File.OpenRead(mergePath),
highestOccurrenceMappingPath is null ? throw new ArgumentNullException(nameof(highestOccurrenceMappingPath)) : File.OpenRead(highestOccurrenceMappingPath),
preTokenizer, normalizer, filterUnsupportedChars, true)
{ {
if (vocabularyPath is null) }
/// <summary>
/// Construct tokenizer's model object to use with the English Robert model.
/// </summary>
/// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
/// <param name="mergeStream">The stream of a file containing the tokens's pairs list.</param>
/// <param name="highestOccurrenceMappingStream">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) :
this(vocabularyStream, mergeStream, highestOccurrenceMappingStream, preTokenizer, normalizer, filterUnsupportedChars, false)
{
}
public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer, Normalizer? normalizer, bool filterUnsupportedChars, bool disposeStream)
{
if (vocabularyStream is null)
{ {
throw new ArgumentNullException(nameof(vocabularyPath)); throw new ArgumentNullException(nameof(vocabularyStream));
} }
if (mergePath is null) if (mergeStream is null)
{ {
throw new ArgumentNullException(nameof(mergePath)); throw new ArgumentNullException(nameof(mergeStream));
} }
if (highestOccurrenceMappingPath is null) if (highestOccurrenceMappingStream is null)
{ {
throw new ArgumentNullException(nameof(highestOccurrenceMappingPath)); throw new ArgumentNullException(nameof(highestOccurrenceMappingStream));
} }
FilterUnsupportedChars = filterUnsupportedChars; FilterUnsupportedChars = filterUnsupportedChars;
_preTokenizer = preTokenizer;
using Stream vocabularyStream = File.OpenRead(vocabularyPath); _normalizer = normalizer;
using Stream mergeStream = File.OpenRead(mergePath);
using Stream highestOccurrenceMappingStream = File.OpenRead(highestOccurrenceMappingPath);
// vocabularyPath like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json" // vocabularyPath like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
// merge file like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe" // merge file like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
@ -79,48 +104,24 @@ namespace Microsoft.ML.Tokenizers
_unicodeToByte = _byteToUnicode.Reverse(); _unicodeToByte = _byteToUnicode.Reverse();
_cache = new StringSpanOrdinalKeyCache<List<Token>>(); _cache = new StringSpanOrdinalKeyCache<List<Token>>();
if (disposeStream)
{
vocabularyStream.Dispose();
mergeStream.Dispose();
highestOccurrenceMappingStream.Dispose();
}
} }
/// <summary> /// <summary>
/// Construct tokenizer's model object to use with the English Robert model. /// Gets the PreTokenizer used by the Tokenizer.
/// </summary> /// </summary>
/// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param> public override PreTokenizer? PreTokenizer => _preTokenizer;
/// <param name="mergeStream">The stream of a file containing the tokens's pairs list.</param>
/// <param name="highestOccurrenceMappingStream">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
/// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, bool filterUnsupportedChars = true)
{
if (vocabularyStream is null)
{
throw new ArgumentNullException(nameof(vocabularyStream));
}
if (mergeStream is null) /// <summary>
{ /// Gets the Normalizer in use by the Tokenizer.
throw new ArgumentNullException(nameof(mergeStream)); /// </summary>
} public override Normalizer? Normalizer => _normalizer;
if (highestOccurrenceMappingStream is null)
{
throw new ArgumentNullException(nameof(highestOccurrenceMappingStream));
}
FilterUnsupportedChars = filterUnsupportedChars;
_vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
_vocab = GetVocabulary(vocabularyStream);
_vocabReverse = _vocab.ReverseSorted();
_mergeRanks = GetMergeRanks(mergeStream);
int maxCharValue = GetByteToUnicode(out _byteToUnicode);
_charToString = new string[maxCharValue];
for (char c = (char)0; c < (char)maxCharValue; c++)
{
_charToString[c] = c.ToString();
}
_unicodeToByte = _byteToUnicode.Reverse();
_cache = new StringSpanOrdinalKeyCache<List<Token>>();
}
/// <summary> /// <summary>
/// Gets the dictionary mapping tokens to Ids. /// Gets the dictionary mapping tokens to Ids.
@ -167,16 +168,64 @@ namespace Microsoft.ML.Tokenizers
return null; return null;
} }
/// <summary>
/// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text, Span<char>.Empty, out normalizedString, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(null, text, out normalizedString, considerPreTokenization, considerNormalization);
private IReadOnlyList<Token> Encode(string? text, ReadOnlySpan<char> textSpan, out string? normalizedString, bool considerPreTokenization, bool considerNormalization)
{
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedString = null;
return [];
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
if (splits is not null)
{
List<Token> tokens = new();
foreach ((int Offset, int Length) split in splits)
{
foreach (Token t in EncodeInternal(textSpanToEncode.Slice(split.Offset, split.Length)))
{
tokens.Add(new Token(t.Id, t.Value, (split.Offset + t.Offset.Index, t.Offset.Length)));
}
}
return tokens;
}
else
{
return EncodeInternal(textSpanToEncode);
}
}
/// <summary> /// <summary>
/// Encode a text string to a list of tokens. /// Encode a text string to a list of tokens.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <returns>The list of tokens generated from the text tokenization.</returns> /// <returns>The list of tokens generated from the text tokenization.</returns>
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text) private IReadOnlyList<Token> EncodeInternal(ReadOnlySpan<char> text)
{ {
if (text.IsEmpty) if (text.IsEmpty)
{ {
return Bpe.EmptyTokensList; return [];
} }
char[] token = ArrayPool<char>.Shared.Rent(text.Length); char[] token = ArrayPool<char>.Shared.Rent(text.Length);
@ -197,7 +246,7 @@ namespace Microsoft.ML.Tokenizers
{ {
ArrayPool<char>.Shared.Return(token); ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping); ArrayPool<int>.Shared.Return(indexMapping);
return Array.Empty<Token>(); return [];
} }
if (_cache.TryGetValue(text, out List<Token>? hit)) if (_cache.TryGetValue(text, out List<Token>? hit))
@ -215,32 +264,257 @@ namespace Microsoft.ML.Tokenizers
} }
/// <summary> /// <summary>
/// Encode a split text string to a list of Ids and add them to the accumulatedIds list. /// Encodes input text to token Ids.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param> /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsInternal(text, accumulatedIds, out textLength, maxTokens); /// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
textLength = 0;
normalizedString = null;
return [];
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
List<int> ids = new();
if (splits is not null)
{
textLength = 0;
foreach ((int Offset, int Length) split in splits)
{
EncodeToIdsInternal(textSpanToEncode.Slice(split.Offset, split.Length), ids, out int length, maxTokenCount - ids.Count);
textLength = split.Offset + length;
if (length < split.Length || ids.Count >= maxTokenCount)
{
break;
}
}
}
else
{
EncodeToIdsInternal(textSpanToEncode, ids, out textLength, maxTokenCount);
}
return ids;
}
/// <summary> /// <summary>
/// Get the number of tokens that the input text will be encoded to. /// Get the number of tokens that the input text will be encoded to.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <returns>The number of token Ids that the input text will be encoded to.</returns>
public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsInternal(text, null, out textLength, maxTokens); public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
/// <summary> /// <summary>
/// Get the number of tokens that the input text will be encoded to. /// Get the number of tokens that the input text will be encoded to.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <returns>The number of token Ids that the input text will be encoded to.</returns>
public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => EncodeToIdsFromEndInternal(text, null, out textIndex, maxTokens); public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(null, text, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
tokenCount = CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
return textLength;
}
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
tokenCount = CountTokens(null, text, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
return textLength;
}
private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
}
textLength = 0;
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedString = null;
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
int count = 0;
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
count += EncodeToIdsInternal(textSpanToEncode.Slice(split.Offset, split.Length), null, out int length, maxTokenCount - count);
textLength = split.Offset + length;
if (length < split.Length || count >= maxTokenCount)
{
break;
}
}
}
else
{
count += EncodeToIdsInternal(textSpanToEncode, null, out textLength, maxTokenCount);
}
return count;
}
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If normalization is applied (a normalizer is configured and <paramref name="considerNormalization" /> is true), this is set to the normalized form of <paramref name="text" />; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The number of tokens the returned suffix encodes to; never exceeds <paramref name="maxTokenCount" />.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if normalization is enabled;
/// conversely, if all tokens fit, the result will be 0.
/// </returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="maxTokenCount"/> is less than or equal to zero.</exception>
public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> LastIndexOf(text, Span<char>.Empty, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If normalization is applied (a normalizer is configured and <paramref name="considerNormalization" /> is true), this is set to the normalized form of <paramref name="text" />; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The number of tokens the returned suffix encodes to; never exceeds <paramref name="maxTokenCount" />.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="normalizedString"/>; conversely, if all tokens fit, the result will be 0.
/// </returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="maxTokenCount"/> is less than or equal to zero.</exception>
public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> LastIndexOf(null, text, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
/// <summary>
/// Shared implementation for the LastIndexOfTokenCount overloads: finds the start index of the
/// longest suffix of the input that encodes to at most <paramref name="maxTokenCount"/> tokens.
/// </summary>
/// <param name="text">The input as a string, or null when the span overload was used.</param>
/// <param name="textSpan">The input as a span; only consulted when <paramref name="text"/> is null.</param>
/// <param name="maxTokenCount">The maximum number of tokens the suffix may encode to.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <param name="normalizedString">The normalized text when normalization was applied; otherwise null.</param>
/// <param name="tokenCount">The number of tokens the returned suffix encodes to.</param>
/// <returns>The index of the first character included in the suffix; 0 when the whole input fits.</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="maxTokenCount"/> is less than or equal to zero.</exception>
private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int tokenCount)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
}
// Empty input: nothing to encode, so the "suffix" starts at index 0.
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedString = null;
tokenCount = 0;
return 0;
}
// splits is null when pre-tokenization is not applied; otherwise it yields (offset, length) pieces.
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
if (splits is not null)
{
tokenCount = 0;
// Walk the pre-tokenized pieces from the end, accumulating tokens until the budget is spent.
foreach ((int Offset, int Length) split in splits.Reverse())
{
tokenCount += EncodeToIdsFromEndInternal(textSpanToEncode.Slice(split.Offset, split.Length), null, out int textIndex, maxTokenCount - tokenCount);
// textIndex > 0 means this piece was only partially consumed, so the suffix starts inside it;
// hitting the token budget also terminates the scan.
if (textIndex > 0 || tokenCount >= maxTokenCount)
{
return split.Offset + textIndex;
}
}
}
else
{
// No pre-tokenization: encode the whole text from the end in one call.
tokenCount = EncodeToIdsFromEndInternal(textSpanToEncode, null, out int textLength, maxTokenCount);
return textLength;
}
// All pieces were fully consumed within the budget: the entire input fits.
return 0;
}
private int EncodeToIdsResult(List<Token> tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textLength) private int EncodeToIdsResult(List<Token> tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textLength)
{ {
@ -347,7 +621,7 @@ namespace Microsoft.ML.Tokenizers
{ {
ArrayPool<char>.Shared.Return(token); ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping); ArrayPool<int>.Shared.Return(indexMapping);
textLength = 0; textLength = text.Length;
return 0; return 0;
} }
@ -390,7 +664,7 @@ namespace Microsoft.ML.Tokenizers
{ {
ArrayPool<char>.Shared.Return(token); ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping); ArrayPool<int>.Shared.Return(indexMapping);
textIndex = text.Length; textIndex = 0;
return 0; return 0;
} }
@ -410,7 +684,32 @@ namespace Microsoft.ML.Tokenizers
public override int? MapTokenToId(ReadOnlySpan<char> token) => _vocab.TryGetValue(token, out int value) ? value : null; public override int? MapTokenToId(ReadOnlySpan<char> token) => _vocab.TryGetValue(token, out int value) ? value : null;
/// <summary> /// <summary>
/// Convert a list of tokens Ids to highest occurrence rankings. /// Decode the given ids, back to a String.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <returns>The decoded string.</returns>
/// <summary>
/// Decode the given ids, back to a String.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <returns>The decoded string.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="ids"/> is null.</exception>
public override string? Decode(IEnumerable<int> ids)
{
    if (ids is null)
    {
        throw new ArgumentNullException(nameof(ids));
    }

    var builder = new ValueStringBuilder();
    foreach (int id in ids)
    {
        // Ids with no mapped token are silently skipped.
        string? token = MapIdToToken(id);
        if (token is not null)
        {
            builder.Append(token);
        }
    }

    return builder.ToString();
}
/// <summary>
/// Convert a list of token Ids to highest occurrence rankings.
/// </summary> /// </summary>
/// <param name="ids">The Ids list to map to the high occurrence rank.</param> /// <param name="ids">The Ids list to map to the high occurrence rank.</param>
/// <returns>The list of ranks mapped from the list of Ids.</returns> /// <returns>The list of ranks mapped from the list of Ids.</returns>
@ -432,7 +731,7 @@ namespace Microsoft.ML.Tokenizers
} }
/// <summary> /// <summary>
/// Convert a list of tokens Ids to highest occurrence values. /// Convert a list of token Ids to highest occurrence values.
/// </summary> /// </summary>
/// <param name="ids">The Ids list to map to the high occurrence values.</param> /// <param name="ids">The Ids list to map to the high occurrence values.</param>
/// <returns>The list of occurrence values mapped from the list of Ids.</returns> /// <returns>The list of occurrence values mapped from the list of Ids.</returns>
@ -454,7 +753,7 @@ namespace Microsoft.ML.Tokenizers
} }
/// <summary> /// <summary>
/// Convert a list of highest occurrence rankings to tokens Ids list . /// Convert a list of highest occurrence rankings to token Ids list .
/// </summary> /// </summary>
/// <param name="ranks">The high occurrence ranks list to map to the Ids list.</param> /// <param name="ranks">The high occurrence ranks list to map to the Ids list.</param>
/// <returns>The list of Ids mapped from the list of ranks.</returns> /// <returns>The list of Ids mapped from the list of ranks.</returns>
@ -638,7 +937,7 @@ namespace Microsoft.ML.Tokenizers
{ {
if (token.Length == 0) if (token.Length == 0)
{ {
return Bpe.EmptyTokensList; return [];
} }
if (token.Length == 1) if (token.Length == 1)

Просмотреть файл

@ -1,180 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// Represents a model used during Tokenization (like BPE or Word Piece or Unigram).
/// </summary>
public abstract class Model
{
    /// <summary>
    /// Encode a text to a list of tokens.
    /// </summary>
    /// <param name="text">The text to encode.</param>
    /// <returns>The list of tokens generated from the text tokenization.</returns>
    public abstract IReadOnlyList<Token> Encode(ReadOnlySpan<char> text);

    /// <summary>
    /// Encode a text to a list of Ids and add them to the accumulatedIds list.
    /// </summary>
    /// <param name="text">The text to encode.</param>
    /// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
    /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
    /// <param name="maxTokens">The maximum number of tokens to encode.</param>
    /// <returns>The number of tokens that the input text will be encoded to.</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="accumulatedIds"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="maxTokens"/> is less than or equal to zero.</exception>
    /// <remarks>
    /// This method does the default implementation that uses the Encode method to get the token's Ids.
    /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
    /// </remarks>
    public virtual int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
    {
        if (accumulatedIds is null)
        {
            throw new ArgumentNullException(nameof(accumulatedIds));
        }

        // Validate maxTokens the same way CountTokens/CountTokensFromEnd do. Previously a
        // non-positive value flowed into Math.Min and this method returned a negative count.
        if (maxTokens <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0.");
        }

        // Default implementation is not optimized for memory allocation. It is recommended to override this method for the sake of the performance.
        textLength = 0;
        var tokens = Encode(text);

        int count = Math.Min(tokens.Count, maxTokens);
        for (int i = 0; i < count; i++)
        {
            // Assumes the consumed tokens cover the input contiguously from the start, so
            // summing offset lengths yields the consumed text length; overrides should
            // supply exact semantics for models where that does not hold.
            textLength += tokens[i].Offset.Length;
            accumulatedIds.Add(tokens[i].Id);
        }

        return count;
    }

    /// <summary>
    /// Get the number of tokens that the input text will be encoded to.
    /// </summary>
    /// <param name="text">The text to encode.</param>
    /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
    /// <param name="maxTokens">The maximum number of tokens to encode.</param>
    /// <returns>The number of tokens that the input text will be encoded to.</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="maxTokens"/> is less than or equal to zero.</exception>
    /// <remarks>
    /// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids.
    /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
    /// </remarks>
    public virtual int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue)
    {
        if (maxTokens <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0.");
        }

        if (maxTokens == int.MaxValue)
        {
            // Unlimited path: defer to EncodeToIds, which a derived model may have optimized.
            // The id list is only allocated on this path where it is actually consumed.
            var ids = new List<int>();
            EncodeToIds(text, ids, out _);
            textLength = text.Length;
            return ids.Count;
        }

        IReadOnlyList<Token> tokens = Encode(text);
        textLength = 0;
        int count = Math.Min(tokens.Count, maxTokens);
        for (int i = 0; i < count; i++)
        {
            textLength += tokens[i].Offset.Length;
        }

        return count;
    }

    /// <summary>
    /// Get the number of tokens that the input text will be encoded to.
    /// </summary>
    /// <param name="text">The text to encode.</param>
    /// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
    /// <param name="maxTokens">The maximum number of tokens to encode.</param>
    /// <returns>The number of tokens that the input text will be encoded to.</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="maxTokens"/> is less than or equal to zero.</exception>
    /// <remarks>
    /// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids.
    /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
    /// </remarks>
    public virtual int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue)
    {
        if (maxTokens <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0.");
        }

        if (maxTokens == int.MaxValue)
        {
            // Unlimited path: everything fits, so the suffix starts at index 0.
            var ids = new List<int>();
            EncodeToIds(text, ids, out _);
            textIndex = 0;
            return ids.Count;
        }

        IReadOnlyList<Token> tokens = Encode(text);
        textIndex = text.Length;
        int count = Math.Min(tokens.Count, maxTokens);
        int tokensCount = tokens.Count;
        int end = tokensCount - count;
        for (int i = tokensCount - 1; i >= end; i--)
        {
            // Walk backward over the tokens that fit, moving the start index toward the front.
            textIndex -= tokens[i].Offset.Length;
        }

        return count;
    }

    /// <summary>
    /// Map the token to encoded id with the option to skip the special tokens.
    /// </summary>
    /// <param name="token">The token to map to Id</param>
    /// <returns>The mapped Id of the token.</returns>
    public abstract int? MapTokenToId(ReadOnlySpan<char> token);

    /// <summary>
    /// Map the encoded Id to the token.
    /// </summary>
    /// <param name="id">The Id to map to the token.</param>
    /// <returns>The mapped token of the Id.</returns>
    public abstract string? MapIdToToken(int id);

    /// <summary>
    /// Decode the given ids, back to a String.
    /// </summary>
    /// <param name="ids">The list of ids that we want to decode.</param>
    /// <returns>The decoded string.</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="ids"/> is null.</exception>
    /// <remarks>
    /// This method does the default implementation that uses the MapIdToToken method to get the token.
    /// Tokenizer models may opt to override this method to ensure accurate results if the default implementation
    /// provided here proves insufficient for the model's specific scenario.
    /// </remarks>
    public virtual string? Decode(IEnumerable<int> ids)
    {
        if (ids is null)
        {
            throw new ArgumentNullException(nameof(ids));
        }

        ValueStringBuilder sb = new ValueStringBuilder();
        foreach (int id in ids)
        {
            // Ids with no mapped token are silently skipped.
            if (MapIdToToken(id) is string s)
            {
                sb.Append(s);
            }
        }

        return sb.ToString();
    }
}
}

Просмотреть файл

@ -19,7 +19,7 @@ namespace Microsoft.ML.Tokenizers
/// <summary> /// <summary>
/// SentencePieceBpe is a tokenizer that splits the input into tokens using the SentencePiece Bpe model. /// SentencePieceBpe is a tokenizer that splits the input into tokens using the SentencePiece Bpe model.
/// </summary> /// </summary>
public sealed class SentencePieceBpe : Model public sealed class SentencePieceBpe : Tokenizer
{ {
private const int UninitializedId = -2; // indicate if the symbol contains uninitialized id. private const int UninitializedId = -2; // indicate if the symbol contains uninitialized id.
@ -29,9 +29,10 @@ namespace Microsoft.ML.Tokenizers
private readonly int _maxByteId; private readonly int _maxByteId;
private readonly int _byteCodeToIdOffset; // offset of mapping byte code to the to the Ids. private readonly int _byteCodeToIdOffset; // offset of mapping byte code to the to the Ids.
private readonly int _oneByteUtf8EncodingMaxId; // the maximum value of the one byte UTF-8 character. private readonly int _oneByteUtf8EncodingMaxId; // the maximum value of the one byte UTF-8 character.
private readonly Normalizer? _normalizer;
internal SentencePieceBpe(ModelProto modelProto, bool addBos, bool addEos) : internal SentencePieceBpe(ModelProto modelProto, bool addBos, bool addEos) :
this(modelProto is null ? throw new ArgumentNullException(nameof(modelProto)) : modelProto) this(modelProto is null ? throw new ArgumentNullException(nameof(modelProto)) : modelProto)
{ {
AddBeginningOfSentence = addBos; AddBeginningOfSentence = addBos;
AddEndOfSentence = addEos; AddEndOfSentence = addEos;
@ -65,6 +66,8 @@ namespace Microsoft.ML.Tokenizers
EscapeWhiteSpaces = modelProto.NormalizerSpec.EscapeWhitespaces; EscapeWhiteSpaces = modelProto.NormalizerSpec.EscapeWhitespaces;
TreatWhitespaceAsSuffix = modelProto.TrainerSpec.TreatWhitespaceAsSuffix; TreatWhitespaceAsSuffix = modelProto.TrainerSpec.TreatWhitespaceAsSuffix;
ByteFallback = modelProto.TrainerSpec.ByteFallback; ByteFallback = modelProto.TrainerSpec.ByteFallback;
_normalizer = new SentencePieceNormalizer(modelProto.NormalizerSpec.RemoveExtraWhitespaces, AddDummyPrefix, EscapeWhiteSpaces, modelProto.TrainerSpec.TreatWhitespaceAsSuffix);
} }
/// <summary> /// <summary>
@ -127,6 +130,16 @@ namespace Microsoft.ML.Tokenizers
/// </summary> /// </summary>
public int UnknownId { get; } public int UnknownId { get; }
/// <summary>
/// Gets the PreTokenizer used by the Tokenizer.
/// </summary>
public override PreTokenizer? PreTokenizer => null;
/// <summary>
/// Gets the Normalizer in use by the Tokenizer.
/// </summary>
public override Normalizer? Normalizer => _normalizer;
/// <summary> /// <summary>
/// The vocabulary of the model. /// The vocabulary of the model.
/// </summary> /// </summary>
@ -152,12 +165,74 @@ namespace Microsoft.ML.Tokenizers
} }
/// <summary> /// <summary>
/// Encode a text to a list of tokens. /// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <returns>The list of tokens generated from the text tokenization.</returns> /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <remarks>The input text has to be normalized before calling this method.</remarks> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text) => Encode(text, AddBeginningOfSentence, AddEndOfSentence); /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true)
=> Encode(text, Span<char>.Empty, out normalizedString, AddBeginningOfSentence, AddEndOfSentence, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text into a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If a normalizer is configured and <paramref name="considerNormalization" /> is true, this is set to the normalized form of <paramref name="text" />; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization. This tokenizer has no pre-tokenizer, so the flag has no effect.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true)
=> Encode(null, text, out normalizedString, AddBeginningOfSentence, AddEndOfSentence, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text into a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If a normalizer is configured and <paramref name="considerNormalization" /> is true, this is set to the normalized form of <paramref name="text" />; otherwise, this value will be set to null.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization. This tokenizer has no pre-tokenizer, so the flag has no effect.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
=> Encode(text, Span<char>.Empty, out normalizedString, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text into a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If a normalizer is configured and <paramref name="considerNormalization" /> is true, this is set to the normalized form of <paramref name="text" />; otherwise, this value will be set to null.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization. This tokenizer has no pre-tokenizer, so the flag has no effect.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
=> Encode(null, text, out normalizedString, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization);
/// <summary>
/// Shared implementation for the Encode overloads: optionally normalizes the input and
/// forwards it to the span-based encoder.
/// </summary>
private IReadOnlyList<Token> Encode(string? text, ReadOnlySpan<char> textSpan, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization, bool considerNormalization)
{
    // considerPreTokenization is intentionally unused: this tokenizer exposes no pre-tokenizer.
    if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
    {
        normalizedString = null;
        return [];
    }

    ReadOnlySpan<char> raw = text is null ? textSpan : text.AsSpan();

    if (!considerNormalization || _normalizer is null)
    {
        // No normalization requested or available: encode the input as-is.
        normalizedString = null;
        return EncodeInternal(raw, addBeginningOfSentence, addEndOfSentence);
    }

    normalizedString = text is not null ? _normalizer.Normalize(text) : _normalizer.Normalize(textSpan);
    return EncodeInternal(normalizedString.AsSpan(), addBeginningOfSentence, addEndOfSentence);
}
/// <summary> /// <summary>
/// Encode a text to a list of tokens. /// Encode a text to a list of tokens.
@ -167,11 +242,11 @@ namespace Microsoft.ML.Tokenizers
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param> /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <returns>The list of tokens generated from the text tokenization.</returns> /// <returns>The list of tokens generated from the text tokenization.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks> /// <remarks>The input text has to be normalized before calling this method.</remarks>
public IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence) private IReadOnlyList<Token> EncodeInternal(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence)
{ {
if (text.Length == 0) if (text.Length == 0)
{ {
return Array.Empty<Token>(); return [];
} }
BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length); BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length);
@ -306,16 +381,168 @@ namespace Microsoft.ML.Tokenizers
} }
/// <summary> /// <summary>
/// Encode a text to a list of Ids and add them to the accumulatedIds list. /// Encodes input text to tokes Ids.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, Span<char>.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization. This tokenizer has no pre-tokenizer, so the flag has no effect.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param> /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <remarks>The input text has to be normalized before calling this method.</remarks> /// <returns>The list of encoded Ids.</returns>
public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue) public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, AddBeginningOfSentence, AddEndOfSentence, accumulatedIds, out textLength, maxTokens); => EncodeToIds(text, Span<char>.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If a normalizer is configured and <paramref name="considerNormalization" /> is true, this is set to the normalized form of <paramref name="text" />; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization. This tokenizer has no pre-tokenizer, so the flag has no effect.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="maxTokenCount"/> is less than or equal to zero.</exception>
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization. This tokenizer has no pre-tokenizer, so the flag has no effect.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization. This tokenizer has no pre-tokenizer, so the flag has no effect.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If a normalizer is configured and <paramref name="considerNormalization" /> is true, this is set to the normalized form of <paramref name="text" />; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization. This tokenizer has no pre-tokenizer, so the flag has no effect.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="maxTokenCount"/> is less than or equal to zero.</exception>
public IReadOnlyList<int> EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If a normalizer is configured and <paramref name="considerNormalization" /> is true, this is set to the normalized form of <paramref name="text" />; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization. This tokenizer has no pre-tokenizer, so the flag has no effect.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="maxTokenCount"/> is less than or equal to zero.</exception>
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
/// <summary>
/// Core dispatcher for the string and span <c>EncodeToIds</c> overloads: validates the token budget,
/// short-circuits empty input, and forwards to the span-based encoding implementation.
/// </summary>
/// <param name="text">The text to encode, or null when the caller supplied a span.</param>
/// <param name="textSpan">The span to encode when <paramref name="text"/> is null.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <param name="normalizedString">Set to the normalized text when normalization is performed; otherwise null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode; must be greater than zero.</param>
/// <returns>The list of encoded Ids.</returns>
private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
    if (maxTokenCount <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
    }

    // No input at all: report nothing consumed and return an empty list.
    if (textSpan.IsEmpty && string.IsNullOrEmpty(text))
    {
        normalizedString = null;
        textLength = 0;
        return Array.Empty<int>();
    }

    ReadOnlySpan<char> input = text is not null ? text.AsSpan() : textSpan;
    return EncodeToIds(input, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
}
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <param name="normalizedString">Set to the normalized form of <paramref name="text" /> when normalization is performed (i.e. <paramref name="considerNormalization" /> is true and the tokenizer has a normalizer); otherwise, set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode; must be greater than zero.</param>
/// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization,
    out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
    if (maxTokenCount <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
    }

    if (text.IsEmpty)
    {
        normalizedString = null;
        textLength = 0;
        return Array.Empty<int>();
    }

    // Encode the raw input unless normalization is both requested and available.
    ReadOnlySpan<char> textToEncode;
    if (!considerNormalization || _normalizer is null)
    {
        normalizedString = null;
        textToEncode = text;
    }
    else
    {
        normalizedString = _normalizer.Normalize(text);
        textToEncode = normalizedString.AsSpan();
    }

    List<int> result = new List<int>();
    EncodeToIds(textToEncode, addBeginningOfSentence, addEndOfSentence, result, out textLength, maxTokenCount);
    return result;
}
/// <summary> /// <summary>
/// Encode a text to a list of Ids and add them to the accumulatedIds list. /// Encode a text to a list of Ids and add them to the accumulatedIds list.
@ -328,7 +555,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks> /// <remarks>The input text has to be normalized before calling this method.</remarks>
public int EncodeToIds(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue) private int EncodeToIds(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
{ {
if (maxTokens <= 0) if (maxTokens <= 0)
{ {
@ -378,7 +605,8 @@ namespace Microsoft.ML.Tokenizers
{ {
if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textLength)) if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textLength))
{ {
break; ArrayPool<BpeSymbol>.Shared.Return(symbols);
return idsCount;
} }
} }
else else
@ -391,6 +619,7 @@ namespace Microsoft.ML.Tokenizers
} }
else else
{ {
ArrayPool<BpeSymbol>.Shared.Return(symbols);
return idsCount; return idsCount;
} }
} }
@ -510,21 +739,252 @@ namespace Microsoft.ML.Tokenizers
/// Get the number of tokens that the input text will be encoded to. /// Get the number of tokens that the input text will be encoded to.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param> /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <returns>The number of token Ids that the input text will be encoded to.</returns>
/// <returns>The number of tokens that the input text will be encoded to.</returns> public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
/// <remarks>The input text has to be normalized before calling this method.</remarks> => CountTokens(text, Span<char>.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _);
public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => CountTokens(text, AddBeginningOfSentence, AddEndOfSentence, out textLength, maxTokens);
/// <summary> /// <summary>
/// Get the number of tokens that the input text will be encoded to. /// Get the number of tokens that the input text will be encoded to.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <summary>
/// Get the number of tokens that the input text will be encoded to, using the tokenizer's
/// configured beginning/end of sentence settings.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
    => CountTokens(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _);
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">Set to the normalized form of <paramref name="text" /> when normalization is performed (i.e. <paramref name="considerNormalization" /> is true and the tokenizer has a normalizer); otherwise, set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
    // Delegate to the counting core with the tokenizer's configured BOS/EOS settings;
    // the core reports the length of text consumed within the token budget.
    tokenCount = CountTokens(text, Span<char>.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
    return textLength;
}
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">Set to the normalized form of <paramref name="text" /> when normalization is performed (i.e. <paramref name="considerNormalization" /> is true and the tokenizer has a normalizer); otherwise, set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
    // Delegate to the counting core with the tokenizer's configured BOS/EOS settings;
    // the core reports the length of text consumed within the token budget.
    tokenCount = CountTokens(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
    return textLength;
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// This overload lets the caller control the beginning/end of sentence tokens explicitly
/// instead of using the tokenizer's configured defaults.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public int CountTokens(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
    => CountTokens(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// This overload lets the caller control the beginning/end of sentence tokens explicitly
/// instead of using the tokenizer's configured defaults.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public int CountTokens(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
    => CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">Set to the normalized form of <paramref name="text" /> when normalization is performed (i.e. <paramref name="considerNormalization" /> is true and the tokenizer has a normalizer); otherwise, set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public int IndexOfTokenCount(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
    // The counting core reports the length of text consumed within the token budget.
    tokenCount = CountTokens(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
    return textLength;
}
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">Set to the normalized form of <paramref name="text" /> when normalization is performed (i.e. <paramref name="considerNormalization" /> is true and the tokenizer has a normalizer); otherwise, set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public int IndexOfTokenCount(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
    // The counting core reports the length of text consumed within the token budget.
    tokenCount = CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
    return textLength;
}
/// <summary>
/// Bridges the string and span based <c>CountTokens</c> callers to the span-based counting core.
/// </summary>
private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
    // Prefer the string when present; otherwise use the caller-provided span.
    ReadOnlySpan<char> input = text is not null ? text.AsSpan() : textSpan;
    return CountTokens(input, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <param name="normalizedString">Set to the normalized form of <paramref name="text" /> when normalization is performed (i.e. <paramref name="considerNormalization" /> is true and the tokenizer has a normalizer); otherwise, set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode; must be greater than zero.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public int CountTokens(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
    if (maxTokenCount <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
    }

    if (text.IsEmpty)
    {
        normalizedString = null;
        textLength = 0;
        return 0;
    }

    // Normalize only when requested and a normalizer is configured; otherwise count on the raw input.
    ReadOnlySpan<char> textToEncode;
    if (considerNormalization && _normalizer is not null)
    {
        normalizedString = _normalizer.Normalize(text);
        textToEncode = normalizedString.AsSpan();
    }
    else
    {
        normalizedString = null;
        textToEncode = text;
    }

    return CountTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out textLength, maxTokenCount);
}
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">Set to the normalized form of <paramref name="text" /> when normalization is performed (i.e. <paramref name="considerNormalization" /> is true and the tokenizer has a normalizer); otherwise, set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if normalization is enabled;
/// conversely, if all tokens fit, the result will be 0.
/// </returns>
public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
    => LastIndexOf(text, Span<char>.Empty, maxTokenCount, considerNormalization, out normalizedString, out tokenCount);
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">Set to the normalized form of <paramref name="text" /> when normalization is performed (i.e. <paramref name="considerNormalization" /> is true and the tokenizer has a normalizer); otherwise, set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="normalizedString"/>; conversely, if all tokens fit, the result will be 0.
/// </returns>
public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
    => LastIndexOf(null, text, maxTokenCount, considerNormalization, out normalizedString, out tokenCount);
/// <summary>
/// Dispatches the string or span input of the <c>LastIndexOfTokenCount</c> overrides to the
/// span-based implementation, using the tokenizer's configured
/// <c>AddBeginningOfSentence</c>/<c>AddEndOfSentence</c> settings.
/// </summary>
private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerNormalization, out string? normalizedString, out int tokenCount)
    => LastIndexOfTokenCount(text is null ? textSpan : text.AsSpan(), maxTokenCount, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out tokenCount);
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity; must be greater than zero.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <param name="normalizedString">Set to the normalized form of <paramref name="text" /> when normalization is performed (i.e. <paramref name="considerNormalization" /> is true and the tokenizer has a normalizer); otherwise, set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="normalizedString"/>; conversely, if all tokens fit, the result will be 0.
/// </returns>
public int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int tokenCount)
{
    if (maxTokenCount <= 0)
    {
        // Message kept consistent with the other encoding/counting entry points in this class.
        throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
    }

    if (text.IsEmpty)
    {
        normalizedString = null;
        tokenCount = 0;
        return 0;
    }

    // Normalize only when requested and a normalizer is configured; otherwise encode the raw input.
    ReadOnlySpan<char> textToEncode;
    if (considerNormalization && _normalizer is not null)
    {
        normalizedString = _normalizer.Normalize(text);
        textToEncode = normalizedString.AsSpan();
    }
    else
    {
        normalizedString = null;
        textToEncode = text;
    }

    // Counting from the end yields the start index of the included suffix.
    tokenCount = CountTokensFromEnd(textToEncode, addBeginningOfSentence, addEndOfSentence, out int textIndex, maxTokenCount);
    return textIndex;
}
/// <summary> /// <summary>
/// Get the number of tokens that the input text will be encoded to. /// Get the number of tokens that the input text will be encoded to.
@ -536,7 +996,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks> /// <remarks>The input text has to be normalized before calling this method.</remarks>
public int CountTokens(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textLength, int maxTokens = int.MaxValue) private int CountTokens(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textLength, int maxTokens = int.MaxValue)
{ {
textLength = 0; textLength = 0;
if (text.IsEmpty) if (text.IsEmpty)
@ -705,7 +1165,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks> /// <remarks>The input text has to be normalized before calling this method.</remarks>
public int CountTokensFromEnd(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue) private int CountTokensFromEnd(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue)
{ {
textIndex = text.Length; textIndex = text.Length;
if (text.IsEmpty) if (text.IsEmpty)
@ -944,7 +1404,7 @@ namespace Microsoft.ML.Tokenizers
else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token)) else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token))
{ {
// escape the dummy prefix if needed. // escape the dummy prefix if needed.
sb.Append(AddDummyPrefix && !TreatWhitespaceAsSuffix && token.Length > 0 && token[0] == LlamaNormalizer.DummyPrefix ? sb.Append(AddDummyPrefix && !TreatWhitespaceAsSuffix && token.Length > 0 && token[0] == SentencePieceNormalizer.DummyPrefix ?
token.AsSpan(1) : token.AsSpan(1) :
token.AsSpan()); token.AsSpan());
} }
@ -999,7 +1459,7 @@ namespace Microsoft.ML.Tokenizers
FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb); FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb);
} }
if (AddDummyPrefix && TreatWhitespaceAsSuffix && sb.Length > 0 && sb[sb.Length - 1] == LlamaNormalizer.DummyPrefix) if (AddDummyPrefix && TreatWhitespaceAsSuffix && sb.Length > 0 && sb[sb.Length - 1] == SentencePieceNormalizer.DummyPrefix)
{ {
sb.RemoveLastChar(); sb.RemoveLastChar();
} }
@ -1014,7 +1474,7 @@ namespace Microsoft.ML.Tokenizers
ArrayPool<char>.Shared.Return(charPoolArray); ArrayPool<char>.Shared.Return(charPoolArray);
} }
return sb.ToString(LlamaNormalizer.DummyPrefix, ' '); return sb.ToString(SentencePieceNormalizer.DummyPrefix, ' ');
static void FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, ref ValueStringBuilder sb) static void FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, ref ValueStringBuilder sb)
{ {

Просмотреть файл

@ -19,9 +19,9 @@ using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers namespace Microsoft.ML.Tokenizers
{ {
/// <summary> /// <summary>
/// Represent the rapid Byte Pair Encoding model commonly referred to as Tiktoken. /// Represent the rapid Byte Pair Encoding tokenizer.
/// </summary> /// </summary>
public sealed partial class Tiktoken : Model public sealed partial class Tiktoken : Tokenizer
{ {
private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder; private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder;
private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder; private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder;
@ -29,46 +29,56 @@ namespace Microsoft.ML.Tokenizers
private readonly Dictionary<StringSpanOrdinalKey, (int Id, string Token)> _vocab; private readonly Dictionary<StringSpanOrdinalKey, (int Id, string Token)> _vocab;
private IReadOnlyDictionary<string, int>? _vocabOriginal; private IReadOnlyDictionary<string, int>? _vocabOriginal;
private const int MaxWordLengthToCache = 15; private const int MaxWordLengthToCache = 15;
private readonly PreTokenizer? _preTokenizer;
private readonly Normalizer? _normalizer;
/// <summary> /// <summary>
/// Create a new Tiktoken tokenizer's model object. /// Create a new Tiktoken tokenizer's object.
/// </summary> /// </summary>
/// <param name="vocabFilePath">The path to the BPE vocab file.</param> /// <param name="vocabFilePath">The path to the BPE vocab file.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param> /// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="cacheSize">The size of the cache to use.</param> /// <param name="cacheSize">The size of the cache to use.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabFilePath"/> is null or empty.</exception> /// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabFilePath"/> is null or empty.</exception>
/// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception> /// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception>
public Tiktoken(string vocabFilePath, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) : public Tiktoken(string vocabFilePath, PreTokenizer? preTokenizer, IReadOnlyDictionary<string, int>? specialTokens = null, Normalizer? normalizer = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), specialTokens, cacheSize, disposeStream: true) this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), preTokenizer, specialTokens, normalizer, cacheSize, disposeStream: true)
{ {
} }
/// <summary> /// <summary>
/// Create a new Tiktoken tokenizer's model object. /// Create a new Tiktoken tokenizer's object.
/// </summary> /// </summary>
/// <param name="vocabStream">The stream to the BPE vocab file.</param> /// <param name="vocabStream">The stream to the BPE vocab file.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param> /// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="cacheSize">The size of the cache to use.</param> /// <param name="cacheSize">The size of the cache to use.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabStream"/> is null or empty.</exception> /// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabStream"/> is null or empty.</exception>
/// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception> /// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception>
public Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) : public Tiktoken(Stream vocabStream, PreTokenizer? preTokenizer, IReadOnlyDictionary<string, int>? specialTokens = null, Normalizer? normalizer = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
this(vocabStream ?? throw new ArgumentNullException(nameof(vocabStream)), specialTokens, cacheSize, disposeStream: false) this(vocabStream ?? throw new ArgumentNullException(nameof(vocabStream)), preTokenizer, specialTokens, normalizer, cacheSize, disposeStream: false)
{ {
} }
/// <summary> /// <summary>
/// Create a new Tiktoken tokenizer's model object. /// Create a new Tiktoken tokenizer's object.
/// </summary> /// </summary>
/// <param name="encoder">The dictionary mapping token utf-8 bytes to Ids.</param> /// <param name="encoder">The dictionary mapping token utf-8 bytes to Ids.</param>
/// <param name="decoder">The dictionary mapping Ids to token utf-8 bytes.</param> /// <param name="decoder">The dictionary mapping Ids to token utf-8 bytes.</param>
/// <param name="vocab">The dictionary mapping string tokens to Ids.</param> /// <param name="vocab">The dictionary mapping string tokens to Ids.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param> /// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="cacheSize">The max size of the cache to use.</param> /// <param name="cacheSize">The max size of the cache to use.</param>
internal Tiktoken( internal Tiktoken(
Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<ReadOnlyMemory<byte>, int> encoder,
Dictionary<int, ReadOnlyMemory<byte>> decoder, Dictionary<int, ReadOnlyMemory<byte>> decoder,
Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab,
PreTokenizer? preTokenizer,
IReadOnlyDictionary<string, int>? specialTokens, IReadOnlyDictionary<string, int>? specialTokens,
Normalizer? normalizer = null,
int cacheSize = LruCache<int[]>.DefaultCacheSize) int cacheSize = LruCache<int[]>.DefaultCacheSize)
{ {
_encoder = encoder ?? throw new ArgumentNullException(nameof(encoder)); _encoder = encoder ?? throw new ArgumentNullException(nameof(encoder));
@ -78,19 +88,26 @@ namespace Microsoft.ML.Tokenizers
_encoder = encoder!; _encoder = encoder!;
_decoder = decoder!; _decoder = decoder!;
_vocab = vocab!; _vocab = vocab!;
_preTokenizer = preTokenizer;
_normalizer = normalizer;
_cache = new LruCache<(int[] Bytes, string Token)>(cacheSize); _cache = new LruCache<(int[] Bytes, string Token)>(cacheSize);
SpecialTokens = specialTokens; SpecialTokens = specialTokens;
CacheSpecialTokensEncoding(specialTokens); CacheSpecialTokensEncoding(specialTokens);
} }
private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens, int cacheSize, bool disposeStream) private Tiktoken(Stream vocabStream, PreTokenizer? preTokenizer, IReadOnlyDictionary<string, int>? specialTokens, Normalizer? normalizer, int cacheSize, bool disposeStream)
{ {
try try
{ {
_cache = new LruCache<(int[] Bytes, string Token)>(cacheSize); _cache = new LruCache<(int[] Bytes, string Token)>(cacheSize);
(_encoder, _vocab, _decoder) = LoadTiktokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult(); (_encoder, _vocab, _decoder) = LoadTiktokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
_preTokenizer = preTokenizer;
_normalizer = normalizer;
SpecialTokens = specialTokens; SpecialTokens = specialTokens;
CacheSpecialTokensEncoding(specialTokens); CacheSpecialTokensEncoding(specialTokens);
} }
@ -103,6 +120,16 @@ namespace Microsoft.ML.Tokenizers
} }
} }
/// <summary>
/// Gets the PreTokenizer used by the Tokenizer.
/// </summary>
public override PreTokenizer? PreTokenizer => _preTokenizer;
/// <summary>
/// Gets the Normalizer in use by the Tokenizer.
/// </summary>
public override Normalizer? Normalizer => _normalizer;
private void CacheSpecialTokensEncoding(IReadOnlyDictionary<string, int>? specialTokens) private void CacheSpecialTokensEncoding(IReadOnlyDictionary<string, int>? specialTokens)
{ {
Debug.Assert(_cache is not null); Debug.Assert(_cache is not null);
@ -118,54 +145,6 @@ namespace Microsoft.ML.Tokenizers
} }
} }
/// <summary>
/// Create a new Tiktoken tokenizer's model object asynchronously.
/// </summary>
/// <param name="vocabStream">The stream to the BPE vocab file.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="cacheSize">The size of the cache to use.</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>Tiktoken tokenizer's object.</returns>
public static async Task<Tiktoken> CreateAsync(
Stream vocabStream,
IReadOnlyDictionary<string, int>? specialTokens = null,
int cacheSize = LruCache<int[]>.DefaultCacheSize,
CancellationToken cancellationToken = default)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
await LoadTiktokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);
return new Tiktoken(encoder, decoder, vocab, specialTokens, cacheSize);
}
/// <summary>
/// Create a new Tiktoken tokenizer's object asynchronously.
/// </summary>
/// <param name="vocabFilePath">The BPE vocab file.</param>
/// <param name="specialTokensEncoder">The dictionary mapping special tokens to Ids.</param>
/// <param name="cacheSize">The size of the cache to use.</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>Tiktoken tokenizer's model object.</returns>
public static async Task<Tiktoken> CreateAsync(
string vocabFilePath,
IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
int cacheSize = LruCache<int[]>.DefaultCacheSize,
CancellationToken cancellationToken = default)
{
if (vocabFilePath is null)
{
throw new ArgumentNullException(nameof(vocabFilePath));
}
using Stream vocabStream = File.OpenRead(vocabFilePath);
return await CreateAsync(vocabStream, specialTokensEncoder, cacheSize, cancellationToken).ConfigureAwait(false);
}
/// <summary> /// <summary>
/// Load BPE vocab dictionary from a stream. /// Load BPE vocab dictionary from a stream.
/// </summary> /// </summary>
@ -254,37 +233,80 @@ namespace Microsoft.ML.Tokenizers
} }
} }
/// <summary>
/// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text, Span<char>.Empty, out normalizedString, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(null, text, out normalizedString, considerPreTokenization, considerNormalization);
private IReadOnlyList<Token> Encode(string? text, ReadOnlySpan<char> textSpan, out string? normalizedString, bool considerPreTokenization, bool considerNormalization)
{
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedString = null;
return [];
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
List<Token> tokens = new();
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
Encode(textSpanToEncode.Slice(split.Offset, split.Length), tokens, split.Offset);
}
}
else
{
Encode(textSpanToEncode, tokens, 0);
}
return tokens;
}
/// <summary> /// <summary>
/// Encode text to a list of tokens. /// Encode text to a list of tokens.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <returns>The list of tokens generated from the text tokenization.</returns> /// <param name="tokens">The list of tokens to populate.</param>
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text) /// <param name="offset">The offset to start encoding from.</param>
private void Encode(ReadOnlySpan<char> text, List<Token> tokens, int offset)
{ {
Token[] tokens; Debug.Assert(!text.IsEmpty);
if (text.IsEmpty)
{
return Array.Empty<Token>();
}
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value)) if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
{ {
tokens = new Token[value.Ids.Length]; tokens.Add(new Token(value.Ids[0], value.Token, (offset, value.Token.Length)));
tokens[0] = new Token(value.Ids[0], value.Token, (0, value.Token.Length));
for (int i = 1; i < value.Ids.Length; i++) for (int i = 1; i < value.Ids.Length; i++)
{ {
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width. // One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
tokens[i] = new Token(value.Ids[i], "", (text.Length, 0)); tokens.Add(new Token(value.Ids[i], "", (offset + text.Length, 0)));
} }
return tokens; return;
} }
// cache miss // cache miss
if (_vocab.TryGetValue(text, out (int Id, string Token) mappedId)) if (_vocab.TryGetValue(text, out (int Id, string Token) mappedId))
{ {
return new Token[1] { new(mappedId.Id, mappedId.Token, (0, mappedId.Token.Length)) }; tokens.Add(new Token(mappedId.Id, mappedId.Token, (offset, mappedId.Token.Length)));
return;
} }
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length)); byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
@ -301,15 +323,98 @@ namespace Microsoft.ML.Tokenizers
_cache.Add(textAsString, (encodedIds, textAsString)); _cache.Add(textAsString, (encodedIds, textAsString));
} }
tokens = new Token[encodedIds.Length]; tokens.Add(new Token(encodedIds[0], textAsString, (offset, text.Length)));
tokens[0] = new Token(encodedIds[0], textAsString, (0, text.Length));
for (int i = 1; i < encodedIds.Length; i++) for (int i = 1; i < encodedIds.Length; i++)
{ {
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width. // One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
tokens[i] = new Token(encodedIds[i], "", (text.Length, 0)); tokens.Add(new Token(encodedIds[i], "", (offset + text.Length, 0)));
}
}
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
=> EncodeToIds(null, text, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
} }
return tokens; if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
textLength = 0;
normalizedString = null;
return [];
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
List<int> ids = new();
if (splits is not null)
{
textLength = 0;
foreach ((int Offset, int Length) split in splits)
{
EncodeToIds(textSpanToEncode.Slice(split.Offset, split.Length), ids, out int length, maxTokenCount - ids.Count);
textLength = split.Offset + length;
if (length < split.Length || ids.Count >= maxTokenCount)
{
break;
}
}
}
else
{
EncodeToIds(textSpanToEncode, ids, out textLength);
}
return ids;
} }
/// <summary> /// <summary>
@ -318,11 +423,11 @@ namespace Microsoft.ML.Tokenizers
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <param name="accumulatedIds">The list of accumulated Ids.</param> /// <param name="accumulatedIds">The list of accumulated Ids.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param> /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue) private int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokenCount = int.MaxValue)
{ {
Debug.Assert(maxTokens > 0); Debug.Assert(maxTokenCount > 0);
if (text.IsEmpty) if (text.IsEmpty)
{ {
@ -332,7 +437,7 @@ namespace Microsoft.ML.Tokenizers
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value)) if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
{ {
if (value.Ids.Length <= maxTokens) if (value.Ids.Length <= maxTokenCount)
{ {
accumulatedIds.AddRange(value.Ids); accumulatedIds.AddRange(value.Ids);
textLength = text.Length; textLength = text.Length;
@ -362,7 +467,7 @@ namespace Microsoft.ML.Tokenizers
} }
int result; int result;
if (encodedIds.Length <= maxTokens) if (encodedIds.Length <= maxTokenCount)
{ {
accumulatedIds.AddRange(encodedIds); accumulatedIds.AddRange(encodedIds);
textLength = text.Length; textLength = text.Length;
@ -378,6 +483,104 @@ namespace Microsoft.ML.Tokenizers
return result; return result;
} }
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
=> CountTokens(null, text, considerPreTokenization, considerNormalization, out _, out _);
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
tokenCount = CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
return textLength;
}
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
tokenCount = CountTokens(null, text, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
return textLength;
}
private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
}
textLength = 0;
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedString = null;
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
int count = 0;
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
count += CountTokens(textSpanToEncode.Slice(split.Offset, split.Length), out int length, maxTokenCount - count);
textLength = split.Offset + length;
if (length < split.Length || count >= maxTokenCount)
{
break;
}
}
}
else
{
count = CountTokens(textSpanToEncode, out textLength, maxTokenCount);
}
return count;
}
/// <summary> /// <summary>
/// Get the number of tokens that the input text will be encoded to. /// Get the number of tokens that the input text will be encoded to.
/// </summary> /// </summary>
@ -385,7 +588,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param> /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) private int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue)
{ {
Debug.Assert(maxTokens > 0); Debug.Assert(maxTokens > 0);
@ -439,6 +642,76 @@ namespace Microsoft.ML.Tokenizers
return result; return result;
} }
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if normalization is enabled;
/// conversely, if all tokens fit, the result will be 0.
/// </returns>
public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> LastIndexOf(text, Span<char>.Empty, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="normalizedString"/>; conversely, if all tokens fit, the result will be 0.
/// </returns>
public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
=> LastIndexOf(null, text, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int tokenCount)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedString = null;
tokenCount = 0;
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
if (splits is not null)
{
tokenCount = 0;
foreach ((int Offset, int Length) split in splits.Reverse())
{
tokenCount += CountTokensFromEnd(textSpanToEncode.Slice(split.Offset, split.Length), out int textIndex, maxTokenCount - tokenCount);
if (textIndex > 0 || tokenCount >= maxTokenCount)
{
return split.Offset + textIndex;
}
}
return 0;
}
else
{
tokenCount = CountTokensFromEnd(textSpanToEncode, out int textLength, maxTokenCount);
return textLength;
}
}
/// <summary> /// <summary>
/// Get the number of tokens that the input text will be encoded to. /// Get the number of tokens that the input text will be encoded to.
/// </summary> /// </summary>
@ -446,7 +719,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param> /// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
/// <param name="maxTokens">The maximum number of tokens to encode.</param> /// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns> /// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) private int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue)
{ {
Debug.Assert(maxTokens > 0); Debug.Assert(maxTokens > 0);
@ -510,7 +783,7 @@ namespace Microsoft.ML.Tokenizers
{ {
if (token.IsEmpty) if (token.IsEmpty)
{ {
return 0; return null;
} }
if (_cache.TryGetValue(token, out (int[] Ids, string Token) value)) if (_cache.TryGetValue(token, out (int[] Ids, string Token) value))
@ -676,8 +949,8 @@ namespace Microsoft.ML.Tokenizers
[ [
// chat // chat
( "gpt-4-", ModelEncoding.Cl100kBase), // e.g., gpt-4-0314, etc., plus gpt-4-32k ( "gpt-4-", ModelEncoding.Cl100kBase), // e.g., gpt-4-0314, etc., plus gpt-4-32k
( "gpt-3.5-turbo-", ModelEncoding.Cl100kBase), // e.g, gpt-3.5-turbo-0301, -0401, etc. ( "gpt-3.5-", ModelEncoding.Cl100kBase), // e.g, gpt-3.5-turbo-0301, -0401, etc.
( "gpt-35-turbo-", ModelEncoding.Cl100kBase ) // Azure deployment name ( "gpt-35-", ModelEncoding.Cl100kBase ) // Azure deployment name
]; ];
private static readonly Dictionary<string, ModelEncoding> _modelToEncoding = private static readonly Dictionary<string, ModelEncoding> _modelToEncoding =
@ -687,6 +960,7 @@ namespace Microsoft.ML.Tokenizers
{ "gpt-4", ModelEncoding.Cl100kBase }, { "gpt-4", ModelEncoding.Cl100kBase },
{ "gpt-3.5-turbo", ModelEncoding.Cl100kBase }, { "gpt-3.5-turbo", ModelEncoding.Cl100kBase },
{ "gpt-3.5-turbo-16k", ModelEncoding.Cl100kBase }, { "gpt-3.5-turbo-16k", ModelEncoding.Cl100kBase },
{ "gpt-35", ModelEncoding.Cl100kBase }, // Azure deployment name
{ "gpt-35-turbo", ModelEncoding.Cl100kBase }, // Azure deployment name { "gpt-35-turbo", ModelEncoding.Cl100kBase }, // Azure deployment name
{ "gpt-35-turbo-16k", ModelEncoding.Cl100kBase }, // Azure deployment name { "gpt-35-turbo-16k", ModelEncoding.Cl100kBase }, // Azure deployment name
@ -817,26 +1091,13 @@ namespace Microsoft.ML.Tokenizers
private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase); private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
internal static Tokenizer CreateTokenizerForModel( internal static Tokenizer CreateForModel(
string modelName, ModelEncoding modelEncoding,
IReadOnlyDictionary<string, int>? extraSpecialTokens = null, string? modelName = null,
Normalizer? normalizer = null) IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
Normalizer? normalizer = null)
{ {
if (string.IsNullOrEmpty(modelName)) (Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) tiktokenConfiguration = GetTiktokenConfigurations(modelEncoding, modelName);
{
throw new ArgumentNullException(nameof(modelName));
}
return CreateTokenizerForModel(GetModelEncoding(modelName), modelName, extraSpecialTokens, normalizer);
}
internal static Tokenizer CreateTokenizerForModel(
ModelEncoding modelEncoding,
string? modelName = null,
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
Normalizer? normalizer = null)
{
(Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) tiktokenConfiguration = Tiktoken.GetTiktokenConfigurations(modelEncoding, modelName);
if (extraSpecialTokens is not null) if (extraSpecialTokens is not null)
{ {
@ -858,10 +1119,14 @@ namespace Microsoft.ML.Tokenizers
_tiktokenCache.TryAdd(tiktokenConfiguration.VocabFile, cache); _tiktokenCache.TryAdd(tiktokenConfiguration.VocabFile, cache);
} }
return new Tokenizer( return new Tiktoken(
new Tiktoken(cache.encoder, cache.decoder, cache.vocab, tiktokenConfiguration.SpecialTokens, LruCache<int[]>.DefaultCacheSize), cache.encoder,
new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), cache.decoder,
normalizer); cache.vocab,
new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
tiktokenConfiguration.SpecialTokens,
normalizer,
LruCache<int[]>.DefaultCacheSize);
} }
} }
} }

Просмотреть файл

@ -97,24 +97,24 @@ namespace Microsoft.ML.Tokenizers
return changes; return changes;
} }
public void MergeAll(Dictionary<Pair<int>, (int, int)> merges, float? dropout) public void MergeAll(Dictionary<Pair<int>, (int, int)> merges, float? dropout, ref PriorityQueue<Merge>? priorityQueue)
{ {
// Queue<Merge> queue = new Queue<Merge>(_symbols.Count); priorityQueue ??= new PriorityQueue<Merge>(_symbols.Count);
PriorityQueue<Merge> queue = new PriorityQueue<Merge>(_symbols.Count); priorityQueue.Clear();
Vec<Merge> skip = new Vec<Merge>(queue.Count); Vec<Merge> skip = new Vec<Merge>(priorityQueue.Count);
for (int i = 0; i < _symbols.Count - 1; i++) for (int i = 0; i < _symbols.Count - 1; i++)
{ {
if (merges.TryGetValue(Pair<int>.Create(_symbols[i].C, _symbols[i + 1].C), out (int m1, int m2) value)) if (merges.TryGetValue(Pair<int>.Create(_symbols[i].C, _symbols[i + 1].C), out (int m1, int m2) value))
{ {
queue.Enqueue(new Merge(i, value.m1, value.m2)); priorityQueue.Enqueue(new Merge(i, value.m1, value.m2));
} }
} }
while (queue.Count > 0) while (priorityQueue.Count > 0)
{ {
Merge top = queue.Dequeue(); Merge top = priorityQueue.Dequeue();
if (dropout.HasValue && (_random ??= new()).NextDouble() < dropout) if (dropout.HasValue && (_random ??= new()).NextDouble() < dropout)
{ {
skip.Push(top); skip.Push(top);
@ -124,7 +124,7 @@ namespace Microsoft.ML.Tokenizers
// Re-insert the skipped elements // Re-insert the skipped elements
for (int i = 0; i < skip.Count; i++) for (int i = 0; i < skip.Count; i++)
{ {
queue.Enqueue(skip[i]); priorityQueue.Enqueue(skip[i]);
} }
skip.Clear(); skip.Clear();
@ -166,7 +166,7 @@ namespace Microsoft.ML.Tokenizers
if (merges.TryGetValue(newPair, out value)) if (merges.TryGetValue(newPair, out value))
{ {
queue.Enqueue(new Merge(current.Prev, value.m1, value.m2)); priorityQueue.Enqueue(new Merge(current.Prev, value.m1, value.m2));
} }
} }
@ -178,7 +178,7 @@ namespace Microsoft.ML.Tokenizers
Pair<int> newPair = Pair<int>.Create(current.C, nextSymbol.C); Pair<int> newPair = Pair<int>.Create(current.C, nextSymbol.C);
if (merges.TryGetValue(newPair, out value)) if (merges.TryGetValue(newPair, out value))
{ {
queue.Enqueue(new Merge(top.Pos, value.m1, value.m2)); priorityQueue.Enqueue(new Merge(top.Pos, value.m1, value.m2));
} }
} }
} }
@ -289,19 +289,16 @@ namespace Microsoft.ML.Tokenizers
return sb.ToString(); return sb.ToString();
} }
public List<Token> ToTokens(SortedDictionary<int, string> vocabReverse) public void ToTokens(SortedDictionary<int, string> vocabReverse, List<Token> tokens, int offset)
{ {
List<Token> tokens = new(SymbolsCount);
int index = 0; int index = 0;
for (int i = 0; i < SymbolsCount; i++) for (int i = 0; i < SymbolsCount; i++)
{ {
int endIndex = index + _symbols[i].Len; int endIndex = index + _symbols[i].Len;
tokens.Add(new Token(_symbols[i].C, vocabReverse[_symbols[i].C], (index, _symbols[i].Len))); tokens.Add(new Token(_symbols[i].C, vocabReverse[_symbols[i].C], (index + offset, _symbols[i].Len)));
index += _symbols[i].Len; index += _symbols[i].Len;
} }
return tokens;
} }
} }
} }

Просмотреть файл

@ -3,6 +3,8 @@
// See the LICENSE file in the project root for more information. // See the LICENSE file in the project root for more information.
using System; using System;
using System.Buffers;
using System.Diagnostics;
namespace Microsoft.ML.Tokenizers namespace Microsoft.ML.Tokenizers
{ {
@ -22,5 +24,27 @@ namespace Microsoft.ML.Tokenizers
/// <param name="original">The original string to normalize to lowercase form.</param> /// <param name="original">The original string to normalize to lowercase form.</param>
/// <returns>The lower-cased normalized string.</returns> /// <returns>The lower-cased normalized string.</returns>
public override string Normalize(string original) => original.ToLowerInvariant(); public override string Normalize(string original) => original.ToLowerInvariant();
/// <summary>
/// Lowercase the original string.
/// </summary>
/// <param name="original">The original string to normalize to lowercase form.</param>
/// <returns>The lower-cased normalized string.</returns>
public override string Normalize(ReadOnlySpan<char> original)
{
if (original.IsEmpty)
{
return string.Empty;
}
char[] arrayPoolArray = ArrayPool<char>.Shared.Rent(original.Length);
int length = original.ToLowerInvariant(arrayPoolArray);
Debug.Assert(length == original.Length);
string result = new string(arrayPoolArray, 0, length);
ArrayPool<char>.Shared.Return(arrayPoolArray);
return result;
}
} }
} }

Просмотреть файл

@ -17,5 +17,12 @@ namespace Microsoft.ML.Tokenizers
/// <param name="original">The original string to normalize.</param> /// <param name="original">The original string to normalize.</param>
/// <returns>The normalized string.</returns> /// <returns>The normalized string.</returns>
public abstract string Normalize(string original); public abstract string Normalize(string original);
/// <summary>
/// Process the original string to modify it and obtain a normalized string.
/// </summary>
/// <param name="original">The original string to normalize.</param>
/// <returns>The normalized string.</returns>
public abstract string Normalize(ReadOnlySpan<char> original);
} }
} }

Просмотреть файл

@ -10,14 +10,14 @@ namespace Microsoft.ML.Tokenizers
/// <summary> /// <summary>
/// Normalize the string to lowercase form before processing it with the tokenizer. /// Normalize the string to lowercase form before processing it with the tokenizer.
/// </summary> /// </summary>
public sealed class LlamaNormalizer : Normalizer public sealed class SentencePieceNormalizer : Normalizer
{ {
internal const char DummyPrefix = '\u2581'; // '▁' (LOWER ONE EIGHT BLOCK) internal const char DummyPrefix = '\u2581'; // '▁' (LOWER ONE EIGHT BLOCK)
/// <summary> /// <summary>
/// Creates a LowerCaseNormalizer object. /// Creates a LowerCaseNormalizer object.
/// </summary> /// </summary>
public LlamaNormalizer(bool removeExtraWhiteSpaces, bool addDummyPrefix, bool escapeWhiteSpaces, bool treatWhitespaceAsSuffix) public SentencePieceNormalizer(bool removeExtraWhiteSpaces, bool addDummyPrefix, bool escapeWhiteSpaces, bool treatWhitespaceAsSuffix)
{ {
RemoveExtraWhiteSpaces = removeExtraWhiteSpaces; RemoveExtraWhiteSpaces = removeExtraWhiteSpaces;
AddDummyPrefix = addDummyPrefix; AddDummyPrefix = addDummyPrefix;
@ -40,7 +40,7 @@ namespace Microsoft.ML.Tokenizers
public bool TreatWhitespaceAsSuffix { get; } public bool TreatWhitespaceAsSuffix { get; }
/// <summary> /// <summary>
/// Normalize the original string according to SentencePiece normalization with Llama model. /// Normalize the original string according to SentencePiece normalization.
/// </summary> /// </summary>
/// <param name="original">The original string to normalize.</param> /// <param name="original">The original string to normalize.</param>
/// <returns>The normalized string.</returns> /// <returns>The normalized string.</returns>
@ -51,6 +51,16 @@ namespace Microsoft.ML.Tokenizers
return string.Empty; return string.Empty;
} }
return Normalize(original.AsSpan());
}
/// <summary>
/// Normalize the original string according to SentencePiece normalization.
/// </summary>
/// <param name="original">The original string to normalize.</param>
/// <returns>The normalized string.</returns>
public override string Normalize(ReadOnlySpan<char> original)
{
int startIndex = 0; int startIndex = 0;
int endIndex = original.Length - 1; int endIndex = original.Length - 1;

Просмотреть файл

@ -3,6 +3,8 @@
// See the LICENSE file in the project root for more information. // See the LICENSE file in the project root for more information.
using System; using System;
using System.Buffers;
using System.Diagnostics;
namespace Microsoft.ML.Tokenizers namespace Microsoft.ML.Tokenizers
{ {
@ -22,5 +24,27 @@ namespace Microsoft.ML.Tokenizers
/// <param name="original">The original string to normalize to uppercase form.</param> /// <param name="original">The original string to normalize to uppercase form.</param>
/// <returns>The upper-cased normalized string.</returns> /// <returns>The upper-cased normalized string.</returns>
public override string Normalize(string original) => original.ToUpperInvariant(); public override string Normalize(string original) => original.ToUpperInvariant();
/// <summary>
/// Uppercase the original string.
/// </summary>
/// <param name="original">The original string to normalize to uppercase form.</param>
/// <returns>The upper-cased normalized string.</returns>
public override string Normalize(ReadOnlySpan<char> original)
{
if (original.IsEmpty)
{
return string.Empty;
}
char[] arrayPoolArray = ArrayPool<char>.Shared.Rent(original.Length);
int length = original.ToUpperInvariant(arrayPoolArray);
Debug.Assert(length == original.Length);
string result = new string(arrayPoolArray, 0, length);
ArrayPool<char>.Shared.Return(arrayPoolArray);
return result;
}
} }
} }

Просмотреть файл

@ -3,66 +3,12 @@
// See the LICENSE file in the project root for more information. // See the LICENSE file in the project root for more information.
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
namespace Microsoft.ML.Tokenizers namespace Microsoft.ML.Tokenizers
{ {
/// <summary>
/// This Split contains the underlying split token as well as its offsets
/// in the original string. These offsets are in the `original` referential.
/// It also contains any `Token` associated to the current split.
/// </summary>
public struct Split : IEquatable<Split>
{
private readonly string? _originalString;
private string? _tokenString;
/// <summary>
/// Gets the underlying split token. Each SubString is represented by a token
/// and in the end we might be carrying a lot of SubString representing various parts of the
/// original input string.
/// </summary>
public string TokenString => _tokenString ??= _originalString!.Substring(Offset.Index, Offset.Length);
/// <summary>
/// Gets the underlying split token as a span.
/// </summary>
public ReadOnlySpan<char> TokenSpan => _tokenString is string s ? s.AsSpan() : _originalString.AsSpan(Offset.Index, Offset.Length);
/// <summary>
/// Returns the offset mapping to the original string
/// </summary>
public (int Index, int Length) Offset { get; }
/// <summary>
/// create a Split object using the token and the offset
/// </summary>
/// <param name="token">The token string</param>
/// <param name="offset">The offset mapping to the original string</param>
public Split(string token, (int Index, int Length) offset)
{
_tokenString = token;
Offset = offset;
}
internal Split(string originalString, string? token, (int Index, int Length) offset)
{
_originalString = originalString;
_tokenString = token;
Offset = offset;
}
/// <summary>
/// Indicates whether the current Split object is equal to another Split object.
/// </summary>
/// <param name="other">The Split object to compare with the current object.</param>
public bool Equals(Split other) =>
(_originalString == other._originalString || TokenString == other.TokenString) &&
Offset.Index == other.Offset.Index &&
Offset.Length == other.Offset.Length;
}
/// <summary> /// <summary>
/// Base class for all pre-tokenizers classes. /// Base class for all pre-tokenizers classes.
/// The PreTokenizer is in charge of doing the pre-segmentation step. /// The PreTokenizer is in charge of doing the pre-segmentation step.
@ -70,23 +16,54 @@ namespace Microsoft.ML.Tokenizers
public abstract class PreTokenizer public abstract class PreTokenizer
{ {
/// <summary> /// <summary>
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string. /// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
/// </summary> /// </summary>
/// <param name="text">The string to split into tokens.</param> /// <param name="text">The string to split into tokens.</param>
/// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns> /// <returns>The offsets and lengths of the tokens, expressed as pairs, are relative to the original string.</returns>
public abstract IEnumerable<Split> PreTokenize(string text); public abstract IEnumerable<(int Offset, int Length)> PreTokenize(string text);
internal static IEnumerable<Split> SplitText(string text, Regex regex) /// <summary>
/// Get the offsets and lengths of the tokens relative to the original string.
/// </summary>
/// <param name="text">The character span to split into tokens.</param>
/// <returns>The offsets and lengths of the tokens, expressed as pairs, are relative to the original string.</returns>
public abstract IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan<char> text);
internal static IEnumerable<(int Offset, int Length)> SplitText(string text, Regex regex)
{ {
(int Offset, int Length) match; (int Offset, int Length) match;
int beginning = 0; int beginning = 0;
while (TryGetMatch(regex, text, beginning, text.Length - beginning, out match)) while (TryGetMatch(regex, text, beginning, text.Length - beginning, out match))
{ {
yield return new Split(text, null, (match.Offset, match.Length)); yield return (match.Offset, match.Length);
beginning = match.Offset + match.Length; beginning = match.Offset + match.Length;
} }
} }
internal static IEnumerable<(int Offset, int Length)> SplitText(ReadOnlySpan<char> text, Regex regex)
{
#if NET7_0_OR_GREATER
char[] buffer = ArrayPool<char>.Shared.Rent(text.Length);
text.CopyTo(buffer);
return SplitText(buffer, regex, text.Length);
static IEnumerable<(int Offset, int Length)> SplitText(char[] text, Regex regex, int textLength)
{
(int Offset, int Length) match;
int beginning = 0;
while (TryGetMatch(regex, text, beginning, textLength - beginning, out match))
{
yield return (match.Offset, match.Length);
beginning = match.Offset + match.Length;
}
ArrayPool<char>.Shared.Return(text);
}
#else
return SplitText(text.ToString(), regex);
#endif // NET7_0_OR_GREATER
}
internal static bool TryGetMatch(Regex regex, string text, int beginning, int length, out (int offset, int length) match) internal static bool TryGetMatch(Regex regex, string text, int beginning, int length, out (int offset, int length) match)
{ {
#if NET7_0_OR_GREATER #if NET7_0_OR_GREATER
@ -106,5 +83,18 @@ namespace Microsoft.ML.Tokenizers
match = default; match = default;
return false; return false;
} }
#if NET7_0_OR_GREATER
internal static bool TryGetMatch(Regex regex, scoped ReadOnlySpan<char> text, int beginning, int length, out (int offset, int length) match)
{
foreach (ValueMatch m in regex.EnumerateMatches(text.Slice(beginning, length)))
{
match = (beginning + m.Index, m.Length);
return true;
}
match = default;
return false;
}
#endif // NET7_0_OR_GREATER
} }
} }

Просмотреть файл

@ -18,15 +18,30 @@ namespace Microsoft.ML.Tokenizers
public static RobertaPreTokenizer Instance { get; } = new RobertaPreTokenizer(); public static RobertaPreTokenizer Instance { get; } = new RobertaPreTokenizer();
/// <summary> /// <summary>
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string. /// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
/// </summary> /// </summary>
/// <param name="text">The string to split into tokens.</param> /// <param name="text">The string to split into tokens.</param>
/// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns> /// <returns>The offsets and lengths of the tokens, expressed as pairs, are relative to the original string.</returns>
public override IEnumerable<Split> PreTokenize(string text) public override IEnumerable<(int Offset, int Length)> PreTokenize(string text)
{ {
if (string.IsNullOrEmpty(text)) if (string.IsNullOrEmpty(text))
{ {
return Array.Empty<Split>(); return [];
}
return SplitText(text, Tiktoken.P50kBaseRegex());
}
/// <summary>
/// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
/// </summary>
/// <param name="text">The string to split into tokens.</param>
/// <returns>The offsets and lengths of the tokens, expressed as pairs, are relative to the original string.</returns>
public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan<char> text)
{
if (text.IsEmpty)
{
return [];
} }
return SplitText(text, Tiktoken.P50kBaseRegex()); return SplitText(text, Tiktoken.P50kBaseRegex());

Просмотреть файл

@ -1,35 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// The pre-tokenizer for SentencePiece tokenizers.
/// </summary>
internal sealed partial class SentencePiecePreTokenizer : PreTokenizer
{
/// <summary>
/// Gets a singleton instance of the Roberta pre-tokenizer..
/// </summary>
public static SentencePiecePreTokenizer Instance { get; } = new SentencePiecePreTokenizer();
/// <summary>
/// Return the whole text as one chunk.
/// </summary>
/// <param name="text">The string to split into tokens.</param>
/// <returns>The original string as one chunk.</returns>
public override IEnumerable<Split> PreTokenize(string text)
{
if (string.IsNullOrEmpty(text))
{
yield break;
}
yield return new Split(text, (0, text.Length));
}
}
}

Просмотреть файл

@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information. // See the LICENSE file in the project root for more information.
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
@ -39,20 +40,20 @@ namespace Microsoft.ML.Tokenizers
} }
/// <summary> /// <summary>
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string. /// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
/// </summary> /// </summary>
/// <param name="text">The string to split into tokens.</param> /// <param name="text">The string to split into tokens.</param>
/// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns> /// <returns>The offsets and lengths of the tokens, expressed as pairs, are relative to the original string.</returns>
public override IEnumerable<Split> PreTokenize(string text) public override IEnumerable<(int Offset, int Length)> PreTokenize(string text)
{ {
if (string.IsNullOrEmpty(text)) if (string.IsNullOrEmpty(text))
{ {
return Array.Empty<Split>(); return [];
} }
return SplitText(text, _regex, _specialTokensRegex); return SplitText(text, _regex, _specialTokensRegex);
static IEnumerable<Split> SplitText(string text, Regex regex, Regex? specialTokensRegex) static IEnumerable<(int Offset, int Length)> SplitText(string text, Regex regex, Regex? specialTokensRegex)
{ {
(int Offset, int Length) match; (int Offset, int Length) match;
int beginning = 0; int beginning = 0;
@ -69,21 +70,77 @@ namespace Microsoft.ML.Tokenizers
while (TryGetMatch(regex, text, beginning, specialMatch.Offset - beginning, out match)) while (TryGetMatch(regex, text, beginning, specialMatch.Offset - beginning, out match))
{ {
yield return new Split(text, null, (match.Offset, match.Length)); yield return (match.Offset, match.Length);
beginning = match.Offset + match.Length; beginning = match.Offset + match.Length;
} }
yield return new Split(text, null, (specialMatch.Offset, specialMatch.Length)); yield return (specialMatch.Offset, specialMatch.Length);
beginning = specialMatch.Offset + specialMatch.Length; beginning = specialMatch.Offset + specialMatch.Length;
} }
} }
while (TryGetMatch(regex, text, beginning, text.Length - beginning, out match)) while (TryGetMatch(regex, text, beginning, text.Length - beginning, out match))
{ {
yield return new Split(text, null, (match.Offset, match.Length)); yield return (match.Offset, match.Length);
beginning = match.Length + match.Offset; beginning = match.Length + match.Offset;
} }
} }
} }
/// <summary>
/// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
/// </summary>
/// <param name="text">The string to split into tokens.</param>
/// <returns>The offsets and lengths of the tokens, expressed as pairs, are relative to the original string.</returns>
public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan<char> text)
{
if (text.IsEmpty)
{
return [];
}
#if NET7_0_OR_GREATER
char[] buffer = ArrayPool<char>.Shared.Rent(text.Length);
text.CopyTo(buffer);
return SplitText(buffer, _regex, _specialTokensRegex, text.Length);
static IEnumerable<(int Offset, int Length)> SplitText(char[] text, Regex regex, Regex? specialTokensRegex, int textLength)
{
(int Offset, int Length) match;
int beginning = 0;
if (specialTokensRegex is not null)
{
while (true)
{
(int Offset, int Length) specialMatch;
if (!TryGetMatch(specialTokensRegex, text.AsSpan(), beginning, textLength - beginning, out specialMatch))
{
break;
}
while (TryGetMatch(regex, text.AsSpan(), beginning, specialMatch.Offset - beginning, out match))
{
yield return (match.Offset, match.Length);
beginning = match.Offset + match.Length;
}
yield return (specialMatch.Offset, specialMatch.Length);
beginning = specialMatch.Offset + specialMatch.Length;
}
}
while (TryGetMatch(regex, text.AsSpan(), beginning, textLength - beginning, out match))
{
yield return (match.Offset, match.Length);
beginning = match.Length + match.Offset;
}
ArrayPool<char>.Shared.Return(text);
}
#else
return PreTokenize(text.ToString());
#endif // NET7_0_OR_GREATER
}
} }
} }

Просмотреть файл

@ -29,15 +29,30 @@ namespace Microsoft.ML.Tokenizers
#endif #endif
/// <summary> /// <summary>
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string. /// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
/// </summary> /// </summary>
/// <param name="text">The string to split into tokens.</param> /// <param name="text">The string to split into tokens.</param>
/// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns> /// <returns>The offsets and lengths of the tokens, expressed as pairs, are relative to the original string.</returns>
public override IEnumerable<Split> PreTokenize(string text) public override IEnumerable<(int Offset, int Length)> PreTokenize(string text)
{ {
if (string.IsNullOrEmpty(text)) if (string.IsNullOrEmpty(text))
{ {
return Array.Empty<Split>(); return [];
}
return SplitText(text, PretokenizeRegex());
}
/// <summary>
/// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
/// </summary>
/// <param name="text">The string to split into tokens.</param>
/// <returns>The offsets and lengths of the tokens, expressed as pairs, are relative to the original string.</returns>
public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan<char> text)
{
if (text.IsEmpty)
{
return [];
} }
return SplitText(text, PretokenizeRegex()); return SplitText(text, PretokenizeRegex());

Просмотреть файл

@ -12,7 +12,7 @@ namespace Microsoft.ML.Tokenizers
/// Represent the token produced from the tokenization process containing the token substring, /// Represent the token produced from the tokenization process containing the token substring,
/// the id associated to the token substring, and the offset mapping to the original string. /// the id associated to the token substring, and the offset mapping to the original string.
/// </summary> /// </summary>
public sealed class Token public readonly struct Token
{ {
/// <summary> /// <summary>
/// Gets the Id value associated to the token. /// Gets the Id value associated to the token.
@ -22,12 +22,12 @@ namespace Microsoft.ML.Tokenizers
/// <summary> /// <summary>
/// Gets the token string value. /// Gets the token string value.
/// </summary> /// </summary>
public string Value { get; set; } public string Value { get; }
/// <summary> /// <summary>
/// Gets the offset mapping to the original string. /// Gets the offset mapping to the original string.
/// </summary> /// </summary>
public (int Index, int Length) Offset { get; internal set; } public (int Index, int Length) Offset { get; }
/// <summary> /// <summary>
/// Construct a new Token object using the token value, Id, and the offset mapping to the original string. /// Construct a new Token object using the token value, Id, and the offset mapping to the original string.

Просмотреть файл

@ -8,6 +8,7 @@ using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Text;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -15,263 +16,279 @@ using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers namespace Microsoft.ML.Tokenizers
{ {
/// <summary> /// <summary>
/// A Tokenizer works as a pipeline. It processes some raw text as input and outputs a EncodingResult object. /// serves as an abstraction for concrete tokenizers, enabling the encoding of text into tokens and IDs, as well as the decoding of IDs back into text.
/// </summary> /// </summary>
public partial class Tokenizer public abstract class Tokenizer
{ {
/// <summary> /// <summary>
/// Create a new Tokenizer object. /// Gets the PreTokenizer used by the Tokenizer.
/// </summary> /// </summary>
/// <param name="model">The Model in use by the Tokenizer.</param> public virtual PreTokenizer? PreTokenizer => null;
/// <param name="preTokenizer">The optional PreTokenizer in use by the Tokenizer. WhiteSpace PreTokenizer will be used if this parameter is null.</param>
/// <param name="normalizer">The optional Normalizer in use by the Tokenizer.</param>
public Tokenizer(Model model, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null)
{
Model = model;
PreTokenizer = preTokenizer ?? WhiteSpace.Instance;
Normalizer = normalizer;
}
/// <summary> /// <summary>
/// Gets the Model in use by the Tokenizer. /// Gets the Normalizer in use by the Tokenizer.
/// </summary> /// </summary>
public Model Model { get; } public virtual Normalizer? Normalizer => null;
/// <summary> /// <summary>
/// Gets or sets the PreTokenizer used by the Tokenizer. /// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
public PreTokenizer PreTokenizer { get; }
/// <summary>
/// Gets or sets the Normalizer in use by the Tokenizer.
/// </summary>
public Normalizer? Normalizer { get; }
/// <summary>
/// Encodes input text to object has the tokens list, tokens Ids, tokens offset mapping.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <returns>The tokenization result includes the tokens list, tokens Ids, tokens offset mapping.</returns> /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
public EncodingResult Encode(string text) /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
{ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
if (text is null) /// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
{ public virtual IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text.AsSpan(), out normalizedString, considerPreTokenization, considerNormalization);
throw new ArgumentNullException(nameof(text));
}
string normalized = Normalizer is null ? text : Normalizer.Normalize(text);
bool offsetsMappedToOriginal = true;
EncodingResult encoding = new(text, normalized, PreTokenizer.PreTokenize(normalized), offsetsMappedToOriginal);
foreach (Split split in encoding.Splits)
{
IReadOnlyList<Token> tokens = Model.Encode(split.TokenString.AsSpan());
foreach (Token token in tokens)
{
token.Offset = (token.Offset.Index + split.Offset.Index, token.Offset.Length);
}
encoding.AddTokens(tokens);
}
return encoding;
}
/// <summary> /// <summary>
/// Encodes input text to tokens Ids. /// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public abstract IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns> /// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(string text) public virtual IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) => EncodeToIds(text.AsSpan(), considerPreTokenization, considerNormalization);
{
if (text is null)
{
throw new ArgumentNullException(nameof(text));
}
string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text;
List<int> idsList = new();
foreach (Split split in PreTokenizer.PreTokenize(normalized))
{
Model.EncodeToIds(split.TokenSpan, idsList, out _);
}
return idsList;
}
/// <summary> /// <summary>
/// Encodes input text to tokens Ids up to maximum number of tokens. /// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public abstract IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);
/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param> /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="processedText">If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.</param> /// <param name="normalizedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param> /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns> /// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string processedText, out int textLength) public virtual IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
{ => EncodeToIds(text.AsSpan(), maxTokenCount, out normalizedText, out textLength, considerPreTokenization, considerNormalization);
processedText = text;
textLength = 0;
if (text is null) /// <summary>
{ /// Encodes input text to token Ids up to maximum number of tokens.
throw new ArgumentNullException(nameof(text)); /// </summary>
} /// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
if (maxTokenCount <= 0) /// <param name="normalizedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
{ /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than 0."); /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
} /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
if (Normalizer is not null) public abstract IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true);
{
processedText = Normalizer.Normalize(text);
}
List<int> idsList = new();
foreach (Split split in PreTokenizer.PreTokenize(processedText))
{
Model.EncodeToIds(split.TokenSpan, idsList, out int length, maxTokenCount - idsList.Count);
if (length < split.Offset.Length || idsList.Count >= maxTokenCount)
{
break;
}
}
return idsList;
}
/// <summary> /// <summary>
/// Get the number of tokens that the input text will be encoded to. /// Get the number of tokens that the input text will be encoded to.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <returns>The number of tokens Ids that the input text will be encoded to.</returns> /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <exception cref="ArgumentNullException">The input text is null.</exception> /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <exception cref="ArgumentException">Unable to encode the text.</exception> /// <returns>The number of token Ids that the input text will be encoded to.</returns>
public int CountTokens(string text) public virtual int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
{ => CountTokens(text.AsSpan(), considerPreTokenization, considerNormalization);
if (text is null)
{
throw new ArgumentNullException(nameof(text));
}
string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text; /// <summary>
/// Get the number of tokens that the input text will be encoded to.
int idsCount = 0; /// </summary>
foreach (Split split in PreTokenizer.PreTokenize(normalized)) /// <param name="text">The text to encode.</param>
{ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
idsCount += Model.CountTokens(split.TokenSpan, out _); /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
} /// <returns>The number of token Ids that the input text will be encoded to.</returns>
public abstract int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);
return idsCount;
}
/// <summary> /// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param> /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="processedText">If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.</param> /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param> /// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns> /// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit. /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, if all tokens fit, the result will be length of the <paramref name="processedText"/>. /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">The input text is null.</exception> public virtual int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
/// <exception cref="ArgumentOutOfRangeException">The maximum token count must be greater than 0.</exception> => IndexOfTokenCount(text.AsSpan(), maxTokenCount, out normalizedString, out tokenCount, considerPreTokenization, considerNormalization);
public int IndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount)
=> IndexOf(text, maxTokenCount, out processedText, out tokenCount); /// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public abstract int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
/// <summary> /// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary> /// </summary>
/// <param name="text">The text to encode.</param> /// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param> /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="processedText">If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.</param> /// <param name="processedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param> /// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns> /// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="processedText"/>; conversely, if all tokens fit, the result will be 0. /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="processedText"/>; conversely, if all tokens fit, the result will be 0.
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">The input text is null.</exception> public virtual int LastIndexOfTokenCount(string text, int maxTokenCount, out string? processedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
/// <exception cref="ArgumentOutOfRangeException">The maximum token count must be greater than 0.</exception> => LastIndexOfTokenCount(text.AsSpan(), maxTokenCount, out processedText, out tokenCount, considerPreTokenization, considerNormalization);
/// <remarks>
/// If the whole text can be encoded within the token limit, the returned index will be 0.
/// </remarks>
public int LastIndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount)
=> LastIndexOf(text, maxTokenCount, out processedText, out tokenCount);
private int IndexOf(string text, int maxTokenCount, out string processedText, out int tokenCount) /// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="processedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="processedText"/>; conversely, if all tokens fit, the result will be 0.
/// </returns>
public abstract int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? processedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
/// <summary>
/// Map the token to encoded Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <returns>The mapped Id of the token.</returns>
public virtual int? MapTokenToId(string token)
{ {
if (text is null) if (token is null)
{ {
throw new ArgumentNullException(nameof(text)); throw new ArgumentNullException(nameof(token));
} }
if (maxTokenCount <= 0) return MapTokenToId(token.AsSpan());
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
}
processedText = Normalizer is not null ? Normalizer.Normalize(text) : text;
tokenCount = 0;
IEnumerable<Split> splits = PreTokenizer.PreTokenize(processedText);
foreach (Split split in splits)
{
tokenCount += Model.CountTokens(split.TokenSpan, out int textLength, maxTokenCount - tokenCount);
if (textLength < split.Offset.Length || tokenCount >= maxTokenCount)
{
return split.Offset.Index + textLength;
}
}
return processedText.Length;
} }
private int LastIndexOf(string text, int maxTokenCount, out string processedText, out int tokenCount) /// <summary>
{ /// Map the token to encoded Id.
if (text is null) /// </summary>
{ /// <param name="token">The token to map to the Id.</param>
throw new ArgumentNullException(nameof(text)); /// <returns>The mapped Id of the token.</returns>
} public abstract int? MapTokenToId(ReadOnlySpan<char> token);
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
}
processedText = Normalizer is not null ? Normalizer.Normalize(text) : text;
tokenCount = 0;
IEnumerable<Split> splits = PreTokenizer.PreTokenize(processedText);
foreach (Split split in splits.Reverse())
{
tokenCount += Model.CountTokensFromEnd(split.TokenSpan, out int textIndex, maxTokenCount - tokenCount);
if (textIndex > 0 || tokenCount >= maxTokenCount)
{
return split.Offset.Index + textIndex;
}
}
return 0;
}
/// <summary> /// <summary>
/// Decodes the Id to the mapped token. /// Decodes the Id to the mapped token.
/// </summary> /// </summary>
/// <param name="id">The id to map to the token.</param> /// <param name="id">The id to map to the token.</param>
/// <returns>The decoded string or null if there is no token mapped to the input id.</returns> /// <returns>The decoded string or null if there is no token mapped to the input id.</returns>
public string? Decode(int id) => Model.MapIdToToken(id); public abstract string? MapIdToToken(int id);
/// <summary> /// <summary>
/// Decode the given ids, back to a String. /// Decode the given ids, back to a String.
/// </summary> /// </summary>
/// <param name="ids">The list of ids that we want to decode.</param> /// <param name="ids">The list of ids that we want to decode.</param>
/// <returns>The decoded string.</returns> /// <returns>The decoded string.</returns>
public string? Decode(IEnumerable<int> ids) => Model.Decode(ids); public virtual string? Decode(IEnumerable<int> ids)
{
if (ids is null)
{
throw new ArgumentNullException(nameof(ids));
}
ValueStringBuilder sb = new ValueStringBuilder();
foreach (int id in ids)
{
if (MapIdToToken(id) is string s)
{
sb.Append(s);
}
}
return sb.ToString();
}
//
// Factory Methods
//
/// <summary>
/// Create a new Tiktoken tokenizer's object asynchronously.
/// </summary>
/// <param name="vocabStream">The stream to the BPE vocab file.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="cacheSize">The size of the cache to use.</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer's object.</returns>
public static async Task<Tokenizer> CreateTiktokenAsync(
Stream vocabStream,
PreTokenizer? preTokenizer,
Normalizer? normalizer,
IReadOnlyDictionary<string, int>? specialTokens = null,
int cacheSize = LruCache<int[]>.DefaultCacheSize,
CancellationToken cancellationToken = default)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
await Tiktoken.LoadTiktokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);
return new Tiktoken(encoder, decoder, vocab, preTokenizer, specialTokens, normalizer, cacheSize);
}
/// <summary>
/// Create a new Tiktoken tokenizer's object asynchronously.
/// </summary>
/// <param name="vocabFilePath">The BPE vocab file.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="specialTokensEncoder">The dictionary mapping special tokens to Ids.</param>
/// <param name="cacheSize">The size of the cache to use.</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer's object.</returns>
public static async Task<Tokenizer> CreateTiktokenAsync(
string vocabFilePath,
PreTokenizer? preTokenizer,
Normalizer? normalizer,
IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
int cacheSize = LruCache<int[]>.DefaultCacheSize,
CancellationToken cancellationToken = default)
{
if (vocabFilePath is null)
{
throw new ArgumentNullException(nameof(vocabFilePath));
}
using Stream vocabStream = File.OpenRead(vocabFilePath);
return await CreateTiktokenAsync(vocabStream, preTokenizer, normalizer, specialTokensEncoder, cacheSize, cancellationToken).ConfigureAwait(false);
}
/// <summary> /// <summary>
/// Create a Tiktoken tokenizer based on model name and vocab file. /// Create a Tiktoken tokenizer based on model name and vocab file.
@ -304,10 +321,11 @@ namespace Microsoft.ML.Tokenizers
} }
} }
return new Tokenizer( return new Tiktoken(vocabStream,
new Tiktoken(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize), new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), tiktokenConfiguration.SpecialTokens,
normalizer); normalizer,
cacheSize);
} }
/// <summary> /// <summary>
@ -343,10 +361,11 @@ namespace Microsoft.ML.Tokenizers
} }
} }
return new Tokenizer( return await CreateTiktokenAsync(vocabStream,
await Tiktoken.CreateAsync(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize, cancellationToken).ConfigureAwait(false), new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), normalizer,
normalizer); tiktokenConfiguration.SpecialTokens,
cacheSize, cancellationToken).ConfigureAwait(false);
} }
/// <summary> /// <summary>
@ -357,7 +376,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="normalizer">To normalize the text before tokenization</param> /// <param name="normalizer">To normalize the text before tokenization</param>
/// <returns>The tokenizer</returns> /// <returns>The tokenizer</returns>
public static Tokenizer CreateTiktokenForModel(string modelName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null) public static Tokenizer CreateTiktokenForModel(string modelName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null)
=> Tiktoken.CreateTokenizerForModel(modelName, extraSpecialTokens, normalizer); => Tiktoken.CreateForModel(Tiktoken.GetModelEncoding(modelName), modelName, extraSpecialTokens, normalizer);
/// <summary> /// <summary>
/// Create tokenizer based on encoding name /// Create tokenizer based on encoding name
@ -395,7 +414,7 @@ namespace Microsoft.ML.Tokenizers
throw new ArgumentException($"The encoding name '{encodingName}' is not supported. The only supported encoding names are: {Tiktoken.Cl100kBaseEncodingName}, {Tiktoken.P50kBaseEncodingName}, {Tiktoken.P50kEditEncodingName}, and {Tiktoken.R50kBaseEncodingName}.", nameof(encodingName)); throw new ArgumentException($"The encoding name '{encodingName}' is not supported. The only supported encoding names are: {Tiktoken.Cl100kBaseEncodingName}, {Tiktoken.P50kBaseEncodingName}, {Tiktoken.P50kEditEncodingName}, and {Tiktoken.R50kBaseEncodingName}.", nameof(encodingName));
} }
return Tiktoken.CreateTokenizerForModel(modelEncoding, modelName: null, extraSpecialTokens, normalizer); return Tiktoken.CreateForModel(modelEncoding, modelName: null, extraSpecialTokens, normalizer);
} }
/// <summary> /// <summary>
@ -427,16 +446,70 @@ namespace Microsoft.ML.Tokenizers
throw new ArgumentException($"Normalization '{modelProto.NormalizerSpec.Name}' is not supported.", nameof(modelProto)); throw new ArgumentException($"Normalization '{modelProto.NormalizerSpec.Name}' is not supported.", nameof(modelProto));
} }
LlamaNormalizer normalizer = new( SentencePieceNormalizer normalizer = new(
modelProto.NormalizerSpec.RemoveExtraWhitespaces, modelProto.NormalizerSpec.RemoveExtraWhitespaces,
modelProto.NormalizerSpec.AddDummyPrefix, modelProto.NormalizerSpec.AddDummyPrefix,
modelProto.NormalizerSpec.EscapeWhitespaces, modelProto.NormalizerSpec.EscapeWhitespaces,
modelProto.TrainerSpec.TreatWhitespaceAsSuffix); modelProto.TrainerSpec.TreatWhitespaceAsSuffix);
return new Tokenizer( return new SentencePieceBpe(modelProto, addBeginOfSentence, addEndOfSentence);
new SentencePieceBpe(modelProto, addBeginOfSentence, addEndOfSentence), }
SentencePiecePreTokenizer.Instance,
normalizer); internal static IEnumerable<(int Offset, int Length)>? InitializeForEncoding(
string? text,
ReadOnlySpan<char> textSpan,
bool considerPreTokenization,
bool considerNormalization,
Normalizer? normalizer,
PreTokenizer? preTokenizer,
out string? normalizedString,
out ReadOnlySpan<char> textSpanToEncode)
{
normalizedString = null;
IEnumerable<(int Offset, int Length)>? splits = null;
if (text is null)
{
if (considerNormalization && (normalizer is not null))
{
normalizedString = normalizer.Normalize(textSpan.ToString());
textSpanToEncode = normalizedString.AsSpan();
if (considerPreTokenization && preTokenizer is not null)
{
splits = preTokenizer.PreTokenize(normalizedString);
}
}
else
{
textSpanToEncode = textSpan;
if (considerPreTokenization && preTokenizer is not null)
{
splits = preTokenizer.PreTokenize(textSpan);
}
}
}
else
{
if (considerNormalization && (normalizer is not null))
{
normalizedString = normalizer.Normalize(text);
textSpanToEncode = normalizedString.AsSpan();
if (considerPreTokenization && preTokenizer is not null)
{
splits = preTokenizer.PreTokenize(normalizedString);
}
}
else
{
textSpanToEncode = text.AsSpan();
if (considerPreTokenization && preTokenizer is not null)
{
splits = preTokenizer.PreTokenize(text);
}
}
}
return splits;
} }
} }
} }

Просмотреть файл

@ -71,6 +71,8 @@ namespace Microsoft.ML.Tokenizers
return s; return s;
} }
public void Clear() => _data.Clear();
public bool IsConsistent() public bool IsConsistent()
{ {
// is the heap property true for all data? // is the heap property true for all data?

Просмотреть файл

@ -26,12 +26,12 @@ namespace Microsoft.ML.TorchSharp.Extensions
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt" // "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt"
Assembly assembly = typeof(TokenizerExtensions).Assembly; Assembly assembly = typeof(TokenizerExtensions).Assembly;
EnglishRoberta model = new EnglishRoberta( _instance = new EnglishRoberta(
assembly.GetManifestResourceStream("encoder.json"), assembly.GetManifestResourceStream("encoder.json"),
assembly.GetManifestResourceStream("vocab.bpe"), assembly.GetManifestResourceStream("vocab.bpe"),
assembly.GetManifestResourceStream("dict.txt")); assembly.GetManifestResourceStream("dict.txt"),
model.AddMaskSymbol(); new RobertaPreTokenizer());
_instance = new Tokenizer(model, new RobertaPreTokenizer()); (_instance as EnglishRoberta).AddMaskSymbol();
} }
return _instance; return _instance;
@ -40,7 +40,7 @@ namespace Microsoft.ML.TorchSharp.Extensions
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static EnglishRoberta RobertaModel(this Tokenizer tokenizer) internal static EnglishRoberta RobertaModel(this Tokenizer tokenizer)
{ {
EnglishRoberta model = tokenizer.Model as EnglishRoberta; EnglishRoberta model = tokenizer as EnglishRoberta;
if (model is null) if (model is null)
{ {
throw new InvalidOperationException($"The input tokenizer is not using the EnglishRoberta model."); throw new InvalidOperationException($"The input tokenizer is not using the EnglishRoberta model.");
@ -51,8 +51,7 @@ namespace Microsoft.ML.TorchSharp.Extensions
internal static IReadOnlyList<int> EncodeToConverted(this Tokenizer tokenizer, string sentence) internal static IReadOnlyList<int> EncodeToConverted(this Tokenizer tokenizer, string sentence)
{ {
EncodingResult encoding = tokenizer.Encode(sentence); return tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(tokenizer.EncodeToIds(sentence));
return tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(encoding.Ids);
} }
} }
} }

Просмотреть файл

@ -167,16 +167,16 @@ namespace Microsoft.ML.TorchSharp.NasBert
Sentence1Getter(ref sentenceRom); Sentence1Getter(ref sentenceRom);
var sentence = sentenceRom.ToString(); var sentence = sentenceRom.ToString();
Tensor t; Tensor t;
var encoding = Tokenizer.Encode(sentence); IReadOnlyList<Token> encoding = Tokenizer.Encode(sentence, out string normalizedString);
if (target.Length != encoding.Tokens.Count) if (target.Length != encoding.Count)
{ {
var targetIndex = 0; var targetIndex = 0;
var targetEditor = VBufferEditor.Create(ref target, encoding.Tokens.Count); var targetEditor = VBufferEditor.Create(ref target, encoding.Count);
var newValues = targetEditor.Values; var newValues = targetEditor.Values;
for (var i = 0; i < encoding.Tokens.Count; i++) for (var i = 0; i < encoding.Count; i++)
{ {
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i])) if (NerTrainer.TokenStartsWithSpace(encoding[i].Value))
{ {
newValues[i] = target.GetItemOrDefault(++targetIndex); newValues[i] = target.GetItemOrDefault(++targetIndex);
} }
@ -187,7 +187,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
} }
target = targetEditor.Commit(); target = targetEditor.Commit();
} }
t = torch.tensor((ZeroArray).Concat(Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(encoding.Ids)).ToList(), device: Device); t = torch.tensor((ZeroArray).Concat(Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(encoding.Select(t => t.Id).ToArray())).ToList(), device: Device);
if (t.NumberOfElements > 512) if (t.NumberOfElements > 512)
t = t.slice(0, 0, 512, 1); t = t.slice(0, 0, 512, 1);
@ -377,16 +377,16 @@ namespace Microsoft.ML.TorchSharp.NasBert
private void CondenseOutput(ref VBuffer<UInt32> dst, string sentence, Tokenizer tokenizer, TensorCacher outputCacher) private void CondenseOutput(ref VBuffer<UInt32> dst, string sentence, Tokenizer tokenizer, TensorCacher outputCacher)
{ {
var pre = tokenizer.PreTokenizer.PreTokenize(sentence); var pre = tokenizer.PreTokenizer.PreTokenize(sentence);
EncodingResult encoding = tokenizer.Encode(sentence); IReadOnlyList<Token> encoding = tokenizer.Encode(sentence, out string normalizedString);
var argmax = (outputCacher as BertTensorCacher).Result.argmax(-1); var argmax = (outputCacher as BertTensorCacher).Result.argmax(-1);
var prediction = argmax.ToArray<long>(); var prediction = argmax.ToArray<long>();
var targetIndex = 0; var targetIndex = 0;
// Figure out actual count of output tokens // Figure out actual count of output tokens
for (var i = 0; i < encoding.Tokens.Count; i++) for (var i = 0; i < encoding.Count; i++)
{ {
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i])) if (NerTrainer.TokenStartsWithSpace(encoding[i].Value))
{ {
targetIndex++; targetIndex++;
} }
@ -398,9 +398,9 @@ namespace Microsoft.ML.TorchSharp.NasBert
newValues[targetIndex++] = (uint)prediction[0]; newValues[targetIndex++] = (uint)prediction[0];
for (var i = 1; i < encoding.Tokens.Count; i++) for (var i = 1; i < encoding.Count; i++)
{ {
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i])) if (NerTrainer.TokenStartsWithSpace(encoding[i].Value))
{ {
newValues[targetIndex++] = (uint)prediction[i]; newValues[targetIndex++] = (uint)prediction[i];
} }

Просмотреть файл

@ -401,9 +401,9 @@ namespace Microsoft.ML.TorchSharp.Roberta
answerIndexGetter(ref answerIndex); answerIndexGetter(ref answerIndex);
var contextString = context.ToString(); var contextString = context.ToString();
var contextTokens = Tokenizer.Encode(contextString); var contextTokens = Tokenizer.Encode(contextString, out string normalized);
var contextToken = contextTokens.Tokens; var contextToken = contextTokens.Select(t => t.Value).ToArray();
var contextTokenId = Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(contextTokens.Ids); var contextTokenId = Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(contextTokens.Select(t => t.Id).ToArray());
var mapping = AlignAnswerPosition(contextToken, contextString); var mapping = AlignAnswerPosition(contextToken, contextString);
if (mapping == null) if (mapping == null)
@ -437,7 +437,7 @@ namespace Microsoft.ML.TorchSharp.Roberta
private Dictionary<int, int> AlignAnswerPosition(IReadOnlyList<string> tokens, string text) private Dictionary<int, int> AlignAnswerPosition(IReadOnlyList<string> tokens, string text)
{ {
EnglishRoberta robertaModel = Tokenizer.Model as EnglishRoberta; EnglishRoberta robertaModel = Tokenizer as EnglishRoberta;
Debug.Assert(robertaModel is not null); Debug.Assert(robertaModel is not null);
var mapping = new Dictionary<int, int>(); var mapping = new Dictionary<int, int>();
@ -854,9 +854,9 @@ namespace Microsoft.ML.TorchSharp.Roberta
contextGetter(ref context); contextGetter(ref context);
questionGetter(ref question); questionGetter(ref question);
var contextTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.Encode(context.ToString()).Ids); var contextTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.EncodeToIds(context.ToString()));
var questionTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.Encode(question.ToString()).Ids); var questionTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.EncodeToIds(question.ToString()));
var srcTensor = torch.tensor((new[] { 0 /* InitToken */ }).Concat(questionTokenId).Concat(new[] { 2 /* SeparatorToken */ }).Concat(contextTokenId).ToList(), device: _parent.Device); var srcTensor = torch.tensor((new[] { 0 /* InitToken */ }).Concat(questionTokenId).Concat(new[] { 2 /* SeparatorToken */ }).Concat(contextTokenId).ToList(), device: _parent.Device);

Просмотреть файл

@ -248,28 +248,28 @@ namespace Microsoft.ML.Tokenizers.Tests
try try
{ {
Bpe bpe = new Bpe(vocabFile, mergesFile, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownToken); Bpe bpe = new Bpe(vocabFile, mergesFile, unknownToken: unknownToken, continuingSubwordPrefix: continuingSubwordPrefix, endOfWordSuffix: endOfWordSuffix, fuseUnknownTokens: fuseUnknownToken);
Tokenizer tokenizer = new Tokenizer(bpe); Tokenizer tokenizer = bpe;
EncodingResult encoding = tokenizer.Encode(sentence); IReadOnlyList<Token> encoding = tokenizer.Encode(sentence, out _);
int[] encodingIds = encoding.Select(t => t.Id).ToArray();
IReadOnlyList<int> idsList = tokenizer.EncodeToIds(sentence); IReadOnlyList<int> idsList = tokenizer.EncodeToIds(sentence);
Assert.Equal(expectedTokens.Length, encoding.Tokens.Count); Assert.Equal(expectedTokens.Length, encoding.Count);
Assert.Equal(offsets.Length, encoding.Offsets.Count); Assert.Equal(offsets.Length, encoding.Count);
Assert.Equal(ids.Length, encoding.Ids.Count); Assert.Equal(ids.Length, encoding.Count);
Assert.Equal(ids.Length, idsList.Count); Assert.Equal(ids.Length, idsList.Count);
Assert.Equal(ids.Length, tokenizer.CountTokens(sentence)); Assert.Equal(ids.Length, tokenizer.CountTokens(sentence));
Assert.Equal(decodedTokens, tokenizer.Decode(encoding.Ids)); Assert.Equal(decodedTokens, tokenizer.Decode(encodingIds));
Assert.Equal(decodedTokensWithoutUnknownToken, bpe.Decode(encoding.Ids, considerSpecialTokens: false)); Assert.Equal(decodedTokensWithoutUnknownToken, bpe.Decode(encodingIds, considerSpecialTokens: false));
for (int i = 0; i < encoding.Tokens.Count; i++) for (int i = 0; i < encoding.Count; i++)
{ {
Assert.Equal(expectedTokens[i], encoding.Tokens[i]); Assert.Equal(expectedTokens[i], encoding[i].Value);
Assert.Equal(offsets[i], encoding.Offsets[i]); Assert.Equal(offsets[i], encoding[i].Offset);
Assert.Equal(ids[i], encoding.Ids[i]); Assert.Equal(ids[i], encoding[i].Id);
Assert.Equal(ids[i], idsList[i]); Assert.Equal(ids[i], idsList[i]);
Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i])); Assert.Equal(encoding[i].Value, tokenizer.MapIdToToken(encodingIds[i]));
Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan())); Assert.Equal(encodingIds[i], tokenizer.MapTokenToId(encoding[i].Value.AsSpan()));
Assert.Equal(encoding.Tokens[i], tokenizer.Decode(encoding.Ids[i]));
} }
} }
finally finally
@ -282,6 +282,41 @@ namespace Microsoft.ML.Tokenizers.Tests
} }
} }
private static Tokenizer? _gpt2Tokenizer = null;
private static Tokenizer GetGpt2Tokenizer()
{
if (_gpt2Tokenizer is null)
{
// "https://huggingface.co/openai-community/gpt2/raw/main/vocab.json";
// "https://huggingface.co/openai-community/gpt2/raw/main/merges.txt";
using Stream vocabStream = File.OpenRead(Path.Combine(@"Gpt-2", "vocab.json"));
using Stream mergesStream = File.OpenRead(Path.Combine(@"Gpt-2", "merges.txt"));
_gpt2Tokenizer = new Bpe(vocabStream, mergesStream);
}
return _gpt2Tokenizer;
}
[Fact]
public void TestGpt2Vocab()
{
Tokenizer tokenizer = GetGpt2Tokenizer();
string text = "The quick brown fox jumps over the lazy dog!";
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
IReadOnlyList<int> ids = tokenizer.EncodeToIds(text);
Assert.Equal(12, encoding.Count);
Assert.Equal(encoding.Select(t => t.Id).ToArray(), ids);
Assert.Equal(12, tokenizer.CountTokens(text));
TokenizerTests.TestTokenLimits(tokenizer);
}
public static IEnumerable<object?[]> BpeTestData public static IEnumerable<object?[]> BpeTestData
{ {
get get
@ -290,55 +325,84 @@ namespace Microsoft.ML.Tokenizers.Tests
yield return new object?[] yield return new object?[]
{ {
"the brown fox jumped over the lazy dog!", "the brown fox jumped over the lazy dog!",
new string[] {"the", "brown", "fox", "jumped", "over", "the", "lazy", "dog", "!"}, new string[] { "the", "brown", "fox", "j", "umped", "over", "the", "l", "azy", "dog", "!" },
new (int, int)[] {(0, 3), (4, 9), (10, 13), (14, 20), (21, 25), (26, 29), (30, 34), (35, 38), (38, 39)} new (int Index, int Length)[] { (0, 3), (4, 5), (10, 3), (14, 1), (15, 5), (21, 4), (26, 3), (30, 1), (31, 3), (35, 3), (38, 1) },
new int[] { 1169, 33282, 12792, 73, 27073, 2502, 1169, 75, 12582, 9703, 0 }
}; };
yield return new object?[] yield return new object?[]
{ {
"he traveled to Egypt during the summer, the weather was hot and ammunition." , "he traveled to Egypt during the summer, the weather was hot and ammunition." ,
new string[] {"he", "traveled", "to", "Egypt", "during", "the", "summer", ",", "the", "weather", "was", "hot", "and", "ammunition", "."}, new string[] { "he", "travel", "ed", "to", "Egypt", "during", "the", "sum", "mer", ",", "the", "weather", "was", "hot", "and", "am", "munition", "." },
new (int, int)[] {(0, 2), (3, 11), (12, 14), (15, 20), (21, 27), (28, 31), (32, 38), (38, 39), (40, 43), (44, 51), (52, 55), (56, 59), (60, 63), (64, 74), (74, 75)} new (int Index, int Length)[] { (0, 2), (3, 6), (9, 2), (12, 2), (15, 5), (21, 6), (28, 3), (32, 3), (35, 3), (38, 1), (40, 3), (44, 7), (52, 3), (56, 3), (60, 3), (64, 2), (66, 8), (74, 1) },
new int[] { 258, 35927, 276, 1462, 39299, 42122, 1169, 16345, 647, 11, 1169, 23563, 9776, 8940, 392, 321, 12640, 13 }
}; };
yield return new object?[] yield return new object?[]
{ {
"She played many games and she felt exhausted afterward", "She played many games and she felt exhausted afterward",
new string[] {"She", "played", "many", "games", "and", "she", "felt", "exhausted", "afterward"}, new string[] { "She", "played", "many", "games", "and", "she", "felt", "ex", "ha", "usted", "after", "ward" },
new (int, int)[] {(0, 3), (4, 10), (11, 15), (16, 21), (22, 25), (26, 29), (30, 34), (35, 44), (45, 54)} new (int Index, int Length)[] { (0, 3), (4, 6), (11, 4), (16, 5), (22, 3), (26, 3), (30, 4), (35, 2), (37, 2), (39, 5), (45, 5), (50, 4) },
new int[] { 3347, 21542, 21834, 19966, 392, 7091, 31985, 1069, 3099, 8459, 8499, 904 }
}; };
yield return new object?[] yield return new object?[]
{ {
"Hello, y'all! How are you 😁 ?", "Hello, y'all! How are you 😁 ?",
new string[] {"Hello", ",", "y", "'", "all", "!", "How", "are", "you", "[UNK]", "?"}, new string[] { "Hello", ",", "y", "'", "all", "!", "How", "are", "you", "?" },
new (int, int)[] {(0, 5), (5, 6), (7, 8), (8, 9), (9, 12), (12, 13), (14, 17), (18, 21), (22, 25), (26, 28), (29, 30)} new (int Index, int Length)[] { (0, 5), (5, 1), (7, 1), (8, 1), (9, 3), (12, 1), (14, 3), (18, 3), (22, 3), (29, 1) },
new int[] { 15496, 11, 88, 6, 439, 0, 2437, 533, 5832, 30 }
}; };
} }
} }
private const string Gpt2VocabUrl = "https://huggingface.co/openai-community/gpt2/raw/main/vocab.json"; [Theory]
private const string Gpt2MergesUrl = "https://huggingface.co/openai-community/gpt2/raw/main/merges.txt"; [MemberData(nameof(BpeTestData))]
public void TestBpeTokenizer(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
[Fact]
public async void TestGpt2Vocab()
{ {
using HttpClient httpClient = new HttpClient(); Tokenizer tokenizer = GetGpt2Tokenizer();
using Stream vocabStream = await httpClient.GetStreamAsync(Gpt2VocabUrl);
using Stream mergesStream = await httpClient.GetStreamAsync(Gpt2MergesUrl);
Bpe bpe = new Bpe(vocabStream, mergesStream); IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
Tokenizer tokenizer = new Tokenizer(bpe); IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);
string text = "The quick brown fox jumps over the lazy dog!"; Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
EncodingResult encoding = tokenizer.Encode(text); Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray());
IReadOnlyList<int> ids = tokenizer.EncodeToIds(text); Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray());
Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray());
Assert.Equal(12, encoding.Tokens.Count); Assert.Equal(expectedIds, tokenizer.EncodeToIds(text));
Assert.Equal(12, encoding.Offsets.Count); Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan()));
Assert.Equal(12, encoding.Ids.Count); Assert.Equal(expectedIds, tokenizer.EncodeToIds(text, expectedIds.Length, out string? normalizedString, out int length));
Assert.Equal(encoding.Ids, ids); Assert.Null(normalizedString);
Assert.Equal(12, tokenizer.CountTokens(text)); Assert.Equal(text.Length, length);
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(text.Length, length);
TokenizerTests.TestTokenLimits(tokenizer); Assert.Equal(expectedIds.Take(expectedIds.Length - 2), tokenizer.EncodeToIds(text, expectedIds.Length - 2, out normalizedString, out length));
Assert.Null(normalizedString);
int expectedLength = expectedOffsets[expectedOffsets.Length - 3].Index + expectedOffsets[expectedOffsets.Length - 3].Length;
Assert.Equal(expectedLength, length);
Assert.Equal(expectedIds.Take(expectedIds.Length - 2), tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length - 2, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(expectedLength, length);
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text));
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan()));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length - 3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length - 3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text, 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(3, tokenCount);
} }
private static string WriteToMergeFile((string, string)[] mergeEntries) private static string WriteToMergeFile((string, string)[] mergeEntries)
@ -360,7 +424,7 @@ namespace Microsoft.ML.Tokenizers.Tests
return fileName; return fileName;
} }
internal static Bpe CreateEmptyBpe() internal static Bpe CreateEmptyBpe(PreTokenizer? preTokenizer = null, Normalizer? normalizer = null)
{ {
using MemoryStream emptyVocabStream = new MemoryStream(); using MemoryStream emptyVocabStream = new MemoryStream();
using StreamWriter writer = new StreamWriter(emptyVocabStream); using StreamWriter writer = new StreamWriter(emptyVocabStream);
@ -368,7 +432,7 @@ namespace Microsoft.ML.Tokenizers.Tests
writer.Flush(); writer.Flush();
emptyVocabStream.Position = 0; emptyVocabStream.Position = 0;
return new Bpe(vocabStream: emptyVocabStream, mergesStream: null, unknownToken: "Ukn"); return new Bpe(vocabStream: emptyVocabStream, mergesStream: null, preTokenizer: preTokenizer ?? WhiteSpace.Instance, normalizer: normalizer, unknownToken: "Ukn");
} }
} }
} }

Просмотреть файл

@ -6,9 +6,11 @@ using System;
using System.IO; using System.IO;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Net.Http;
using Xunit; using Xunit;
using System.Diagnostics;
using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers.Tests namespace Microsoft.ML.Tokenizers.Tests
{ {
@ -80,6 +82,34 @@ namespace Microsoft.ML.Tokenizers.Tests
} }
} }
private static Tokenizer? _robertaTokenizer = null;
private async static Task<Tokenizer> GetRobertaTokenizer()
{
if (_robertaTokenizer is null)
{
string vocabFile = Utils.CreateTemporaryFile("json");
string mergeFile = Utils.CreateTemporaryFile("txt");
string translationFile = Utils.CreateTemporaryFile("txt");
try
{
await Utils.DownloadFile(_vocabUrl, vocabFile);
await Utils.DownloadFile(_mergeUrl, mergeFile);
await Utils.DownloadFile(_dictUrl, translationFile);
_robertaTokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
}
finally
{
Utils.DeleteFile(vocabFile);
Utils.DeleteFile(mergeFile);
Utils.DeleteFile(translationFile);
}
}
return _robertaTokenizer;
}
[Fact] [Fact]
public async void TokenizationTest() public async void TokenizationTest()
{ {
@ -94,26 +124,26 @@ namespace Microsoft.ML.Tokenizers.Tests
await Utils.DownloadFile(_mergeUrl, mergeFile); await Utils.DownloadFile(_mergeUrl, mergeFile);
await Utils.DownloadFile(_dictUrl, translationFile); await Utils.DownloadFile(_dictUrl, translationFile);
Tokenizer tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile), RobertaPreTokenizer.Instance); Tokenizer tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer); TestTokenizer(tokenizer);
TokenizerTests.TestTokenLimits(tokenizer); TokenizerTests.TestTokenLimits(tokenizer);
tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile, filterUnsupportedChars: false), RobertaPreTokenizer.Instance); tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false);
TestTokenizer(tokenizer); TestTokenizer(tokenizer);
using Stream vocabStream = File.OpenRead(vocabFile); using Stream vocabStream = File.OpenRead(vocabFile);
using Stream mergeStream = File.OpenRead(mergeFile); using Stream mergeStream = File.OpenRead(mergeFile);
using Stream translationStream = File.OpenRead(translationFile); using Stream translationStream = File.OpenRead(translationFile);
tokenizer = new Tokenizer(new EnglishRoberta(vocabStream, mergeStream, translationStream), RobertaPreTokenizer.Instance); tokenizer = new EnglishRoberta(vocabStream, mergeStream, translationStream, RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer); TestTokenizer(tokenizer);
// Ensure caching works regardless of which method is called first. // Ensure caching works regardless of which method is called first.
for (CallingOrder order = CallingOrder.Encode; order <= CallingOrder.CountTokens; order++) for (CallingOrder order = CallingOrder.Encode; order <= CallingOrder.CountTokens; order++)
{ {
tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile), RobertaPreTokenizer.Instance); tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer, order); TestTokenizer(tokenizer, order);
tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile, filterUnsupportedChars: false), RobertaPreTokenizer.Instance); tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false);
TestTokenizer(tokenizer, order); TestTokenizer(tokenizer, order);
} }
} }
@ -132,6 +162,94 @@ namespace Microsoft.ML.Tokenizers.Tests
} }
} }
public static IEnumerable<object?[]> RobertaTestData
{
get
{
// string to tokenize, produced tokens, the token offsets
yield return new object?[]
{
"the brown fox jumped over the lazy dog!",
new string[] { "the", "Ġbrown", "Ġfox", "Ġjumped", "Ġover", "Ġthe", "Ġlazy", "Ġdog", "!" },
new (int Index, int Length)[] { (0, 3), (3, 6), (9, 4), (13, 7), (20, 5), (25, 4), (29, 5), (34, 4), (38, 1) },
new int[] { 1169, 7586, 21831, 11687, 625, 262, 16931, 3290, 0 }
};
yield return new object?[]
{
"he traveled to Egypt during the summer, the weather was hot and ammunition." ,
new string[] { "he", "Ġtraveled", "Ġto", "ĠEgypt", "Ġduring", "Ġthe", "Ġsummer", ",", "Ġthe", "Ġweather", "Ġwas", "Ġhot", "Ġand", "Ġammunition", "." },
new (int Index, int Length)[] { (0, 2), (2, 9), (11, 3), (14, 6), (20, 7), (27, 4), (31, 7), (38, 1), (39, 4), (43, 8), (51, 4), (55, 4), (59, 4), (63, 11), (74, 1) },
new int[] { 258, 14113, 284, 6365, 1141, 262, 3931, 11, 262, 6193, 373, 3024, 290, 14271, 13 }
};
yield return new object?[]
{
"She played many games and she felt exhausted afterward",
new string[] { "She", "Ġplayed", "Ġmany", "Ġgames", "Ġand", "Ġshe", "Ġfelt", "Ġexhausted", "Ġafterward" },
new (int Index, int Length)[] { (0, 3), (3, 7), (10, 5), (15, 6), (21, 4), (25, 4), (29, 5), (34, 10), (44, 10) },
new int[] { 3347, 2826, 867, 1830, 290, 673, 2936, 19064, 20875 }
};
yield return new object?[]
{
"Hello, y'all! How are you 😁 ?",
new string[] { "Hello", ",", "Ġy", "'", "all", "!", "ĠHow", "Ġare", "Ġyou", "Ġ", "Ġ?" },
new (int Index, int Length)[] { (0, 5), (5, 1), (6, 2), (8, 1), (9, 3), (12, 1), (13, 4), (17, 4), (21, 4), (25, 1), (28, 2) },
new int[] { 15496, 11, 331, 6, 439, 0, 1374, 389, 345, 220, 5633 }
};
}
}
[Theory]
[MemberData(nameof(RobertaTestData))]
public async void TestTokenizerEncoding(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
{
Tokenizer tokenizer = await GetRobertaTokenizer();
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray());
Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray());
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text));
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan()));
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text, expectedIds.Length, out string? normalizedString, out int length));
Assert.Null(normalizedString);
Assert.Equal(text.Length, length);
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(text.Length, length);
Assert.Equal(expectedIds.Take(expectedIds.Length - 2), tokenizer.EncodeToIds(text, expectedIds.Length - 2, out normalizedString, out length));
Assert.Null(normalizedString);
int expectedLength = expectedOffsets[expectedOffsets.Length - 3].Index + expectedOffsets[expectedOffsets.Length - 3].Length;
Assert.Equal(expectedLength, length);
Assert.Equal(expectedIds.Take(expectedIds.Length - 2), tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length - 2, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(expectedLength, length);
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text));
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan()));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length - 3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length - 3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text, 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(3, tokenCount);
}
private enum CallingOrder private enum CallingOrder
{ {
Encode, Encode,
@ -143,66 +261,67 @@ namespace Microsoft.ML.Tokenizers.Tests
// Calling with callIdsFirst = true will test the other way around. // Calling with callIdsFirst = true will test the other way around.
private void TestTokenizer(Tokenizer tokenizer, CallingOrder callingOrder = CallingOrder.Encode) private void TestTokenizer(Tokenizer tokenizer, CallingOrder callingOrder = CallingOrder.Encode)
{ {
Assert.NotNull(tokenizer.Model); Assert.True(tokenizer is EnglishRoberta);
Assert.True(tokenizer.Model is EnglishRoberta);
Assert.True(tokenizer.PreTokenizer is RobertaPreTokenizer); Assert.True(tokenizer.PreTokenizer is RobertaPreTokenizer);
foreach (object[] p in BertaData) foreach (object[] p in BertaData)
{ {
IReadOnlyList<int> ids; IReadOnlyList<int> ids;
EncodingResult encoding; IReadOnlyList<Token> encoding;
int idsCount; int idsCount;
if (callingOrder == CallingOrder.Encode) if (callingOrder == CallingOrder.Encode)
{ {
encoding = tokenizer.Encode((string)p[0]); encoding = tokenizer.Encode((string)p[0], out _);
ids = tokenizer.EncodeToIds((string)p[0]); ids = tokenizer.EncodeToIds((string)p[0]);
idsCount = tokenizer.CountTokens((string)p[0]); idsCount = tokenizer.CountTokens((string)p[0]);
} }
else if (callingOrder == CallingOrder.EncodeToIds) else if (callingOrder == CallingOrder.EncodeToIds)
{ {
ids = tokenizer.EncodeToIds((string)p[0]); ids = tokenizer.EncodeToIds((string)p[0]);
encoding = tokenizer.Encode((string)p[0]); encoding = tokenizer.Encode((string)p[0], out _);
idsCount = tokenizer.CountTokens((string)p[0]); idsCount = tokenizer.CountTokens((string)p[0]);
} }
else // CountTokens else // CountTokens
{ {
idsCount = tokenizer.CountTokens((string)p[0]); idsCount = tokenizer.CountTokens((string)p[0]);
ids = tokenizer.EncodeToIds((string)p[0]); ids = tokenizer.EncodeToIds((string)p[0]);
encoding = tokenizer.Encode((string)p[0]); encoding = tokenizer.Encode((string)p[0], out _);
} }
Assert.Equal(p[1], encoding.Ids); int[] encodingIds = encoding.Select(t => t.Id).ToArray();
(int, int)[] offsets = encoding.Select(t => t.Offset).ToArray();
string[] tokens = encoding.Select(t => t.Value).ToArray();
Assert.Equal(p[1], encodingIds);
Assert.Equal(p[1], ids); Assert.Equal(p[1], ids);
Assert.Equal(((int[])p[1]).Length, idsCount); Assert.Equal(((int[])p[1]).Length, idsCount);
Assert.Equal(p[3], encoding.Offsets); Assert.Equal(p[3], offsets);
Assert.Equal(encoding.Ids.Count, encoding.Tokens.Count);
Assert.Equal(encoding.Ids.Count, encoding.Offsets.Count);
EnglishRoberta? robertaModel = tokenizer.Model as EnglishRoberta; EnglishRoberta? robertaModel = tokenizer as EnglishRoberta;
Assert.Equal(p[2], encoding.Tokens); Assert.Equal(p[2], tokens);
Assert.Equal(string.Concat((string[])(p[robertaModel!.FilterUnsupportedChars ? 5 : 2])), tokenizer.Decode(encoding.Ids)); Assert.Equal(string.Concat((string[])(p[robertaModel!.FilterUnsupportedChars ? 5 : 2])), tokenizer.Decode(encodingIds));
Assert.NotNull(robertaModel); Assert.NotNull(robertaModel);
Assert.Equal(encoding.Ids, robertaModel!.ConvertOccurrenceRanksToIds(robertaModel!.ConvertIdsToOccurrenceRanks(encoding.Ids))); Assert.Equal(encodingIds, robertaModel!.ConvertOccurrenceRanksToIds(robertaModel!.ConvertIdsToOccurrenceRanks(encodingIds)));
Assert.Equal(p[4], robertaModel.ConvertIdsToOccurrenceValues(encoding.Ids)); Assert.Equal(p[4], robertaModel.ConvertIdsToOccurrenceValues(encodingIds));
for (int i = 0; i < encoding.Tokens.Count; i++) for (int i = 0; i < tokens.Length; i++)
{ {
if (robertaModel.FilterUnsupportedChars) if (robertaModel.FilterUnsupportedChars)
{ {
string[]? filteredToken = p[5] as string[]; string[]? filteredToken = p[5] as string[];
Assert.Equal(filteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i])); Assert.Equal(filteredToken![i], tokenizer.MapIdToToken(encodingIds[i]));
} }
else else
{ {
Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i])); Assert.Equal(tokens[i], tokenizer.MapIdToToken(encodingIds[i]));
string[]? unfilteredToken = p[2] as string[]; string[]? unfilteredToken = p[2] as string[];
Assert.Equal(unfilteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i])); Assert.Equal(unfilteredToken![i], tokenizer.MapIdToToken(encodingIds[i]));
} }
Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan())); Assert.Equal(encodingIds[i], tokenizer.MapTokenToId(tokens[i].AsSpan()));
} }
} }
} }

Просмотреть файл

@ -17,20 +17,21 @@ namespace Microsoft.ML.Tokenizers.Tests
public class LlamaTests public class LlamaTests
{ {
private static readonly HttpClient _httpClient = new HttpClient() { Timeout = TimeSpan.FromMinutes(5) }; private static readonly HttpClient _httpClient = new HttpClient() { Timeout = TimeSpan.FromMinutes(5) };
private static Tokenizer _llamaTokenizer = CreateLlamaTokenizer().GetAwaiter().GetResult(); private static Tokenizer _llamaTokenizer = CreateLlamaTokenizer();
private static Tokenizer _llamaMistralTokenizer = CreateLMistralTokenizer().GetAwaiter().GetResult(); private static Tokenizer _llamaMistralTokenizer = CreateLMistralTokenizer();
private static async Task<Tokenizer> CreateLlamaTokenizer() private static Tokenizer CreateLlamaTokenizer()
{ {
const string modelUrl = @"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model"; // @"https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.model?download=true";
using Stream remoteStream = await _httpClient.GetStreamAsync(modelUrl); // @"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model";
using Stream remoteStream = File.OpenRead(Path.Combine(@"Llama", "tokenizer.model"));
return Tokenizer.CreateLlama(remoteStream); return Tokenizer.CreateLlama(remoteStream);
} }
private static async Task<Tokenizer> CreateLMistralTokenizer() private static Tokenizer CreateLMistralTokenizer()
{ {
const string modelUrl = @"https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/tokenizer.model?download=true"; // @"https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/tokenizer.model?download=true";
using Stream remoteStream = await _httpClient.GetStreamAsync(modelUrl); using Stream remoteStream = File.OpenRead(Path.Combine(@"Mistral", "tokenizer.model"));
return Tokenizer.CreateLlama(remoteStream); return Tokenizer.CreateLlama(remoteStream);
} }
@ -185,13 +186,13 @@ namespace Microsoft.ML.Tokenizers.Tests
[MemberData(nameof(LlamaTestData))] [MemberData(nameof(LlamaTestData))]
public void TestLlamaTokenizer(Tokenizer llamaTokenizer, string input, int[] ids, string[] tokens, (int Index, int Length)[] offsets) public void TestLlamaTokenizer(Tokenizer llamaTokenizer, string input, int[] ids, string[] tokens, (int Index, int Length)[] offsets)
{ {
SentencePieceBpe? bpe = llamaTokenizer.Model as SentencePieceBpe; SentencePieceBpe? bpe = llamaTokenizer as SentencePieceBpe;
Assert.NotNull(bpe); Assert.NotNull(bpe);
EncodingResult result = llamaTokenizer.Encode(input); IReadOnlyList<Token> result = llamaTokenizer.Encode(input, out _);
Assert.Equal(ids, result.Ids); Assert.Equal(ids, result.Select(t => t.Id).ToArray());
Assert.Equal(tokens, result.Tokens); Assert.Equal(tokens, result.Select(t => t.Value).ToArray());
Assert.Equal(offsets, result.Offsets); Assert.Equal(offsets, result.Select(t => t.Offset).ToArray());
Assert.Equal(input, llamaTokenizer.Decode(ids)); Assert.Equal(input, llamaTokenizer.Decode(ids));
Assert.Equal(ids, llamaTokenizer.EncodeToIds(input)); Assert.Equal(ids, llamaTokenizer.EncodeToIds(input));
Assert.Equal(ids.Length, llamaTokenizer.CountTokens(input)); Assert.Equal(ids.Length, llamaTokenizer.CountTokens(input));
@ -208,32 +209,29 @@ namespace Microsoft.ML.Tokenizers.Tests
bool isEmptyInput = string.IsNullOrEmpty(input); bool isEmptyInput = string.IsNullOrEmpty(input);
IReadOnlyList<Token> bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false); IReadOnlyList<Token> bpeTokens = bpe.Encode(normalizedInput.AsSpan(), out _, addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
Assert.Equal(ids.Skip(1), bpeTokens.Select(token => token.Id)); Assert.Equal(ids.Skip(1), bpeTokens.Select(token => token.Id));
Assert.Equal(tokens.Skip(1), bpeTokens.Select(token => token.Value)); Assert.Equal(tokens.Skip(1), bpeTokens.Select(token => token.Value));
Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id))); Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
List<int> encodedIds = new(); IReadOnlyList<int> encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds: encodedIds, out _);
Assert.Equal(ids.Skip(1), encodedIds); Assert.Equal(ids.Skip(1), encodedIds);
Assert.Equal(isEmptyInput ? 0 : ids.Length - 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false, out _)); Assert.Equal(isEmptyInput ? 0 : ids.Length - 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false));
bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true); bpeTokens = bpe.Encode(normalizedInput.AsSpan(), out _, addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id)); Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id));
Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Skip(1).Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value)); Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Skip(1).Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value));
Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id))); Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
encodedIds.Clear(); encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false);
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true, accumulatedIds: encodedIds, out _);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), encodedIds); Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), encodedIds);
Assert.Equal(isEmptyInput ? 0 : ids.Length, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true, out _)); Assert.Equal(isEmptyInput ? 0 : ids.Length, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false));
bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true); bpeTokens = bpe.Encode(normalizedInput.AsSpan(), out _, addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id)); Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id));
Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value)); Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value));
Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id))); Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
encodedIds.Clear(); encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false);
bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true, accumulatedIds: encodedIds, out _);
Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), encodedIds); Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), encodedIds);
Assert.Equal(isEmptyInput ? 0 : ids.Length + 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true, out _)); Assert.Equal(isEmptyInput ? 0 : ids.Length + 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false));
} }
public static IEnumerable<object[]> LlamaTokenizersListData() public static IEnumerable<object[]> LlamaTokenizersListData()
@ -244,11 +242,17 @@ namespace Microsoft.ML.Tokenizers.Tests
[Theory] [Theory]
[MemberData(nameof(LlamaTokenizersListData))] [MemberData(nameof(LlamaTokenizersListData))]
public void TestLlamaTokenizerWithInvalidInput(Tokenizer llamaTokenizer) public void TestLlamaTokenizerWithEmptyInput(Tokenizer llamaTokenizer)
{ {
Assert.Throws<ArgumentNullException>(() => llamaTokenizer.Encode(null!)); Assert.Equal([], llamaTokenizer.Encode((string)null!, out _));
Assert.Throws<ArgumentNullException>(() => llamaTokenizer.EncodeToIds(null!)); Assert.Equal([], llamaTokenizer.Encode(Span<char>.Empty, out _));
Assert.Throws<ArgumentNullException>(() => llamaTokenizer.CountTokens(null!));
Assert.Equal([], llamaTokenizer.EncodeToIds((string)null!));
Assert.Equal([], llamaTokenizer.EncodeToIds(Span<char>.Empty));
Assert.Equal(0, llamaTokenizer.CountTokens((string)null!));
Assert.Equal(0, llamaTokenizer.CountTokens(Span<char>.Empty));
Assert.Throws<ArgumentNullException>(() => llamaTokenizer.Decode(null!)); Assert.Throws<ArgumentNullException>(() => llamaTokenizer.Decode(null!));
} }
@ -256,7 +260,7 @@ namespace Microsoft.ML.Tokenizers.Tests
[MemberData(nameof(LlamaTokenizersListData))] [MemberData(nameof(LlamaTokenizersListData))]
public void TestLlamaTokenizerProperties(Tokenizer llamaTokenizer) public void TestLlamaTokenizerProperties(Tokenizer llamaTokenizer)
{ {
SentencePieceBpe? bpe = llamaTokenizer.Model as SentencePieceBpe; SentencePieceBpe? bpe = llamaTokenizer as SentencePieceBpe;
Assert.NotNull(bpe); Assert.NotNull(bpe);
Assert.NotNull(llamaTokenizer.Normalizer); Assert.NotNull(llamaTokenizer.Normalizer);
@ -284,34 +288,242 @@ namespace Microsoft.ML.Tokenizers.Tests
} }
[Fact] [Fact]
public void TestLlamaNormalizer() public void TestSentencePieceNormalizer()
{ {
LlamaNormalizer normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false); SentencePieceNormalizer normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false);
Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!")); Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!"));
Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false); normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false);
Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!")); Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!"));
Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false); normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false);
Assert.Equal(" Hello, World!", normalizer.Normalize("Hello, World!")); Assert.Equal(" Hello, World!", normalizer.Normalize("Hello, World!"));
Assert.Equal(" Hello, World!", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false); normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false);
Assert.Equal("▁Hello,▁World!", normalizer.Normalize("Hello, World!")); Assert.Equal("▁Hello,▁World!", normalizer.Normalize("Hello, World!"));
Assert.Equal("▁Hello,▁World!", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false); normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false);
Assert.Equal("▁Hello,▁▁▁▁▁▁World!", normalizer.Normalize("Hello, World!")); Assert.Equal("▁Hello,▁▁▁▁▁▁World!", normalizer.Normalize("Hello, World!"));
Assert.Equal("▁Hello,▁▁▁▁▁▁World!", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true); normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true);
Assert.Equal("Hello,▁World!▁", normalizer.Normalize("Hello, World!")); Assert.Equal("Hello,▁World!▁", normalizer.Normalize("Hello, World!"));
Assert.Equal("Hello,▁World!▁", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true); normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true);
Assert.Equal("Hello,▁World!", normalizer.Normalize("Hello, World!")); Assert.Equal("Hello,▁World!", normalizer.Normalize("Hello, World!"));
Assert.Equal("Hello,▁World!", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true); normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true);
Assert.Equal("Hello,▁▁▁▁▁▁World!▁", normalizer.Normalize("Hello, World!")); Assert.Equal("Hello,▁▁▁▁▁▁World!▁", normalizer.Normalize("Hello, World!"));
Assert.Equal("Hello,▁▁▁▁▁▁World!▁", normalizer.Normalize("Hello, World!".AsSpan()));
normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: true); normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: true);
Assert.Equal("Hello, World! ", normalizer.Normalize("Hello, World!")); Assert.Equal("Hello, World! ", normalizer.Normalize("Hello, World!"));
Assert.Equal("Hello, World! ", normalizer.Normalize("Hello, World!".AsSpan()));
}
// Theory data shared by the Llama tokenizer tests below. Each row carries:
//   [0] the raw input text to tokenize,
//   [1] the expected SentencePiece-normalized form of the input (spaces escaped to '▁',
//       with a leading dummy-prefix '▁' — see the SentencePieceNormalizer tests above),
//   [2] the expected tokens, starting with the beginning-of-sentence token "<s>",
//   [3] the expected (Index, Length) offset of each token into the normalized text
//       (the BOS token has a zero-length offset; multi-byte fallback tokens like "<0xF0>"
//       share the offset of the character they encode),
//   [4] the expected token ids.
public static IEnumerable<object?[]> TokenizerTestData
{
get
{
// string to tokenize, expected normalized text, produced tokens, the token offsets, and the token ids
yield return new object?[]
{
"the brown fox jumped over the lazy dog!",
"▁the▁brown▁fox▁jumped▁over▁the▁lazy▁dog!",
new string[] { "<s>", "▁the", "▁brown", "▁fo", "x", "▁jump", "ed", "▁over", "▁the", "▁lazy", "▁dog", "!" },
new (int Index, int Length)[] { (0, 0), (0, 4), (4, 6), (10, 3), (13, 1), (14, 5), (19, 2), (21, 5), (26, 4), (30, 5), (35, 4), (39, 1) },
new int[] { 1, 278, 17354, 1701, 29916, 12500, 287, 975, 278, 17366, 11203, 29991 }
};
yield return new object?[]
{
"he traveled to Egypt during the summer, the weather was hot and ammunition." ,
"▁he▁traveled▁to▁Egypt▁during▁the▁summer,▁the▁weather▁was▁hot▁and▁ammunition." ,
new string[] { "<s>", "▁he", "▁tra", "ve", "led", "▁to", "▁Egypt", "▁during", "▁the", "▁summer", ",", "▁the", "▁weather", "▁was", "▁hot", "▁and", "▁am", "mun", "ition", "." },
new (int Index, int Length)[] { (0, 0), (0, 3), (3, 4), (7, 2), (9, 3), (12, 3), (15, 6), (21, 7), (28, 4), (32, 7), (39, 1), (40, 4), (44, 8), (52, 4), (56, 4), (60, 4), (64, 3), (67, 3), (70, 5), (75, 1) },
new int[] { 1, 540, 1020, 345, 839, 304, 12892, 2645, 278, 11801, 29892, 278, 14826, 471, 7375, 322, 626, 24579, 654, 29889 }
};
yield return new object?[]
{
"She played many games and she felt exhausted afterward",
"▁She▁played▁many▁games▁and▁she▁felt▁exhausted▁afterward",
new string[] { "<s>", "▁She", "▁played", "▁many", "▁games", "▁and", "▁she", "▁felt", "▁exha", "usted", "▁after", "ward" },
new (int Index, int Length)[] { (0, 0), (0, 4), (4, 7), (11, 5), (16, 6), (22, 4), (26, 4), (30, 5), (35, 5), (40, 5), (45, 6), (51, 4) },
new int[] { 1, 2296, 5318, 1784, 8090, 322, 1183, 7091, 18782, 16656, 1156, 1328 }
};
// Non-ASCII case: the emoji has no vocabulary entry, so it is encoded via byte-fallback
// tokens <0xF0><0x9F><0x98><0x81>; the trailing three bytes get zero-length offsets.
yield return new object?[]
{
"Hello, y'all! How are you 😁 ?",
"▁Hello,▁y'all!▁How▁are▁you▁😁▁?",
new string[] { "<s>", "▁Hello", ",", "▁y", "'", "all", "!", "▁How", "▁are", "▁you", "▁", "<0xF0>", "<0x9F>", "<0x98>", "<0x81>", "▁?" },
new (int Index, int Length)[] { (0, 0), (0, 6), (6, 1), (7, 2), (9, 1), (10, 3), (13, 1), (14, 4), (18, 4), (22, 4), (26, 1), (27, 2), (27, 0), (27, 0), (27, 0), (29, 2) },
new int[] { 1, 15043, 29892, 343, 29915, 497, 29991, 1128, 526, 366, 29871, 243, 162, 155, 132, 1577 }
};
}
}
// Verifies Tokenizer.Encode for the Llama tokenizer: the produced tokens, offsets, and ids
// must match the expected data for both the string and the ReadOnlySpan<char> overloads,
// and for every combination of the SentencePieceBpe beginning-of-sentence / end-of-sentence /
// normalization options.
[Theory]
[MemberData(nameof(TokenizerTestData))]
public void TestTokenizerEncoding(string text, string normalizedText, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
{
Tokenizer tokenizer = _llamaTokenizer;
// The Llama tokenizer is configured with a normalizer but no pre-tokenizer.
Assert.NotNull(tokenizer.Normalizer);
Assert.Null(tokenizer.PreTokenizer);
// The string overload and the span overload must produce identical encodings.
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray());
Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray());
SentencePieceBpe sentencePieceBpe = (tokenizer as SentencePieceBpe)!;
// Exercise all 8 combinations of the encode options. When considerNormalization is false,
// the already-normalized text is passed in so the expected output stays the same.
foreach (bool considerNormalization in new[] { true, false })
foreach (bool addBeginningOfSentence in new[] { true, false })
foreach (bool addEndOfSentence in new[] { true, false })
{
encoding = sentencePieceBpe.Encode(
considerNormalization ? text : normalizedText,
out _,
addBeginningOfSentence: addBeginningOfSentence,
addEndOfSentence: addEndOfSentence,
considerPreTokenization: false,
considerNormalization: considerNormalization);
encoding1 = sentencePieceBpe.Encode(
considerNormalization ? text.AsSpan() : normalizedText.AsSpan(),
out _,
addBeginningOfSentence: addBeginningOfSentence,
addEndOfSentence: addEndOfSentence,
considerPreTokenization: false,
considerNormalization: considerNormalization);
// Adjust the expectations for the current flags: drop the leading "<s>" entry when BOS is
// not added, and append the EOS token (with a zero-length offset at the end of the
// normalized text) when EOS is added.
string[] expectedTokens1 = addBeginningOfSentence ? expectedTokens : expectedTokens.Skip(1).ToArray();
expectedTokens1 = addEndOfSentence ? expectedTokens1.Concat(new[] { sentencePieceBpe.EndOfSentenceToken }).ToArray() : expectedTokens1;
(int Index, int Length)[] expectedOffsets1 = addBeginningOfSentence ? expectedOffsets : expectedOffsets.Skip(1).ToArray();
expectedOffsets1 = addEndOfSentence ? expectedOffsets1.Concat(new[] { (normalizedText.Length, 0) }).ToArray() : expectedOffsets1;
int[] expectedIds1 = addBeginningOfSentence ? expectedIds : expectedIds.Skip(1).ToArray();
expectedIds1 = addEndOfSentence ? expectedIds1.Concat(new[] { sentencePieceBpe.EndOfSentenceId }).ToArray() : expectedIds1;
Assert.Equal(expectedTokens1, encoding.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets1, encoding.Select(t => t.Offset).ToArray());
Assert.Equal(expectedIds1, encoding.Select(t => t.Id).ToArray());
}
}
// Verifies Tokenizer.EncodeToIds for the Llama tokenizer, including the maxTokenCount
// overloads: the reported normalized string and consumed text length must be correct both
// when the full input fits and when the id list is truncated (here by 6 tokens).
[Theory]
[MemberData(nameof(TokenizerTestData))]
public void TestTokenizerEncodingToIds(string text, string normalizedText, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
{
Tokenizer tokenizer = _llamaTokenizer;
Assert.NotNull(expectedTokens);
Assert.NotNull(expectedOffsets);
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text));
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan()));
// With maxTokenCount equal to the full count, the whole normalized text is consumed.
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text, expectedIds.Length, out string? normalizedString, out int length));
Assert.Equal(normalizedText, normalizedString);
Assert.Equal(normalizedText.Length, length);
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, out normalizedString, out length));
Assert.Equal(normalizedText, normalizedString);
Assert.Equal(normalizedText.Length, length);
SentencePieceBpe sentencePieceBpe = (tokenizer as SentencePieceBpe)!;
// Exercise all combinations of the BOS/EOS/normalization flags against the SentencePieceBpe
// overloads. When considerNormalization is false the pre-normalized text is passed in.
foreach (bool considerNormalization in new[] { true, false })
foreach (bool addBeginningOfSentence in new[] { true, false })
foreach (bool addEndOfSentence in new[] { true, false })
{
// (string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
// Expected ids for the current flags: drop the BOS id when not added, append the EOS id when added.
int[] expectedIds1 = addBeginningOfSentence ? expectedIds : expectedIds.Skip(1).ToArray();
expectedIds1 = addEndOfSentence ? expectedIds1.Concat(new[] { sentencePieceBpe.EndOfSentenceId }).ToArray() : expectedIds1;
Assert.Equal(expectedIds1, sentencePieceBpe.EncodeToIds(
considerNormalization ? text : normalizedText,
addBeginningOfSentence: addBeginningOfSentence,
addEndOfSentence: addEndOfSentence,
expectedIds1.Length,
out normalizedString,
out length,
considerNormalization: considerNormalization));
Assert.Equal(expectedIds1, sentencePieceBpe.EncodeToIds(
considerNormalization ? text.AsSpan() : normalizedText.AsSpan(),
addBeginningOfSentence: addBeginningOfSentence,
addEndOfSentence: addEndOfSentence,
expectedIds1.Length,
out normalizedString,
out length,
considerNormalization: considerNormalization));
// When normalization is skipped, no normalized string is produced (null is reported).
Assert.Equal(considerNormalization ? normalizedText : null, normalizedString);
Assert.Equal(normalizedText.Length, length);
// Truncated encoding: request 6 fewer tokens and verify the prefix of ids is returned.
Assert.Equal(expectedIds1.Take(expectedIds1.Length - 6), sentencePieceBpe.EncodeToIds(
considerNormalization ? text : normalizedText,
addBeginningOfSentence: addBeginningOfSentence,
addEndOfSentence: addEndOfSentence,
expectedIds1.Length - 6,
out normalizedString,
out length,
considerNormalization: considerNormalization));
Assert.Equal(considerNormalization ? normalizedText : null, normalizedString);
// The consumed length must equal the end (Index + Length) of the last token that fit.
(int Index, int Length)[] expectedOffsets1 = addBeginningOfSentence ? expectedOffsets.Take(expectedIds1.Length - 6).ToArray() : expectedOffsets.Skip(1).Take(expectedIds1.Length - 6).ToArray();
int expectedLength = expectedOffsets1[expectedOffsets1.Length - 1].Index + expectedOffsets1[expectedOffsets1.Length - 1].Length;
Assert.Equal(expectedLength, length);
Assert.Equal(expectedIds1.Take(expectedIds1.Length - 6), sentencePieceBpe.EncodeToIds(
considerNormalization ? text.AsSpan() : normalizedText.AsSpan(),
addBeginningOfSentence: addBeginningOfSentence,
addEndOfSentence: addEndOfSentence,
expectedIds1.Length - 6,
out normalizedString,
out length,
considerNormalization: considerNormalization));
Assert.Equal(expectedLength, length);
}
}
// Verifies CountTokens, IndexOfTokenCount, and LastIndexOfTokenCount for the Llama
// tokenizer. IndexOfTokenCount with (count - 6) must report the index in the normalized
// text just past the last fitting token; LastIndexOfTokenCount with 7 must report the
// start index of the 7th-from-last token. Both must also report the normalized string
// and the actual token count.
[Theory]
[MemberData(nameof(TokenizerTestData))]
public void TestTokenizerCountTokens(string text, string normalizedText, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
{
Tokenizer tokenizer = _llamaTokenizer;
Assert.NotNull(expectedTokens);
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text));
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan()));
// Forward: the cut point is the end (Index + Length) of the (count - 6)th token.
Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index + expectedOffsets[expectedOffsets.Length - 7].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 6, out string? normalizedString, out int tokenCount));
Assert.Equal(normalizedText, normalizedString)
;
Assert.Equal(expectedIds.Length - 6, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index + expectedOffsets[expectedOffsets.Length - 7].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 6, out normalizedString, out tokenCount));
Assert.Equal(normalizedText, normalizedString);
Assert.Equal(expectedIds.Length - 6, tokenCount);
// Backward: counting 7 tokens from the end lands on the start index of that token.
Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index, tokenizer.LastIndexOfTokenCount(text, 7, out normalizedString, out tokenCount));
Assert.Equal(normalizedText, normalizedString);
Assert.Equal(7, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 7, out normalizedString, out tokenCount));
Assert.Equal(normalizedText, normalizedString);
Assert.Equal(7, tokenCount);
} }
} }
} }

Просмотреть файл

@ -42,6 +42,7 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotNetRemoteExecutorVersion)" /> <PackageReference Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotNetRemoteExecutorVersion)" />
<PackageReference Include="System.Text.Json" Version="$(SystemTextJsonVersion)" /> <PackageReference Include="System.Text.Json" Version="$(SystemTextJsonVersion)" />
<PackageReference Include="Microsoft.ML.TestTokenizers" Version="$(MicrosoftMLTestTokenizersVersion)" />
</ItemGroup> </ItemGroup>
</Project> </Project>

Просмотреть файл

@ -61,10 +61,9 @@ namespace Microsoft.ML.Tokenizers.Tests
string normalizedText = normalizer.Normalize(text); string normalizedText = normalizer.Normalize(text);
Assert.Equal(normalized, normalizedText); Assert.Equal(normalized, normalizedText);
Tokenizer tokenizer = new Tokenizer(BpeTests.CreateEmptyBpe(), WhiteSpace.Instance, normalizer); Tokenizer tokenizer = BpeTests.CreateEmptyBpe(preTokenizer: null, normalizer);
EncodingResult encoding = tokenizer.Encode(text); IReadOnlyList<Token> tokens = tokenizer.Encode(text, out string? normalizedString);
Assert.Equal(text, encoding.OriginalString); Assert.Equal(normalized, normalizedString);
Assert.Equal(normalized, encoding.NormalizedString);
} }
public class RemoveQuotesNormalizer : Normalizer public class RemoveQuotesNormalizer : Normalizer
@ -77,6 +76,22 @@ namespace Microsoft.ML.Tokenizers.Tests
return original; return original;
} }
return RemoveQuotes(original.AsSpan(), index);
}
public override string Normalize(ReadOnlySpan<char> original)
{
int index = original.IndexOf('"');
if (index <= 0)
{
return original.ToString();
}
return RemoveQuotes(original, index);
}
private string RemoveQuotes(ReadOnlySpan<char> original, int index)
{
StringBuilder sb = new StringBuilder(original.Length); StringBuilder sb = new StringBuilder(original.Length);
List<int> mapping = new List<int>(); List<int> mapping = new List<int>();
@ -97,7 +112,7 @@ namespace Microsoft.ML.Tokenizers.Tests
break; break;
} }
index = original.IndexOf('"', start); index = original.Slice(start).IndexOf('"');
if (index <= 0) if (index <= 0)
{ {
for (int i = start; i < original.Length; i++) for (int i = start; i < original.Length; i++)
@ -107,6 +122,8 @@ namespace Microsoft.ML.Tokenizers.Tests
} }
break; break;
} }
index += start;
} while (true); } while (true);
return sb.ToString(); return sb.ToString();
@ -130,6 +147,16 @@ namespace Microsoft.ML.Tokenizers.Tests
return original.Normalize(_normalizationForm); return original.Normalize(_normalizationForm);
} }
public override string Normalize(ReadOnlySpan<char> original)
{
if (original.IsEmpty)
{
return string.Empty;
}
return original.ToString().Normalize(_normalizationForm);
}
} }
} }
} }

Просмотреть файл

@ -20,62 +20,63 @@ namespace Microsoft.ML.Tokenizers.Tests
{ {
WhiteSpace.Instance, WhiteSpace.Instance,
"How are you doing?", "How are you doing?",
new Split[] { new Split("How", (0, 3)), new Split("are", (4, 3)), new Split("you", (8, 3)), new Split("doing", (12, 5)), new Split("?", (17, 1)),} new (int Offset, int Length)[] { (0, 3), (4, 3), (8, 3), (12, 5), (17, 1), }
}; };
yield return new object[] yield return new object[]
{ {
WhiteSpace.Instance, WhiteSpace.Instance,
"I_am_Just_Fine!", "I_am_Just_Fine!",
new Split[] { new Split("I_am_Just_Fine", (0, 14)), new Split("!", (14, 1)) } new (int Offset, int Length)[] { (0, 14), (14, 1) }
}; };
yield return new object[] yield return new object[]
{ {
new SpacePreTokenizer(), new SpacePreTokenizer(),
"How are you doing?!", "How are you doing?!",
new Split[] { new Split("How", (0, 3)), new Split("are", (4, 3)), new Split("you", (11, 3)), new Split("doing?!", (15, 7)) } new (int Offset, int Length)[] { (0, 3), (4, 3), (11, 3), (15, 7) }
}; };
yield return new object[] yield return new object[]
{ {
new SpacePreTokenizer(), new SpacePreTokenizer(),
new string(' ', 100), new string(' ', 100),
new Split[] { } new (int Offset, int Length)[] { }
}; };
} }
} }
[Theory] [Theory]
[MemberData(nameof(PreTokenizerData))] [MemberData(nameof(PreTokenizerData))]
public void TestPreTokenizer(PreTokenizer preTokenizer, string text, Split[] splits) public void TestPreTokenizer(PreTokenizer preTokenizer, string text, (int Offset, int Length)[] splits)
{ {
Split[] splitParts = preTokenizer.PreTokenize(text).ToArray<Split>(); (int Offset, int Length)[] splitParts = preTokenizer.PreTokenize(text).ToArray<(int Offset, int Length)>();
Assert.Equal(splits, splitParts); Assert.Equal(splits, splitParts);
// Empty tokenizer which tokenize all parts as unknown tokens. // Empty tokenizer which tokenize all parts as unknown tokens.
Tokenizer tokenizer = new Tokenizer(BpeTests.CreateEmptyBpe(), preTokenizer); Tokenizer tokenizer = BpeTests.CreateEmptyBpe(normalizer: null, preTokenizer: preTokenizer);
EncodingResult encoding = tokenizer.Encode(text); IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
Assert.True(encoding.Tokens.Count >= splitParts.Length, $"Expected to have {encoding.Tokens.Count} >= {splitParts.Length}"); Assert.True(encoding.Count >= splitParts.Length, $"Expected to have {encoding.Count} >= {splitParts.Length}");
} }
[Fact] [Fact]
public void TestWhiteSpacePreTokenizer() public void TestWhiteSpacePreTokenizer()
{ {
Assert.Empty(WhiteSpace.Instance.PreTokenize(null!)); Assert.Empty(WhiteSpace.Instance.PreTokenize((string)null!));
} }
public class SpacePreTokenizer : PreTokenizer public class SpacePreTokenizer : PreTokenizer
{ {
public override IEnumerable<Split> PreTokenize(string text) public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan<char> text)
{ {
List<Split> splits = new(); if (text.IsEmpty)
if (string.IsNullOrEmpty(text))
{ {
return splits; return [];
} }
List<(int Offset, int Length)> splits = new();
int index = 0; int index = 0;
while (true) while (true)
{ {
@ -92,7 +93,7 @@ namespace Microsoft.ML.Tokenizers.Tests
if (index < text.Length) if (index < text.Length)
{ {
splits.Add(new Split(text.Substring(index, end - index), (index, end - index))); splits.Add((index, end - index));
} }
else else
{ {
@ -104,6 +105,16 @@ namespace Microsoft.ML.Tokenizers.Tests
return splits; return splits;
} }
public override IEnumerable<(int Offset, int Length)> PreTokenize(string text)
{
if (string.IsNullOrEmpty(text))
{
return [];
}
return PreTokenize(text.AsSpan());
}
} }
} }
} }

Просмотреть файл

@ -38,8 +38,8 @@ namespace Microsoft.ML.Tokenizers.Tests
{ {
TestGPT4TokenizationEncoding(GPT4); TestGPT4TokenizationEncoding(GPT4);
Assert.True(GPT4.Model is Tiktoken); Assert.True(GPT4 is Tiktoken);
IReadOnlyDictionary<string, int>? specialTokensEncoder = (GPT4.Model as Tiktoken)!.SpecialTokens; IReadOnlyDictionary<string, int>? specialTokensEncoder = (GPT4 as Tiktoken)!.SpecialTokens;
string tokenizerDataFileName = Utils.CreateTemporaryFile("tiktoken"); string tokenizerDataFileName = Utils.CreateTemporaryFile("tiktoken");
@ -53,21 +53,21 @@ namespace Microsoft.ML.Tokenizers.Tests
try try
{ {
Tokenizer tokenizer = new Tokenizer(new Tiktoken(tokenizerDataFileName, specialTokensEncoder), GPT4.PreTokenizer); Tokenizer tokenizer = new Tiktoken(tokenizerDataFileName, GPT4.PreTokenizer, specialTokensEncoder);
TestGPT4TokenizationEncoding(tokenizer); TestGPT4TokenizationEncoding(tokenizer);
using (Stream stream = File.OpenRead(tokenizerDataFileName)) using (Stream stream = File.OpenRead(tokenizerDataFileName))
{ {
tokenizer = new Tokenizer(new Tiktoken(stream, specialTokensEncoder), GPT4.PreTokenizer); tokenizer = new Tiktoken(stream, GPT4.PreTokenizer, specialTokensEncoder);
} }
TestGPT4TokenizationEncoding(tokenizer); TestGPT4TokenizationEncoding(tokenizer);
tokenizer = new Tokenizer(await Tiktoken.CreateAsync(tokenizerDataFileName, specialTokensEncoder), GPT4.PreTokenizer); tokenizer = await Tokenizer.CreateTiktokenAsync(tokenizerDataFileName, GPT4.PreTokenizer, normalizer: null, specialTokensEncoder);
TestGPT4TokenizationEncoding(tokenizer); TestGPT4TokenizationEncoding(tokenizer);
using (Stream stream = File.OpenRead(tokenizerDataFileName)) using (Stream stream = File.OpenRead(tokenizerDataFileName))
{ {
tokenizer = new Tokenizer(await Tiktoken.CreateAsync(stream, specialTokensEncoder), GPT4.PreTokenizer); tokenizer = await Tokenizer.CreateTiktokenAsync(stream, GPT4.PreTokenizer, normalizer: null, specialTokensEncoder);
} }
TestGPT4TokenizationEncoding(tokenizer); TestGPT4TokenizationEncoding(tokenizer);
@ -109,11 +109,11 @@ namespace Microsoft.ML.Tokenizers.Tests
try try
{ {
Tiktoken tiktoken = (tokenizer.Model as Tiktoken)!; Tiktoken tiktoken = (tokenizer as Tiktoken)!;
Tokenizer externalTokenizer = new Tokenizer(new Tiktoken(tokenizerDataFileName, tiktoken.SpecialTokens), tokenizer.PreTokenizer); Tokenizer externalTokenizer = new Tiktoken(tokenizerDataFileName, tokenizer.PreTokenizer, tiktoken.SpecialTokens);
IReadOnlyDictionary<ReadOnlyMemory<byte>, int> encoder = tiktoken.Encoder; IReadOnlyDictionary<ReadOnlyMemory<byte>, int> encoder = tiktoken.Encoder;
IReadOnlyDictionary<ReadOnlyMemory<byte>, int> externalEncoder = (externalTokenizer.Model as Tiktoken)!.Encoder; IReadOnlyDictionary<ReadOnlyMemory<byte>, int> externalEncoder = (externalTokenizer as Tiktoken)!.Encoder;
Assert.Equal(externalEncoder.Count, encoder.Count); Assert.Equal(externalEncoder.Count, encoder.Count);
foreach (KeyValuePair<ReadOnlyMemory<byte>, int> kvp in encoder) foreach (KeyValuePair<ReadOnlyMemory<byte>, int> kvp in encoder)
@ -135,13 +135,17 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(new List<int>() { 9906, 4435 }, encoded); Assert.Equal(new List<int>() { 9906, 4435 }, encoded);
Assert.Equal(text, tokenizer.Decode(encoded.ToArray())!); Assert.Equal(text, tokenizer.Decode(encoded.ToArray())!);
EncodingResult result = tokenizer.Encode(text); IReadOnlyList<Token> result = tokenizer.Encode(text, out string? normalizedString);
int idsCount = tokenizer.CountTokens(text); int idsCount = tokenizer.CountTokens(text);
Assert.Equal(encoded, result.Ids);
Assert.Equal(new string[] { "Hello", " World" }, result.Tokens); int[] ids = result.Select(token => token.Id).ToArray();
Assert.Equal(new List<(int, int)> { (0, 5), (5, 6) }, result.Offsets); string[] tokens = result.Select(token => token.Value).ToArray();
(int, int)[] offsets = result.Select(token => token.Offset).ToArray();
Assert.Equal(encoded, ids);
Assert.Equal(new string[] { "Hello", " World" }, tokens);
Assert.Equal(new List<(int, int)> { (0, 5), (5, 6) }, offsets);
Assert.Equal(encoded.Count, idsCount); Assert.Equal(encoded.Count, idsCount);
Assert.Equal(encoded, result.Ids); Assert.Equal(encoded, ids);
TestGPT4Tokenizer(tokenizer); TestGPT4Tokenizer(tokenizer);
} }
@ -154,13 +158,18 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(new List<int>() { 100264, 9906, 4435, 100265 }, encoded); Assert.Equal(new List<int>() { 100264, 9906, 4435, 100265 }, encoded);
Assert.Equal(text, GPT4.Decode(encoded.ToArray())); Assert.Equal(text, GPT4.Decode(encoded.ToArray()));
EncodingResult result = GPT4.Encode(text); IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
int idsCount = GPT4.CountTokens(text); int idsCount = GPT4.CountTokens(text);
Assert.Equal(encoded, result.Ids);
Assert.Equal(new string[] { "<|im_start|>", "Hello", " World", "<|im_end|>" }, result.Tokens); int[] ids = result.Select(token => token.Id).ToArray();
Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 6), (23, 10) }, result.Offsets); string[] tokens = result.Select(token => token.Value).ToArray();
(int, int)[] offsets = result.Select(token => token.Offset).ToArray();
Assert.Equal(encoded, ids);
Assert.Equal(new string[] { "<|im_start|>", "Hello", " World", "<|im_end|>" }, tokens);
Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 6), (23, 10) }, offsets);
Assert.Equal(encoded.Count, idsCount); Assert.Equal(encoded.Count, idsCount);
Assert.Equal(encoded, result.Ids); Assert.Equal(encoded, ids);
} }
private void TestGPT4Tokenizer(Tokenizer gpt4Tokenizer) private void TestGPT4Tokenizer(Tokenizer gpt4Tokenizer)
@ -192,12 +201,16 @@ namespace Microsoft.ML.Tokenizers.Tests
string? decoded = GPT4.Decode(encoded.ToArray()); string? decoded = GPT4.Decode(encoded.ToArray());
Assert.Equal(text, decoded); Assert.Equal(text, decoded);
EncodingResult result = GPT4.Encode(text); IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
int[] ids = result.Select(token => token.Id).ToArray();
string[] tokens = result.Select(token => token.Value).ToArray();
(int, int)[] offsets = result.Select(token => token.Offset).ToArray();
int idsCount = GPT4.CountTokens(text); int idsCount = GPT4.CountTokens(text);
Assert.Equal(encoded, result.Ids); Assert.Equal(encoded, ids);
Assert.Equal(encoded.Count, idsCount); Assert.Equal(encoded.Count, idsCount);
Assert.Equal(new string[] { "<|im_start|>", "Hello", "<|im_end|>", " World" }, result.Tokens); Assert.Equal(new string[] { "<|im_start|>", "Hello", "<|im_end|>", " World" }, tokens);
Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 10), (27, 6) }, result.Offsets); Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 10), (27, 6) }, offsets);
} }
[Fact] [Fact]
@ -207,12 +220,10 @@ namespace Microsoft.ML.Tokenizers.Tests
IReadOnlyList<int> encoded = GPT4.EncodeToIds(text); IReadOnlyList<int> encoded = GPT4.EncodeToIds(text);
Assert.Empty(encoded); Assert.Empty(encoded);
EncodingResult result = GPT4.Encode(text); IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
int idsCount = GPT4.CountTokens(text); int idsCount = GPT4.CountTokens(text);
Assert.Empty(result.Ids); Assert.Empty(result);
Assert.Empty(result.Tokens); Assert.Equal(0, idsCount);
Assert.Empty(result.Offsets);
Assert.Equal(result.Ids.Count, idsCount);
} }
[Fact] [Fact]
@ -224,11 +235,11 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(new List<int>() { 100264, 9906, 2928, 99834, 4435, 100265 }, encoded); Assert.Equal(new List<int>() { 100264, 9906, 2928, 99834, 4435, 100265 }, encoded);
Assert.Equal(text, GPT4.Decode(encoded.ToArray())); Assert.Equal(text, GPT4.Decode(encoded.ToArray()));
EncodingResult result = GPT4.Encode(text); IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
Assert.Equal(encoded, result.Ids); Assert.Equal(encoded, result.Select(token => token.Id).ToArray());
Assert.Equal(encoded.Count, idsCount); Assert.Equal(encoded.Count, idsCount);
Assert.Equal(new string[] { "<|im_start|>", "Hello", " ⭐", "", " World", "<|im_end|>" }, result.Tokens); Assert.Equal(new string[] { "<|im_start|>", "Hello", " ⭐", "", " World", "<|im_end|>" }, result.Select(token => token.Value).ToArray());
Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 2), (19, 0), (19, 6), (25, 10) }, result.Offsets); Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 2), (19, 0), (19, 6), (25, 10) }, result.Select(token => token.Offset).ToArray());
} }
[Fact] [Fact]
@ -310,9 +321,12 @@ namespace Microsoft.ML.Tokenizers.Tests
[Theory] [Theory]
[InlineData("gpt-4")] [InlineData("gpt-4")]
[InlineData("gpt-4-")] [InlineData("gpt-4-")]
[InlineData("gpt-3.5-")]
[InlineData("gpt-3.5-turbo")] [InlineData("gpt-3.5-turbo")]
[InlineData("gpt-3.5-turbo-")] [InlineData("gpt-3.5-turbo-")]
[InlineData("gpt-3.5-turbo-16k")] [InlineData("gpt-3.5-turbo-16k")]
[InlineData("gpt-35")]
[InlineData("gpt-35-")]
[InlineData("gpt-35-turbo")] [InlineData("gpt-35-turbo")]
[InlineData("gpt-35-turbo-16k")] [InlineData("gpt-35-turbo-16k")]
[InlineData("gpt-35-turbo-")] [InlineData("gpt-35-turbo-")]
@ -351,7 +365,7 @@ namespace Microsoft.ML.Tokenizers.Tests
public void TestAllSupportedModelNames(string modelName) public void TestAllSupportedModelNames(string modelName)
{ {
Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel(modelName); Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel(modelName);
Assert.NotNull(tokenizer.Model); Assert.True(tokenizer is Tiktoken);
Assert.NotNull(tokenizer.PreTokenizer); Assert.NotNull(tokenizer.PreTokenizer);
} }
@ -363,7 +377,7 @@ namespace Microsoft.ML.Tokenizers.Tests
public void TestAllSupportedEncodingNames(string encodingName) public void TestAllSupportedEncodingNames(string encodingName)
{ {
Tokenizer tokenizer = Tokenizer.CreateTiktokenForEncoding(encodingName); Tokenizer tokenizer = Tokenizer.CreateTiktokenForEncoding(encodingName);
Assert.NotNull(tokenizer.Model); Assert.True(tokenizer is Tiktoken);
Assert.NotNull(tokenizer.PreTokenizer); Assert.NotNull(tokenizer.PreTokenizer);
string modelName = encodingName.ToLowerInvariant() switch string modelName = encodingName.ToLowerInvariant() switch
@ -377,15 +391,16 @@ namespace Microsoft.ML.Tokenizers.Tests
Tokenizer tokenizer1 = Tokenizer.CreateTiktokenForModel(modelName); Tokenizer tokenizer1 = Tokenizer.CreateTiktokenForModel(modelName);
Tiktoken? model1 = tokenizer.Model as Tiktoken; Assert.True(tokenizer is Tiktoken);
Tiktoken? model2 = tokenizer1.Model as Tiktoken; Assert.True(tokenizer1 is Tiktoken);
Assert.NotNull(model1);
Assert.NotNull(model2);
Assert.Equal(model2.Encoder, model1.Encoder); Tiktoken tiktoken = (tokenizer as Tiktoken)!;
Assert.Equal(model2.Decoder, model1.Decoder); Tiktoken tiktoken1 = (tokenizer1 as Tiktoken)!;
Assert.Equal(model2.SpecialTokens, model1.SpecialTokens);
Assert.Equal(model2.Vocab, model1.Vocab); Assert.Equal(tiktoken1.Encoder, tiktoken.Encoder);
Assert.Equal(tiktoken1.Decoder, tiktoken.Decoder);
Assert.Equal(tiktoken1.SpecialTokens, tiktoken.SpecialTokens);
Assert.Equal(tiktoken1.Vocab, tiktoken.Vocab);
} }
[Fact] [Fact]
@ -408,11 +423,99 @@ namespace Microsoft.ML.Tokenizers.Tests
RemoteExecutor.Invoke(static (name) => RemoteExecutor.Invoke(static (name) =>
{ {
Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel(name); Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel(name);
Assert.NotNull(tokenizer.Model); Assert.True(tokenizer is Tiktoken);
Assert.NotNull(tokenizer.PreTokenizer); Assert.NotNull(tokenizer.PreTokenizer);
}, modelName).Dispose(); }, modelName).Dispose();
} }
public static IEnumerable<object?[]> TokenizerTestData
{
get
{
// string to tokenize, produced tokens, the token offsets
yield return new object?[]
{
"the brown fox jumped over the lazy dog!",
new string[] { "the", " brown", " fox", " jumped", " over", " the", " lazy", " dog", "!" },
new (int Index, int Length)[] { (0, 3), (3, 6), (9, 4), (13, 7), (20, 5), (25, 4), (29, 5), (34, 4), (38, 1) },
new int[] { 1820, 14198, 39935, 27096, 927, 279, 16053, 5679, 0 }
};
yield return new object?[]
{
"he traveled to Egypt during the summer, the weather was hot and ammunition." ,
new string[] { "he", " traveled", " to", " Egypt", " during", " the", " summer", ",", " the", " weather", " was", " hot", " and", " ammunition", "." },
new (int Index, int Length)[] { (0, 2), (2, 9), (11, 3), (14, 6), (20, 7), (27, 4), (31, 7), (38, 1), (39, 4), (43, 8), (51, 4), (55, 4), (59, 4), (63, 11), (74, 1) },
new int[] { 383, 31796, 311, 15212, 2391, 279, 7474, 11, 279, 9282, 574, 4106, 323, 37768, 13 }
};
yield return new object?[]
{
"She played many games and she felt exhausted afterward",
new string[] { "She", " played", " many", " games", " and", " she", " felt", " exhausted", " afterward" },
new (int Index, int Length)[] { (0, 3), (3, 7), (10, 5), (15, 6), (21, 4), (25, 4), (29, 5), (34, 10), (44, 10) },
new int[] { 8100, 6476, 1690, 3953, 323, 1364, 6612, 39019, 49043 }
};
yield return new object?[]
{
"Hello, y'all! How are you 😁 ?",
new string[] { "Hello", ",", " y", "'all", "!", " How", " are", " you", " 😁", "", " ?" },
new (int Index, int Length)[] { (0, 5), (5, 1), (6, 2), (8, 4), (12, 1), (13, 4), (17, 4), (21, 4), (25, 3), (28, 0), (28, 2) },
new int[] { 9906, 11, 379, 65948, 0, 2650, 527, 499, 27623, 223, 949 }
};
}
}
[Theory]
[MemberData(nameof(TokenizerTestData))]
public void TestTokenizerEncoding(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
{
Tokenizer tokenizer = GPT4;
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray());
Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray());
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text));
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan()));
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text, expectedIds.Length, out string? normalizedString, out int length));
Assert.Null(normalizedString);
Assert.Equal(text.Length, length);
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(text.Length, length);
Assert.Equal(expectedIds.Take(expectedIds.Length - 4), tokenizer.EncodeToIds(text, expectedIds.Length - 4, out normalizedString, out length));
Assert.Null(normalizedString);
int expectedLength = expectedOffsets[expectedOffsets.Length - 5].Index + expectedOffsets[expectedOffsets.Length - 5].Length;
Assert.Equal(expectedLength, length);
Assert.Equal(expectedIds.Take(expectedIds.Length - 4), tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length - 4, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(expectedLength, length);
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text));
Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan()));
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length - 3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length - 3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text, 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(3, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 3, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(3, tokenCount);
}
// Test running copy the test data files to the output folder but sometimes the file content is mutated replacing '\n' with '\r\n'. // Test running copy the test data files to the output folder but sometimes the file content is mutated replacing '\n' with '\r\n'.
// This method reads the file and removes the extra inserted '\r' characters. Having '\r' in the file content will cause the tests to fail. // This method reads the file and removes the extra inserted '\r' characters. Having '\r' in the file content will cause the tests to fail.
private string ReadAndSanitizeFile(string path) private string ReadAndSanitizeFile(string path)

Просмотреть файл

@ -26,12 +26,12 @@ namespace Microsoft.ML.Tokenizers.Tests
for (int i = 1; i <= fullIdsList.Count; i++) for (int i = 1; i <= fullIdsList.Count; i++)
{ {
int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string processedText1, out int tokenCount1); int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string? processedText1, out int tokenCount1);
int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string processedText2, out int tokenCount2); int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string? processedText2, out int tokenCount2);
IReadOnlyList<int> partialIdsList = tokenizer.EncodeToIds(input, maxTokenCount: i, out string processedText, out int textLength); IReadOnlyList<int> partialIdsList = tokenizer.EncodeToIds(input, maxTokenCount: i, out string? processedText, out int textLength);
Assert.True(textLength <= processedText.Length); Assert.True(processedText is null || textLength <= processedText.Length);
Assert.True(tokenizer.Normalizer is not null || processedText == input); Assert.True(tokenizer.Normalizer is not null || processedText is null);
Assert.Equal(fullIdsList.Take(partialIdsList.Count), partialIdsList); Assert.Equal(fullIdsList.Take(partialIdsList.Count), partialIdsList);
@ -42,14 +42,13 @@ namespace Microsoft.ML.Tokenizers.Tests
// In this case, we'll get index1 equal to zero and nothing really will need to be tested. // In this case, we'll get index1 equal to zero and nothing really will need to be tested.
if (tokenCount1 > 0 && index1 > 0) if (tokenCount1 > 0 && index1 > 0)
{ {
string prefixString = processedText1.Substring(0, index1); string prefixString = (processedText1 ?? input).Substring(0, index1);
if (tokenizer.Model is SentencePieceBpe) if (tokenizer is SentencePieceBpe)
{ {
// SentencePieceBpe model normalize the text and insert more characters. // SentencePieceBpe model normalize the text and insert more characters.
// We call the model directly to bypass the normalization step // We call the model directly to bypass the normalization step
prefixIds = new List<int>(); prefixIds = tokenizer.EncodeToIds(prefixString.AsSpan(), considerNormalization: false);
tokenizer.Model.EncodeToIds(prefixString.AsSpan(), (prefixIds as IList<int>)!, out _);
} }
else else
{ {
@ -61,14 +60,13 @@ namespace Microsoft.ML.Tokenizers.Tests
if (tokenCount2 > 0) if (tokenCount2 > 0)
{ {
string suffixString = processedText2.Substring(index2); string suffixString = (processedText2 ?? input).Substring(index2);
if (tokenizer.Model is SentencePieceBpe) if (tokenizer is SentencePieceBpe)
{ {
// SentencePieceBpe model normalize the text and insert more characters. // SentencePieceBpe model normalize the text and insert more characters.
// We call the model directly to bypass the normalization step // We call the model directly to bypass the normalization step
suffixIds = new List<int>(); suffixIds = tokenizer.EncodeToIds(suffixString.AsSpan(), considerNormalization: false);
tokenizer.Model.EncodeToIds(suffixString.AsSpan(), (suffixIds as IList<int>)!, out _);
if (i < fullIdsList.Count) if (i < fullIdsList.Count)
{ {
suffixIds = suffixIds.Skip(1).ToList(); // Skip the start of sentence token <s> suffixIds = suffixIds.Skip(1).ToList(); // Skip the start of sentence token <s>
@ -85,12 +83,13 @@ namespace Microsoft.ML.Tokenizers.Tests
if (i == fullIdsList.Count) if (i == fullIdsList.Count)
{ {
if (index1 != processedText1.Length) string s = processedText1 ?? input;
if (index1 != s.Length)
{ {
// It's possible that the remaining text on the left doesn't produce any tokens, as in the case of BPE, // It's possible that the remaining text on the left doesn't produce any tokens, as in the case of BPE,
// where the pre-tokenizer removes spaces and the left text consists entirely of spaces. // where the pre-tokenizer removes spaces and the left text consists entirely of spaces.
Assert.True(index1 < processedText1.Length); Assert.True(index1 < s.Length);
Assert.Equal(0, tokenizer.CountTokens(processedText1.Substring(index1))); Assert.Equal(0, tokenizer.CountTokens(s.Substring(index1)));
} }
if (index2 != 0) if (index2 != 0)
@ -98,7 +97,7 @@ namespace Microsoft.ML.Tokenizers.Tests
// It's possible that the remaining text on the right doesn't produce any tokens, as in the case of BPE, // It's possible that the remaining text on the right doesn't produce any tokens, as in the case of BPE,
// where the pre-tokenizer removes spaces and the left text consists entirely of spaces. // where the pre-tokenizer removes spaces and the left text consists entirely of spaces.
Assert.True(index2 > 0); Assert.True(index2 > 0);
Assert.Equal(0, tokenizer.CountTokens(processedText1.Substring(0, index2))); Assert.Equal(0, tokenizer.CountTokens(s.Substring(0, index2)));
} }
Assert.Equal(fullIdsList, prefixIds); Assert.Equal(fullIdsList, prefixIds);
@ -106,13 +105,15 @@ namespace Microsoft.ML.Tokenizers.Tests
} }
} }
Assert.Equal(0, tokenizer.IndexOfTokenCount((string)null!, maxTokenCount: 10, out _, out _));
Assert.Equal(0, tokenizer.LastIndexOfTokenCount((string)null!, maxTokenCount: 10, out _, out _));
Assert.Equal(0, tokenizer.IndexOfTokenCount(Span<char>.Empty, maxTokenCount: 10, out _, out _));
Assert.Equal(0, tokenizer.LastIndexOfTokenCount(Span<char>.Empty, maxTokenCount: 10, out _, out _));
Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: 0, out _, out _)); Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: 0, out _, out _));
Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: -1, out _, out _)); Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: -1, out _, out _));
Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: 0, out _, out _)); Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: 0, out _, out _));
Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: -1, out _, out _)); Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: -1, out _, out _));
Assert.Throws<ArgumentNullException>(() => tokenizer.IndexOfTokenCount(null!, maxTokenCount: 10, out _, out _));
Assert.Throws<ArgumentNullException>(() => tokenizer.LastIndexOfTokenCount(null!, maxTokenCount: 10, out _, out _));
} }
} }
} }