Tokenizer's APIs Update (#7128)
* Tokenizer's APIs Update
* Address the feedback
* Address the feedback and use the new TestTokenizers package

Parent: fac1e1018b
Commit: 72cfdf611a
@@ -41,7 +41,7 @@
     <MicrosoftDotNetInteractiveVersion>1.0.0-beta.23509.3</MicrosoftDotNetInteractiveVersion>
     <MicrosoftMLOnnxRuntimeVersion>1.16.3</MicrosoftMLOnnxRuntimeVersion>
     <MlNetMklDepsVersion>0.0.0.12</MlNetMklDepsVersion>
     <!--
     @("inteltbb.devel", "win", "2021.7.1.15305")
     -->
     <OneDalPkgVersion Condition="'$(OS)' == 'Windows_NT'">2023.0.0.23189</OneDalPkgVersion>
@@ -87,6 +87,7 @@
     <MicrosoftMLTensorFlowTestModelsVersion>0.0.13-test</MicrosoftMLTensorFlowTestModelsVersion>
     <MicrosoftMLTestDatabasesVersion>0.0.6-test</MicrosoftMLTestDatabasesVersion>
     <MicrosoftMLTestModelsVersion>0.0.7-test</MicrosoftMLTestModelsVersion>
+    <MicrosoftMLTestTokenizersVersion>2.0.0-beta.24218.2</MicrosoftMLTestTokenizersVersion>
     <SystemDataSqlClientVersion>4.8.6</SystemDataSqlClientVersion>
     <SystemDataSQLiteCoreVersion>1.0.118</SystemDataSQLiteCoreVersion>
     <XunitCombinatorialVersion>1.2.7</XunitCombinatorialVersion>
@@ -1,152 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.Collections.Generic;
-using System.Text;
-
-namespace Microsoft.ML.Tokenizers
-{
-    /// <summary>
-    /// The Encoding represents the output of a Tokenizer.
-    /// </summary>
-    public sealed class EncodingResult
-    {
-        /// <summary>
-        /// Create a new object of the EncodingResult object.
-        /// </summary>
-        /// <param name="originalString">The list of tokens to merge.</param>
-        /// <param name="normalizedString">The list of tokens to merge.</param>
-        /// <param name="splits">The list of tokens to merge.</param>
-        /// <param name="offsetsMappedToOriginalString">Indicate whether the offsets is mapped to the original string or the normalized string.</param>
-        public EncodingResult(string originalString, string normalizedString, IEnumerable<Split> splits, bool offsetsMappedToOriginalString)
-        {
-            OriginalString = originalString;
-            NormalizedString = normalizedString;
-            Splits = splits;
-            OffsetsMappedToOriginalString = offsetsMappedToOriginalString;
-        }
-
-        /// <summary>
-        /// Gets the original tokenized string.
-        /// </summary>
-        public string? OriginalString { get; }
-
-        /// <summary>
-        /// Gets the normalized form of the original string.
-        /// </summary>
-        public string? NormalizedString { get; }
-
-        /// <summary>
-        /// Gets the normalized form of the original string.
-        /// </summary>
-        public bool OffsetsMappedToOriginalString { get; }
-
-        internal IEnumerable<Split> Splits { get; }
-        private List<Token>? _tokens;
-        private List<string>? _tokensWords;
-        private List<int>? _ids;
-        private List<(int Index, int Length)>? _offsets;
-
-        internal void AddTokens(IReadOnlyList<Token> addedTokens)
-        {
-            if (_tokens is null)
-            {
-                _tokens = new(addedTokens);
-                return;
-            }
-
-            foreach (var token in addedTokens)
-            {
-                _tokens.Add(token);
-            }
-        }
-
-        /// <summary>
-        /// Gets list of the tokens Ids.
-        /// The Ids are the main input to a Language Model. They are the token indices, the numerical representations that a LM understands.
-        /// </summary>
-        public IReadOnlyList<int> Ids
-        {
-            get
-            {
-                if (_ids is not null)
-                {
-                    return _ids;
-                }
-
-                if (_tokens is null)
-                {
-                    return Array.Empty<int>();
-                }
-
-                _ids = new List<int>(_tokens.Count);
-
-                foreach (var token in _tokens)
-                {
-                    _ids.Add(token.Id);
-                }
-
-                return _ids;
-            }
-        }
-
-        /// <summary>
-        /// Gets the generated tokens. They are the string representation of the Ids.
-        /// </summary>
-        public IReadOnlyList<string> Tokens
-        {
-            get
-            {
-                if (_tokensWords is not null)
-                {
-                    return _tokensWords;
-                }
-
-                if (_tokens is null)
-                {
-                    return Array.Empty<string>();
-                }
-
-                _tokensWords = new List<string>(_tokens.Count);
-
-                foreach (var token in _tokens)
-                {
-                    _tokensWords.Add(token.Value);
-                }
-
-                return _tokensWords;
-            }
-        }
-
-        /// <summary>
-        /// Gets The list of offsets. These offsets let's you slice the input string, and thus retrieve
-        /// the original part that led to producing the corresponding token.
-        /// </summary>
-        public IReadOnlyList<(int Index, int Length)> Offsets
-        {
-            get
-            {
-                if (_offsets is not null)
-                {
-                    return _offsets;
-                }
-
-                if (_tokens is null)
-                {
-                    return Array.Empty<(int, int)>();
-                }
-
-                _offsets = new List<(int Index, int Length)>(_tokens.Count);
-
-                foreach (var token in _tokens)
-                {
-                    _offsets.Add(token.Offset);
-                }
-
-                return _offsets;
-            }
-        }
-    }
-}
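With `EncodingResult` deleted, callers no longer receive a lazily materialized result object; `Encode` on the tokenizer now returns the token list directly (see the Bpe hunks below). A minimal sketch of the new calling pattern, assuming a BPE vocab/merges pair on disk; the file names are placeholders, not part of this commit:

    using System;
    using System.Collections.Generic;
    using Microsoft.ML.Tokenizers;

    class EncodeDemo
    {
        static void Main()
        {
            // Hypothetical model files; any BPE vocab/merges pair would do.
            Tokenizer tokenizer = new Bpe("vocab.json", "merges.txt");

            // Encode returns IReadOnlyList<Token> directly instead of an EncodingResult.
            IReadOnlyList<Token> tokens = tokenizer.Encode("Hello world", out string? normalizedString);

            foreach (Token token in tokens)
            {
                // Each Token carries the id, the string value, and the (Index, Length) offset
                // that EncodingResult.Ids/Tokens/Offsets used to expose as separate lists.
                Console.WriteLine($"{token.Id} '{token.Value}' @ ({token.Offset.Index}, {token.Offset.Length})");
            }
        }
    }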
@@ -17,13 +17,15 @@ namespace Microsoft.ML.Tokenizers
     /// <summary>
     /// Represent the Byte Pair Encoding model.
     /// </summary>
-    public sealed class Bpe : Model
+    public sealed class Bpe : Tokenizer
     {
         /// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.

         private const int MaxWordLengthToCache = 15;
         private string? _unknownToken;
         private int? _unknownTokenId;
+        private readonly PreTokenizer? _preTokenizer;
+        private readonly Normalizer? _normalizer;

         /// <summary>
         /// Gets or Sets unknown token. The unknown token to be used when we encounter an unknown char
@@ -74,13 +76,15 @@ namespace Microsoft.ML.Tokenizers
         /// </summary>
         /// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
         /// <param name="mergesFile">The file path containing the tokens's pairs list.</param>
+        /// <param name="preTokenizer">The pre-tokenizer to use.</param>
+        /// <param name="normalizer">The normalizer to use.</param>
         /// <param name="unknownToken"> The unknown token to be used by the model.</param>
         /// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that don’t represent a beginning of word.</param>
         /// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
         /// <param name="fuseUnknownTokens">Indicate whether allowing multiple unknown tokens get fused.</param>
-        public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
+        public Bpe(string vocabFile, string? mergesFile, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
             this(vocabFile is null ? throw new ArgumentNullException(nameof(vocabFile)) : File.Open(vocabFile, FileMode.Open, FileAccess.Read),
-                mergesFile is null ? null : File.Open(mergesFile, FileMode.Open, FileAccess.Read), unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: true)
+                mergesFile is null ? null : File.Open(mergesFile, FileMode.Open, FileAccess.Read), preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: true)
         {
         }

@@ -89,16 +93,18 @@ namespace Microsoft.ML.Tokenizers
         /// </summary>
         /// <param name="vocabStream">The JSON stream containing the dictionary of string keys and their ids.</param>
         /// <param name="mergesStream">The stream containing the tokens's pairs list.</param>
+        /// <param name="preTokenizer">The pre-tokenizer to use.</param>
+        /// <param name="normalizer">The normalizer to use.</param>
         /// <param name="unknownToken"> The unknown token to be used by the model.</param>
         /// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that don’t represent a beginning of word.</param>
         /// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
         /// <param name="fuseUnknownTokens">Indicate whether allowing multiple unknown tokens get fused.</param>
-        public Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
-            this(vocabStream, mergesStream, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: false)
+        public Bpe(Stream vocabStream, Stream? mergesStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
+            this(vocabStream, mergesStream, preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: false)
         {
         }

-        private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, string? continuingSubwordPrefix, string? endOfWordSuffix, bool fuseUnknownTokens, bool disposeStreams)
+        private Bpe(Stream vocabStream, Stream? mergesStream, PreTokenizer? preTokenizer, Normalizer? normalizer, string? unknownToken, string? continuingSubwordPrefix, string? endOfWordSuffix, bool fuseUnknownTokens, bool disposeStreams)
         {
             try
             {
@@ -110,6 +116,8 @@ namespace Microsoft.ML.Tokenizers
                 FuseUnknownTokens = fuseUnknownTokens;
                 ContinuingSubwordPrefix = continuingSubwordPrefix;
                 EndOfWordSuffix = endOfWordSuffix;
+                _preTokenizer = preTokenizer ?? WhiteSpace.Instance; // Default to WhiteSpace pre-tokenizer
+                _normalizer = normalizer;

                 (Dictionary<StringSpanOrdinalKey, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
                 _vocab = vocab1 ?? new Dictionary<StringSpanOrdinalKey, int>();
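These constructor changes make a `Bpe` a complete `Tokenizer` on its own: it owns its pre-tokenizer (falling back to `WhiteSpace.Instance` when none is passed, as the hunk above shows) and an optional normalizer. A hedged construction sketch, with placeholder file names and assuming the `WhiteSpace.Instance` singleton used by the constructor is publicly exposed:

    using System.IO;
    using Microsoft.ML.Tokenizers;

    class BpeConstructionDemo
    {
        static void Main()
        {
            // File-path overload: preTokenizer/normalizer are omitted, so the
            // constructor falls back to the WhiteSpace pre-tokenizer internally.
            var bpe = new Bpe("vocab.json", "merges.txt", unknownToken: "[UNK]");

            // Stream overload with an explicit pre-tokenizer; this overload passes
            // disposeStreams: false, so the caller retains ownership of the streams.
            using Stream vocab = File.OpenRead("vocab.json");
            using Stream merges = File.OpenRead("merges.txt");
            var bpe2 = new Bpe(vocab, merges, preTokenizer: WhiteSpace.Instance, normalizer: null, unknownToken: "[UNK]");
        }
    }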
@@ -166,47 +174,320 @@ namespace Microsoft.ML.Tokenizers
         }

         /// <summary>
-        /// Encode a text string to a list of tokens.
+        /// Gets the PreTokenizer used by the Tokenizer.
+        /// </summary>
+        public override PreTokenizer? PreTokenizer => _preTokenizer;
+
+        /// <summary>
+        /// Gets the Normalizer in use by the Tokenizer.
+        /// </summary>
+        public override Normalizer? Normalizer => _normalizer;
+
+        /// <summary>
+        /// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
         /// </summary>
         /// <param name="text">The text to encode.</param>
-        /// <returns>The list of tokens generated from the text tokenization.</returns>
-        public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
+        /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+        public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text, Span<char>.Empty, out normalizedString, considerPreTokenization, considerNormalization);
+
+        /// <summary>
+        /// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+        public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(null, text, out normalizedString, considerPreTokenization, considerNormalization);
+
+        private IReadOnlyList<Token> Encode(string? text, ReadOnlySpan<char> textSpan, out string? normalizedString, bool considerPreTokenization, bool considerNormalization)
         {
-            if (text.Length == 0)
+            if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
             {
-                return EmptyTokensList;
+                normalizedString = null;
+                return [];
             }

-            return EncodeWithCache(text);
+            IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+            List<Token> tokens = new();
+            PriorityQueue<Merge>? priorityQueue = null;
+
+            if (splits is not null)
+            {
+                foreach ((int Offset, int Length) split in splits)
+                {
+                    EncodeWithCache(textSpanToEncode.Slice(split.Offset, split.Length), tokens, split.Offset, ref priorityQueue);
+                }
+            }
+            else
+            {
+                EncodeWithCache(textSpanToEncode, tokens, 0, ref priorityQueue);
+            }
+
+            return tokens;
         }

         /// <summary>
-        /// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
+        /// Encodes input text to token Ids.
         /// </summary>
         /// <param name="text">The text to encode.</param>
-        /// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The list of encoded Ids.</returns>
+        public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
+            => EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
+
+        /// <summary>
+        /// Encodes input text to token Ids.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The list of encoded Ids.</returns>
+        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
+            => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out _, out _);
+
+        /// <summary>
+        /// Encodes input text to token Ids up to maximum number of tokens.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+        /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
         /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
-        /// <param name="maxTokens">The maximum number of tokens to encode.</param>
-        /// <returns>The number of tokens that the input text will be encoded to.</returns>
-        public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsWithCache(text, accumulatedIds, maxTokens, out textLength);
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The list of encoded Ids.</returns>
+        public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+            => EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
+
+        /// <summary>
+        /// Encodes input text to token Ids up to maximum number of tokens.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+        /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+        /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The list of encoded Ids.</returns>
+        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+            => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
+
+        private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
+        {
+            if (maxTokenCount <= 0)
+            {
+                throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
+            }
+
+            if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+            {
+                textLength = 0;
+                normalizedString = null;
+                return [];
+            }
+
+            IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+            List<int> ids = new();
+            PriorityQueue<Merge>? priorityQueue = null;
+
+            if (splits is not null)
+            {
+                textLength = 0;
+                foreach ((int Offset, int Length) split in splits)
+                {
+                    EncodeToIdsWithCache(textSpanToEncode.Slice(split.Offset, split.Length), ids, maxTokenCount - ids.Count, out int length, ref priorityQueue);
+                    textLength = split.Offset + length;
+
+                    if (length < split.Length || ids.Count >= maxTokenCount)
+                    {
+                        break;
+                    }
+                }
+            }
+            else
+            {
+                EncodeToIdsWithCache(textSpanToEncode, ids, maxTokenCount, out textLength, ref priorityQueue);
+            }
+
+            return ids;
+        }

         /// <summary>
         /// Get the number of tokens that the input text will be encoded to.
         /// </summary>
         /// <param name="text">The text to encode.</param>
-        /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
-        /// <param name="maxTokens">The maximum number of tokens to encode.</param>
-        /// <returns>The number of tokens that the input text will be encoded to.</returns>
-        public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsWithCache(text, null, maxTokens, out textLength);
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The number of token Ids that the input text will be encoded to.</returns>
+        public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
+            => CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);

         /// <summary>
         /// Get the number of tokens that the input text will be encoded to.
         /// </summary>
         /// <param name="text">The text to encode.</param>
-        /// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
-        /// <param name="maxTokens">The maximum number of tokens to encode.</param>
-        /// <returns>The number of tokens that the input text will be encoded to.</returns>
-        public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => EncodeToIdsFromEndWithCache(text, null, maxTokens, out textIndex);
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The number of token Ids that the input text will be encoded to.</returns>
+        public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
+            => CountTokens(null, text, considerPreTokenization, considerNormalization, out _, out _);
+
+        /// <summary>
+        /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+        /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+        /// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>
+        /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+        /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+        /// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
+        /// </returns>
+        public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+        {
+            tokenCount = CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
+            return textLength;
+        }
+
+        /// <summary>
+        /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+        /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+        /// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>
+        /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+        /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+        /// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
+        /// </returns>
+        public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+        {
+            tokenCount = CountTokens(null, text, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
+            return textLength;
+        }
+
+        private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
+        {
+            if (maxTokenCount <= 0)
+            {
+                throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
+            }
+
+            textLength = 0;
+            if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+            {
+                normalizedString = null;
+                return 0;
+            }
+
+            IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+            PriorityQueue<Merge>? priorityQueue = null;
+            int count = 0;
+            if (splits is not null)
+            {
+                foreach ((int Offset, int Length) split in splits)
+                {
+                    count += EncodeToIdsWithCache(textSpanToEncode.Slice(split.Offset, split.Length), null, maxTokenCount - count, out int length, ref priorityQueue);
+                    textLength = split.Offset + length;
+
+                    if (length < split.Length || count >= maxTokenCount)
+                    {
+                        break;
+                    }
+                }
+            }
+            else
+            {
+                count = EncodeToIdsWithCache(textSpanToEncode, null, maxTokenCount, out textLength, ref priorityQueue);
+            }
+
+            return count;
+        }
+
+        /// <summary>
+        /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+        /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+        /// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>
+        /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
+        /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if normalization is enabled;
+        /// conversely, if all tokens fit, the result will be 0.
+        /// </returns>
+        public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+            => LastIndexOf(text, Span<char>.Empty, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
+
+        /// <summary>
+        /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+        /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+        /// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>
+        /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
+        /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="normalizedString"/>; conversely, if all tokens fit, the result will be 0.
+        /// </returns>
+        public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+            => LastIndexOf(null, text, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
+
+        private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int tokenCount)
+        {
+            if (maxTokenCount <= 0)
+            {
+                throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
+            }
+
+            if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+            {
+                normalizedString = null;
+                tokenCount = 0;
+                return 0;
+            }
+
+            IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+            PriorityQueue<Merge>? priorityQueue = null;
+
+            if (splits is not null)
+            {
+                tokenCount = 0;
+                foreach ((int Offset, int Length) split in splits.Reverse())
+                {
+                    tokenCount += EncodeToIdsFromEndWithCache(textSpanToEncode.Slice(split.Offset, split.Length), null, maxTokenCount - tokenCount, out int textIndex, ref priorityQueue);
+                    if (textIndex > 0 || tokenCount >= maxTokenCount)
+                    {
+                        return split.Offset + textIndex;
+                    }
+                }
+            }
+            else
+            {
+                tokenCount = EncodeToIdsFromEndWithCache(textSpanToEncode, null, maxTokenCount, out int textLength, ref priorityQueue);
+                return textLength;
+            }

+            return 0;
+        }
+
         /// <summary>
         /// Map the token to encoded Id.
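Taken together, these overloads replace the old single `maxTokens`-style members with a string/span pair for each operation plus explicit budget queries. A sketch exercising the new surface, using the signatures from the hunk above (the text, limits, and file names are arbitrary placeholders):

    using System;
    using System.Collections.Generic;
    using Microsoft.ML.Tokenizers;

    class BudgetDemo
    {
        static void Main()
        {
            Tokenizer tokenizer = new Bpe("vocab.json", "merges.txt"); // placeholder files
            string text = "the quick brown fox";

            // Plain id encoding.
            IReadOnlyList<int> ids = tokenizer.EncodeToIds(text);

            // Capped encoding: textLength reports how many input characters were consumed.
            IReadOnlyList<int> capped = tokenizer.EncodeToIds(text, maxTokenCount: 3, out string? normalized, out int textLength);

            // Budget queries without materializing the ids.
            int total = tokenizer.CountTokens(text);
            int endIndex = tokenizer.IndexOfTokenCount(text, maxTokenCount: 3, out normalized, out int tokenCount);
            int startIndex = tokenizer.LastIndexOfTokenCount(text, maxTokenCount: 3, out normalized, out tokenCount);

            Console.WriteLine($"{total} tokens; first 3 end at {endIndex}, last 3 start at {startIndex}");
        }
    }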
@@ -383,7 +664,7 @@ namespace Microsoft.ML.Tokenizers
             return s;
         }

-        internal Word MergeWord(ReadOnlySpan<char> w)
+        internal Word MergeWord(ReadOnlySpan<char> w, ref PriorityQueue<Merge>? priorityQueue)
         {
             Word word = Word.WithCapacity(w.Length);
             (int Id, int Len)? unk = null;
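The `ref PriorityQueue<Merge>? priorityQueue` parameter threaded through `MergeWord`, `MergeAll`, and the encode helpers below lets a single lazily created queue be reused across all pre-tokenized splits of one encode call, rather than allocating a fresh queue per word; that appears to be the motivation for the signature changes in this and the following hunks.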
@@ -494,23 +775,24 @@ namespace Microsoft.ML.Tokenizers
                 word.Add(unk.Value.Id, unk.Value.Len);
             }

-            word.MergeAll(Merges, Dropout);
+            word.MergeAll(Merges, Dropout, ref priorityQueue);
             return word;
         }

-        internal List<Token> WordToTokens(ref Word word) => word.ToTokens(VocabReverse);
+        internal void WordToTokens(ref Word word, List<Token> tokens, int offset) => word.ToTokens(VocabReverse, tokens, offset);

-        internal List<Token> EncodeWithCache(ReadOnlySpan<char> text)
+        internal void EncodeWithCache(ReadOnlySpan<char> text, List<Token> tokens, int offset, ref PriorityQueue<Merge>? priorityQueue)
         {
             Word word;
             if (Cache is not null)
             {
                 if (Cache.TryGetValue(text, out word))
                 {
-                    return WordToTokens(ref word);
+                    WordToTokens(ref word, tokens, offset);
+                    return;
                 }

-                word = MergeWord(text);
+                word = MergeWord(text, ref priorityQueue);

                 if (text.Length <= MaxWordLengthToCache)
                 {
@@ -519,15 +801,15 @@ namespace Microsoft.ML.Tokenizers
             }
             else
             {
-                word = MergeWord(text);
+                word = MergeWord(text, ref priorityQueue);
             }

-            return WordToTokens(ref word);
+            WordToTokens(ref word, tokens, offset);
         }

         internal int WordToIds(ref Word word, IList<int>? accumulatedIds, out int textLength, int fullTextLength, int maxTokens)
         {
-            if (word.SymbolsCount < maxTokens)
+            if (word.SymbolsCount <= maxTokens)
             {
                 textLength = fullTextLength;
                 if (accumulatedIds is not null)
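The comparison fix from `<` to `<=` matters at the boundary: a word whose symbol count equals `maxTokens` fits entirely within the budget, but previously fell through to the capped counting path; the same correction is applied to `WordToIdsFromEnd` in the next hunk.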
@@ -548,7 +830,7 @@ namespace Microsoft.ML.Tokenizers

         internal int WordToIdsFromEnd(ref Word word, IList<int>? accumulatedIds, out int textIndex, int fullTextLength, int maxTokens)
         {
-            if (word.SymbolsCount < maxTokens)
+            if (word.SymbolsCount <= maxTokens)
             {
                 textIndex = 0;
                 if (accumulatedIds is not null)
@@ -567,7 +849,7 @@ namespace Microsoft.ML.Tokenizers
             return word.CountIdsUpToMaxFromEnd(maxTokens, fullTextLength, out textIndex);
         }

-        internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textLength)
+        private int EncodeToIdsWithCache(ReadOnlySpan<char> text, List<int>? accumulatedIds, int maxTokens, out int textLength, ref PriorityQueue<Merge>? priorityQueue)
         {
             Word word;

@@ -578,7 +860,7 @@ namespace Microsoft.ML.Tokenizers
                 return WordToIds(ref hit, accumulatedIds, out textLength, text.Length, maxTokens);
             }

-            word = MergeWord(text);
+            word = MergeWord(text, ref priorityQueue);

             if (text.Length <= MaxWordLengthToCache)
             {
@@ -587,13 +869,13 @@ namespace Microsoft.ML.Tokenizers
             }
             else
             {
-                word = MergeWord(text);
+                word = MergeWord(text, ref priorityQueue);
             }

             return WordToIds(ref word, accumulatedIds, out textLength, text.Length, maxTokens);
         }

-        internal int EncodeToIdsFromEndWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textIndex)
+        internal int EncodeToIdsFromEndWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textIndex, ref PriorityQueue<Merge>? priorityQueue)
         {
             Word word;

@@ -604,7 +886,7 @@ namespace Microsoft.ML.Tokenizers
                 return WordToIdsFromEnd(ref hit, accumulatedIds, out textIndex, text.Length, maxTokens);
             }

-            word = MergeWord(text);
+            word = MergeWord(text, ref priorityQueue);

             if (text.Length <= MaxWordLengthToCache)
             {
@@ -613,12 +895,10 @@ namespace Microsoft.ML.Tokenizers
             }
             else
             {
-                word = MergeWord(text);
+                word = MergeWord(text, ref priorityQueue);
             }

             return WordToIdsFromEnd(ref word, accumulatedIds, out textIndex, text.Length, maxTokens);
         }
-
-        internal static readonly List<Token> EmptyTokensList = new();
     }
 }
@@ -8,6 +8,7 @@ using System.Collections.Generic;
 using System.Diagnostics;
 using System.IO;
 using System.Linq;
+using System.Text;
 using System.Text.Json;

 namespace Microsoft.ML.Tokenizers
@@ -15,7 +16,7 @@ namespace Microsoft.ML.Tokenizers
     /// <summary>
     /// Represent the Byte Pair Encoding model.
     /// </summary>
-    public sealed class EnglishRoberta : Model
+    public sealed class EnglishRoberta : Tokenizer
     {
         private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence;
         private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
@@ -26,6 +27,8 @@ namespace Microsoft.ML.Tokenizers
         private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
         private readonly string[] _charToString;
         private readonly StringSpanOrdinalKeyCache<List<Token>> _cache;
+        private readonly PreTokenizer? _preTokenizer;
+        private readonly Normalizer? _normalizer;

         /// <summary>
         /// Indicate if want to filter the unsupported characters during the decoding.
@@ -38,29 +41,51 @@ namespace Microsoft.ML.Tokenizers
         /// <param name="vocabularyPath">The JSON file path containing the dictionary of string keys and their ids.</param>
         /// <param name="mergePath">The file path containing the tokens's pairs list.</param>
         /// <param name="highestOccurrenceMappingPath">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
+        /// <param name="preTokenizer">The pre-tokenizer to use.</param>
+        /// <param name="normalizer">The normalizer to use.</param>
         /// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
-        public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath, bool filterUnsupportedChars = true)
+        public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) :
+            this(vocabularyPath is null ? throw new ArgumentNullException(nameof(vocabularyPath)) : File.OpenRead(vocabularyPath),
+                mergePath is null ? throw new ArgumentNullException(nameof(mergePath)) : File.OpenRead(mergePath),
+                highestOccurrenceMappingPath is null ? throw new ArgumentNullException(nameof(highestOccurrenceMappingPath)) : File.OpenRead(highestOccurrenceMappingPath),
+                preTokenizer, normalizer, filterUnsupportedChars, true)
         {
-            if (vocabularyPath is null)
+        }
+
+        /// <summary>
+        /// Construct tokenizer's model object to use with the English Robert model.
+        /// </summary>
+        /// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
+        /// <param name="mergeStream">The stream of a file containing the tokens's pairs list.</param>
+        /// <param name="highestOccurrenceMappingStream">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
+        /// <param name="preTokenizer">The pre-tokenizer to use.</param>
+        /// <param name="normalizer">The normalizer to use.</param>
+        /// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
+        public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) :
+            this(vocabularyStream, mergeStream, highestOccurrenceMappingStream, preTokenizer, normalizer, filterUnsupportedChars, false)
+        {
+        }
+
+        public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer, Normalizer? normalizer, bool filterUnsupportedChars, bool disposeStream)
+        {
+            if (vocabularyStream is null)
             {
-                throw new ArgumentNullException(nameof(vocabularyPath));
+                throw new ArgumentNullException(nameof(vocabularyStream));
             }

-            if (mergePath is null)
+            if (mergeStream is null)
             {
-                throw new ArgumentNullException(nameof(mergePath));
+                throw new ArgumentNullException(nameof(mergeStream));
             }

-            if (highestOccurrenceMappingPath is null)
+            if (highestOccurrenceMappingStream is null)
             {
-                throw new ArgumentNullException(nameof(highestOccurrenceMappingPath));
+                throw new ArgumentNullException(nameof(highestOccurrenceMappingStream));
             }

             FilterUnsupportedChars = filterUnsupportedChars;
-            using Stream vocabularyStream = File.OpenRead(vocabularyPath);
-            using Stream mergeStream = File.OpenRead(mergePath);
-            using Stream highestOccurrenceMappingStream = File.OpenRead(highestOccurrenceMappingPath);
+            _preTokenizer = preTokenizer;
+            _normalizer = normalizer;

             // vocabularyPath like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
             // merge file like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
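EnglishRoberta gets the same constructor treatment: the path-based overload now chains to the stream-based one, and a public overload exposes `disposeStream` so ownership of the streams can be handed to the tokenizer. A construction sketch; the three file names stand in for the GPT-2 encoder.json, vocab.bpe, and occurrence-mapping files referenced in the comments above and are placeholders:

    using System.IO;
    using Microsoft.ML.Tokenizers;

    class RobertaConstructionDemo
    {
        static void Main()
        {
            // Path-based: streams are opened internally and disposed (disposeStream: true).
            var roberta = new EnglishRoberta("encoder.json", "vocab.bpe", "dict.txt");

            // Stream-based, handing stream ownership to the tokenizer.
            var fromStreams = new EnglishRoberta(
                File.OpenRead("encoder.json"),
                File.OpenRead("vocab.bpe"),
                File.OpenRead("dict.txt"),
                preTokenizer: null,
                normalizer: null,
                filterUnsupportedChars: true,
                disposeStream: true);
        }
    }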
@@ -79,48 +104,24 @@ namespace Microsoft.ML.Tokenizers

             _unicodeToByte = _byteToUnicode.Reverse();
             _cache = new StringSpanOrdinalKeyCache<List<Token>>();
+
+            if (disposeStream)
+            {
+                vocabularyStream.Dispose();
+                mergeStream.Dispose();
+                highestOccurrenceMappingStream.Dispose();
+            }
         }

         /// <summary>
-        /// Construct tokenizer's model object to use with the English Robert model.
+        /// Gets the PreTokenizer used by the Tokenizer.
         /// </summary>
-        /// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
-        /// <param name="mergeStream">The stream of a file containing the tokens's pairs list.</param>
-        /// <param name="highestOccurrenceMappingStream">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
-        /// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
-        public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, bool filterUnsupportedChars = true)
-        {
-            if (vocabularyStream is null)
-            {
-                throw new ArgumentNullException(nameof(vocabularyStream));
-            }
-
-            if (mergeStream is null)
-            {
-                throw new ArgumentNullException(nameof(mergeStream));
-            }
-
-            if (highestOccurrenceMappingStream is null)
-            {
-                throw new ArgumentNullException(nameof(highestOccurrenceMappingStream));
-            }
-
-            FilterUnsupportedChars = filterUnsupportedChars;
-
-            _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
-            _vocab = GetVocabulary(vocabularyStream);
-            _vocabReverse = _vocab.ReverseSorted();
-            _mergeRanks = GetMergeRanks(mergeStream);
-            int maxCharValue = GetByteToUnicode(out _byteToUnicode);
-            _charToString = new string[maxCharValue];
-            for (char c = (char)0; c < (char)maxCharValue; c++)
-            {
-                _charToString[c] = c.ToString();
-            }
-
-            _unicodeToByte = _byteToUnicode.Reverse();
-            _cache = new StringSpanOrdinalKeyCache<List<Token>>();
-        }
+        public override PreTokenizer? PreTokenizer => _preTokenizer;

         /// <summary>
+        /// Gets the Normalizer in use by the Tokenizer.
+        /// </summary>
+        public override Normalizer? Normalizer => _normalizer;
+
+        /// <summary>
         /// Gets the dictionary mapping tokens to Ids.
@@ -167,16 +168,64 @@ namespace Microsoft.ML.Tokenizers
         return null;
     }

+/// <summary>
+/// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
+/// </summary>
+/// <param name="text">The text to encode.</param>
+/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text, Span<char>.Empty, out normalizedString, considerPreTokenization, considerNormalization);
+
+/// <summary>
+/// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
+/// </summary>
+/// <param name="text">The text to encode.</param>
+/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(null, text, out normalizedString, considerPreTokenization, considerNormalization);
+
+private IReadOnlyList<Token> Encode(string? text, ReadOnlySpan<char> textSpan, out string? normalizedString, bool considerPreTokenization, bool considerNormalization)
+{
+    if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+    {
+        normalizedString = null;
+        return [];
+    }
+
+    IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+    if (splits is not null)
+    {
+        List<Token> tokens = new();
+        foreach ((int Offset, int Length) split in splits)
+        {
+            foreach (Token t in EncodeInternal(textSpanToEncode.Slice(split.Offset, split.Length)))
+            {
+                tokens.Add(new Token(t.Id, t.Value, (split.Offset + t.Offset.Index, t.Offset.Length)));
+            }
+        }
+
+        return tokens;
+    }
+    else
+    {
+        return EncodeInternal(textSpanToEncode);
+    }
+}
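A minimal sketch of calling the new overload (assuming the robertaTokenizer instance from the earlier sketch):

    IReadOnlyList<Token> tokens = robertaTokenizer.Encode("Hello world", out string? normalized);
    foreach (Token t in tokens)
    {
        // Offset is an (Index, Length) pair into the (possibly normalized) text.
        Console.WriteLine($"{t.Value} -> {t.Id} @ {t.Offset}");
    }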
 /// <summary>
 /// Encode a text string to a list of tokens.
 /// </summary>
 /// <param name="text">The text to encode.</param>
 /// <returns>The list of tokens generated from the text tokenization.</returns>
-public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
+private IReadOnlyList<Token> EncodeInternal(ReadOnlySpan<char> text)
 {
     if (text.IsEmpty)
     {
-        return Bpe.EmptyTokensList;
+        return [];
     }

     char[] token = ArrayPool<char>.Shared.Rent(text.Length);

@@ -197,7 +246,7 @@ namespace Microsoft.ML.Tokenizers
     {
         ArrayPool<char>.Shared.Return(token);
         ArrayPool<int>.Shared.Return(indexMapping);
-        return Array.Empty<Token>();
+        return [];
     }

     if (_cache.TryGetValue(text, out List<Token>? hit))
@@ -215,32 +264,257 @@ namespace Microsoft.ML.Tokenizers
 }

 /// <summary>
-/// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
+/// Encodes input text to token Ids.
 /// </summary>
 /// <param name="text">The text to encode.</param>
-/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
-/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
-/// <param name="maxTokens">The maximum number of tokens to encode.</param>
-/// <returns>The number of tokens that the input text will be encoded to.</returns>
-public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsInternal(text, accumulatedIds, out textLength, maxTokens);
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The list of encoded Ids.</returns>
+public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
+    => EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
+
+/// <summary>
+/// Encodes input text to token Ids.
+/// </summary>
+/// <param name="text">The text to encode.</param>
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The list of encoded Ids.</returns>
+public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
+    => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out _, out _);
+
+/// <summary>
+/// Encodes input text to token Ids up to maximum number of tokens.
+/// </summary>
+/// <param name="text">The text to encode.</param>
+/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The list of encoded Ids.</returns>
+public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+    => EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
+
+/// <summary>
+/// Encodes input text to token Ids up to maximum number of tokens.
+/// </summary>
+/// <param name="text">The text to encode.</param>
+/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The list of encoded Ids.</returns>
+public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+    => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
+
+private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
+{
+    if (maxTokenCount <= 0)
+    {
+        throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
+    }
+
+    if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+    {
+        textLength = 0;
+        normalizedString = null;
+        return [];
+    }
+
+    IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+    List<int> ids = new();
+    if (splits is not null)
+    {
+        textLength = 0;
+        foreach ((int Offset, int Length) split in splits)
+        {
+            EncodeToIdsInternal(textSpanToEncode.Slice(split.Offset, split.Length), ids, out int length, maxTokenCount - ids.Count);
+            textLength = split.Offset + length;
+
+            if (length < split.Length || ids.Count >= maxTokenCount)
+            {
+                break;
+            }
+        }
+    }
+    else
+    {
+        EncodeToIdsInternal(textSpanToEncode, ids, out textLength, maxTokenCount);
+    }
+
+    return ids;
+}
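Sketch of the budgeted overload (same assumed robertaTokenizer instance); textLength reports how much input the returned ids cover:

    IReadOnlyList<int> ids = robertaTokenizer.EncodeToIds(
        "some longer input text", maxTokenCount: 8, out string? normalizedText, out int textLength);
    // ids.Count <= 8; textLength is the covered prefix length of the (possibly normalized) text.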
 /// <summary>
 /// Get the number of tokens that the input text will be encoded to.
 /// </summary>
 /// <param name="text">The text to encode.</param>
-/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
-/// <param name="maxTokens">The maximum number of tokens to encode.</param>
-/// <returns>The number of tokens that the input text will be encoded to.</returns>
-public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsInternal(text, null, out textLength, maxTokens);
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The number of token Ids that the input text will be encoded to.</returns>
+public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
+    => CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);

 /// <summary>
 /// Get the number of tokens that the input text will be encoded to.
 /// </summary>
 /// <param name="text">The text to encode.</param>
-/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
-/// <param name="maxTokens">The maximum number of tokens to encode.</param>
-/// <returns>The number of tokens that the input text will be encoded to.</returns>
-public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => EncodeToIdsFromEndInternal(text, null, out textIndex, maxTokens);
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The number of token Ids that the input text will be encoded to.</returns>
+public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
+    => CountTokens(null, text, considerPreTokenization, considerNormalization, out _, out _);
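Counting is now a one-liner that never materializes an id list for the caller; per the private CountTokens below, a null accumulator is passed through to EncodeToIdsInternal:

    int tokenCount = robertaTokenizer.CountTokens("How many tokens is this?");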
/// <summary>
|
||||||
|
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="text">The text to encode.</param>
|
||||||
|
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
|
||||||
|
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
|
||||||
|
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
|
||||||
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
|
/// <returns>
|
||||||
|
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
|
||||||
|
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
|
||||||
|
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
|
||||||
|
/// </returns>
|
||||||
|
public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||||
|
{
|
||||||
|
tokenCount = CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
|
||||||
|
return textLength;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="text">The text to encode.</param>
|
||||||
|
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
|
||||||
|
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
|
||||||
|
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
|
||||||
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
|
/// <returns>
|
||||||
|
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
|
||||||
|
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
|
||||||
|
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
|
||||||
|
/// </returns>
|
||||||
|
public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||||
|
{
|
||||||
|
tokenCount = CountTokens(null, text, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
|
||||||
|
return textLength;
|
||||||
|
}
|
||||||
|
|
||||||
|
private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
|
||||||
|
{
|
||||||
|
if (maxTokenCount <= 0)
|
||||||
|
{
|
||||||
|
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
|
||||||
|
}
|
||||||
|
|
||||||
|
textLength = 0;
|
||||||
|
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
|
||||||
|
{
|
||||||
|
normalizedString = null;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
|
||||||
|
|
||||||
|
int count = 0;
|
||||||
|
if (splits is not null)
|
||||||
|
{
|
||||||
|
foreach ((int Offset, int Length) split in splits)
|
||||||
|
{
|
||||||
|
count += EncodeToIdsInternal(textSpanToEncode.Slice(split.Offset, split.Length), null, out int length, maxTokenCount - count);
|
||||||
|
textLength = split.Offset + length;
|
||||||
|
|
||||||
|
if (length < split.Length || count >= maxTokenCount)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
count += EncodeToIdsInternal(textSpanToEncode, null, out textLength, maxTokenCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
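Sketch of truncating from the start under a token budget with the new IndexOfTokenCount (same assumed instance):

    string input = "text that may exceed the budget";
    int cutIndex = robertaTokenizer.IndexOfTokenCount(
        input, maxTokenCount: 10, out string? norm, out int produced);
    string kept = (norm ?? input).Substring(0, cutIndex); // longest prefix that fits in 10 tokens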
+/// <summary>
+/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
+/// </summary>
+/// <param name="text">The text to encode.</param>
+/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>
+/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
+/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be the length of the text or of the <paramref name="normalizedString"/> if normalization is enabled;
+/// conversely, if all tokens fit, the result will be 0.
+/// </returns>
+public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+    => LastIndexOf(text, Span<char>.Empty, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
+
+/// <summary>
+/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
+/// </summary>
+/// <param name="text">The text to encode.</param>
+/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>
+/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
+/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be the length of the <paramref name="normalizedString"/>; conversely, if all tokens fit, the result will be 0.
+/// </returns>
+public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+    => LastIndexOf(null, text, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
+
+private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int tokenCount)
+{
+    if (maxTokenCount <= 0)
+    {
+        throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
+    }
+
+    if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+    {
+        normalizedString = null;
+        tokenCount = 0;
+        return 0;
+    }
+
+    IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+    if (splits is not null)
+    {
+        tokenCount = 0;
+        foreach ((int Offset, int Length) split in splits.Reverse())
+        {
+            tokenCount += EncodeToIdsFromEndInternal(textSpanToEncode.Slice(split.Offset, split.Length), null, out int textIndex, maxTokenCount - tokenCount);
+            if (textIndex > 0 || tokenCount >= maxTokenCount)
+            {
+                return split.Offset + textIndex;
+            }
+        }
+    }
+    else
+    {
+        tokenCount = EncodeToIdsFromEndInternal(textSpanToEncode, null, out int textLength, maxTokenCount);
+        return textLength;
+    }
+
+    return 0;
+}
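And the mirror image, keeping the tail within the budget (useful for chat-history style truncation; reuses input from the previous sketch):

    int startIndex = robertaTokenizer.LastIndexOfTokenCount(
        input, maxTokenCount: 10, out string? norm2, out int produced2);
    string tail = (norm2 ?? input).Substring(startIndex); // longest suffix that fits in 10 tokens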
 private int EncodeToIdsResult(List<Token> tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textLength)
 {

@@ -347,7 +621,7 @@ namespace Microsoft.ML.Tokenizers
     {
         ArrayPool<char>.Shared.Return(token);
         ArrayPool<int>.Shared.Return(indexMapping);
-        textLength = 0;
+        textLength = text.Length;
         return 0;
     }

@@ -390,7 +664,7 @@ namespace Microsoft.ML.Tokenizers
     {
         ArrayPool<char>.Shared.Return(token);
         ArrayPool<int>.Shared.Return(indexMapping);
-        textIndex = text.Length;
+        textIndex = 0;
         return 0;
     }

@@ -410,7 +684,32 @@ namespace Microsoft.ML.Tokenizers
 public override int? MapTokenToId(ReadOnlySpan<char> token) => _vocab.TryGetValue(token, out int value) ? value : null;

 /// <summary>
-/// Convert a list of tokens Ids to highest occurrence rankings.
+/// Decode the given ids, back to a String.
+/// </summary>
+/// <param name="ids">The list of ids that we want to decode.</param>
+/// <returns>The decoded string.</returns>
+public override string? Decode(IEnumerable<int> ids)
+{
+    if (ids is null)
+    {
+        throw new ArgumentNullException(nameof(ids));
+    }
+
+    ValueStringBuilder sb = new ValueStringBuilder();
+
+    foreach (int id in ids)
+    {
+        if (MapIdToToken(id) is string s)
+        {
+            sb.Append(s);
+        }
+    }
+
+    return sb.ToString();
+}
+
+/// <summary>
+/// Convert a list of token Ids to highest occurrence rankings.
 /// </summary>
 /// <param name="ids">The Ids list to map to the high occurrence rank.</param>
 /// <returns>The list of ranks mapped from the list of Ids.</returns>
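Sketch of round-tripping through the new Decode override; note it simply concatenates mapped tokens, so for byte-level vocabularies the output may still carry the model's character mapping:

    IReadOnlyList<int> encoded = robertaTokenizer.EncodeToIds("hello");
    string? decoded = robertaTokenizer.Decode(encoded);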
@@ -432,7 +731,7 @@ namespace Microsoft.ML.Tokenizers
 }

 /// <summary>
-/// Convert a list of tokens Ids to highest occurrence values.
+/// Convert a list of token Ids to highest occurrence values.
 /// </summary>
 /// <param name="ids">The Ids list to map to the high occurrence values.</param>
 /// <returns>The list of occurrence values mapped from the list of Ids.</returns>

@@ -454,7 +753,7 @@ namespace Microsoft.ML.Tokenizers
 }

 /// <summary>
-/// Convert a list of highest occurrence rankings to tokens Ids list.
+/// Convert a list of highest occurrence rankings to token Ids list.
 /// </summary>
 /// <param name="ranks">The high occurrence ranks list to map to the Ids list.</param>
 /// <returns>The list of Ids mapped from the list of ranks.</returns>

@@ -638,7 +937,7 @@ namespace Microsoft.ML.Tokenizers
 {
     if (token.Length == 0)
     {
-        return Bpe.EmptyTokensList;
+        return [];
     }

     if (token.Length == 1)
@@ -1,180 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.Collections.Generic;
-using System.Text;
-
-namespace Microsoft.ML.Tokenizers
-{
-    /// <summary>
-    /// Represents a model used during Tokenization (like BPE or Word Piece or Unigram).
-    /// </summary>
-    public abstract class Model
-    {
-        /// <summary>
-        /// Encode a text to a list of tokens.
-        /// </summary>
-        /// <param name="text">The text to encode.</param>
-        /// <returns>The list of tokens generated from the text tokenization.</returns>
-        public abstract IReadOnlyList<Token> Encode(ReadOnlySpan<char> text);
-
-        /// <summary>
-        /// Encode a text to a list of Ids and add them to the accumulatedIds list.
-        /// </summary>
-        /// <param name="text">The text to encode.</param>
-        /// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
-        /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
-        /// <param name="maxTokens">The maximum number of tokens to encode.</param>
-        /// <returns>The number of tokens that the input text will be encoded to.</returns>
-        /// <remarks>
-        /// This method does the default implementation that uses the Encode method to get the token's Ids.
-        /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
-        /// </remarks>
-        public virtual int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
-        {
-            if (accumulatedIds is null)
-            {
-                throw new ArgumentNullException(nameof(accumulatedIds));
-            }
-
-            // Default implementation is not optimized for memory allocation. It is recommended to override this method for the sake of the performance.
-            textLength = 0;
-            var tokens = Encode(text);
-
-            int count = Math.Min(tokens.Count, maxTokens);
-
-            for (int i = 0; i < count; i++)
-            {
-                textLength += tokens[i].Offset.Length;
-                accumulatedIds.Add(tokens[i].Id);
-            }
-
-            return count;
-        }
-
-        /// <summary>
-        /// Get the number of tokens that the input text will be encoded to.
-        /// </summary>
-        /// <param name="text">The text to encode.</param>
-        /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
-        /// <param name="maxTokens">The maximum number of tokens to encode.</param>
-        /// <returns>The number of tokens that the input text will be encoded to.</returns>
-        /// <remarks>
-        /// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids.
-        /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
-        /// </remarks>
-        public virtual int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue)
-        {
-            if (maxTokens <= 0)
-            {
-                throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0.");
-            }
-
-            var ids = new List<int>();
-
-            if (maxTokens == int.MaxValue)
-            {
-                EncodeToIds(text, ids, out _);
-                textLength = text.Length;
-                return ids.Count;
-            }
-
-            IReadOnlyList<Token> tokens = Encode(text);
-            textLength = 0;
-            int count = Math.Min(tokens.Count, maxTokens);
-            for (int i = 0; i < count; i++)
-            {
-                textLength += tokens[i].Offset.Length;
-            }
-
-            return count;
-        }
-
-        /// <summary>
-        /// Get the number of tokens that the input text will be encoded to.
-        /// </summary>
-        /// <param name="text">The text to encode.</param>
-        /// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
-        /// <param name="maxTokens">The maximum number of tokens to encode.</param>
-        /// <returns>The number of tokens that the input text will be encoded to.</returns>
-        /// <remarks>
-        /// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids.
-        /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
-        /// </remarks>
-        public virtual int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue)
-        {
-            if (maxTokens <= 0)
-            {
-                throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0.");
-            }
-
-            var ids = new List<int>();
-
-            if (maxTokens == int.MaxValue)
-            {
-                EncodeToIds(text, ids, out _);
-                textIndex = 0;
-                return ids.Count;
-            }
-
-            IReadOnlyList<Token> tokens = Encode(text);
-            textIndex = text.Length;
-            int count = Math.Min(tokens.Count, maxTokens);
-
-            int tokensCount = tokens.Count;
-            int end = tokensCount - count;
-            for (int i = tokensCount - 1; i >= end; i--)
-            {
-                textIndex -= tokens[i].Offset.Length;
-            }
-
-            return count;
-        }
-
-        /// <summary>
-        /// Map the token to encoded id with the option to skip the special tokens.
-        /// </summary>
-        /// <param name="token">The token to map to Id</param>
-        /// <returns>The mapped Id of the token.</returns>
-        public abstract int? MapTokenToId(ReadOnlySpan<char> token);
-
-        /// <summary>
-        /// Map the encoded Id to the token.
-        /// </summary>
-        /// <param name="id">The Id to map to the token.</param>
-        /// <returns>The mapped token of the Id.</returns>
-        public abstract string? MapIdToToken(int id);
-
-        /// <summary>
-        /// Decode the given ids, back to a String.
-        /// </summary>
-        /// <param name="ids">The list of ids that we want to decode.</param>
-        /// <returns>The decoded string.</returns>
-        /// <remarks>
-        /// This method does the default implementation that uses the MapIdToToken method to get the token.
-        /// Tokenizer models may opt to override this method to ensure accurate results if the default implementation
-        /// provided here proves insufficient for the model's specific scenario.
-        /// </remarks>
-        public virtual string? Decode(IEnumerable<int> ids)
-        {
-            if (ids is null)
-            {
-                throw new ArgumentNullException(nameof(ids));
-            }
-
-            ValueStringBuilder sb = new ValueStringBuilder();
-
-            foreach (int id in ids)
-            {
-                if (MapIdToToken(id) is string s)
-                {
-                    sb.Append(s);
-                }
-            }
-
-            return sb.ToString();
-        }
-    }
-}
@@ -19,7 +19,7 @@ namespace Microsoft.ML.Tokenizers
 /// <summary>
 /// SentencePieceBpe is a tokenizer that splits the input into tokens using the SentencePiece Bpe model.
 /// </summary>
-public sealed class SentencePieceBpe : Model
+public sealed class SentencePieceBpe : Tokenizer
 {
     private const int UninitializedId = -2; // indicate if the symbol contains uninitialized id.

@@ -29,9 +29,10 @@ namespace Microsoft.ML.Tokenizers
     private readonly int _maxByteId;
     private readonly int _byteCodeToIdOffset; // offset of mapping byte code to the Ids.
     private readonly int _oneByteUtf8EncodingMaxId; // the maximum value of the one byte UTF-8 character.
+    private readonly Normalizer? _normalizer;

     internal SentencePieceBpe(ModelProto modelProto, bool addBos, bool addEos) :
         this(modelProto is null ? throw new ArgumentNullException(nameof(modelProto)) : modelProto)
     {
         AddBeginningOfSentence = addBos;
         AddEndOfSentence = addEos;

@@ -65,6 +66,8 @@ namespace Microsoft.ML.Tokenizers
     EscapeWhiteSpaces = modelProto.NormalizerSpec.EscapeWhitespaces;
     TreatWhitespaceAsSuffix = modelProto.TrainerSpec.TreatWhitespaceAsSuffix;
     ByteFallback = modelProto.TrainerSpec.ByteFallback;
+
+    _normalizer = new SentencePieceNormalizer(modelProto.NormalizerSpec.RemoveExtraWhitespaces, AddDummyPrefix, EscapeWhiteSpaces, modelProto.TrainerSpec.TreatWhitespaceAsSuffix);
 }

@@ -127,6 +130,16 @@ namespace Microsoft.ML.Tokenizers
 /// </summary>
 public int UnknownId { get; }

+/// <summary>
+/// Gets the PreTokenizer used by the Tokenizer.
+/// </summary>
+public override PreTokenizer? PreTokenizer => null;
+
+/// <summary>
+/// Gets the Normalizer in use by the Tokenizer.
+/// </summary>
+public override Normalizer? Normalizer => _normalizer;
+
 /// <summary>
 /// The vocabulary of the model.
 /// </summary>
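Since SentencePieceBpe now derives from Tokenizer directly, the pipeline pieces are discoverable on the instance. A sketch (obtaining bpe from whatever SentencePiece factory the library exposes is assumed):

    void Inspect(SentencePieceBpe bpe)
    {
        Console.WriteLine(bpe.PreTokenizer is null);    // True: no pre-tokenizer is used
        Console.WriteLine(bpe.Normalizer is not null);  // True: a SentencePieceNormalizer is installed
    }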
@@ -152,12 +165,74 @@ namespace Microsoft.ML.Tokenizers
 }

 /// <summary>
-/// Encode a text to a list of tokens.
+/// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
 /// </summary>
 /// <param name="text">The text to encode.</param>
-/// <returns>The list of tokens generated from the text tokenization.</returns>
-/// <remarks>The input text has to be normalized before calling this method.</remarks>
-public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text) => Encode(text, AddBeginningOfSentence, AddEndOfSentence);
+/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true)
+    => Encode(text, Span<char>.Empty, out normalizedString, AddBeginningOfSentence, AddEndOfSentence, considerPreTokenization, considerNormalization);
+
+/// <summary>
+/// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
+/// </summary>
+/// <param name="text">The text to encode.</param>
+/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true)
+    => Encode(null, text, out normalizedString, AddBeginningOfSentence, AddEndOfSentence, considerPreTokenization, considerNormalization);
+
+/// <summary>
+/// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
+/// </summary>
+/// <param name="text">The text to encode.</param>
+/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+public IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
+    => Encode(text, Span<char>.Empty, out normalizedString, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization);
+
+/// <summary>
+/// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
+/// </summary>
+/// <param name="text">The text to encode.</param>
+/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+public IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
+    => Encode(null, text, out normalizedString, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization);
+
+private IReadOnlyList<Token> Encode(string? text, ReadOnlySpan<char> textSpan, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization, bool considerNormalization)
+{
+    if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+    {
+        normalizedString = null;
+        return [];
+    }
+
+    ReadOnlySpan<char> textToEncode = text is null ? textSpan : text.AsSpan();
+    if (considerNormalization && _normalizer is not null)
+    {
+        normalizedString = text is not null ? _normalizer.Normalize(text) : _normalizer.Normalize(textSpan);
+        textToEncode = normalizedString.AsSpan();
+    }
+    else
+    {
+        normalizedString = null;
+    }
+
+    return EncodeInternal(textToEncode, addBeginningOfSentence, addEndOfSentence);
+}
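Sketch of overriding the model's BOS/EOS defaults per call with the new overload (same assumed bpe instance):

    IReadOnlyList<Token> pieces = bpe.Encode(
        "Hello world", out string? spNormalized,
        addBeginningOfSentence: true, addEndOfSentence: false);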
 /// <summary>
 /// Encode a text to a list of tokens.
 /// </summary>

@@ -167,11 +242,11 @@ namespace Microsoft.ML.Tokenizers
 /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
 /// <returns>The list of tokens generated from the text tokenization.</returns>
 /// <remarks>The input text has to be normalized before calling this method.</remarks>
-public IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence)
+private IReadOnlyList<Token> EncodeInternal(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence)
 {
     if (text.Length == 0)
     {
-        return Array.Empty<Token>();
+        return [];
     }

     BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length);
@ -306,16 +381,168 @@ namespace Microsoft.ML.Tokenizers
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Encode a text to a list of Ids and add them to the accumulatedIds list.
|
/// Encodes input text to tokes Ids.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="text">The text to encode.</param>
|
/// <param name="text">The text to encode.</param>
|
||||||
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
|
/// <returns>The list of encoded Ids.</returns>
|
||||||
|
public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||||
|
=> EncodeToIds(text, Span<char>.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Encodes input text to token Ids.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="text">The text to encode.</param>
|
||||||
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
|
/// <returns>The list of encoded Ids.</returns>
|
||||||
|
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||||
|
=> EncodeToIds(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Encodes input text to token Ids up to maximum number of tokens.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="text">The text to encode.</param>
|
||||||
|
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
|
||||||
|
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
|
||||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||||
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
/// <remarks>The input text has to be normalized before calling this method.</remarks>
|
/// <returns>The list of encoded Ids.</returns>
|
||||||
public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
|
public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||||
=> EncodeToIds(text, AddBeginningOfSentence, AddEndOfSentence, accumulatedIds, out textLength, maxTokens);
|
=> EncodeToIds(text, Span<char>.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Encodes input text to token Ids up to maximum number of tokens.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="text">The text to encode.</param>
|
||||||
|
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
|
||||||
|
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
|
||||||
|
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||||
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
|
/// <returns>The list of encoded Ids.</returns>
|
||||||
|
public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||||
|
=> EncodeToIds(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Encodes input text to token Ids.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="text">The text to encode.</param>
|
||||||
|
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
|
||||||
|
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
|
||||||
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
|
/// <returns>The list of encoded Ids.</returns>
|
||||||
|
public IReadOnlyList<int> EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||||
|
=> EncodeToIds(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Encodes input text to token Ids.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="text">The text to encode.</param>
|
||||||
|
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
|
||||||
|
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
|
||||||
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
|
/// <returns>The list of encoded Ids.</returns>
|
||||||
|
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||||
|
=> EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Encodes input text to token Ids up to maximum number of tokens.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="text">The text to encode.</param>
|
||||||
|
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
|
||||||
|
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
|
||||||
|
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
|
||||||
|
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
|
||||||
|
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||||
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
|
/// <returns>The list of encoded Ids.</returns>
|
||||||
|
public IReadOnlyList<int> EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||||
|
=> EncodeToIds(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Encodes input text to token Ids up to maximum number of tokens.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="text">The text to encode.</param>
|
||||||
|
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
|
||||||
|
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
|
||||||
|
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
|
||||||
|
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
|
||||||
|
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||||
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
|
/// <returns>The list of encoded Ids.</returns>
|
||||||
|
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||||
|
=> EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
|
||||||
|
|
||||||
|
|
||||||
|
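All of the overloads above funnel into a single private worker, so they share the same truncation semantics. A minimal usage sketch, assuming `tokenizer` is an already-constructed instance of the SentencePiece-style tokenizer these methods belong to (the variable names and input are illustrative, not from this PR):

// Unbounded encoding with per-call BOS/EOS control.
IReadOnlyList<int> ids = tokenizer.EncodeToIds("Hello world", addBeginningOfSentence: true, addEndOfSentence: false);

// Bounded encoding: at most 10 ids; textLength reports how much of the
// (possibly normalized) input those ids cover.
IReadOnlyList<int> truncated = tokenizer.EncodeToIds("Hello world", true, false, 10, out string? normalized, out int textLength);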
private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
    if (maxTokenCount <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
    }

    if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
    {
        normalizedString = null;
        textLength = 0;
        return [];
    }

    return EncodeToIds(text is null ? textSpan : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
}

/// <summary>
/// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization,
    out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
    if (maxTokenCount <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
    }

    if (text.IsEmpty)
    {
        normalizedString = null;
        textLength = 0;
        return [];
    }

    ReadOnlySpan<char> textToEncode;

    if (considerNormalization && _normalizer is not null)
    {
        normalizedString = _normalizer.Normalize(text);
        textToEncode = normalizedString.AsSpan();
    }
    else
    {
        normalizedString = null;
        textToEncode = text;
    }

    List<int> ids = new();

    EncodeToIds(textToEncode, addBeginningOfSentence, addEndOfSentence, ids, out textLength, maxTokenCount);

    return ids;
}
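Worth noting from the body above: when `considerNormalization` is false, or no normalizer is configured, the input is encoded verbatim and `normalizedString` comes back null. A hedged sketch of both paths (again assuming a constructed `tokenizer`; the pre-normalized text with the SentencePiece "▁" dummy prefix is only illustrative):

// Encode with normalization (default): `normalized` may differ from the input.
var withNorm = tokenizer.EncodeToIds("Hello world".AsSpan(), true, false, considerNormalization: true, out string? normalized, out int consumed);

// Encode already-normalized text verbatim: `normalized` is null here.
var noNorm = tokenizer.EncodeToIds("▁Hello▁world".AsSpan(), true, false, considerNormalization: false, out normalized, out consumed);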
/// <summary>
/// Encode a text to a list of Ids and add them to the accumulatedIds list.
@@ -328,7 +555,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks>
-public int EncodeToIds(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
+private int EncodeToIds(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
{
    if (maxTokens <= 0)
    {
@@ -378,7 +605,8 @@ namespace Microsoft.ML.Tokenizers
    {
        if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textLength))
        {
-           break;
+           ArrayPool<BpeSymbol>.Shared.Return(symbols);
+           return idsCount;
        }
    }
    else
@@ -391,6 +619,7 @@ namespace Microsoft.ML.Tokenizers
    }
    else
    {
+       ArrayPool<BpeSymbol>.Shared.Return(symbols);
        return idsCount;
    }
}
@@ -510,21 +739,252 @@ namespace Microsoft.ML.Tokenizers
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
-/// <returns>The number of tokens that the input text will be encoded to.</returns>
-/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
-/// <param name="maxTokens">The maximum number of tokens to encode.</param>
-/// <returns>The number of tokens that the input text will be encoded to.</returns>
-/// <remarks>The input text has to be normalized before calling this method.</remarks>
-public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => CountTokens(text, AddBeginningOfSentence, AddEndOfSentence, out textLength, maxTokens);
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The number of token Ids that the input text will be encoded to.</returns>
+public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
+    => CountTokens(text, Span<char>.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _);

/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
-/// <param name="textIndex">Starting from this index to the end of the text will encompasses the maximum encoded tokens.</param>
-/// <param name="maxTokens">The maximum number of tokens to encode.</param>
+/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+/// <returns>The number of token Ids that the input text will be encoded to.</returns>
+public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
+    => CountTokens(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _);
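Since the `CountTokens` overrides route through the same worker as `EncodeToIds`, counting avoids materializing the id list when only the number matters. A small sketch with a hypothetical `tokenizer` instance:

// How many ids would this text produce under the tokenizer's defaults?
int n = tokenizer.CountTokens("The quick brown fox");
Console.WriteLine($"would encode to {n} ids");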
/// <summary>
|
||||||
|
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="text">The text to encode.</param>
|
||||||
|
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
|
||||||
|
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
|
||||||
|
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
|
||||||
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
|
/// <returns>
|
||||||
|
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
|
||||||
|
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
|
||||||
|
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
|
||||||
|
/// </returns>
|
||||||
|
public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||||
|
{
|
||||||
|
tokenCount = CountTokens(text, Span<char>.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
|
||||||
|
return textLength;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="text">The text to encode.</param>
|
||||||
|
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
|
||||||
|
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
|
||||||
|
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
|
||||||
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
|
/// <returns>
|
||||||
|
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
|
||||||
|
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
|
||||||
|
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
|
||||||
|
/// </returns>
|
||||||
|
public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||||
|
{
|
||||||
|
tokenCount = CountTokens(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
|
||||||
|
return textLength;
|
||||||
|
}
|
||||||
|
|
||||||
|
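The return value enables the common pattern of cutting a string at a token budget without a second encode pass. A hedged sketch (hypothetical `tokenizer` and `document` values; note the cut index applies to the normalized form when normalization is on):

int budget = 128;
int cut = tokenizer.IndexOfTokenCount(document, budget, out string? normalized, out int used);

// `cut` is the index just past the last included character, so this prefix
// encodes to at most `budget` ids (`used` of them, precisely).
string prefix = (normalized ?? document).Substring(0, cut);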
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public int CountTokens(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
    => CountTokens(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);

/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
public int CountTokens(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
    => CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The number of tokens that can be generated, which will not exceed the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public int IndexOfTokenCount(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
    tokenCount = CountTokens(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
    return textLength;
}

/// <summary>
/// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The number of tokens that can be generated, which will not exceed the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// </returns>
public int IndexOfTokenCount(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
    tokenCount = CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
    return textLength;
}
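These per-call `addBeginningOfSentence`/`addEndOfSentence` flags override the tokenizer's configured defaults, which matters when budgeting a middle segment of a prompt that should not pick up the sentence specials. A hedged sketch (hypothetical `tokenizer` and `segment`):

// Find how much of a middle segment fits in 256 ids, without BOS/EOS.
int cut = tokenizer.IndexOfTokenCount(segment, addBeginningOfSentence: false, addEndOfSentence: false,
    maxTokenCount: 256, out string? normalized, out int used);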
private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
    => CountTokens(text is null ? textSpan : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);

/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
-public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => CountTokensFromEnd(text, AddBeginningOfSentence, AddEndOfSentence, out textIndex, maxTokens);
+public int CountTokens(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
{
    if (maxTokenCount <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
    }

    if (text.IsEmpty)
    {
        normalizedString = null;
        textLength = 0;
        return 0;
    }

    ReadOnlySpan<char> textToEncode;
    if (considerNormalization && _normalizer is not null)
    {
        normalizedString = _normalizer.Normalize(text);
        textToEncode = normalizedString.AsSpan();
    }
    else
    {
        normalizedString = null;
        textToEncode = text;
    }

    return CountTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out textLength, maxTokenCount);
}
/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The number of tokens that can be generated, which will not exceed the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if normalization is enabled;
/// conversely, if all tokens fit, the result will be 0.
/// </returns>
public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
    => LastIndexOf(text, Span<char>.Empty, maxTokenCount, considerNormalization, out normalizedString, out tokenCount);

/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The number of tokens that can be generated, which will not exceed the maximum token count.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="normalizedString"/>; conversely, if all tokens fit, the result will be 0.
/// </returns>
public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
    => LastIndexOf(null, text, maxTokenCount, considerNormalization, out normalizedString, out tokenCount);

private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerNormalization, out string? normalizedString, out int tokenCount)
    => LastIndexOfTokenCount(text is null ? textSpan : text.AsSpan(), maxTokenCount, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out tokenCount);

/// <summary>
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="tokenCount">The number of tokens that can be generated, which will not exceed the maximum token count.</param>
/// <returns>
/// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="normalizedString"/>; conversely, if all tokens fit, the result will be 0.
/// </returns>
public int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int tokenCount)
{
    if (maxTokenCount <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
    }

    if (text.IsEmpty)
    {
        normalizedString = null;
        tokenCount = 0;
        return 0;
    }

    ReadOnlySpan<char> textToEncode;
    if (considerNormalization && _normalizer is not null)
    {
        normalizedString = _normalizer.Normalize(text);
        textToEncode = normalizedString.AsSpan();
    }
    else
    {
        normalizedString = null;
        textToEncode = text;
    }

    tokenCount = CountTokensFromEnd(textToEncode, addBeginningOfSentence, addEndOfSentence, out int textIndex, maxTokenCount);
    return textIndex;
}
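`LastIndexOfTokenCount` is the mirror image of `IndexOfTokenCount`: it keeps the longest suffix that fits the budget, which is the natural shape for sliding chat-history windows. A hedged sketch (hypothetical `tokenizer` and `history` values):

// Keep the most recent text that still fits in 512 ids.
int start = tokenizer.LastIndexOfTokenCount(history, 512, out string? normalized, out int used);
string tail = (normalized ?? history).Substring(start);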
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
@@ -536,7 +996,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks>
-public int CountTokens(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textLength, int maxTokens = int.MaxValue)
+private int CountTokens(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textLength, int maxTokens = int.MaxValue)
{
    textLength = 0;
    if (text.IsEmpty)
@@ -705,7 +1165,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="maxTokens">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>The input text has to be normalized before calling this method.</remarks>
-public int CountTokensFromEnd(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue)
+private int CountTokensFromEnd(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue)
{
    textIndex = text.Length;
    if (text.IsEmpty)
@@ -944,7 +1404,7 @@ namespace Microsoft.ML.Tokenizers
else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token))
{
    // escape the dummy prefix if needed.
-   sb.Append(AddDummyPrefix && !TreatWhitespaceAsSuffix && token.Length > 0 && token[0] == LlamaNormalizer.DummyPrefix ?
+   sb.Append(AddDummyPrefix && !TreatWhitespaceAsSuffix && token.Length > 0 && token[0] == SentencePieceNormalizer.DummyPrefix ?
        token.AsSpan(1) :
        token.AsSpan());
}
@@ -999,7 +1459,7 @@ namespace Microsoft.ML.Tokenizers
    FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb);
}

-if (AddDummyPrefix && TreatWhitespaceAsSuffix && sb.Length > 0 && sb[sb.Length - 1] == LlamaNormalizer.DummyPrefix)
+if (AddDummyPrefix && TreatWhitespaceAsSuffix && sb.Length > 0 && sb[sb.Length - 1] == SentencePieceNormalizer.DummyPrefix)
{
    sb.RemoveLastChar();
}
@@ -1014,7 +1474,7 @@ namespace Microsoft.ML.Tokenizers
    ArrayPool<char>.Shared.Return(charPoolArray);
}

-return sb.ToString(LlamaNormalizer.DummyPrefix, ' ');
+return sb.ToString(SentencePieceNormalizer.DummyPrefix, ' ');

static void FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, ref ValueStringBuilder sb)
{
@@ -19,9 +19,9 @@ using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers
{
    /// <summary>
-   /// Represent the rapid Byte Pair Encoding model commonly referred to as Tiktoken.
+   /// Represent the rapid Byte Pair Encoding tokenizer.
    /// </summary>
-   public sealed partial class Tiktoken : Model
+   public sealed partial class Tiktoken : Tokenizer
    {
        private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder;
        private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder;
@@ -29,46 +29,56 @@ namespace Microsoft.ML.Tokenizers
        private readonly Dictionary<StringSpanOrdinalKey, (int Id, string Token)> _vocab;
        private IReadOnlyDictionary<string, int>? _vocabOriginal;
        private const int MaxWordLengthToCache = 15;
+       private readonly PreTokenizer? _preTokenizer;
+       private readonly Normalizer? _normalizer;

        /// <summary>
-       /// Create a new Tiktoken tokenizer's model object.
+       /// Create a new Tiktoken tokenizer's object.
        /// </summary>
        /// <param name="vocabFilePath">The path to the BPE vocab file.</param>
+       /// <param name="preTokenizer">The pre-tokenizer to use.</param>
        /// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
+       /// <param name="normalizer">The normalizer to use.</param>
        /// <param name="cacheSize">The size of the cache to use.</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabFilePath"/> is null or empty.</exception>
        /// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception>
-       public Tiktoken(string vocabFilePath, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
-           this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), specialTokens, cacheSize, disposeStream: true)
+       public Tiktoken(string vocabFilePath, PreTokenizer? preTokenizer, IReadOnlyDictionary<string, int>? specialTokens = null, Normalizer? normalizer = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
+           this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), preTokenizer, specialTokens, normalizer, cacheSize, disposeStream: true)
        {
        }

        /// <summary>
-       /// Create a new Tiktoken tokenizer's model object.
+       /// Create a new Tiktoken tokenizer's object.
        /// </summary>
        /// <param name="vocabStream">The stream to the BPE vocab file.</param>
+       /// <param name="preTokenizer">The pre-tokenizer to use.</param>
        /// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
+       /// <param name="normalizer">The normalizer to use.</param>
        /// <param name="cacheSize">The size of the cache to use.</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabStream"/> is null or empty.</exception>
        /// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception>
-       public Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
-           this(vocabStream ?? throw new ArgumentNullException(nameof(vocabStream)), specialTokens, cacheSize, disposeStream: false)
+       public Tiktoken(Stream vocabStream, PreTokenizer? preTokenizer, IReadOnlyDictionary<string, int>? specialTokens = null, Normalizer? normalizer = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
+           this(vocabStream ?? throw new ArgumentNullException(nameof(vocabStream)), preTokenizer, specialTokens, normalizer, cacheSize, disposeStream: false)
        {
        }
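With `Tiktoken` now deriving from `Tokenizer`, it is constructed directly rather than wrapped in a separate tokenizer object. A hedged construction sketch; the vocab file name and the special-token mapping below are placeholders, not values taken from this PR:

// Placeholder path and special-token id; pass null to skip
// pre-tokenization and normalization, per the new ctor signature above.
var specialTokens = new Dictionary<string, int> { ["<|endoftext|>"] = 100257 };
var tokenizer = new Tiktoken("cl100k_base.tiktoken", preTokenizer: null, specialTokens);
IReadOnlyList<int> ids = tokenizer.EncodeToIds("hello world");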
        /// <summary>
-       /// Create a new Tiktoken tokenizer's model object.
+       /// Create a new Tiktoken tokenizer's object.
        /// </summary>
        /// <param name="encoder">The dictionary mapping token utf-8 bytes to Ids.</param>
        /// <param name="decoder">The dictionary mapping Ids to token utf-8 bytes.</param>
        /// <param name="vocab">The dictionary mapping string tokens to Ids.</param>
+       /// <param name="preTokenizer">The pre-tokenizer to use.</param>
        /// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
+       /// <param name="normalizer">The normalizer to use.</param>
        /// <param name="cacheSize">The max size of the cache to use.</param>
        internal Tiktoken(
            Dictionary<ReadOnlyMemory<byte>, int> encoder,
            Dictionary<int, ReadOnlyMemory<byte>> decoder,
            Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab,
+           PreTokenizer? preTokenizer,
            IReadOnlyDictionary<string, int>? specialTokens,
+           Normalizer? normalizer = null,
            int cacheSize = LruCache<int[]>.DefaultCacheSize)
        {
            _encoder = encoder ?? throw new ArgumentNullException(nameof(encoder));
@@ -78,19 +88,26 @@ namespace Microsoft.ML.Tokenizers
            _encoder = encoder!;
            _decoder = decoder!;
            _vocab = vocab!;

+           _preTokenizer = preTokenizer;
+           _normalizer = normalizer;

            _cache = new LruCache<(int[] Bytes, string Token)>(cacheSize);

            SpecialTokens = specialTokens;
            CacheSpecialTokensEncoding(specialTokens);
        }

-       private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens, int cacheSize, bool disposeStream)
+       private Tiktoken(Stream vocabStream, PreTokenizer? preTokenizer, IReadOnlyDictionary<string, int>? specialTokens, Normalizer? normalizer, int cacheSize, bool disposeStream)
        {
            try
            {
                _cache = new LruCache<(int[] Bytes, string Token)>(cacheSize);
                (_encoder, _vocab, _decoder) = LoadTiktokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();

+               _preTokenizer = preTokenizer;
+               _normalizer = normalizer;

                SpecialTokens = specialTokens;
                CacheSpecialTokensEncoding(specialTokens);
            }
@@ -103,6 +120,16 @@ namespace Microsoft.ML.Tokenizers
            }
        }

+       /// <summary>
+       /// Gets the PreTokenizer used by the Tokenizer.
+       /// </summary>
+       public override PreTokenizer? PreTokenizer => _preTokenizer;

+       /// <summary>
+       /// Gets the Normalizer in use by the Tokenizer.
+       /// </summary>
+       public override Normalizer? Normalizer => _normalizer;

        private void CacheSpecialTokensEncoding(IReadOnlyDictionary<string, int>? specialTokens)
        {
            Debug.Assert(_cache is not null);
@@ -118,54 +145,6 @@ namespace Microsoft.ML.Tokenizers
            }
        }
-       /// <summary>
-       /// Create a new Tiktoken tokenizer's model object asynchronously.
-       /// </summary>
-       /// <param name="vocabStream">The stream to the BPE vocab file.</param>
-       /// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
-       /// <param name="cacheSize">The size of the cache to use.</param>
-       /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
-       /// <returns>Tiktoken tokenizer's object.</returns>
-       public static async Task<Tiktoken> CreateAsync(
-           Stream vocabStream,
-           IReadOnlyDictionary<string, int>? specialTokens = null,
-           int cacheSize = LruCache<int[]>.DefaultCacheSize,
-           CancellationToken cancellationToken = default)
-       {
-           if (vocabStream is null)
-           {
-               throw new ArgumentNullException(nameof(vocabStream));
-           }
-
-           (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
-               await LoadTiktokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);
-
-           return new Tiktoken(encoder, decoder, vocab, specialTokens, cacheSize);
-       }
-
-       /// <summary>
-       /// Create a new Tiktoken tokenizer's object asynchronously.
-       /// </summary>
-       /// <param name="vocabFilePath">The BPE vocab file.</param>
-       /// <param name="specialTokensEncoder">The dictionary mapping special tokens to Ids.</param>
-       /// <param name="cacheSize">The size of the cache to use.</param>
-       /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
-       /// <returns>Tiktoken tokenizer's model object.</returns>
-       public static async Task<Tiktoken> CreateAsync(
-           string vocabFilePath,
-           IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
-           int cacheSize = LruCache<int[]>.DefaultCacheSize,
-           CancellationToken cancellationToken = default)
-       {
-           if (vocabFilePath is null)
-           {
-               throw new ArgumentNullException(nameof(vocabFilePath));
-           }
-
-           using Stream vocabStream = File.OpenRead(vocabFilePath);
-           return await CreateAsync(vocabStream, specialTokensEncoder, cacheSize, cancellationToken).ConfigureAwait(false);
-       }
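Callers that depended on the removed `CreateAsync` factories can open the vocab stream themselves and use the synchronous constructor. A hedged migration sketch (the path is a placeholder):

// Before (removed in this PR):
// Tiktoken tokenizer = await Tiktoken.CreateAsync(vocabStream);

// After: construct directly over the stream; null keeps the default
// (no) pre-tokenization and normalization.
using Stream vocabStream = File.OpenRead("cl100k_base.tiktoken");
var tokenizer = new Tiktoken(vocabStream, preTokenizer: null);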
        /// <summary>
        /// Load BPE vocab dictionary from a stream.
        /// </summary>
@@ -254,37 +233,80 @@ namespace Microsoft.ML.Tokenizers
            }
        }
/// <summary>
/// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text, Span<char>.Empty, out normalizedString, considerPreTokenization, considerNormalization);

/// <summary>
/// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(null, text, out normalizedString, considerPreTokenization, considerNormalization);

private IReadOnlyList<Token> Encode(string? text, ReadOnlySpan<char> textSpan, out string? normalizedString, bool considerPreTokenization, bool considerNormalization)
{
    if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
    {
        normalizedString = null;
        return [];
    }

    IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);

    List<Token> tokens = new();

    if (splits is not null)
    {
        foreach ((int Offset, int Length) split in splits)
        {
            Encode(textSpanToEncode.Slice(split.Offset, split.Length), tokens, split.Offset);
        }
    }
    else
    {
        Encode(textSpanToEncode, tokens, 0);
    }

    return tokens;
}
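The offset plumbing above (`split.Offset` threaded into the per-split worker) is what lets each token map back to a span of the input. A sketch, assuming a constructed `tokenizer` and the usual `Id`/`Value`/`Offset` members on `Token`:

foreach (Token t in tokenizer.Encode("hello world", out string? normalized))
{
    // Offset is an (Index, Length) pair into `normalized ?? input`.
    Console.WriteLine($"{t.Id}\t'{t.Value}'\t@{t.Offset}");
}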
/// <summary>
/// Encode text to a list of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
-/// <returns>The list of tokens generated from the text tokenization.</returns>
-public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
+/// <param name="tokens">The list of tokens to populate.</param>
+/// <param name="offset">The offset to start encoding from.</param>
+private void Encode(ReadOnlySpan<char> text, List<Token> tokens, int offset)
{
-   Token[] tokens;
-
-   if (text.IsEmpty)
-   {
-       return Array.Empty<Token>();
-   }
+   Debug.Assert(!text.IsEmpty);

    if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
    {
-       tokens = new Token[value.Ids.Length];
-       tokens[0] = new Token(value.Ids[0], value.Token, (0, value.Token.Length));
+       tokens.Add(new Token(value.Ids[0], value.Token, (offset, value.Token.Length)));
        for (int i = 1; i < value.Ids.Length; i++)
        {
            // One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
-           tokens[i] = new Token(value.Ids[i], "", (text.Length, 0));
+           tokens.Add(new Token(value.Ids[i], "", (offset + text.Length, 0)));
        }

-       return tokens;
+       return;
    }

    // cache miss
    if (_vocab.TryGetValue(text, out (int Id, string Token) mappedId))
    {
-       return new Token[1] { new(mappedId.Id, mappedId.Token, (0, mappedId.Token.Length)) };
+       tokens.Add(new Token(mappedId.Id, mappedId.Token, (offset, mappedId.Token.Length)));
+       return;
    }

    byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
@@ -301,15 +323,98 @@ namespace Microsoft.ML.Tokenizers
        _cache.Add(textAsString, (encodedIds, textAsString));
    }

-   tokens = new Token[encodedIds.Length];
-   tokens[0] = new Token(encodedIds[0], textAsString, (0, text.Length));
+   tokens.Add(new Token(encodedIds[0], textAsString, (offset, text.Length)));
    for (int i = 1; i < encodedIds.Length; i++)
    {
        // One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
-       tokens[i] = new Token(encodedIds[i], "", (text.Length, 0));
+       tokens.Add(new Token(encodedIds[i], "", (offset + text.Length, 0)));
    }
}
/// <summary>
|
||||||
|
/// Encodes input text to token Ids.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="text">The text to encode.</param>
|
||||||
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
|
/// <returns>The list of encoded Ids.</returns>
|
||||||
|
public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||||
|
=> EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Encodes input text to token Ids.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="text">The text to encode.</param>
|
||||||
|
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||||
|
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||||
|
/// <returns>The list of encoded Ids.</returns>
|
||||||
|
        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
            => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out _, out _);

        /// <summary>
        /// Encodes input text to token Ids up to the maximum number of tokens.
        /// </summary>
        /// <param name="text">The text to encode.</param>
        /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
        /// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
        /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
        /// <returns>The list of encoded Ids.</returns>
        public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
            => EncodeToIds(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);

        /// <summary>
        /// Encodes input text to token Ids up to the maximum number of tokens.
        /// </summary>
        /// <param name="text">The text to encode.</param>
        /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
        /// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
        /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
        /// <returns>The list of encoded Ids.</returns>
        public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
            => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);

        private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
        {
            if (maxTokenCount <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
            }

            if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
            {
                textLength = 0;
                normalizedString = null;
                return [];
            }

            IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);

            List<int> ids = new();

            if (splits is not null)
            {
                textLength = 0;
                foreach ((int Offset, int Length) split in splits)
                {
                    EncodeToIds(textSpanToEncode.Slice(split.Offset, split.Length), ids, out int length, maxTokenCount - ids.Count);
                    textLength = split.Offset + length;

                    if (length < split.Length || ids.Count >= maxTokenCount)
                    {
                        break;
                    }
                }
            }
            else
            {
                EncodeToIds(textSpanToEncode, ids, out textLength);
            }

            return ids;
        }
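For illustration, a minimal usage sketch of the capped overload above; this is hypothetical caller code, and `tokenizer` is assumed to be an already-constructed concrete tokenizer instance:

        // Hypothetical usage of the capped EncodeToIds overload shown above.
        string input = "Hello, world! This is a longer document.";

        // Encode at most 5 tokens; textLength reports how much of the
        // (possibly normalized) input those tokens cover.
        IReadOnlyList<int> ids = tokenizer.EncodeToIds(input, maxTokenCount: 5, out string? normalized, out int textLength);

        // The prefix of the input that the returned ids actually represent:
        string covered = (normalized ?? input).Substring(0, textLength);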
        /// <summary>
@ -318,11 +423,11 @@ namespace Microsoft.ML.Tokenizers
        /// <param name="text">The text to encode.</param>
        /// <param name="accumulatedIds">The list of accumulated Ids.</param>
        /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
-       /// <param name="maxTokens">The maximum number of tokens to encode.</param>
+       /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
        /// <returns>The number of tokens that the input text will be encoded to.</returns>
-       public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
+       private int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokenCount = int.MaxValue)
        {
-           Debug.Assert(maxTokens > 0);
+           Debug.Assert(maxTokenCount > 0);

            if (text.IsEmpty)
            {
@ -332,7 +437,7 @@ namespace Microsoft.ML.Tokenizers

            if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
            {
-               if (value.Ids.Length <= maxTokens)
+               if (value.Ids.Length <= maxTokenCount)
                {
                    accumulatedIds.AddRange(value.Ids);
                    textLength = text.Length;
@ -362,7 +467,7 @@ namespace Microsoft.ML.Tokenizers
            }

            int result;
-           if (encodedIds.Length <= maxTokens)
+           if (encodedIds.Length <= maxTokenCount)
            {
                accumulatedIds.AddRange(encodedIds);
                textLength = text.Length;
@ -378,6 +483,104 @@ namespace Microsoft.ML.Tokenizers
            return result;
        }
        /// <summary>
        /// Get the number of tokens that the input text will be encoded to.
        /// </summary>
        /// <param name="text">The text to encode.</param>
        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
        /// <returns>The number of token Ids that the input text will be encoded to.</returns>
        public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
            => CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);

        /// <summary>
        /// Get the number of tokens that the input text will be encoded to.
        /// </summary>
        /// <param name="text">The text to encode.</param>
        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
        /// <returns>The number of token Ids that the input text will be encoded to.</returns>
        public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
            => CountTokens(null, text, considerPreTokenization, considerNormalization, out _, out _);

        /// <summary>
        /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
        /// </summary>
        /// <param name="text">The text to encode.</param>
        /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
        /// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
        /// <param name="tokenCount">The number of tokens that were generated, which will not exceed the maximum token count.</param>
        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
        /// <returns>
        /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
        /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
        /// if all tokens fit, the result will be the length of the text, or of the <paramref name="normalizedString"/> if normalization is enabled.
        /// </returns>
        public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
        {
            tokenCount = CountTokens(text, Span<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
            return textLength;
        }

        /// <summary>
        /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
        /// </summary>
        /// <param name="text">The text to encode.</param>
        /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
        /// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
        /// <param name="tokenCount">The number of tokens that were generated, which will not exceed the maximum token count.</param>
        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
        /// <returns>
        /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
        /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
        /// if all tokens fit, the result will be the length of the text, or of the <paramref name="normalizedString"/> if normalization is enabled.
        /// </returns>
        public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
        {
            tokenCount = CountTokens(null, text, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
            return textLength;
        }

        private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
        {
            if (maxTokenCount <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
            }

            textLength = 0;
            if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
            {
                normalizedString = null;
                return 0;
            }

            IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);

            int count = 0;
            if (splits is not null)
            {
                foreach ((int Offset, int Length) split in splits)
                {
                    count += CountTokens(textSpanToEncode.Slice(split.Offset, split.Length), out int length, maxTokenCount - count);
                    textLength = split.Offset + length;

                    if (length < split.Length || count >= maxTokenCount)
                    {
                        break;
                    }
                }
            }
            else
            {
                count = CountTokens(textSpanToEncode, out textLength, maxTokenCount);
            }

            return count;
        }
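A sketch of how IndexOfTokenCount can drive token-budget truncation; hypothetical caller code, with `tokenizer` and `prompt` assumed to exist:

        // Hypothetical usage: trim a prompt so it encodes to at most 1024 tokens.
        int cut = tokenizer.IndexOfTokenCount(prompt, maxTokenCount: 1024, out string? normalized, out int tokenCount);
        string trimmed = (normalized ?? prompt).Substring(0, cut);
        // tokenCount <= 1024, and `trimmed` encodes to exactly tokenCount tokens.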
        /// <summary>
        /// Get the number of tokens that the input text will be encoded to.
        /// </summary>
@ -385,7 +588,7 @@ namespace Microsoft.ML.Tokenizers
        /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
        /// <param name="maxTokens">The maximum number of tokens to encode.</param>
        /// <returns>The number of tokens that the input text will be encoded to.</returns>
-       public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue)
+       private int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue)
        {
            Debug.Assert(maxTokens > 0);

@ -439,6 +642,76 @@ namespace Microsoft.ML.Tokenizers
            return result;
        }
        /// <summary>
        /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
        /// </summary>
        /// <param name="text">The text to encode.</param>
        /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
        /// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
        /// <param name="tokenCount">The number of tokens that were generated, which will not exceed the maximum token count.</param>
        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
        /// <returns>
        /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
        /// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the length of the text, or of the <paramref name="normalizedString"/> if normalization is enabled;
        /// conversely, if all tokens fit, the result will be 0.
        /// </returns>
        public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
            => LastIndexOf(text, Span<char>.Empty, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);

        /// <summary>
        /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
        /// </summary>
        /// <param name="text">The text to encode.</param>
        /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
        /// <param name="normalizedString">If the tokenizer's normalization is enabled and <paramRef name="considerNormalization" /> is true, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
        /// <param name="tokenCount">The number of tokens that were generated, which will not exceed the maximum token count.</param>
        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
        /// <returns>
        /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
        /// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the length of the <paramref name="normalizedString"/>; conversely, if all tokens fit, the result will be 0.
        /// </returns>
        public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
            => LastIndexOf(null, text, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);

        private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int tokenCount)
        {
            if (maxTokenCount <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
            }

            if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
            {
                normalizedString = null;
                tokenCount = 0;
                return 0;
            }

            IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);

            if (splits is not null)
            {
                tokenCount = 0;
                foreach ((int Offset, int Length) split in splits.Reverse())
                {
                    tokenCount += CountTokensFromEnd(textSpanToEncode.Slice(split.Offset, split.Length), out int textIndex, maxTokenCount - tokenCount);
                    if (textIndex > 0 || tokenCount >= maxTokenCount)
                    {
                        return split.Offset + textIndex;
                    }
                }

                return 0;
            }
            else
            {
                tokenCount = CountTokensFromEnd(textSpanToEncode, out int textLength, maxTokenCount);
                return textLength;
            }
        }
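The complementary from-the-end operation; a hypothetical sketch keeping the most recent part of a transcript within a budget (`tokenizer` and `transcript` are assumed):

        // Hypothetical usage: keep the tail of a transcript within 512 tokens.
        int start = tokenizer.LastIndexOfTokenCount(transcript, maxTokenCount: 512, out string? normalized, out int tokenCount);
        string tail = (normalized ?? transcript).Substring(start);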
        /// <summary>
        /// Get the number of tokens that the input text will be encoded to.
        /// </summary>
@ -446,7 +719,7 @@ namespace Microsoft.ML.Tokenizers
        /// <param name="textIndex">Starting from this index to the end of the text will encompass the maximum encoded tokens.</param>
        /// <param name="maxTokens">The maximum number of tokens to encode.</param>
        /// <returns>The number of tokens that the input text will be encoded to.</returns>
-       public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue)
+       private int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue)
        {
            Debug.Assert(maxTokens > 0);

@ -510,7 +783,7 @@ namespace Microsoft.ML.Tokenizers
        {
            if (token.IsEmpty)
            {
-               return 0;
+               return null;
            }

            if (_cache.TryGetValue(token, out (int[] Ids, string Token) value))
@ -676,8 +949,8 @@ namespace Microsoft.ML.Tokenizers
        [
            // chat
            ( "gpt-4-", ModelEncoding.Cl100kBase),          // e.g., gpt-4-0314, etc., plus gpt-4-32k
-           ( "gpt-3.5-turbo-", ModelEncoding.Cl100kBase),  // e.g., gpt-3.5-turbo-0301, -0401, etc.
+           ( "gpt-3.5-", ModelEncoding.Cl100kBase),        // e.g., gpt-3.5-turbo-0301, -0401, etc.
-           ( "gpt-35-turbo-", ModelEncoding.Cl100kBase )   // Azure deployment name
+           ( "gpt-35-", ModelEncoding.Cl100kBase )         // Azure deployment name
        ];

        private static readonly Dictionary<string, ModelEncoding> _modelToEncoding =
@ -687,6 +960,7 @@ namespace Microsoft.ML.Tokenizers
            { "gpt-4", ModelEncoding.Cl100kBase },
            { "gpt-3.5-turbo", ModelEncoding.Cl100kBase },
            { "gpt-3.5-turbo-16k", ModelEncoding.Cl100kBase },
+           { "gpt-35", ModelEncoding.Cl100kBase },           // Azure deployment name
            { "gpt-35-turbo", ModelEncoding.Cl100kBase },     // Azure deployment name
            { "gpt-35-turbo-16k", ModelEncoding.Cl100kBase }, // Azure deployment name
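The two tables imply a two-step resolution order: exact model name first, then prefix matching. A hypothetical sketch of that order (the real GetModelEncoding implementation lives elsewhere in the file and may differ; `_modelPrefixToEncoding` is an assumed name for the prefix table above):

        // Hypothetical sketch of the lookup order the two tables above imply.
        private static ModelEncoding GetModelEncoding(string modelName)
        {
            if (_modelToEncoding.TryGetValue(modelName, out ModelEncoding encoding))
            {
                return encoding; // exact match, e.g. "gpt-3.5-turbo"
            }

            foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding) // hypothetical field name
            {
                if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
                {
                    return Encoding; // prefix match, e.g. "gpt-4-0314" matches "gpt-4-"
                }
            }

            throw new NotSupportedException($"The model '{modelName}' is not supported.");
        }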
@ -817,26 +1091,13 @@ namespace Microsoft.ML.Tokenizers

        private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);

-       internal static Tokenizer CreateTokenizerForModel(
-           string modelName,
-           IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
-           Normalizer? normalizer = null)
-       {
-           if (string.IsNullOrEmpty(modelName))
-           {
-               throw new ArgumentNullException(nameof(modelName));
-           }
-
-           return CreateTokenizerForModel(GetModelEncoding(modelName), modelName, extraSpecialTokens, normalizer);
-       }
-
-       internal static Tokenizer CreateTokenizerForModel(
-           ModelEncoding modelEncoding,
-           string? modelName = null,
-           IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
-           Normalizer? normalizer = null)
-       {
-           (Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) tiktokenConfiguration = Tiktoken.GetTiktokenConfigurations(modelEncoding, modelName);
+       internal static Tokenizer CreateForModel(
+           ModelEncoding modelEncoding,
+           string? modelName = null,
+           IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
+           Normalizer? normalizer = null)
+       {
+           (Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) tiktokenConfiguration = GetTiktokenConfigurations(modelEncoding, modelName);

            if (extraSpecialTokens is not null)
            {
@ -858,10 +1119,14 @@ namespace Microsoft.ML.Tokenizers
                _tiktokenCache.TryAdd(tiktokenConfiguration.VocabFile, cache);
            }

-           return new Tokenizer(
-               new Tiktoken(cache.encoder, cache.decoder, cache.vocab, tiktokenConfiguration.SpecialTokens, LruCache<int[]>.DefaultCacheSize),
-               new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
-               normalizer);
+           return new Tiktoken(
+               cache.encoder,
+               cache.decoder,
+               cache.vocab,
+               new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
+               tiktokenConfiguration.SpecialTokens,
+               normalizer,
+               LruCache<int[]>.DefaultCacheSize);
        }
    }
}
@ -97,24 +97,24 @@ namespace Microsoft.ML.Tokenizers
            return changes;
        }

-       public void MergeAll(Dictionary<Pair<int>, (int, int)> merges, float? dropout)
+       public void MergeAll(Dictionary<Pair<int>, (int, int)> merges, float? dropout, ref PriorityQueue<Merge>? priorityQueue)
        {
-           // Queue<Merge> queue = new Queue<Merge>(_symbols.Count);
-           PriorityQueue<Merge> queue = new PriorityQueue<Merge>(_symbols.Count);
+           priorityQueue ??= new PriorityQueue<Merge>(_symbols.Count);
+           priorityQueue.Clear();

-           Vec<Merge> skip = new Vec<Merge>(queue.Count);
+           Vec<Merge> skip = new Vec<Merge>(priorityQueue.Count);

            for (int i = 0; i < _symbols.Count - 1; i++)
            {
                if (merges.TryGetValue(Pair<int>.Create(_symbols[i].C, _symbols[i + 1].C), out (int m1, int m2) value))
                {
-                   queue.Enqueue(new Merge(i, value.m1, value.m2));
+                   priorityQueue.Enqueue(new Merge(i, value.m1, value.m2));
                }
            }

-           while (queue.Count > 0)
+           while (priorityQueue.Count > 0)
            {
-               Merge top = queue.Dequeue();
+               Merge top = priorityQueue.Dequeue();
                if (dropout.HasValue && (_random ??= new()).NextDouble() < dropout)
                {
                    skip.Push(top);
@ -124,7 +124,7 @@ namespace Microsoft.ML.Tokenizers
            // Re-insert the skipped elements
            for (int i = 0; i < skip.Count; i++)
            {
-               queue.Enqueue(skip[i]);
+               priorityQueue.Enqueue(skip[i]);
            }
            skip.Clear();

@ -166,7 +166,7 @@ namespace Microsoft.ML.Tokenizers

                if (merges.TryGetValue(newPair, out value))
                {
-                   queue.Enqueue(new Merge(current.Prev, value.m1, value.m2));
+                   priorityQueue.Enqueue(new Merge(current.Prev, value.m1, value.m2));
                }
            }

@ -178,7 +178,7 @@ namespace Microsoft.ML.Tokenizers
                Pair<int> newPair = Pair<int>.Create(current.C, nextSymbol.C);
                if (merges.TryGetValue(newPair, out value))
                {
-                   queue.Enqueue(new Merge(top.Pos, value.m1, value.m2));
+                   priorityQueue.Enqueue(new Merge(top.Pos, value.m1, value.m2));
                }
            }
        }
@ -289,19 +289,16 @@ namespace Microsoft.ML.Tokenizers
            return sb.ToString();
        }

-       public List<Token> ToTokens(SortedDictionary<int, string> vocabReverse)
+       public void ToTokens(SortedDictionary<int, string> vocabReverse, List<Token> tokens, int offset)
        {
-           List<Token> tokens = new(SymbolsCount);
            int index = 0;

            for (int i = 0; i < SymbolsCount; i++)
            {
                int endIndex = index + _symbols[i].Len;
-               tokens.Add(new Token(_symbols[i].C, vocabReverse[_symbols[i].C], (index, _symbols[i].Len)));
+               tokens.Add(new Token(_symbols[i].C, vocabReverse[_symbols[i].C], (index + offset, _symbols[i].Len)));
                index += _symbols[i].Len;
            }
-
-           return tokens;
        }
    }
}
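The new `ref` parameter exists so a caller can allocate the priority queue once and reuse it across many words; a hypothetical calling pattern, assuming `words` and `merges` are in scope:

        // Hypothetical caller: the queue is created on the first call and
        // cleared/reused on subsequent calls, avoiding a per-word allocation.
        PriorityQueue<Merge>? priorityQueue = null;

        foreach (Word word in words)
        {
            word.MergeAll(merges, dropout: null, ref priorityQueue);
        }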
@ -3,6 +3,8 @@
// See the LICENSE file in the project root for more information.

 using System;
+using System.Buffers;
+using System.Diagnostics;

 namespace Microsoft.ML.Tokenizers
 {
@ -22,5 +24,27 @@ namespace Microsoft.ML.Tokenizers
        /// <param name="original">The original string to normalize to lowercase form.</param>
        /// <returns>The lower-cased normalized string.</returns>
        public override string Normalize(string original) => original.ToLowerInvariant();
+
+       /// <summary>
+       /// Lowercase the original string.
+       /// </summary>
+       /// <param name="original">The original string to normalize to lowercase form.</param>
+       /// <returns>The lower-cased normalized string.</returns>
+       public override string Normalize(ReadOnlySpan<char> original)
+       {
+           if (original.IsEmpty)
+           {
+               return string.Empty;
+           }
+
+           char[] arrayPoolArray = ArrayPool<char>.Shared.Rent(original.Length);
+
+           int length = original.ToLowerInvariant(arrayPoolArray);
+           Debug.Assert(length == original.Length);
+
+           string result = new string(arrayPoolArray, 0, length);
+           ArrayPool<char>.Shared.Return(arrayPoolArray);
+           return result;
+       }
    }
}
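A small sketch of the new span overload in use (hypothetical caller code):

        // Hypothetical usage: normalizing a slice without allocating a substring first.
        LowerCaseNormalizer normalizer = new();
        ReadOnlySpan<char> slice = "Hello World".AsSpan(0, 5);
        string lowered = normalizer.Normalize(slice); // "hello"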
@ -17,5 +17,12 @@ namespace Microsoft.ML.Tokenizers
        /// <param name="original">The original string to normalize.</param>
        /// <returns>The normalized string.</returns>
        public abstract string Normalize(string original);
+
+       /// <summary>
+       /// Process the original string to modify it and obtain a normalized string.
+       /// </summary>
+       /// <param name="original">The original string to normalize.</param>
+       /// <returns>The normalized string.</returns>
+       public abstract string Normalize(ReadOnlySpan<char> original);
    }
}
@ -10,14 +10,14 @@ namespace Microsoft.ML.Tokenizers
    /// <summary>
    /// Normalize the string before processing it with the tokenizer.
    /// </summary>
-   public sealed class LlamaNormalizer : Normalizer
+   public sealed class SentencePieceNormalizer : Normalizer
    {
        internal const char DummyPrefix = '\u2581'; // '▁' (LOWER ONE EIGHTH BLOCK)

        /// <summary>
        /// Creates a SentencePieceNormalizer object.
        /// </summary>
-       public LlamaNormalizer(bool removeExtraWhiteSpaces, bool addDummyPrefix, bool escapeWhiteSpaces, bool treatWhitespaceAsSuffix)
+       public SentencePieceNormalizer(bool removeExtraWhiteSpaces, bool addDummyPrefix, bool escapeWhiteSpaces, bool treatWhitespaceAsSuffix)
        {
            RemoveExtraWhiteSpaces = removeExtraWhiteSpaces;
            AddDummyPrefix = addDummyPrefix;
@ -40,7 +40,7 @@ namespace Microsoft.ML.Tokenizers
        public bool TreatWhitespaceAsSuffix { get; }

        /// <summary>
-       /// Normalize the original string according to SentencePiece normalization with Llama model.
+       /// Normalize the original string according to SentencePiece normalization.
        /// </summary>
        /// <param name="original">The original string to normalize.</param>
        /// <returns>The normalized string.</returns>
@ -51,6 +51,16 @@ namespace Microsoft.ML.Tokenizers
                return string.Empty;
            }
+
+           return Normalize(original.AsSpan());
+       }
+
+       /// <summary>
+       /// Normalize the original string according to SentencePiece normalization.
+       /// </summary>
+       /// <param name="original">The original string to normalize.</param>
+       /// <returns>The normalized string.</returns>
+       public override string Normalize(ReadOnlySpan<char> original)
+       {
            int startIndex = 0;
            int endIndex = original.Length - 1;
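For reference, a hypothetical construction of the renamed normalizer; the flag names come from the constructor above, and the comments describe standard SentencePiece semantics rather than this implementation's documented behavior:

        // Hypothetical usage of the renamed normalizer.
        var normalizer = new SentencePieceNormalizer(
            removeExtraWhiteSpaces: true,
            addDummyPrefix: true,            // prepend the '\u2581' dummy prefix
            escapeWhiteSpaces: true,         // replace spaces with '\u2581'
            treatWhitespaceAsSuffix: false);

        string normalized = normalizer.Normalize("Hello World");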
@ -3,6 +3,8 @@
// See the LICENSE file in the project root for more information.

 using System;
+using System.Buffers;
+using System.Diagnostics;

 namespace Microsoft.ML.Tokenizers
 {
@ -22,5 +24,27 @@ namespace Microsoft.ML.Tokenizers
        /// <param name="original">The original string to normalize to uppercase form.</param>
        /// <returns>The upper-cased normalized string.</returns>
        public override string Normalize(string original) => original.ToUpperInvariant();
+
+       /// <summary>
+       /// Uppercase the original string.
+       /// </summary>
+       /// <param name="original">The original string to normalize to uppercase form.</param>
+       /// <returns>The upper-cased normalized string.</returns>
+       public override string Normalize(ReadOnlySpan<char> original)
+       {
+           if (original.IsEmpty)
+           {
+               return string.Empty;
+           }
+
+           char[] arrayPoolArray = ArrayPool<char>.Shared.Rent(original.Length);
+
+           int length = original.ToUpperInvariant(arrayPoolArray);
+           Debug.Assert(length == original.Length);
+
+           string result = new string(arrayPoolArray, 0, length);
+           ArrayPool<char>.Shared.Return(arrayPoolArray);
+           return result;
+       }
    }
}
@ -3,66 +3,12 @@
// See the LICENSE file in the project root for more information.

 using System;
+using System.Buffers;
 using System.Collections.Generic;
 using System.Text.RegularExpressions;

 namespace Microsoft.ML.Tokenizers
 {
-   /// <summary>
-   /// This Split contains the underlying split token as well as its offsets
-   /// in the original string. These offsets are in the `original` referential.
-   /// It also contains any `Token` associated to the current split.
-   /// </summary>
-   public struct Split : IEquatable<Split>
-   {
-       private readonly string? _originalString;
-       private string? _tokenString;
-
-       /// <summary>
-       /// Gets the underlying split token. Each SubString is represented by a token
-       /// and in the end we might be carrying a lot of SubString representing various parts of the
-       /// original input string.
-       /// </summary>
-       public string TokenString => _tokenString ??= _originalString!.Substring(Offset.Index, Offset.Length);
-
-       /// <summary>
-       /// Gets the underlying split token as a span.
-       /// </summary>
-       public ReadOnlySpan<char> TokenSpan => _tokenString is string s ? s.AsSpan() : _originalString.AsSpan(Offset.Index, Offset.Length);
-
-       /// <summary>
-       /// Returns the offset mapping to the original string
-       /// </summary>
-       public (int Index, int Length) Offset { get; }
-
-       /// <summary>
-       /// create a Split object using the token and the offset
-       /// </summary>
-       /// <param name="token">The token string</param>
-       /// <param name="offset">The offset mapping to the original string</param>
-       public Split(string token, (int Index, int Length) offset)
-       {
-           _tokenString = token;
-           Offset = offset;
-       }
-
-       internal Split(string originalString, string? token, (int Index, int Length) offset)
-       {
-           _originalString = originalString;
-           _tokenString = token;
-           Offset = offset;
-       }
-
-       /// <summary>
-       /// Indicates whether the current Split object is equal to another Split object.
-       /// </summary>
-       /// <param name="other">The Split object to compare with the current object.</param>
-       public bool Equals(Split other) =>
-           (_originalString == other._originalString || TokenString == other.TokenString) &&
-           Offset.Index == other.Offset.Index &&
-           Offset.Length == other.Offset.Length;
-   }
-
    /// <summary>
    /// Base class for all pre-tokenizers classes.
    /// The PreTokenizer is in charge of doing the pre-segmentation step.
@ -70,23 +16,54 @@ namespace Microsoft.ML.Tokenizers
    public abstract class PreTokenizer
    {
        /// <summary>
-       /// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
+       /// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
        /// </summary>
        /// <param name="text">The string to split into tokens.</param>
-       /// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
-       public abstract IEnumerable<Split> PreTokenize(string text);
+       /// <returns>The offsets and lengths of the tokens, expressed as pairs, relative to the original string.</returns>
+       public abstract IEnumerable<(int Offset, int Length)> PreTokenize(string text);
+
+       /// <summary>
+       /// Get the offsets and lengths of the tokens relative to the original string.
+       /// </summary>
+       /// <param name="text">The character span to split into tokens.</param>
+       /// <returns>The offsets and lengths of the tokens, expressed as pairs, relative to the original string.</returns>
+       public abstract IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan<char> text);

-       internal static IEnumerable<Split> SplitText(string text, Regex regex)
+       internal static IEnumerable<(int Offset, int Length)> SplitText(string text, Regex regex)
        {
            (int Offset, int Length) match;
            int beginning = 0;
            while (TryGetMatch(regex, text, beginning, text.Length - beginning, out match))
            {
-               yield return new Split(text, null, (match.Offset, match.Length));
+               yield return (match.Offset, match.Length);
                beginning = match.Offset + match.Length;
            }
        }

+       internal static IEnumerable<(int Offset, int Length)> SplitText(ReadOnlySpan<char> text, Regex regex)
+       {
+#if NET7_0_OR_GREATER
+           char[] buffer = ArrayPool<char>.Shared.Rent(text.Length);
+           text.CopyTo(buffer);
+           return SplitText(buffer, regex, text.Length);
+
+           static IEnumerable<(int Offset, int Length)> SplitText(char[] text, Regex regex, int textLength)
+           {
+               (int Offset, int Length) match;
+               int beginning = 0;
+               while (TryGetMatch(regex, text, beginning, textLength - beginning, out match))
+               {
+                   yield return (match.Offset, match.Length);
+                   beginning = match.Offset + match.Length;
+               }
+
+               ArrayPool<char>.Shared.Return(text);
+           }
+#else
+           return SplitText(text.ToString(), regex);
+#endif // NET7_0_OR_GREATER
+       }

        internal static bool TryGetMatch(Regex regex, string text, int beginning, int length, out (int offset, int length) match)
        {
#if NET7_0_OR_GREATER
@ -106,5 +83,18 @@ namespace Microsoft.ML.Tokenizers
            match = default;
            return false;
        }

+#if NET7_0_OR_GREATER
+       internal static bool TryGetMatch(Regex regex, scoped ReadOnlySpan<char> text, int beginning, int length, out (int offset, int length) match)
+       {
+           foreach (ValueMatch m in regex.EnumerateMatches(text.Slice(beginning, length)))
+           {
+               match = (beginning + m.Index, m.Length);
+               return true;
+           }
+
+           match = default;
+           return false;
+       }
+#endif // NET7_0_OR_GREATER
    }
}
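Under the new contract a pre-tokenizer reports only (Offset, Length) pairs instead of materializing Split objects. A minimal hypothetical implementation against the abstract class above:

        // Hypothetical whitespace pre-tokenizer using the new (Offset, Length) contract.
        internal sealed class SimpleWhitespacePreTokenizer : PreTokenizer
        {
            public override IEnumerable<(int Offset, int Length)> PreTokenize(string text)
                => string.IsNullOrEmpty(text) ? Array.Empty<(int, int)>() : PreTokenize(text.AsSpan());

            public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan<char> text)
            {
                // A span parameter cannot be captured by an iterator, so build the list eagerly.
                List<(int Offset, int Length)> result = new();
                int start = -1;

                for (int i = 0; i <= text.Length; i++)
                {
                    bool separator = i == text.Length || char.IsWhiteSpace(text[i]);
                    if (!separator && start < 0)
                    {
                        start = i; // token begins
                    }
                    else if (separator && start >= 0)
                    {
                        result.Add((start, i - start)); // token ends
                        start = -1;
                    }
                }

                return result;
            }
        }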
@ -18,15 +18,30 @@ namespace Microsoft.ML.Tokenizers
        public static RobertaPreTokenizer Instance { get; } = new RobertaPreTokenizer();

        /// <summary>
-       /// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
+       /// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
        /// </summary>
        /// <param name="text">The string to split into tokens.</param>
-       /// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
-       public override IEnumerable<Split> PreTokenize(string text)
+       /// <returns>The offsets and lengths of the tokens, expressed as pairs, relative to the original string.</returns>
+       public override IEnumerable<(int Offset, int Length)> PreTokenize(string text)
        {
            if (string.IsNullOrEmpty(text))
            {
-               return Array.Empty<Split>();
+               return [];
+           }
+
+           return SplitText(text, Tiktoken.P50kBaseRegex());
+       }
+
+       /// <summary>
+       /// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
+       /// </summary>
+       /// <param name="text">The character span to split into tokens.</param>
+       /// <returns>The offsets and lengths of the tokens, expressed as pairs, relative to the original string.</returns>
+       public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan<char> text)
+       {
+           if (text.IsEmpty)
+           {
+               return [];
            }

            return SplitText(text, Tiktoken.P50kBaseRegex());
@ -1,35 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.Collections.Generic;
-
-namespace Microsoft.ML.Tokenizers
-{
-   /// <summary>
-   /// The pre-tokenizer for SentencePiece tokenizers.
-   /// </summary>
-   internal sealed partial class SentencePiecePreTokenizer : PreTokenizer
-   {
-       /// <summary>
-       /// Gets a singleton instance of the SentencePiece pre-tokenizer.
-       /// </summary>
-       public static SentencePiecePreTokenizer Instance { get; } = new SentencePiecePreTokenizer();
-
-       /// <summary>
-       /// Return the whole text as one chunk.
-       /// </summary>
-       /// <param name="text">The string to split into tokens.</param>
-       /// <returns>The original string as one chunk.</returns>
-       public override IEnumerable<Split> PreTokenize(string text)
-       {
-           if (string.IsNullOrEmpty(text))
-           {
-               yield break;
-           }
-
-           yield return new Split(text, (0, text.Length));
-       }
-   }
-}
@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

 using System;
+using System.Buffers;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text.RegularExpressions;
@ -39,20 +40,20 @@ namespace Microsoft.ML.Tokenizers
        }

        /// <summary>
-       /// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
+       /// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
        /// </summary>
        /// <param name="text">The string to split into tokens.</param>
-       /// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
-       public override IEnumerable<Split> PreTokenize(string text)
+       /// <returns>The offsets and lengths of the tokens, expressed as pairs, relative to the original string.</returns>
+       public override IEnumerable<(int Offset, int Length)> PreTokenize(string text)
        {
            if (string.IsNullOrEmpty(text))
            {
-               return Array.Empty<Split>();
+               return [];
            }

            return SplitText(text, _regex, _specialTokensRegex);

-           static IEnumerable<Split> SplitText(string text, Regex regex, Regex? specialTokensRegex)
+           static IEnumerable<(int Offset, int Length)> SplitText(string text, Regex regex, Regex? specialTokensRegex)
            {
                (int Offset, int Length) match;
                int beginning = 0;
@ -69,21 +70,77 @@ namespace Microsoft.ML.Tokenizers

                    while (TryGetMatch(regex, text, beginning, specialMatch.Offset - beginning, out match))
                    {
-                       yield return new Split(text, null, (match.Offset, match.Length));
+                       yield return (match.Offset, match.Length);
                        beginning = match.Offset + match.Length;
                    }

-                   yield return new Split(text, null, (specialMatch.Offset, specialMatch.Length));
+                   yield return (specialMatch.Offset, specialMatch.Length);
                    beginning = specialMatch.Offset + specialMatch.Length;
                }

                while (TryGetMatch(regex, text, beginning, text.Length - beginning, out match))
                {
-                   yield return new Split(text, null, (match.Offset, match.Length));
+                   yield return (match.Offset, match.Length);
                    beginning = match.Length + match.Offset;
                }
            }
        }

+       /// <summary>
+       /// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
+       /// </summary>
+       /// <param name="text">The character span to split into tokens.</param>
+       /// <returns>The offsets and lengths of the tokens, expressed as pairs, relative to the original string.</returns>
+       public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan<char> text)
+       {
+           if (text.IsEmpty)
+           {
+               return [];
+           }
+
+#if NET7_0_OR_GREATER
+           char[] buffer = ArrayPool<char>.Shared.Rent(text.Length);
+           text.CopyTo(buffer);
+           return SplitText(buffer, _regex, _specialTokensRegex, text.Length);
+
+           static IEnumerable<(int Offset, int Length)> SplitText(char[] text, Regex regex, Regex? specialTokensRegex, int textLength)
+           {
+               (int Offset, int Length) match;
+               int beginning = 0;
+
+               if (specialTokensRegex is not null)
+               {
+                   while (true)
+                   {
+                       (int Offset, int Length) specialMatch;
+                       if (!TryGetMatch(specialTokensRegex, text.AsSpan(), beginning, textLength - beginning, out specialMatch))
+                       {
+                           break;
+                       }
+
+                       while (TryGetMatch(regex, text.AsSpan(), beginning, specialMatch.Offset - beginning, out match))
+                       {
+                           yield return (match.Offset, match.Length);
+                           beginning = match.Offset + match.Length;
+                       }
+
+                       yield return (specialMatch.Offset, specialMatch.Length);
+                       beginning = specialMatch.Offset + specialMatch.Length;
+                   }
+               }
+
+               while (TryGetMatch(regex, text.AsSpan(), beginning, textLength - beginning, out match))
+               {
+                   yield return (match.Offset, match.Length);
+                   beginning = match.Length + match.Offset;
+               }
+
+               ArrayPool<char>.Shared.Return(text);
+           }
+#else
+           return PreTokenize(text.ToString());
+#endif // NET7_0_OR_GREATER
+       }
    }
}
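A hypothetical illustration of the special-token handling above, assuming `preTokenizer` is a TiktokenPreTokenizer whose special-tokens regex includes "<|endoftext|>":

        // Hypothetical usage: special tokens are yielded as whole segments,
        // never split by the main pre-tokenization regex.
        string text = "hello <|endoftext|> world";
        foreach ((int offset, int length) in preTokenizer.PreTokenize(text))
        {
            Console.WriteLine($"{offset,3} {length,3} '{text.Substring(offset, length)}'");
        }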
@ -29,15 +29,30 @@ namespace Microsoft.ML.Tokenizers
#endif

        /// <summary>
-       /// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
+       /// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
        /// </summary>
        /// <param name="text">The string to split into tokens.</param>
-       /// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
-       public override IEnumerable<Split> PreTokenize(string text)
+       /// <returns>The offsets and lengths of the tokens, expressed as pairs, relative to the original string.</returns>
+       public override IEnumerable<(int Offset, int Length)> PreTokenize(string text)
        {
            if (string.IsNullOrEmpty(text))
            {
-               return Array.Empty<Split>();
+               return [];
+           }
+
+           return SplitText(text, PretokenizeRegex());
+       }
+
+       /// <summary>
+       /// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
+       /// </summary>
+       /// <param name="text">The character span to split into tokens.</param>
+       /// <returns>The offsets and lengths of the tokens, expressed as pairs, relative to the original string.</returns>
+       public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan<char> text)
+       {
+           if (text.IsEmpty)
+           {
+               return [];
            }

            return SplitText(text, PretokenizeRegex());
@ -12,7 +12,7 @@ namespace Microsoft.ML.Tokenizers
    /// Represent the token produced from the tokenization process containing the token substring,
    /// the id associated to the token substring, and the offset mapping to the original string.
    /// </summary>
-   public sealed class Token
+   public readonly struct Token
    {
        /// <summary>
        /// Gets the Id value associated to the token.
@ -22,12 +22,12 @@ namespace Microsoft.ML.Tokenizers
        /// <summary>
        /// Gets the token string value.
        /// </summary>
-       public string Value { get; set; }
+       public string Value { get; }

        /// <summary>
        /// Gets the offset mapping to the original string.
        /// </summary>
-       public (int Index, int Length) Offset { get; internal set; }
+       public (int Index, int Length) Offset { get; }

        /// <summary>
        /// Construct a new Token object using the token value, Id, and the offset mapping to the original string.
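Since Token is now a readonly struct with get-only properties, instances are immutable values; offsets must be supplied at construction, which is why ToTokens above takes an offset parameter instead of patching Token.Offset afterwards. A hypothetical construction:

        // Hypothetical usage; argument order follows the constructor described above.
        Token token = new Token(9906, "Hello", (0, 5));
        // token.Offset = (1, 5); // no longer compiles: Offset is get-only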
@ -8,6 +8,7 @@ using System.Collections.Generic;
|
||||||
using System.Diagnostics;
|
using System.Diagnostics;
|
||||||
using System.IO;
|
using System.IO;
|
||||||
using System.Linq;
|
using System.Linq;
|
||||||
|
using System.Text;
|
||||||
using System.Text.RegularExpressions;
|
using System.Text.RegularExpressions;
|
||||||
using System.Threading;
|
using System.Threading;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
|
@@ -15,263 +16,279 @@ using System.Threading.Tasks;
 namespace Microsoft.ML.Tokenizers
 {
     /// <summary>
-    /// A Tokenizer works as a pipeline. It processes some raw text as input and outputs a EncodingResult object.
+    /// serves as an abstraction for concrete tokenizers, enabling the encoding of text into tokens and IDs, as well as the decoding of IDs back into text.
     /// </summary>
-    public partial class Tokenizer
+    public abstract class Tokenizer
     {
         /// <summary>
-        /// Create a new Tokenizer object.
+        /// Gets the PreTokenizer used by the Tokenizer.
         /// </summary>
-        /// <param name="model">The Model in use by the Tokenizer.</param>
-        /// <param name="preTokenizer">The optional PreTokenizer in use by the Tokenizer. WhiteSpace PreTokenizer will be used if this parameter is null.</param>
-        /// <param name="normalizer">The optional Normalizer in use by the Tokenizer.</param>
-        public Tokenizer(Model model, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null)
-        {
-            Model = model;
-            PreTokenizer = preTokenizer ?? WhiteSpace.Instance;
-            Normalizer = normalizer;
-        }
+        public virtual PreTokenizer? PreTokenizer => null;

         /// <summary>
-        /// Gets the Model in use by the Tokenizer.
+        /// Gets the Normalizer in use by the Tokenizer.
         /// </summary>
-        public Model Model { get; }
+        public virtual Normalizer? Normalizer => null;

         /// <summary>
-        /// Gets or sets the PreTokenizer used by the Tokenizer.
-        /// </summary>
-        public PreTokenizer PreTokenizer { get; }
-
-        /// <summary>
-        /// Gets or sets the Normalizer in use by the Tokenizer.
-        /// </summary>
-        public Normalizer? Normalizer { get; }
-
-        /// <summary>
-        /// Encodes input text to object has the tokens list, tokens Ids, tokens offset mapping.
+        /// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
         /// </summary>
         /// <param name="text">The text to encode.</param>
-        /// <returns>The tokenization result includes the tokens list, tokens Ids, tokens offset mapping.</returns>
-        public EncodingResult Encode(string text)
-        {
-            if (text is null)
-            {
-                throw new ArgumentNullException(nameof(text));
-            }
-
-            string normalized = Normalizer is null ? text : Normalizer.Normalize(text);
-            bool offsetsMappedToOriginal = true;
-
-            EncodingResult encoding = new(text, normalized, PreTokenizer.PreTokenize(normalized), offsetsMappedToOriginal);
-
-            foreach (Split split in encoding.Splits)
-            {
-                IReadOnlyList<Token> tokens = Model.Encode(split.TokenString.AsSpan());
-                foreach (Token token in tokens)
-                {
-                    token.Offset = (token.Offset.Index + split.Offset.Index, token.Offset.Length);
-                }
-
-                encoding.AddTokens(tokens);
-            }
-
-            return encoding;
-        }
+        /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+        public virtual IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text.AsSpan(), out normalizedString, considerPreTokenization, considerNormalization);

         /// <summary>
-        /// Encodes input text to tokens Ids.
+        /// Encodes input text a list of <see cref="Token" />s with string value of the token, id, and offset.
         /// </summary>
         /// <param name="text">The text to encode.</param>
+        /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+        public abstract IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true);
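For orientation, a minimal sketch of consuming the reworked `Encode` surface; `tokenizer` stands for any concrete subclass (BPE, Tiktoken, etc.), whose construction is elided here:

```csharp
// Encode now returns the token objects directly; the normalized text comes
// back through the out parameter (null when no normalizer ran).
IReadOnlyList<Token> tokens = tokenizer.Encode("Hello world", out string? normalizedString);

foreach (Token token in tokens)
{
    // Each Token carries the string value, the mapped id, and the (Index, Length) offset.
    Console.WriteLine($"{token.Value} -> {token.Id} at {token.Offset}");
}
```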
         /// <summary>
+        /// Encodes input text to token Ids.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
         /// <returns>The list of encoded Ids.</returns>
-        public IReadOnlyList<int> EncodeToIds(string text)
-        {
-            if (text is null)
-            {
-                throw new ArgumentNullException(nameof(text));
-            }
-
-            string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text;
-            List<int> idsList = new();
-
-            foreach (Split split in PreTokenizer.PreTokenize(normalized))
-            {
-                Model.EncodeToIds(split.TokenSpan, idsList, out _);
-            }
-
-            return idsList;
-        }
+        public virtual IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) => EncodeToIds(text.AsSpan(), considerPreTokenization, considerNormalization);

         /// <summary>
-        /// Encodes input text to tokens Ids up to maximum number of tokens.
+        /// Encodes input text to token Ids.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The list of encoded Ids.</returns>
+        public abstract IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);
+
+        /// <summary>
+        /// Encodes input text to token Ids up to maximum number of tokens.
         /// </summary>
         /// <param name="text">The text to encode.</param>
         /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
-        /// <param name="processedText">If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.</param>
+        /// <param name="normalizedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
         /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
         /// <returns>The list of encoded Ids.</returns>
-        public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string processedText, out int textLength)
-        {
-            processedText = text;
-            textLength = 0;
-
-            if (text is null)
-            {
-                throw new ArgumentNullException(nameof(text));
-            }
-
-            if (maxTokenCount <= 0)
-            {
-                throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than 0.");
-            }
-
-            if (Normalizer is not null)
-            {
-                processedText = Normalizer.Normalize(text);
-            }
-
-            List<int> idsList = new();
-
-            foreach (Split split in PreTokenizer.PreTokenize(processedText))
-            {
-                Model.EncodeToIds(split.TokenSpan, idsList, out int length, maxTokenCount - idsList.Count);
-                if (length < split.Offset.Length || idsList.Count >= maxTokenCount)
-                {
-                    break;
-                }
-            }
-
-            return idsList;
-        }
+        public virtual IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+            => EncodeToIds(text.AsSpan(), maxTokenCount, out normalizedText, out textLength, considerPreTokenization, considerNormalization);
+
+        /// <summary>
+        /// Encodes input text to token Ids up to maximum number of tokens.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+        /// <param name="normalizedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+        /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The list of encoded Ids.</returns>
+        public abstract IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true);
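A short usage sketch of the id-only path, including the capped overload (again assuming some concrete `tokenizer` instance):

```csharp
IReadOnlyList<int> ids = tokenizer.EncodeToIds("the brown fox jumped over the lazy dog!");

// Cap the encoding at 5 tokens; textLength reports how much of the input those
// 5 tokens cover, which is what makes prompt truncation a substring operation.
IReadOnlyList<int> capped = tokenizer.EncodeToIds(
    "the brown fox jumped over the lazy dog!",
    maxTokenCount: 5,
    out string? normalizedText,   // null here unless a normalizer actually ran
    out int textLength);
```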
         /// <summary>
         /// Get the number of tokens that the input text will be encoded to.
         /// </summary>
         /// <param name="text">The text to encode.</param>
-        /// <returns>The number of tokens Ids that the input text will be encoded to.</returns>
-        /// <exception cref="ArgumentNullException">The input text is null.</exception>
-        /// <exception cref="ArgumentException">Unable to encode the text.</exception>
-        public int CountTokens(string text)
-        {
-            if (text is null)
-            {
-                throw new ArgumentNullException(nameof(text));
-            }
-
-            string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text;
-
-            int idsCount = 0;
-            foreach (Split split in PreTokenizer.PreTokenize(normalized))
-            {
-                idsCount += Model.CountTokens(split.TokenSpan, out _);
-            }
-
-            return idsCount;
-        }
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The number of token Ids that the input text will be encoded to.</returns>
+        public virtual int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
+            => CountTokens(text.AsSpan(), considerPreTokenization, considerNormalization);
+
+        /// <summary>
+        /// Get the number of tokens that the input text will be encoded to.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>The number of token Ids that the input text will be encoded to.</returns>
+        public abstract int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);
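`CountTokens` is the cheap way to size-check text before committing to a full encode; a sketch (the 100-token budget is arbitrary, for illustration only):

```csharp
string prompt = "How many tokens will this cost?";
if (tokenizer.CountTokens(prompt) > 100)
{
    // Over budget: fall back to truncation via IndexOfTokenCount, shown below.
}
```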
         /// <summary>
         /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
         /// </summary>
         /// <param name="text">The text to encode.</param>
         /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
-        /// <param name="processedText">If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.</param>
+        /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
         /// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
         /// <returns>
         /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
-        /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, if all tokens fit, the result will be length of the <paramref name="processedText"/>.
+        /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+        /// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
         /// </returns>
-        /// <exception cref="ArgumentNullException">The input text is null.</exception>
-        /// <exception cref="ArgumentOutOfRangeException">The maximum token count must be greater than 0.</exception>
-        public int IndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount)
-            => IndexOf(text, maxTokenCount, out processedText, out tokenCount);
+        public virtual int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+            => IndexOfTokenCount(text.AsSpan(), maxTokenCount, out normalizedString, out tokenCount, considerPreTokenization, considerNormalization);
+
+        /// <summary>
+        /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+        /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+        /// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>
+        /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+        /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+        /// if all tokens fit, the result will be length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
+        /// </returns>
+        public abstract int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
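`IndexOfTokenCount` answers "how much of this string fits in N tokens", so head-truncation becomes a substring; a sketch (`longText` is a placeholder input):

```csharp
int index = tokenizer.IndexOfTokenCount(longText, maxTokenCount: 50,
    out string? normalized, out int tokenCount);

// Slice the text that was actually measured: the normalized form if one was
// produced, otherwise the original input.
string kept = (normalized ?? longText).Substring(0, index);
```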
         /// <summary>
         /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
         /// </summary>
         /// <param name="text">The text to encode.</param>
         /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
-        /// <param name="processedText">If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.</param>
+        /// <param name="processedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
         /// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
         /// <returns>
         /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
         /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="processedText"/>; conversely, if all tokens fit, the result will be 0.
         /// </returns>
-        /// <exception cref="ArgumentNullException">The input text is null.</exception>
-        /// <exception cref="ArgumentOutOfRangeException">The maximum token count must be greater than 0.</exception>
-        /// <remarks>
-        /// If the whole text can be encoded within the token limit, the returned index will be 0.
-        /// </remarks>
-        public int LastIndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount)
-            => LastIndexOf(text, maxTokenCount, out processedText, out tokenCount);
-
-        private int IndexOf(string text, int maxTokenCount, out string processedText, out int tokenCount)
-        {
-            if (text is null)
-            {
-                throw new ArgumentNullException(nameof(text));
-            }
-
-            if (maxTokenCount <= 0)
-            {
-                throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
-            }
-
-            processedText = Normalizer is not null ? Normalizer.Normalize(text) : text;
-            tokenCount = 0;
-
-            IEnumerable<Split> splits = PreTokenizer.PreTokenize(processedText);
-            foreach (Split split in splits)
-            {
-                tokenCount += Model.CountTokens(split.TokenSpan, out int textLength, maxTokenCount - tokenCount);
-                if (textLength < split.Offset.Length || tokenCount >= maxTokenCount)
-                {
-                    return split.Offset.Index + textLength;
-                }
-            }
-
-            return processedText.Length;
-        }
-
-        private int LastIndexOf(string text, int maxTokenCount, out string processedText, out int tokenCount)
-        {
-            if (text is null)
-            {
-                throw new ArgumentNullException(nameof(text));
-            }
-
-            if (maxTokenCount <= 0)
-            {
-                throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
-            }
-
-            processedText = Normalizer is not null ? Normalizer.Normalize(text) : text;
-            tokenCount = 0;
-
-            IEnumerable<Split> splits = PreTokenizer.PreTokenize(processedText);
-            foreach (Split split in splits.Reverse())
-            {
-                tokenCount += Model.CountTokensFromEnd(split.TokenSpan, out int textIndex, maxTokenCount - tokenCount);
-                if (textIndex > 0 || tokenCount >= maxTokenCount)
-                {
-                    return split.Offset.Index + textIndex;
-                }
-            }
-
-            return 0;
-        }
+        public virtual int LastIndexOfTokenCount(string text, int maxTokenCount, out string? processedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+            => LastIndexOfTokenCount(text.AsSpan(), maxTokenCount, out processedText, out tokenCount, considerPreTokenization, considerNormalization);
+
+        /// <summary>
+        /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
+        /// </summary>
+        /// <param name="text">The text to encode.</param>
+        /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+        /// <param name="processedText">If the tokenizer's normalization is enabled or <paramRef name="considerNormalization" /> is false, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+        /// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
+        /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+        /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <returns>
+        /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
+        /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the <paramref name="processedText"/>; conversely, if all tokens fit, the result will be 0.
+        /// </returns>
+        public abstract int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? processedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true);
+
+        /// <summary>
+        /// Map the token to encoded Id.
+        /// </summary>
+        /// <param name="token">The token to map to the Id.</param>
+        /// <returns>The mapped Id of the token.</returns>
+        public virtual int? MapTokenToId(string token)
+        {
+            if (token is null)
+            {
+                throw new ArgumentNullException(nameof(token));
+            }
+
+            return MapTokenToId(token.AsSpan());
+        }
+
+        /// <summary>
+        /// Map the token to encoded Id.
+        /// </summary>
+        /// <param name="token">The token to map to the Id.</param>
+        /// <returns>The mapped Id of the token.</returns>
+        public abstract int? MapTokenToId(ReadOnlySpan<char> token);
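The end-anchored variant supports keeping the tail of a text rather than the head; a sketch (`transcript` is a placeholder input):

```csharp
// Keep only as much of the end of the transcript as fits in 200 tokens.
int start = tokenizer.LastIndexOfTokenCount(transcript, maxTokenCount: 200,
    out string? processed, out int tokenCount);
string tail = (processed ?? transcript).Substring(start);
```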
         /// <summary>
         /// Decodes the Id to the mapped token.
         /// </summary>
         /// <param name="id">The id to map to the token.</param>
         /// <returns>The decoded string or null if there is no token mapped to the input id.</returns>
-        public string? Decode(int id) => Model.MapIdToToken(id);
+        public abstract string? MapIdToToken(int id);

         /// <summary>
         /// Decode the given ids, back to a String.
         /// </summary>
         /// <param name="ids">The list of ids that we want to decode.</param>
         /// <returns>The decoded string.</returns>
-        public string? Decode(IEnumerable<int> ids) => Model.Decode(ids);
+        public virtual string? Decode(IEnumerable<int> ids)
+        {
+            if (ids is null)
+            {
+                throw new ArgumentNullException(nameof(ids));
+            }
+
+            ValueStringBuilder sb = new ValueStringBuilder();
+
+            foreach (int id in ids)
+            {
+                if (MapIdToToken(id) is string s)
+                {
+                    sb.Append(s);
+                }
+            }
+
+            return sb.ToString();
+        }
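Note that the base `Decode` shown above simply concatenates whatever `MapIdToToken` returns for each id; tokenizers that need smarter stitching (byte-level merges, whitespace restoration) are expected to override it. A round-trip sketch:

```csharp
IReadOnlyList<int> ids = tokenizer.EncodeToIds("round trip");
string? roundTripped = tokenizer.Decode(ids); // token values concatenated by the base implementation
```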
+        //
+        // Factory Methods
+        //
+
+        /// <summary>
+        /// Create a new Tiktoken tokenizer's object asynchronously.
+        /// </summary>
+        /// <param name="vocabStream">The stream to the BPE vocab file.</param>
+        /// <param name="preTokenizer">The pre-tokenizer to use.</param>
+        /// <param name="normalizer">The normalizer to use.</param>
+        /// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
+        /// <param name="cacheSize">The size of the cache to use.</param>
+        /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
+        /// <returns>The tokenizer's object.</returns>
+        public static async Task<Tokenizer> CreateTiktokenAsync(
+            Stream vocabStream,
+            PreTokenizer? preTokenizer,
+            Normalizer? normalizer,
+            IReadOnlyDictionary<string, int>? specialTokens = null,
+            int cacheSize = LruCache<int[]>.DefaultCacheSize,
+            CancellationToken cancellationToken = default)
+        {
+            if (vocabStream is null)
+            {
+                throw new ArgumentNullException(nameof(vocabStream));
+            }
+
+            (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
+                        await Tiktoken.LoadTiktokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);
+
+            return new Tiktoken(encoder, decoder, vocab, preTokenizer, specialTokens, normalizer, cacheSize);
+        }
+
+        /// <summary>
+        /// Create a new Tiktoken tokenizer's object asynchronously.
+        /// </summary>
+        /// <param name="vocabFilePath">The BPE vocab file.</param>
+        /// <param name="preTokenizer">The pre-tokenizer to use.</param>
+        /// <param name="normalizer">The normalizer to use.</param>
+        /// <param name="specialTokensEncoder">The dictionary mapping special tokens to Ids.</param>
+        /// <param name="cacheSize">The size of the cache to use.</param>
+        /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
+        /// <returns>The tokenizer's object.</returns>
+        public static async Task<Tokenizer> CreateTiktokenAsync(
+            string vocabFilePath,
+            PreTokenizer? preTokenizer,
+            Normalizer? normalizer,
+            IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
+            int cacheSize = LruCache<int[]>.DefaultCacheSize,
+            CancellationToken cancellationToken = default)
+        {
+            if (vocabFilePath is null)
+            {
+                throw new ArgumentNullException(nameof(vocabFilePath));
+            }
+
+            using Stream vocabStream = File.OpenRead(vocabFilePath);
+            return await CreateTiktokenAsync(vocabStream, preTokenizer, normalizer, specialTokensEncoder, cacheSize, cancellationToken).ConfigureAwait(false);
+        }
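A usage sketch for the stream-based factory; the file name is illustrative, and the pre-tokenizer/normalizer are left null for brevity (a real setup would typically supply the Tiktoken regex pre-tokenizer, as the model-based factories below do):

```csharp
using Stream vocabStream = File.OpenRead("cl100k_base.tiktoken"); // illustrative path
Tokenizer tokenizer = await Tokenizer.CreateTiktokenAsync(
    vocabStream,
    preTokenizer: null,
    normalizer: null);
```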
         /// <summary>
         /// Create a Tiktoken tokenizer based on model name and vocab file.

@@ -304,10 +321,11 @@ namespace Microsoft.ML.Tokenizers
                 }
             }

-            return new Tokenizer(
-                        new Tiktoken(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize),
-                        new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
-                        normalizer);
+            return new Tiktoken(vocabStream,
+                        new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
+                        tiktokenConfiguration.SpecialTokens,
+                        normalizer,
+                        cacheSize);
         }

         /// <summary>

@@ -343,10 +361,11 @@ namespace Microsoft.ML.Tokenizers
                 }
             }

-            return new Tokenizer(
-                        await Tiktoken.CreateAsync(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize, cancellationToken).ConfigureAwait(false),
-                        new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
-                        normalizer);
+            return await CreateTiktokenAsync(vocabStream,
+                        new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
+                        normalizer,
+                        tiktokenConfiguration.SpecialTokens,
+                        cacheSize, cancellationToken).ConfigureAwait(false);
         }

         /// <summary>

@@ -357,7 +376,7 @@ namespace Microsoft.ML.Tokenizers
         /// <param name="normalizer">To normalize the text before tokenization</param>
         /// <returns>The tokenizer</returns>
         public static Tokenizer CreateTiktokenForModel(string modelName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null)
-            => Tiktoken.CreateTokenizerForModel(modelName, extraSpecialTokens, normalizer);
+            => Tiktoken.CreateForModel(Tiktoken.GetModelEncoding(modelName), modelName, extraSpecialTokens, normalizer);

         /// <summary>
         /// Create tokenizer based on encoding name

@@ -395,7 +414,7 @@ namespace Microsoft.ML.Tokenizers
                 throw new ArgumentException($"The encoding name '{encodingName}' is not supported. The only supported encoding names are: {Tiktoken.Cl100kBaseEncodingName}, {Tiktoken.P50kBaseEncodingName}, {Tiktoken.P50kEditEncodingName}, and {Tiktoken.R50kBaseEncodingName}.", nameof(encodingName));
             }

-            return Tiktoken.CreateTokenizerForModel(modelEncoding, modelName: null, extraSpecialTokens, normalizer);
+            return Tiktoken.CreateForModel(modelEncoding, modelName: null, extraSpecialTokens, normalizer);
         }
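The model-name factory remains the simplest entry point; a sketch ("gpt-4" is an example model name, which the factory resolves to its matching encoding):

```csharp
Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel("gpt-4");
int count = tokenizer.CountTokens("Hello, Tiktoken!");
```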

         /// <summary>

@@ -427,16 +446,70 @@ namespace Microsoft.ML.Tokenizers
                 throw new ArgumentException($"Normalization '{modelProto.NormalizerSpec.Name}' is not supported.", nameof(modelProto));
             }

-            LlamaNormalizer normalizer = new(
+            SentencePieceNormalizer normalizer = new(
                         modelProto.NormalizerSpec.RemoveExtraWhitespaces,
                         modelProto.NormalizerSpec.AddDummyPrefix,
                         modelProto.NormalizerSpec.EscapeWhitespaces,
                         modelProto.TrainerSpec.TreatWhitespaceAsSuffix);

-            return new Tokenizer(
-                        new SentencePieceBpe(modelProto, addBeginOfSentence, addEndOfSentence),
-                        SentencePiecePreTokenizer.Instance,
-                        normalizer);
+            return new SentencePieceBpe(modelProto, addBeginOfSentence, addEndOfSentence);
+        }
+
+        internal static IEnumerable<(int Offset, int Length)>? InitializeForEncoding(
+                        string? text,
+                        ReadOnlySpan<char> textSpan,
+                        bool considerPreTokenization,
+                        bool considerNormalization,
+                        Normalizer? normalizer,
+                        PreTokenizer? preTokenizer,
+                        out string? normalizedString,
+                        out ReadOnlySpan<char> textSpanToEncode)
+        {
+            normalizedString = null;
+            IEnumerable<(int Offset, int Length)>? splits = null;
+
+            if (text is null)
+            {
+                if (considerNormalization && (normalizer is not null))
+                {
+                    normalizedString = normalizer.Normalize(textSpan.ToString());
+                    textSpanToEncode = normalizedString.AsSpan();
+                    if (considerPreTokenization && preTokenizer is not null)
+                    {
+                        splits = preTokenizer.PreTokenize(normalizedString);
+                    }
+                }
+                else
+                {
+                    textSpanToEncode = textSpan;
+                    if (considerPreTokenization && preTokenizer is not null)
+                    {
+                        splits = preTokenizer.PreTokenize(textSpan);
+                    }
+                }
+            }
+            else
+            {
+                if (considerNormalization && (normalizer is not null))
+                {
+                    normalizedString = normalizer.Normalize(text);
+                    textSpanToEncode = normalizedString.AsSpan();
+                    if (considerPreTokenization && preTokenizer is not null)
+                    {
+                        splits = preTokenizer.PreTokenize(normalizedString);
+                    }
+                }
+                else
+                {
+                    textSpanToEncode = text.AsSpan();
+                    if (considerPreTokenization && preTokenizer is not null)
+                    {
+                        splits = preTokenizer.PreTokenize(text);
+                    }
+                }
+            }
+
+            return splits;
         }
     }
 }
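The shared `InitializeForEncoding` helper centralizes the normalize-then-pretokenize bookkeeping so each override does not reimplement it. Roughly how a concrete subclass's `Encode` override might use it — a sketch only, assuming the subclass encodes each split independently, with `EncodeInternal` as a hypothetical per-split worker:

```csharp
public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString,
    bool considerPreTokenization = true, bool considerNormalization = true)
{
    // No string instance is available on the span path, so pass null for text.
    IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
        null, text, considerPreTokenization, considerNormalization,
        Normalizer, PreTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);

    List<Token> tokens = new();
    if (splits is not null)
    {
        foreach ((int Offset, int Length) split in splits)
        {
            // EncodeInternal is a hypothetical per-split worker supplied by the subclass;
            // the split offset lets it shift token offsets back to the full text.
            tokens.AddRange(EncodeInternal(textSpanToEncode.Slice(split.Offset, split.Length), split.Offset));
        }
    }
    else
    {
        tokens.AddRange(EncodeInternal(textSpanToEncode, 0));
    }

    return tokens;
}
```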
@@ -71,6 +71,8 @@ namespace Microsoft.ML.Tokenizers
             return s;
         }

+        public void Clear() => _data.Clear();
+
         public bool IsConsistent()
         {
             // is the heap property true for all data?
@@ -26,12 +26,12 @@ namespace Microsoft.ML.TorchSharp.Extensions
                 // "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt"
                 Assembly assembly = typeof(TokenizerExtensions).Assembly;

-                EnglishRoberta model = new EnglishRoberta(
+                _instance = new EnglishRoberta(
                     assembly.GetManifestResourceStream("encoder.json"),
                     assembly.GetManifestResourceStream("vocab.bpe"),
-                    assembly.GetManifestResourceStream("dict.txt"));
-                model.AddMaskSymbol();
-                _instance = new Tokenizer(model, new RobertaPreTokenizer());
+                    assembly.GetManifestResourceStream("dict.txt"),
+                    new RobertaPreTokenizer());
+                (_instance as EnglishRoberta).AddMaskSymbol();
             }

             return _instance;

@@ -40,7 +40,7 @@ namespace Microsoft.ML.TorchSharp.Extensions
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         internal static EnglishRoberta RobertaModel(this Tokenizer tokenizer)
         {
-            EnglishRoberta model = tokenizer.Model as EnglishRoberta;
+            EnglishRoberta model = tokenizer as EnglishRoberta;
             if (model is null)
             {
                 throw new InvalidOperationException($"The input tokenizer is not using the EnglishRoberta model.");

@@ -51,8 +51,7 @@ namespace Microsoft.ML.TorchSharp.Extensions

         internal static IReadOnlyList<int> EncodeToConverted(this Tokenizer tokenizer, string sentence)
        {
-            EncodingResult encoding = tokenizer.Encode(sentence);
-            return tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(encoding.Ids);
+            return tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(tokenizer.EncodeToIds(sentence));
         }
     }
 }
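The consuming pattern therefore changes from unwrapping `tokenizer.Model` to treating the tokenizer itself as the model; a before/after sketch:

```csharp
// Before: the concrete model hung off a wrapper Tokenizer.
// EnglishRoberta model = (EnglishRoberta)tokenizer.Model;

// After: concrete tokenizers derive from Tokenizer, so a type test suffices.
if (tokenizer is EnglishRoberta roberta)
{
    roberta.AddMaskSymbol();
}
```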
@@ -167,16 +167,16 @@ namespace Microsoft.ML.TorchSharp.NasBert
                 Sentence1Getter(ref sentenceRom);
                 var sentence = sentenceRom.ToString();
                 Tensor t;
-                var encoding = Tokenizer.Encode(sentence);
+                IReadOnlyList<Token> encoding = Tokenizer.Encode(sentence, out string normalizedString);

-                if (target.Length != encoding.Tokens.Count)
+                if (target.Length != encoding.Count)
                 {
                     var targetIndex = 0;
-                    var targetEditor = VBufferEditor.Create(ref target, encoding.Tokens.Count);
+                    var targetEditor = VBufferEditor.Create(ref target, encoding.Count);
                     var newValues = targetEditor.Values;
-                    for (var i = 0; i < encoding.Tokens.Count; i++)
+                    for (var i = 0; i < encoding.Count; i++)
                     {
-                        if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
+                        if (NerTrainer.TokenStartsWithSpace(encoding[i].Value))
                         {
                             newValues[i] = target.GetItemOrDefault(++targetIndex);
                         }

@@ -187,7 +187,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
                     }
                     target = targetEditor.Commit();
                 }
-                t = torch.tensor((ZeroArray).Concat(Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(encoding.Ids)).ToList(), device: Device);
+                t = torch.tensor((ZeroArray).Concat(Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(encoding.Select(t => t.Id).ToArray())).ToList(), device: Device);

                 if (t.NumberOfElements > 512)
                     t = t.slice(0, 0, 512, 1);

@@ -377,16 +377,16 @@ namespace Microsoft.ML.TorchSharp.NasBert
         private void CondenseOutput(ref VBuffer<UInt32> dst, string sentence, Tokenizer tokenizer, TensorCacher outputCacher)
         {
             var pre = tokenizer.PreTokenizer.PreTokenize(sentence);
-            EncodingResult encoding = tokenizer.Encode(sentence);
+            IReadOnlyList<Token> encoding = tokenizer.Encode(sentence, out string normalizedString);

             var argmax = (outputCacher as BertTensorCacher).Result.argmax(-1);
             var prediction = argmax.ToArray<long>();

             var targetIndex = 0;
             // Figure out actual count of output tokens
-            for (var i = 0; i < encoding.Tokens.Count; i++)
+            for (var i = 0; i < encoding.Count; i++)
             {
-                if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
+                if (NerTrainer.TokenStartsWithSpace(encoding[i].Value))
                 {
                     targetIndex++;
                 }

@@ -398,9 +398,9 @@ namespace Microsoft.ML.TorchSharp.NasBert

             newValues[targetIndex++] = (uint)prediction[0];

-            for (var i = 1; i < encoding.Tokens.Count; i++)
+            for (var i = 1; i < encoding.Count; i++)
             {
-                if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
+                if (NerTrainer.TokenStartsWithSpace(encoding[i].Value))
                 {
                     newValues[targetIndex++] = (uint)prediction[i];
                 }
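The same mechanical substitution applies anywhere `EncodingResult` was consumed; a migration cheat-sheet in code form (assuming a non-empty encoding):

```csharp
IReadOnlyList<Token> encoding = tokenizer.Encode(sentence, out string? normalized);

int count = encoding.Count;                                 // was encoding.Tokens.Count
string firstValue = encoding[0].Value;                      // was encoding.Tokens[0]
(int Index, int Length) firstOffset = encoding[0].Offset;   // was encoding.Offsets[0]
int[] ids = encoding.Select(t => t.Id).ToArray();           // was encoding.Ids
```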
@@ -401,9 +401,9 @@ namespace Microsoft.ML.TorchSharp.Roberta
             answerIndexGetter(ref answerIndex);

             var contextString = context.ToString();
-            var contextTokens = Tokenizer.Encode(contextString);
-            var contextToken = contextTokens.Tokens;
-            var contextTokenId = Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(contextTokens.Ids);
+            var contextTokens = Tokenizer.Encode(contextString, out string normalized);
+            var contextToken = contextTokens.Select(t => t.Value).ToArray();
+            var contextTokenId = Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(contextTokens.Select(t => t.Id).ToArray());

             var mapping = AlignAnswerPosition(contextToken, contextString);
             if (mapping == null)

@@ -437,7 +437,7 @@ namespace Microsoft.ML.TorchSharp.Roberta

         private Dictionary<int, int> AlignAnswerPosition(IReadOnlyList<string> tokens, string text)
         {
-            EnglishRoberta robertaModel = Tokenizer.Model as EnglishRoberta;
+            EnglishRoberta robertaModel = Tokenizer as EnglishRoberta;
             Debug.Assert(robertaModel is not null);

             var mapping = new Dictionary<int, int>();

@@ -854,9 +854,9 @@ namespace Microsoft.ML.TorchSharp.Roberta
                     contextGetter(ref context);
                     questionGetter(ref question);

-                    var contextTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.Encode(context.ToString()).Ids);
+                    var contextTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.EncodeToIds(context.ToString()));

-                    var questionTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.Encode(question.ToString()).Ids);
+                    var questionTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.EncodeToIds(question.ToString()));

                     var srcTensor = torch.tensor((new[] { 0 /* InitToken */ }).Concat(questionTokenId).Concat(new[] { 2 /* SeparatorToken */ }).Concat(contextTokenId).ToList(), device: _parent.Device);
@@ -248,28 +248,28 @@ namespace Microsoft.ML.Tokenizers.Tests

             try
             {
-                Bpe bpe = new Bpe(vocabFile, mergesFile, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownToken);
-                Tokenizer tokenizer = new Tokenizer(bpe);
-                EncodingResult encoding = tokenizer.Encode(sentence);
+                Bpe bpe = new Bpe(vocabFile, mergesFile, unknownToken: unknownToken, continuingSubwordPrefix: continuingSubwordPrefix, endOfWordSuffix: endOfWordSuffix, fuseUnknownTokens: fuseUnknownToken);
+                Tokenizer tokenizer = bpe;
+                IReadOnlyList<Token> encoding = tokenizer.Encode(sentence, out _);
+                int[] encodingIds = encoding.Select(t => t.Id).ToArray();
                 IReadOnlyList<int> idsList = tokenizer.EncodeToIds(sentence);

-                Assert.Equal(expectedTokens.Length, encoding.Tokens.Count);
-                Assert.Equal(offsets.Length, encoding.Offsets.Count);
-                Assert.Equal(ids.Length, encoding.Ids.Count);
+                Assert.Equal(expectedTokens.Length, encoding.Count);
+                Assert.Equal(offsets.Length, encoding.Count);
+                Assert.Equal(ids.Length, encoding.Count);
                 Assert.Equal(ids.Length, idsList.Count);
                 Assert.Equal(ids.Length, tokenizer.CountTokens(sentence));
-                Assert.Equal(decodedTokens, tokenizer.Decode(encoding.Ids));
-                Assert.Equal(decodedTokensWithoutUnknownToken, bpe.Decode(encoding.Ids, considerSpecialTokens: false));
+                Assert.Equal(decodedTokens, tokenizer.Decode(encodingIds));
+                Assert.Equal(decodedTokensWithoutUnknownToken, bpe.Decode(encodingIds, considerSpecialTokens: false));

-                for (int i = 0; i < encoding.Tokens.Count; i++)
+                for (int i = 0; i < encoding.Count; i++)
                 {
-                    Assert.Equal(expectedTokens[i], encoding.Tokens[i]);
-                    Assert.Equal(offsets[i], encoding.Offsets[i]);
-                    Assert.Equal(ids[i], encoding.Ids[i]);
+                    Assert.Equal(expectedTokens[i], encoding[i].Value);
+                    Assert.Equal(offsets[i], encoding[i].Offset);
+                    Assert.Equal(ids[i], encoding[i].Id);
                     Assert.Equal(ids[i], idsList[i]);
-                    Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
-                    Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan()));
-                    Assert.Equal(encoding.Tokens[i], tokenizer.Decode(encoding.Ids[i]));
+                    Assert.Equal(encoding[i].Value, tokenizer.MapIdToToken(encodingIds[i]));
+                    Assert.Equal(encodingIds[i], tokenizer.MapTokenToId(encoding[i].Value.AsSpan()));
                 }
             }
             finally
@@ -282,6 +282,41 @@ namespace Microsoft.ML.Tokenizers.Tests
             }
         }

+        private static Tokenizer? _gpt2Tokenizer = null;
+
+        private static Tokenizer GetGpt2Tokenizer()
+        {
+            if (_gpt2Tokenizer is null)
+            {
+                // "https://huggingface.co/openai-community/gpt2/raw/main/vocab.json";
+                // "https://huggingface.co/openai-community/gpt2/raw/main/merges.txt";
+                using Stream vocabStream = File.OpenRead(Path.Combine(@"Gpt-2", "vocab.json"));
+                using Stream mergesStream = File.OpenRead(Path.Combine(@"Gpt-2", "merges.txt"));
+
+                _gpt2Tokenizer = new Bpe(vocabStream, mergesStream);
+            }
+
+            return _gpt2Tokenizer;
+        }
+
+        [Fact]
+        public void TestGpt2Vocab()
+        {
+            Tokenizer tokenizer = GetGpt2Tokenizer();
+
+            string text = "The quick brown fox jumps over the lazy dog!";
+
+            IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
+            IReadOnlyList<int> ids = tokenizer.EncodeToIds(text);
+
+            Assert.Equal(12, encoding.Count);
+            Assert.Equal(encoding.Select(t => t.Id).ToArray(), ids);
+            Assert.Equal(12, tokenizer.CountTokens(text));
+
+            TokenizerTests.TestTokenLimits(tokenizer);
+        }
+
         public static IEnumerable<object?[]> BpeTestData
         {
             get
@@ -290,55 +325,84 @@ namespace Microsoft.ML.Tokenizers.Tests
                 yield return new object?[]
                 {
                     "the brown fox jumped over the lazy dog!",
-                    new string[] {"the", "brown", "fox", "jumped", "over", "the", "lazy", "dog", "!"},
-                    new (int, int)[] {(0, 3), (4, 9), (10, 13), (14, 20), (21, 25), (26, 29), (30, 34), (35, 38), (38, 39)}
+                    new string[] { "the", "brown", "fox", "j", "umped", "over", "the", "l", "azy", "dog", "!" },
+                    new (int Index, int Length)[] { (0, 3), (4, 5), (10, 3), (14, 1), (15, 5), (21, 4), (26, 3), (30, 1), (31, 3), (35, 3), (38, 1) },
+                    new int[] { 1169, 33282, 12792, 73, 27073, 2502, 1169, 75, 12582, 9703, 0 }
                 };
                 yield return new object?[]
                 {
                     "he traveled to Egypt during the summer, the weather was hot and ammunition." ,
-                    new string[] {"he", "traveled", "to", "Egypt", "during", "the", "summer", ",", "the", "weather", "was", "hot", "and", "ammunition", "."},
-                    new (int, int)[] {(0, 2), (3, 11), (12, 14), (15, 20), (21, 27), (28, 31), (32, 38), (38, 39), (40, 43), (44, 51), (52, 55), (56, 59), (60, 63), (64, 74), (74, 75)}
+                    new string[] { "he", "travel", "ed", "to", "Egypt", "during", "the", "sum", "mer", ",", "the", "weather", "was", "hot", "and", "am", "munition", "." },
+                    new (int Index, int Length)[] { (0, 2), (3, 6), (9, 2), (12, 2), (15, 5), (21, 6), (28, 3), (32, 3), (35, 3), (38, 1), (40, 3), (44, 7), (52, 3), (56, 3), (60, 3), (64, 2), (66, 8), (74, 1) },
+                    new int[] { 258, 35927, 276, 1462, 39299, 42122, 1169, 16345, 647, 11, 1169, 23563, 9776, 8940, 392, 321, 12640, 13 }
                 };
                 yield return new object?[]
                 {
                     "She played many games and she felt exhausted afterward",
-                    new string[] {"She", "played", "many", "games", "and", "she", "felt", "exhausted", "afterward"},
-                    new (int, int)[] {(0, 3), (4, 10), (11, 15), (16, 21), (22, 25), (26, 29), (30, 34), (35, 44), (45, 54)}
+                    new string[] { "She", "played", "many", "games", "and", "she", "felt", "ex", "ha", "usted", "after", "ward" },
+                    new (int Index, int Length)[] { (0, 3), (4, 6), (11, 4), (16, 5), (22, 3), (26, 3), (30, 4), (35, 2), (37, 2), (39, 5), (45, 5), (50, 4) },
+                    new int[] { 3347, 21542, 21834, 19966, 392, 7091, 31985, 1069, 3099, 8459, 8499, 904 }
                 };
                 yield return new object?[]
                 {
                     "Hello, y'all! How are you 😁 ?",
-                    new string[] {"Hello", ",", "y", "'", "all", "!", "How", "are", "you", "[UNK]", "?"},
-                    new (int, int)[] {(0, 5), (5, 6), (7, 8), (8, 9), (9, 12), (12, 13), (14, 17), (18, 21), (22, 25), (26, 28), (29, 30)}
+                    new string[] { "Hello", ",", "y", "'", "all", "!", "How", "are", "you", "?" },
+                    new (int Index, int Length)[] { (0, 5), (5, 1), (7, 1), (8, 1), (9, 3), (12, 1), (14, 3), (18, 3), (22, 3), (29, 1) },
+                    new int[] { 15496, 11, 88, 6, 439, 0, 2437, 533, 5832, 30 }
                 };
             }
         }

-        private const string Gpt2VocabUrl = "https://huggingface.co/openai-community/gpt2/raw/main/vocab.json";
-        private const string Gpt2MergesUrl = "https://huggingface.co/openai-community/gpt2/raw/main/merges.txt";
-
-        [Fact]
-        public async void TestGpt2Vocab()
-        {
-            using HttpClient httpClient = new HttpClient();
-            using Stream vocabStream = await httpClient.GetStreamAsync(Gpt2VocabUrl);
-            using Stream mergesStream = await httpClient.GetStreamAsync(Gpt2MergesUrl);
-
-            Bpe bpe = new Bpe(vocabStream, mergesStream);
-            Tokenizer tokenizer = new Tokenizer(bpe);
-
-            string text = "The quick brown fox jumps over the lazy dog!";
-
-            EncodingResult encoding = tokenizer.Encode(text);
-            IReadOnlyList<int> ids = tokenizer.EncodeToIds(text);
-
-            Assert.Equal(12, encoding.Tokens.Count);
-            Assert.Equal(12, encoding.Offsets.Count);
-            Assert.Equal(12, encoding.Ids.Count);
-            Assert.Equal(encoding.Ids, ids);
-            Assert.Equal(12, tokenizer.CountTokens(text));
-
-            TokenizerTests.TestTokenLimits(tokenizer);
+        [Theory]
+        [MemberData(nameof(BpeTestData))]
+        public void TestBpeTokenizer(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
+        {
+            Tokenizer tokenizer = GetGpt2Tokenizer();
+
+            IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
+            IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);
+
+            Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
+            Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
+            Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
+
+            Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray());
+            Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray());
+            Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray());
+
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text));
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan()));
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text, expectedIds.Length, out string? normalizedString, out int length));
+            Assert.Null(normalizedString);
+            Assert.Equal(text.Length, length);
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, out normalizedString, out length));
+            Assert.Null(normalizedString);
+            Assert.Equal(text.Length, length);
+
+            Assert.Equal(expectedIds.Take(expectedIds.Length - 2), tokenizer.EncodeToIds(text, expectedIds.Length - 2, out normalizedString, out length));
+            Assert.Null(normalizedString);
+            int expectedLength = expectedOffsets[expectedOffsets.Length - 3].Index + expectedOffsets[expectedOffsets.Length - 3].Length;
+            Assert.Equal(expectedLength, length);
+            Assert.Equal(expectedIds.Take(expectedIds.Length - 2), tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length - 2, out normalizedString, out length));
+            Assert.Null(normalizedString);
+            Assert.Equal(expectedLength, length);
+
+            Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text));
+            Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan()));
+
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(expectedIds.Length - 3, tokenCount);
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(expectedIds.Length - 3, tokenCount);
+
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text, 3, out normalizedString, out tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(3, tokenCount);
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 3, out normalizedString, out tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(3, tokenCount);
         }

         private static string WriteToMergeFile((string, string)[] mergeEntries)
|
@ -360,7 +424,7 @@ namespace Microsoft.ML.Tokenizers.Tests
|
||||||
return fileName;
|
return fileName;
|
||||||
}
|
}
|
||||||
|
|
||||||
internal static Bpe CreateEmptyBpe()
|
internal static Bpe CreateEmptyBpe(PreTokenizer? preTokenizer = null, Normalizer? normalizer = null)
|
||||||
{
|
{
|
||||||
using MemoryStream emptyVocabStream = new MemoryStream();
|
using MemoryStream emptyVocabStream = new MemoryStream();
|
||||||
using StreamWriter writer = new StreamWriter(emptyVocabStream);
|
using StreamWriter writer = new StreamWriter(emptyVocabStream);
|
||||||
|
@ -368,7 +432,7 @@ namespace Microsoft.ML.Tokenizers.Tests
|
||||||
writer.Flush();
|
writer.Flush();
|
||||||
emptyVocabStream.Position = 0;
|
emptyVocabStream.Position = 0;
|
||||||
|
|
||||||
return new Bpe(vocabStream: emptyVocabStream, mergesStream: null, unknownToken: "Ukn");
|
return new Bpe(vocabStream: emptyVocabStream, mergesStream: null, preTokenizer: preTokenizer ?? WhiteSpace.Instance, normalizer: normalizer, unknownToken: "Ukn");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
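The `CreateEmptyBpe` change above captures the core of this API update for `Bpe`: the pre-tokenizer and normalizer are handed to the model's constructor instead of to a wrapping `Tokenizer` object. A minimal sketch of the new construction path, using only the named arguments shown in the hunk; the vocab/merges file names and the "[UNK]" token are placeholders, not values from this repo:

using System.IO;
using Microsoft.ML.Tokenizers;

// Sketch: build a standalone Bpe tokenizer with the updated constructor shape.
using Stream vocabStream = File.OpenRead("vocab.json");    // placeholder path
using Stream mergesStream = File.OpenRead("merges.txt");   // placeholder path

Bpe bpe = new Bpe(
    vocabStream: vocabStream,
    mergesStream: mergesStream,
    preTokenizer: WhiteSpace.Instance,   // the pre-tokenizer now lives on the model itself
    normalizer: null,                    // optional; null means no normalization step
    unknownToken: "[UNK]");              // illustrative; the tests above use "Ukn"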
@ -6,9 +6,11 @@ using System;
using System.IO;
using System.Collections.Generic;
using System.Linq;
-using System.Text;
+using System.Net.Http;

using Xunit;
+using System.Diagnostics;
+using System.Threading.Tasks;

namespace Microsoft.ML.Tokenizers.Tests
{
@ -80,6 +82,34 @@ namespace Microsoft.ML.Tokenizers.Tests
            }
        }

+        private static Tokenizer? _robertaTokenizer = null;
+        private async static Task<Tokenizer> GetRobertaTokenizer()
+        {
+            if (_robertaTokenizer is null)
+            {
+                string vocabFile = Utils.CreateTemporaryFile("json");
+                string mergeFile = Utils.CreateTemporaryFile("txt");
+                string translationFile = Utils.CreateTemporaryFile("txt");
+
+                try
+                {
+                    await Utils.DownloadFile(_vocabUrl, vocabFile);
+                    await Utils.DownloadFile(_mergeUrl, mergeFile);
+                    await Utils.DownloadFile(_dictUrl, translationFile);
+
+                    _robertaTokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
+                }
+                finally
+                {
+                    Utils.DeleteFile(vocabFile);
+                    Utils.DeleteFile(mergeFile);
+                    Utils.DeleteFile(translationFile);
+                }
+            }
+
+            return _robertaTokenizer;
+        }

        [Fact]
        public async void TokenizationTest()
        {
@ -94,26 +124,26 @@ namespace Microsoft.ML.Tokenizers.Tests
            await Utils.DownloadFile(_mergeUrl, mergeFile);
            await Utils.DownloadFile(_dictUrl, translationFile);

-            Tokenizer tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile), RobertaPreTokenizer.Instance);
+            Tokenizer tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
            TestTokenizer(tokenizer);
            TokenizerTests.TestTokenLimits(tokenizer);

-            tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile, filterUnsupportedChars: false), RobertaPreTokenizer.Instance);
+            tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false);
            TestTokenizer(tokenizer);

            using Stream vocabStream = File.OpenRead(vocabFile);
            using Stream mergeStream = File.OpenRead(mergeFile);
            using Stream translationStream = File.OpenRead(translationFile);
-            tokenizer = new Tokenizer(new EnglishRoberta(vocabStream, mergeStream, translationStream), RobertaPreTokenizer.Instance);
+            tokenizer = new EnglishRoberta(vocabStream, mergeStream, translationStream, RobertaPreTokenizer.Instance);
            TestTokenizer(tokenizer);

            // Ensure caching works regardless of which method is called first.
            for (CallingOrder order = CallingOrder.Encode; order <= CallingOrder.CountTokens; order++)
            {
-                tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile), RobertaPreTokenizer.Instance);
+                tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
                TestTokenizer(tokenizer, order);

-                tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile, filterUnsupportedChars: false), RobertaPreTokenizer.Instance);
+                tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false);
                TestTokenizer(tokenizer, order);
            }
        }
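The hunk above is the heart of the `EnglishRoberta` change: the model type now derives from `Tokenizer`, so the `new Tokenizer(model, preTokenizer)` wrapper goes away and the pre-tokenizer (plus the optional `filterUnsupportedChars` flag) moves into the constructor. A hedged before/after sketch; the file paths are placeholders and the same usings as the earlier sketch are assumed:

// Old shape: model wrapped in a Tokenizer.
// Tokenizer tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile), RobertaPreTokenizer.Instance);

// New shape: the model itself is the Tokenizer.
Tokenizer tokenizer = new EnglishRoberta(
    "vocab.json", "merges.txt", "dict.txt",   // placeholder paths
    RobertaPreTokenizer.Instance,
    filterUnsupportedChars: false);           // optional knob exercised by the tests above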
@ -132,6 +162,94 @@ namespace Microsoft.ML.Tokenizers.Tests
            }
        }

+        public static IEnumerable<object?[]> RobertaTestData
+        {
+            get
+            {
+                // string to tokenize, produced tokens, the token offsets
+                yield return new object?[]
+                {
+                    "the brown fox jumped over the lazy dog!",
+                    new string[] { "the", "Ġbrown", "Ġfox", "Ġjumped", "Ġover", "Ġthe", "Ġlazy", "Ġdog", "!" },
+                    new (int Index, int Length)[] { (0, 3), (3, 6), (9, 4), (13, 7), (20, 5), (25, 4), (29, 5), (34, 4), (38, 1) },
+                    new int[] { 1169, 7586, 21831, 11687, 625, 262, 16931, 3290, 0 }
+                };
+                yield return new object?[]
+                {
+                    "he traveled to Egypt during the summer, the weather was hot and ammunition.",
+                    new string[] { "he", "Ġtraveled", "Ġto", "ĠEgypt", "Ġduring", "Ġthe", "Ġsummer", ",", "Ġthe", "Ġweather", "Ġwas", "Ġhot", "Ġand", "Ġammunition", "." },
+                    new (int Index, int Length)[] { (0, 2), (2, 9), (11, 3), (14, 6), (20, 7), (27, 4), (31, 7), (38, 1), (39, 4), (43, 8), (51, 4), (55, 4), (59, 4), (63, 11), (74, 1) },
+                    new int[] { 258, 14113, 284, 6365, 1141, 262, 3931, 11, 262, 6193, 373, 3024, 290, 14271, 13 }
+                };
+                yield return new object?[]
+                {
+                    "She played many games and she felt exhausted afterward",
+                    new string[] { "She", "Ġplayed", "Ġmany", "Ġgames", "Ġand", "Ġshe", "Ġfelt", "Ġexhausted", "Ġafterward" },
+                    new (int Index, int Length)[] { (0, 3), (3, 7), (10, 5), (15, 6), (21, 4), (25, 4), (29, 5), (34, 10), (44, 10) },
+                    new int[] { 3347, 2826, 867, 1830, 290, 673, 2936, 19064, 20875 }
+                };
+                yield return new object?[]
+                {
+                    "Hello, y'all! How are you 😁 ?",
+                    new string[] { "Hello", ",", "Ġy", "'", "all", "!", "ĠHow", "Ġare", "Ġyou", "Ġ", "Ġ?" },
+                    new (int Index, int Length)[] { (0, 5), (5, 1), (6, 2), (8, 1), (9, 3), (12, 1), (13, 4), (17, 4), (21, 4), (25, 1), (28, 2) },
+                    new int[] { 15496, 11, 331, 6, 439, 0, 1374, 389, 345, 220, 5633 }
+                };
+            }
+        }
+
+        [Theory]
+        [MemberData(nameof(RobertaTestData))]
+        public async void TestTokenizerEncoding(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
+        {
+            Tokenizer tokenizer = await GetRobertaTokenizer();
+
+            IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
+            IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);
+
+            Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
+            Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
+            Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
+
+            Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray());
+            Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray());
+            Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray());
+
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text));
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan()));
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text, expectedIds.Length, out string? normalizedString, out int length));
+            Assert.Null(normalizedString);
+            Assert.Equal(text.Length, length);
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, out normalizedString, out length));
+            Assert.Null(normalizedString);
+            Assert.Equal(text.Length, length);
+
+            Assert.Equal(expectedIds.Take(expectedIds.Length - 2), tokenizer.EncodeToIds(text, expectedIds.Length - 2, out normalizedString, out length));
+            Assert.Null(normalizedString);
+            int expectedLength = expectedOffsets[expectedOffsets.Length - 3].Index + expectedOffsets[expectedOffsets.Length - 3].Length;
+            Assert.Equal(expectedLength, length);
+            Assert.Equal(expectedIds.Take(expectedIds.Length - 2), tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length - 2, out normalizedString, out length));
+            Assert.Null(normalizedString);
+            Assert.Equal(expectedLength, length);
+
+            Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text));
+            Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan()));
+
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(expectedIds.Length - 3, tokenCount);
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(expectedIds.Length - 3, tokenCount);
+
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text, 3, out normalizedString, out tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(3, tokenCount);
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 3, out normalizedString, out tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(3, tokenCount);
+        }

        private enum CallingOrder
        {
            Encode,
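The new theory above exercises the whole updated encoding surface in one place, and the shape is worth spelling out: `Encode` now returns `IReadOnlyList<Token>` and reports the normalized string through an `out` parameter (null when, as here, the tokenizer has no normalizer), while `EncodeToIds` accepts a `maxTokenCount` budget and reports how much input text the returned ids cover. A small usage sketch under the same assumptions as the earlier sketches (plus `using System;` and `using System.Linq;`):

IReadOnlyList<Token> tokens = tokenizer.Encode("the brown fox", out string? normalized);
foreach (Token t in tokens)
{
    // Each Token carries its string value, its id, and its (Index, Length) offset into the text.
    Console.WriteLine($"{t.Value} -> {t.Id} @ {t.Offset}");
}

// Budget-limited encoding: return at most 5 token ids; 'consumed' is the length
// of the text prefix the returned ids actually cover.
IReadOnlyList<int> ids = tokenizer.EncodeToIds("the brown fox", 5, out normalized, out int consumed);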
@ -143,66 +261,67 @@ namespace Microsoft.ML.Tokenizers.Tests
        // Calling with callIdsFirst = true will test the other way around.
        private void TestTokenizer(Tokenizer tokenizer, CallingOrder callingOrder = CallingOrder.Encode)
        {
-            Assert.NotNull(tokenizer.Model);
-            Assert.True(tokenizer.Model is EnglishRoberta);
+            Assert.True(tokenizer is EnglishRoberta);
            Assert.True(tokenizer.PreTokenizer is RobertaPreTokenizer);

            foreach (object[] p in BertaData)
            {
                IReadOnlyList<int> ids;
-                EncodingResult encoding;
+                IReadOnlyList<Token> encoding;
                int idsCount;

                if (callingOrder == CallingOrder.Encode)
                {
-                    encoding = tokenizer.Encode((string)p[0]);
+                    encoding = tokenizer.Encode((string)p[0], out _);
                    ids = tokenizer.EncodeToIds((string)p[0]);
                    idsCount = tokenizer.CountTokens((string)p[0]);
                }
                else if (callingOrder == CallingOrder.EncodeToIds)
                {
                    ids = tokenizer.EncodeToIds((string)p[0]);
-                    encoding = tokenizer.Encode((string)p[0]);
+                    encoding = tokenizer.Encode((string)p[0], out _);
                    idsCount = tokenizer.CountTokens((string)p[0]);
                }
                else // CountTokens
                {
                    idsCount = tokenizer.CountTokens((string)p[0]);
                    ids = tokenizer.EncodeToIds((string)p[0]);
-                    encoding = tokenizer.Encode((string)p[0]);
+                    encoding = tokenizer.Encode((string)p[0], out _);
                }

-                Assert.Equal(p[1], encoding.Ids);
+                int[] encodingIds = encoding.Select(t => t.Id).ToArray();
+                (int, int)[] offsets = encoding.Select(t => t.Offset).ToArray();
+                string[] tokens = encoding.Select(t => t.Value).ToArray();
+
+                Assert.Equal(p[1], encodingIds);
                Assert.Equal(p[1], ids);
                Assert.Equal(((int[])p[1]).Length, idsCount);
-                Assert.Equal(p[3], encoding.Offsets);
-                Assert.Equal(encoding.Ids.Count, encoding.Tokens.Count);
-                Assert.Equal(encoding.Ids.Count, encoding.Offsets.Count);
+                Assert.Equal(p[3], offsets);

-                EnglishRoberta? robertaModel = tokenizer.Model as EnglishRoberta;
+                EnglishRoberta? robertaModel = tokenizer as EnglishRoberta;
-                Assert.Equal(p[2], encoding.Tokens);
+                Assert.Equal(p[2], tokens);

-                Assert.Equal(string.Concat((string[])(p[robertaModel!.FilterUnsupportedChars ? 5 : 2])), tokenizer.Decode(encoding.Ids));
+                Assert.Equal(string.Concat((string[])(p[robertaModel!.FilterUnsupportedChars ? 5 : 2])), tokenizer.Decode(encodingIds));

                Assert.NotNull(robertaModel);
-                Assert.Equal(encoding.Ids, robertaModel!.ConvertOccurrenceRanksToIds(robertaModel!.ConvertIdsToOccurrenceRanks(encoding.Ids)));
+                Assert.Equal(encodingIds, robertaModel!.ConvertOccurrenceRanksToIds(robertaModel!.ConvertIdsToOccurrenceRanks(encodingIds)));
-                Assert.Equal(p[4], robertaModel.ConvertIdsToOccurrenceValues(encoding.Ids));
+                Assert.Equal(p[4], robertaModel.ConvertIdsToOccurrenceValues(encodingIds));

-                for (int i = 0; i < encoding.Tokens.Count; i++)
+                for (int i = 0; i < tokens.Length; i++)
                {
                    if (robertaModel.FilterUnsupportedChars)
                    {
                        string[]? filteredToken = p[5] as string[];
-                        Assert.Equal(filteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
+                        Assert.Equal(filteredToken![i], tokenizer.MapIdToToken(encodingIds[i]));
                    }
                    else
                    {
-                        Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
+                        Assert.Equal(tokens[i], tokenizer.MapIdToToken(encodingIds[i]));
                        string[]? unfilteredToken = p[2] as string[];
-                        Assert.Equal(unfilteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
+                        Assert.Equal(unfilteredToken![i], tokenizer.MapIdToToken(encodingIds[i]));
                    }

-                    Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan()));
+                    Assert.Equal(encodingIds[i], tokenizer.MapTokenToId(tokens[i].AsSpan()));
                }
            }
        }
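With the `Model` property gone, the vocabulary helpers sit on the tokenizer itself, as the rewritten loop above shows. A tiny round-trip sketch; the nullable return types are an assumption here, inferred from the null-forgiving usage in the tests:

int? id = tokenizer.MapTokenToId("Ġbrown".AsSpan());                     // token text -> id (assumed nullable)
string? token = id.HasValue ? tokenizer.MapIdToToken(id.Value) : null;   // id -> token text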
@ -17,20 +17,21 @@ namespace Microsoft.ML.Tokenizers.Tests
    public class LlamaTests
    {
        private static readonly HttpClient _httpClient = new HttpClient() { Timeout = TimeSpan.FromMinutes(5) };
-        private static Tokenizer _llamaTokenizer = CreateLlamaTokenizer().GetAwaiter().GetResult();
-        private static Tokenizer _llamaMistralTokenizer = CreateLMistralTokenizer().GetAwaiter().GetResult();
+        private static Tokenizer _llamaTokenizer = CreateLlamaTokenizer();
+        private static Tokenizer _llamaMistralTokenizer = CreateLMistralTokenizer();

-        private static async Task<Tokenizer> CreateLlamaTokenizer()
+        private static Tokenizer CreateLlamaTokenizer()
        {
-            const string modelUrl = @"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model";
-            using Stream remoteStream = await _httpClient.GetStreamAsync(modelUrl);
+            // @"https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.model?download=true";
+            // @"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model";
+            using Stream remoteStream = File.OpenRead(Path.Combine(@"Llama", "tokenizer.model"));
            return Tokenizer.CreateLlama(remoteStream);
        }

-        private static async Task<Tokenizer> CreateLMistralTokenizer()
+        private static Tokenizer CreateLMistralTokenizer()
        {
-            const string modelUrl = @"https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/tokenizer.model?download=true";
-            using Stream remoteStream = await _httpClient.GetStreamAsync(modelUrl);
+            // @"https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/tokenizer.model?download=true";
+            using Stream remoteStream = File.OpenRead(Path.Combine(@"Mistral", "tokenizer.model"));
            return Tokenizer.CreateLlama(remoteStream);
        }
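The Llama fixtures switch from downloading the SentencePiece model during the test run to reading a checked-in copy; `Tokenizer.CreateLlama` only needs a readable stream either way. A sketch of the factory call, with the local path mirroring the relative paths used above:

using Stream modelStream = File.OpenRead(Path.Combine("Llama", "tokenizer.model"));
Tokenizer llama = Tokenizer.CreateLlama(modelStream);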
@ -185,13 +186,13 @@ namespace Microsoft.ML.Tokenizers.Tests
        [MemberData(nameof(LlamaTestData))]
        public void TestLlamaTokenizer(Tokenizer llamaTokenizer, string input, int[] ids, string[] tokens, (int Index, int Length)[] offsets)
        {
-            SentencePieceBpe? bpe = llamaTokenizer.Model as SentencePieceBpe;
+            SentencePieceBpe? bpe = llamaTokenizer as SentencePieceBpe;
            Assert.NotNull(bpe);

-            EncodingResult result = llamaTokenizer.Encode(input);
-            Assert.Equal(ids, result.Ids);
-            Assert.Equal(tokens, result.Tokens);
-            Assert.Equal(offsets, result.Offsets);
+            IReadOnlyList<Token> result = llamaTokenizer.Encode(input, out _);
+            Assert.Equal(ids, result.Select(t => t.Id).ToArray());
+            Assert.Equal(tokens, result.Select(t => t.Value).ToArray());
+            Assert.Equal(offsets, result.Select(t => t.Offset).ToArray());
            Assert.Equal(input, llamaTokenizer.Decode(ids));
            Assert.Equal(ids, llamaTokenizer.EncodeToIds(input));
            Assert.Equal(ids.Length, llamaTokenizer.CountTokens(input));
@ -208,32 +209,29 @@ namespace Microsoft.ML.Tokenizers.Tests

            bool isEmptyInput = string.IsNullOrEmpty(input);

-            IReadOnlyList<Token> bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false);
+            IReadOnlyList<Token> bpeTokens = bpe.Encode(normalizedInput.AsSpan(), out _, addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
            Assert.Equal(ids.Skip(1), bpeTokens.Select(token => token.Id));
            Assert.Equal(tokens.Skip(1), bpeTokens.Select(token => token.Value));
            Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
-            List<int> encodedIds = new();
-            bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds: encodedIds, out _);
+            IReadOnlyList<int> encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
            Assert.Equal(ids.Skip(1), encodedIds);
-            Assert.Equal(isEmptyInput ? 0 : ids.Length - 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false, out _));
+            Assert.Equal(isEmptyInput ? 0 : ids.Length - 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false));

-            bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true);
+            bpeTokens = bpe.Encode(normalizedInput.AsSpan(), out _, addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false);
            Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id));
            Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Skip(1).Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value));
            Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
-            encodedIds.Clear();
-            bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true, accumulatedIds: encodedIds, out _);
+            encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false);
            Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), encodedIds);
-            Assert.Equal(isEmptyInput ? 0 : ids.Length, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true, out _));
+            Assert.Equal(isEmptyInput ? 0 : ids.Length, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false));

-            bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true);
+            bpeTokens = bpe.Encode(normalizedInput.AsSpan(), out _, addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false);
            Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id));
            Assert.Equal(isEmptyInput ? Array.Empty<string>() : tokens.Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value));
            Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id)));
-            encodedIds.Clear();
-            bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true, accumulatedIds: encodedIds, out _);
+            encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false);
            Assert.Equal(isEmptyInput ? Array.Empty<int>() : ids.Concat(new[] { bpe.EndOfSentenceId }), encodedIds);
-            Assert.Equal(isEmptyInput ? 0 : ids.Length + 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true, out _));
+            Assert.Equal(isEmptyInput ? 0 : ids.Length + 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false));
        }

        public static IEnumerable<object[]> LlamaTokenizersListData()
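Two renames ripple through this hunk: `addBeginOfSentence` becomes `addBeginningOfSentence`, and the `accumulatedIds`/`out` style of `EncodeToIds` gives way to returning the list directly, with `considerNormalization` controlling whether the input is normalized again. A sketch of the combinations the test walks through, assuming `bpe` is a `SentencePieceBpe` and `normalizedInput` is already normalized:

// Raw pieces only, no BOS/EOS markers, input already normalized.
IReadOnlyList<Token> pieces = bpe.Encode(normalizedInput.AsSpan(), out _,
    addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);

// EOS only, then BOS + EOS; EncodeToIds now returns its result directly.
IReadOnlyList<int> withEos = bpe.EncodeToIds(normalizedInput.AsSpan(),
    addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false);
IReadOnlyList<int> withBosEos = bpe.EncodeToIds(normalizedInput.AsSpan(),
    addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false);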
@ -244,11 +242,17 @@ namespace Microsoft.ML.Tokenizers.Tests

        [Theory]
        [MemberData(nameof(LlamaTokenizersListData))]
-        public void TestLlamaTokenizerWithInvalidInput(Tokenizer llamaTokenizer)
+        public void TestLlamaTokenizerWithEmptyInput(Tokenizer llamaTokenizer)
        {
-            Assert.Throws<ArgumentNullException>(() => llamaTokenizer.Encode(null!));
-            Assert.Throws<ArgumentNullException>(() => llamaTokenizer.EncodeToIds(null!));
-            Assert.Throws<ArgumentNullException>(() => llamaTokenizer.CountTokens(null!));
+            Assert.Equal([], llamaTokenizer.Encode((string)null!, out _));
+            Assert.Equal([], llamaTokenizer.Encode(Span<char>.Empty, out _));
+
+            Assert.Equal([], llamaTokenizer.EncodeToIds((string)null!));
+            Assert.Equal([], llamaTokenizer.EncodeToIds(Span<char>.Empty));
+
+            Assert.Equal(0, llamaTokenizer.CountTokens((string)null!));
+            Assert.Equal(0, llamaTokenizer.CountTokens(Span<char>.Empty));
+
            Assert.Throws<ArgumentNullException>(() => llamaTokenizer.Decode(null!));
        }
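The rename from `WithInvalidInput` to `WithEmptyInput` tracks a real behavior change: null or empty text no longer throws from the encoding paths; it produces empty results and zero counts, while `Decode(null)` keeps its argument validation. In sketch form:

Assert.Empty(llamaTokenizer.EncodeToIds((string)null!));                   // previously ArgumentNullException
Assert.Equal(0, llamaTokenizer.CountTokens(Span<char>.Empty));             // empty span counts as zero tokens
Assert.Throws<ArgumentNullException>(() => llamaTokenizer.Decode(null!));  // still throws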
@ -256,7 +260,7 @@ namespace Microsoft.ML.Tokenizers.Tests
        [MemberData(nameof(LlamaTokenizersListData))]
        public void TestLlamaTokenizerProperties(Tokenizer llamaTokenizer)
        {
-            SentencePieceBpe? bpe = llamaTokenizer.Model as SentencePieceBpe;
+            SentencePieceBpe? bpe = llamaTokenizer as SentencePieceBpe;
            Assert.NotNull(bpe);
            Assert.NotNull(llamaTokenizer.Normalizer);
@ -284,34 +288,242 @@ namespace Microsoft.ML.Tokenizers.Tests
        }

        [Fact]
-        public void TestLlamaNormalizer()
+        public void TestSentencePieceNormalizer()
        {
-            LlamaNormalizer normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false);
+            SentencePieceNormalizer normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false);
            Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!"));
+            Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!".AsSpan()));

-            normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false);
+            normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false);
            Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!"));
+            Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!".AsSpan()));

-            normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false);
+            normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false);
            Assert.Equal(" Hello, World!", normalizer.Normalize("Hello, World!"));
+            Assert.Equal(" Hello, World!", normalizer.Normalize("Hello, World!".AsSpan()));

-            normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false);
+            normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false);
            Assert.Equal("▁Hello,▁World!", normalizer.Normalize("Hello, World!"));
+            Assert.Equal("▁Hello,▁World!", normalizer.Normalize("Hello, World!".AsSpan()));

-            normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false);
+            normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false);
            Assert.Equal("▁Hello,▁▁▁▁▁▁World!", normalizer.Normalize("Hello, World!"));
+            Assert.Equal("▁Hello,▁▁▁▁▁▁World!", normalizer.Normalize("Hello, World!".AsSpan()));

-            normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true);
+            normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true);
            Assert.Equal("Hello,▁World!▁", normalizer.Normalize("Hello, World!"));
+            Assert.Equal("Hello,▁World!▁", normalizer.Normalize("Hello, World!".AsSpan()));

-            normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true);
+            normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true);
            Assert.Equal("Hello,▁World!", normalizer.Normalize("Hello, World!"));
+            Assert.Equal("Hello,▁World!", normalizer.Normalize("Hello, World!".AsSpan()));

-            normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true);
+            normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true);
            Assert.Equal("Hello,▁▁▁▁▁▁World!▁", normalizer.Normalize("Hello, World!"));
+            Assert.Equal("Hello,▁▁▁▁▁▁World!▁", normalizer.Normalize("Hello, World!".AsSpan()));

-            normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: true);
+            normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: true);
            Assert.Equal("Hello, World! ", normalizer.Normalize("Hello, World!"));
+            Assert.Equal("Hello, World! ", normalizer.Normalize("Hello, World!".AsSpan()));
+        }
+
+        public static IEnumerable<object?[]> TokenizerTestData
+        {
+            get
+            {
+                // string to tokenize, produced tokens, the token offsets
+                yield return new object?[]
+                {
+                    "the brown fox jumped over the lazy dog!",
+                    "▁the▁brown▁fox▁jumped▁over▁the▁lazy▁dog!",
+                    new string[] { "<s>", "▁the", "▁brown", "▁fo", "x", "▁jump", "ed", "▁over", "▁the", "▁lazy", "▁dog", "!" },
+                    new (int Index, int Length)[] { (0, 0), (0, 4), (4, 6), (10, 3), (13, 1), (14, 5), (19, 2), (21, 5), (26, 4), (30, 5), (35, 4), (39, 1) },
+                    new int[] { 1, 278, 17354, 1701, 29916, 12500, 287, 975, 278, 17366, 11203, 29991 }
+                };
+                yield return new object?[]
+                {
+                    "he traveled to Egypt during the summer, the weather was hot and ammunition.",
+                    "▁he▁traveled▁to▁Egypt▁during▁the▁summer,▁the▁weather▁was▁hot▁and▁ammunition.",
+                    new string[] { "<s>", "▁he", "▁tra", "ve", "led", "▁to", "▁Egypt", "▁during", "▁the", "▁summer", ",", "▁the", "▁weather", "▁was", "▁hot", "▁and", "▁am", "mun", "ition", "." },
+                    new (int Index, int Length)[] { (0, 0), (0, 3), (3, 4), (7, 2), (9, 3), (12, 3), (15, 6), (21, 7), (28, 4), (32, 7), (39, 1), (40, 4), (44, 8), (52, 4), (56, 4), (60, 4), (64, 3), (67, 3), (70, 5), (75, 1) },
+                    new int[] { 1, 540, 1020, 345, 839, 304, 12892, 2645, 278, 11801, 29892, 278, 14826, 471, 7375, 322, 626, 24579, 654, 29889 }
+                };
+                yield return new object?[]
+                {
+                    "She played many games and she felt exhausted afterward",
+                    "▁She▁played▁many▁games▁and▁she▁felt▁exhausted▁afterward",
+                    new string[] { "<s>", "▁She", "▁played", "▁many", "▁games", "▁and", "▁she", "▁felt", "▁exha", "usted", "▁after", "ward" },
+                    new (int Index, int Length)[] { (0, 0), (0, 4), (4, 7), (11, 5), (16, 6), (22, 4), (26, 4), (30, 5), (35, 5), (40, 5), (45, 6), (51, 4) },
+                    new int[] { 1, 2296, 5318, 1784, 8090, 322, 1183, 7091, 18782, 16656, 1156, 1328 }
+                };
+                yield return new object?[]
+                {
+                    "Hello, y'all! How are you 😁 ?",
+                    "▁Hello,▁y'all!▁How▁are▁you▁😁▁?",
+                    new string[] { "<s>", "▁Hello", ",", "▁y", "'", "all", "!", "▁How", "▁are", "▁you", "▁", "<0xF0>", "<0x9F>", "<0x98>", "<0x81>", "▁?" },
+                    new (int Index, int Length)[] { (0, 0), (0, 6), (6, 1), (7, 2), (9, 1), (10, 3), (13, 1), (14, 4), (18, 4), (22, 4), (26, 1), (27, 2), (27, 0), (27, 0), (27, 0), (29, 2) },
+                    new int[] { 1, 15043, 29892, 343, 29915, 497, 29991, 1128, 526, 366, 29871, 243, 162, 155, 132, 1577 }
+                };
+            }
+        }
+
+        [Theory]
+        [MemberData(nameof(TokenizerTestData))]
+        public void TestTokenizerEncoding(string text, string normalizedText, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
+        {
+            Tokenizer tokenizer = _llamaTokenizer;
+
+            Assert.NotNull(tokenizer.Normalizer);
+            Assert.Null(tokenizer.PreTokenizer);
+
+            IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
+            IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);
+
+            Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
+            Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
+            Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
+
+            Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray());
+            Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray());
+            Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray());
+
+            SentencePieceBpe sentencePieceBpe = (tokenizer as SentencePieceBpe)!;
+            foreach (bool considerNormalization in new[] { true, false })
+            foreach (bool addBeginningOfSentence in new[] { true, false })
+            foreach (bool addEndOfSentence in new[] { true, false })
+            {
+                encoding = sentencePieceBpe.Encode(
+                                considerNormalization ? text : normalizedText,
+                                out _,
+                                addBeginningOfSentence: addBeginningOfSentence,
+                                addEndOfSentence: addEndOfSentence,
+                                considerPreTokenization: false,
+                                considerNormalization: considerNormalization);
+
+                encoding1 = sentencePieceBpe.Encode(
+                                considerNormalization ? text.AsSpan() : normalizedText.AsSpan(),
+                                out _,
+                                addBeginningOfSentence: addBeginningOfSentence,
+                                addEndOfSentence: addEndOfSentence,
+                                considerPreTokenization: false,
+                                considerNormalization: considerNormalization);
+
+                string[] expectedTokens1 = addBeginningOfSentence ? expectedTokens : expectedTokens.Skip(1).ToArray();
+                expectedTokens1 = addEndOfSentence ? expectedTokens1.Concat(new[] { sentencePieceBpe.EndOfSentenceToken }).ToArray() : expectedTokens1;
+
+                (int Index, int Length)[] expectedOffsets1 = addBeginningOfSentence ? expectedOffsets : expectedOffsets.Skip(1).ToArray();
+                expectedOffsets1 = addEndOfSentence ? expectedOffsets1.Concat(new[] { (normalizedText.Length, 0) }).ToArray() : expectedOffsets1;
+
+                int[] expectedIds1 = addBeginningOfSentence ? expectedIds : expectedIds.Skip(1).ToArray();
+                expectedIds1 = addEndOfSentence ? expectedIds1.Concat(new[] { sentencePieceBpe.EndOfSentenceId }).ToArray() : expectedIds1;
+
+                Assert.Equal(expectedTokens1, encoding.Select(t => t.Value).ToArray());
+                Assert.Equal(expectedOffsets1, encoding.Select(t => t.Offset).ToArray());
+                Assert.Equal(expectedIds1, encoding.Select(t => t.Id).ToArray());
+            }
+        }
+
+        [Theory]
+        [MemberData(nameof(TokenizerTestData))]
+        public void TestTokenizerEncodingToIds(string text, string normalizedText, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
+        {
+            Tokenizer tokenizer = _llamaTokenizer;
+
+            Assert.NotNull(expectedTokens);
+            Assert.NotNull(expectedOffsets);
+
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text));
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan()));
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text, expectedIds.Length, out string? normalizedString, out int length));
+            Assert.Equal(normalizedText, normalizedString);
+            Assert.Equal(normalizedText.Length, length);
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, out normalizedString, out length));
+            Assert.Equal(normalizedText, normalizedString);
+            Assert.Equal(normalizedText.Length, length);
+
+            SentencePieceBpe sentencePieceBpe = (tokenizer as SentencePieceBpe)!;
+            foreach (bool considerNormalization in new[] { true, false })
+            foreach (bool addBeginningOfSentence in new[] { true, false })
+            foreach (bool addEndOfSentence in new[] { true, false })
+            {
+                // (string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+
+                int[] expectedIds1 = addBeginningOfSentence ? expectedIds : expectedIds.Skip(1).ToArray();
+                expectedIds1 = addEndOfSentence ? expectedIds1.Concat(new[] { sentencePieceBpe.EndOfSentenceId }).ToArray() : expectedIds1;
+
+                Assert.Equal(expectedIds1, sentencePieceBpe.EncodeToIds(
+                                            considerNormalization ? text : normalizedText,
+                                            addBeginningOfSentence: addBeginningOfSentence,
+                                            addEndOfSentence: addEndOfSentence,
+                                            expectedIds1.Length,
+                                            out normalizedString,
+                                            out length,
+                                            considerNormalization: considerNormalization));
+
+                Assert.Equal(expectedIds1, sentencePieceBpe.EncodeToIds(
+                                            considerNormalization ? text.AsSpan() : normalizedText.AsSpan(),
+                                            addBeginningOfSentence: addBeginningOfSentence,
+                                            addEndOfSentence: addEndOfSentence,
+                                            expectedIds1.Length,
+                                            out normalizedString,
+                                            out length,
+                                            considerNormalization: considerNormalization));
+
+                Assert.Equal(considerNormalization ? normalizedText : null, normalizedString);
+                Assert.Equal(normalizedText.Length, length);
+
+                Assert.Equal(expectedIds1.Take(expectedIds1.Length - 6), sentencePieceBpe.EncodeToIds(
+                                            considerNormalization ? text : normalizedText,
+                                            addBeginningOfSentence: addBeginningOfSentence,
+                                            addEndOfSentence: addEndOfSentence,
+                                            expectedIds1.Length - 6,
+                                            out normalizedString,
+                                            out length,
+                                            considerNormalization: considerNormalization));
+                Assert.Equal(considerNormalization ? normalizedText : null, normalizedString);
+
+                (int Index, int Length)[] expectedOffsets1 = addBeginningOfSentence ? expectedOffsets.Take(expectedIds1.Length - 6).ToArray() : expectedOffsets.Skip(1).Take(expectedIds1.Length - 6).ToArray();
+
+                int expectedLength = expectedOffsets1[expectedOffsets1.Length - 1].Index + expectedOffsets1[expectedOffsets1.Length - 1].Length;
+                Assert.Equal(expectedLength, length);
+
+                Assert.Equal(expectedIds1.Take(expectedIds1.Length - 6), sentencePieceBpe.EncodeToIds(
+                                            considerNormalization ? text.AsSpan() : normalizedText.AsSpan(),
+                                            addBeginningOfSentence: addBeginningOfSentence,
+                                            addEndOfSentence: addEndOfSentence,
+                                            expectedIds1.Length - 6,
+                                            out normalizedString,
+                                            out length,
+                                            considerNormalization: considerNormalization));
+                Assert.Equal(expectedLength, length);
+            }
+        }
+
+        [Theory]
+        [MemberData(nameof(TokenizerTestData))]
+        public void TestTokenizerCountTokens(string text, string normalizedText, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
+        {
+            Tokenizer tokenizer = _llamaTokenizer;
+
+            Assert.NotNull(expectedTokens);
+
+            Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text));
+            Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan()));
+
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index + expectedOffsets[expectedOffsets.Length - 7].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 6, out string? normalizedString, out int tokenCount));
+            Assert.Equal(normalizedText, normalizedString);
+            Assert.Equal(expectedIds.Length - 6, tokenCount);
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index + expectedOffsets[expectedOffsets.Length - 7].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 6, out normalizedString, out tokenCount));
+            Assert.Equal(normalizedText, normalizedString);
+            Assert.Equal(expectedIds.Length - 6, tokenCount);
+
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index, tokenizer.LastIndexOfTokenCount(text, 7, out normalizedString, out tokenCount));
+            Assert.Equal(normalizedText, normalizedString);
+            Assert.Equal(7, tokenCount);
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 7, out normalizedString, out tokenCount));
+            Assert.Equal(normalizedText, normalizedString);
+            Assert.Equal(7, tokenCount);
        }
    }
}
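`LlamaNormalizer` generalizes to `SentencePieceNormalizer`, and every setting now also has a span overload. The four constructor flags compose exactly as the assertion pairs above show; a compact sketch of the Llama-style configuration:

// Llama-style: collapse extra spaces, prefix a dummy space, escape spaces as '▁'.
var normalizer = new SentencePieceNormalizer(
    removeExtraWhiteSpaces: true, addDummyPrefix: true,
    escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false);

string s1 = normalizer.Normalize("Hello, World!");            // "▁Hello,▁World!"
string s2 = normalizer.Normalize("Hello, World!".AsSpan());   // span overload, same result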
@ -42,6 +42,7 @@
  <ItemGroup>
    <PackageReference Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotNetRemoteExecutorVersion)" />
    <PackageReference Include="System.Text.Json" Version="$(SystemTextJsonVersion)" />
+    <PackageReference Include="Microsoft.ML.TestTokenizers" Version="$(MicrosoftMLTestTokenizersVersion)" />
  </ItemGroup>

</Project>
@ -61,10 +61,9 @@ namespace Microsoft.ML.Tokenizers.Tests
            string normalizedText = normalizer.Normalize(text);
            Assert.Equal(normalized, normalizedText);

-            Tokenizer tokenizer = new Tokenizer(BpeTests.CreateEmptyBpe(), WhiteSpace.Instance, normalizer);
-            EncodingResult encoding = tokenizer.Encode(text);
-            Assert.Equal(text, encoding.OriginalString);
-            Assert.Equal(normalized, encoding.NormalizedString);
+            Tokenizer tokenizer = BpeTests.CreateEmptyBpe(preTokenizer: null, normalizer);
+            IReadOnlyList<Token> tokens = tokenizer.Encode(text, out string? normalizedString);
+            Assert.Equal(normalized, normalizedString);
        }

        public class RemoveQuotesNormalizer : Normalizer
@ -77,6 +76,22 @@ namespace Microsoft.ML.Tokenizers.Tests
                return original;
            }

+            return RemoveQuotes(original.AsSpan(), index);
+        }
+
+        public override string Normalize(ReadOnlySpan<char> original)
+        {
+            int index = original.IndexOf('"');
+            if (index <= 0)
+            {
+                return original.ToString();
+            }
+
+            return RemoveQuotes(original, index);
+        }
+
+        private string RemoveQuotes(ReadOnlySpan<char> original, int index)
+        {
            StringBuilder sb = new StringBuilder(original.Length);
            List<int> mapping = new List<int>();

@ -97,7 +112,7 @@ namespace Microsoft.ML.Tokenizers.Tests
                    break;
                }

-                index = original.IndexOf('"', start);
+                index = original.Slice(start).IndexOf('"');
                if (index <= 0)
                {
                    for (int i = start; i < original.Length; i++)
@ -107,6 +122,8 @@ namespace Microsoft.ML.Tokenizers.Tests
                    }
                    break;
                }

+                index += start;
            } while (true);

            return sb.ToString();
@ -130,6 +147,16 @@ namespace Microsoft.ML.Tokenizers.Tests

            return original.Normalize(_normalizationForm);
        }

+        public override string Normalize(ReadOnlySpan<char> original)
+        {
+            if (original.IsEmpty)
+            {
+                return string.Empty;
+            }
+
+            return original.ToString().Normalize(_normalizationForm);
+        }
        }
    }
}
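Custom normalizers now override both entry points; the string overload can do its null/empty check and delegate to the span overload, mirroring the `RemoveQuotesNormalizer` shape above. A skeletal, purely illustrative example:

public sealed class LowercaseNormalizer : Normalizer   // hypothetical, not part of the library
{
    public override string Normalize(string original)
    {
        if (string.IsNullOrEmpty(original))
        {
            return string.Empty;
        }

        return Normalize(original.AsSpan());
    }

    public override string Normalize(ReadOnlySpan<char> original)
    {
        if (original.IsEmpty)
        {
            return string.Empty;
        }

        // Materialize once and lower-case invariantly.
        return original.ToString().ToLowerInvariant();
    }
}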
@ -20,62 +20,63 @@ namespace Microsoft.ML.Tokenizers.Tests
            {
                WhiteSpace.Instance,
                "How are you doing?",
-                new Split[] { new Split("How", (0, 3)), new Split("are", (4, 3)), new Split("you", (8, 3)), new Split("doing", (12, 5)), new Split("?", (17, 1)),}
+                new (int Offset, int Length)[] { (0, 3), (4, 3), (8, 3), (12, 5), (17, 1), }
            };

            yield return new object[]
            {
                WhiteSpace.Instance,
                "I_am_Just_Fine!",
-                new Split[] { new Split("I_am_Just_Fine", (0, 14)), new Split("!", (14, 1)) }
+                new (int Offset, int Length)[] { (0, 14), (14, 1) }
            };

            yield return new object[]
            {
                new SpacePreTokenizer(),
                "How are you doing?!",
-                new Split[] { new Split("How", (0, 3)), new Split("are", (4, 3)), new Split("you", (11, 3)), new Split("doing?!", (15, 7)) }
+                new (int Offset, int Length)[] { (0, 3), (4, 3), (11, 3), (15, 7) }
            };

            yield return new object[]
            {
                new SpacePreTokenizer(),
                new string(' ', 100),
-                new Split[] { }
+                new (int Offset, int Length)[] { }
            };
        }
    }

        [Theory]
        [MemberData(nameof(PreTokenizerData))]
-        public void TestPreTokenizer(PreTokenizer preTokenizer, string text, Split[] splits)
+        public void TestPreTokenizer(PreTokenizer preTokenizer, string text, (int Offset, int Length)[] splits)
        {
-            Split[] splitParts = preTokenizer.PreTokenize(text).ToArray<Split>();
+            (int Offset, int Length)[] splitParts = preTokenizer.PreTokenize(text).ToArray<(int Offset, int Length)>();
            Assert.Equal(splits, splitParts);

            // Empty tokenizer which tokenize all parts as unknown tokens.
-            Tokenizer tokenizer = new Tokenizer(BpeTests.CreateEmptyBpe(), preTokenizer);
+            Tokenizer tokenizer = BpeTests.CreateEmptyBpe(normalizer: null, preTokenizer: preTokenizer);

-            EncodingResult encoding = tokenizer.Encode(text);
-            Assert.True(encoding.Tokens.Count >= splitParts.Length, $"Expected to have {encoding.Tokens.Count} >= {splitParts.Length}");
+            IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
+            Assert.True(encoding.Count >= splitParts.Length, $"Expected to have {encoding.Count} >= {splitParts.Length}");
        }

        [Fact]
        public void TestWhiteSpacePreTokenizer()
        {
-            Assert.Empty(WhiteSpace.Instance.PreTokenize(null!));
+            Assert.Empty(WhiteSpace.Instance.PreTokenize((string)null!));
        }

        public class SpacePreTokenizer : PreTokenizer
        {
-            public override IEnumerable<Split> PreTokenize(string text)
+            public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan<char> text)
            {
-                List<Split> splits = new();
-                if (string.IsNullOrEmpty(text))
+                if (text.IsEmpty)
                {
-                    return splits;
+                    return [];
                }

+                List<(int Offset, int Length)> splits = new();
+
                int index = 0;
                while (true)
                {
@ -92,7 +93,7 @@ namespace Microsoft.ML.Tokenizers.Tests

                    if (index < text.Length)
                    {
-                        splits.Add(new Split(text.Substring(index, end - index), (index, end - index)));
+                        splits.Add((index, end - index));
                    }
                    else
                    {
@ -104,6 +105,16 @@ namespace Microsoft.ML.Tokenizers.Tests

                return splits;
            }

+            public override IEnumerable<(int Offset, int Length)> PreTokenize(string text)
+            {
+                if (string.IsNullOrEmpty(text))
+                {
+                    return [];
+                }
+
+                return PreTokenize(text.AsSpan());
+            }
        }
    }
}
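The pre-tokenizer contract changes from yielding `Split` objects (which carried substrings) to yielding `(int Offset, int Length)` ranges into the original text, which is what lets the span overload avoid allocations. A skeletal pre-tokenizer in the same shape as `SpacePreTokenizer` above, splitting on commas purely for illustration:

public sealed class CommaPreTokenizer : PreTokenizer   // hypothetical example
{
    public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan<char> text)
    {
        List<(int Offset, int Length)> splits = new();
        int start = 0;
        for (int i = 0; i <= text.Length; i++)
        {
            if (i == text.Length || text[i] == ',')
            {
                if (i > start)
                {
                    splits.Add((start, i - start));   // ranges into the caller's text, no substrings
                }
                start = i + 1;
            }
        }

        return splits;
    }

    public override IEnumerable<(int Offset, int Length)> PreTokenize(string text)
    {
        if (string.IsNullOrEmpty(text))
        {
            return [];
        }

        return PreTokenize(text.AsSpan());
    }
}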
@ -38,8 +38,8 @@ namespace Microsoft.ML.Tokenizers.Tests
        {
            TestGPT4TokenizationEncoding(GPT4);

-            Assert.True(GPT4.Model is Tiktoken);
-            IReadOnlyDictionary<string, int>? specialTokensEncoder = (GPT4.Model as Tiktoken)!.SpecialTokens;
+            Assert.True(GPT4 is Tiktoken);
+            IReadOnlyDictionary<string, int>? specialTokensEncoder = (GPT4 as Tiktoken)!.SpecialTokens;

            string tokenizerDataFileName = Utils.CreateTemporaryFile("tiktoken");

@ -53,21 +53,21 @@ namespace Microsoft.ML.Tokenizers.Tests

            try
            {
-                Tokenizer tokenizer = new Tokenizer(new Tiktoken(tokenizerDataFileName, specialTokensEncoder), GPT4.PreTokenizer);
+                Tokenizer tokenizer = new Tiktoken(tokenizerDataFileName, GPT4.PreTokenizer, specialTokensEncoder);
                TestGPT4TokenizationEncoding(tokenizer);

                using (Stream stream = File.OpenRead(tokenizerDataFileName))
                {
-                    tokenizer = new Tokenizer(new Tiktoken(stream, specialTokensEncoder), GPT4.PreTokenizer);
+                    tokenizer = new Tiktoken(stream, GPT4.PreTokenizer, specialTokensEncoder);
                }
                TestGPT4TokenizationEncoding(tokenizer);

-                tokenizer = new Tokenizer(await Tiktoken.CreateAsync(tokenizerDataFileName, specialTokensEncoder), GPT4.PreTokenizer);
+                tokenizer = await Tokenizer.CreateTiktokenAsync(tokenizerDataFileName, GPT4.PreTokenizer, normalizer: null, specialTokensEncoder);
                TestGPT4TokenizationEncoding(tokenizer);

                using (Stream stream = File.OpenRead(tokenizerDataFileName))
                {
-                    tokenizer = new Tokenizer(await Tiktoken.CreateAsync(stream, specialTokensEncoder), GPT4.PreTokenizer);
+                    tokenizer = await Tokenizer.CreateTiktokenAsync(stream, GPT4.PreTokenizer, normalizer: null, specialTokensEncoder);
                }
                TestGPT4TokenizationEncoding(tokenizer);
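`Tiktoken` follows the same flattening: construct it directly with the pre-tokenizer and optional special-token map, or use the async factory `Tokenizer.CreateTiktokenAsync` when loading should not block. A sketch; the data file name is a placeholder and `GPT4` stands in for the existing fixture whose pre-tokenizer is reused above:

var specialTokens = new Dictionary<string, int> { { "<|im_start|>", 100264 }, { "<|im_end|>", 100265 } };

// Synchronous construction from a tiktoken data file.
Tokenizer t1 = new Tiktoken("cl100k_base.tiktoken", GPT4.PreTokenizer, specialTokens);

// Async factory over a stream.
using Stream stream = File.OpenRead("cl100k_base.tiktoken");
Tokenizer t2 = await Tokenizer.CreateTiktokenAsync(stream, GPT4.PreTokenizer, normalizer: null, specialTokens);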
@ -109,11 +109,11 @@ namespace Microsoft.ML.Tokenizers.Tests
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
Tiktoken tiktoken = (tokenizer.Model as Tiktoken)!;
|
Tiktoken tiktoken = (tokenizer as Tiktoken)!;
|
||||||
Tokenizer externalTokenizer = new Tokenizer(new Tiktoken(tokenizerDataFileName, tiktoken.SpecialTokens), tokenizer.PreTokenizer);
|
Tokenizer externalTokenizer = new Tiktoken(tokenizerDataFileName, tokenizer.PreTokenizer, tiktoken.SpecialTokens);
|
||||||
|
|
||||||
IReadOnlyDictionary<ReadOnlyMemory<byte>, int> encoder = tiktoken.Encoder;
|
IReadOnlyDictionary<ReadOnlyMemory<byte>, int> encoder = tiktoken.Encoder;
|
||||||
IReadOnlyDictionary<ReadOnlyMemory<byte>, int> externalEncoder = (externalTokenizer.Model as Tiktoken)!.Encoder;
|
IReadOnlyDictionary<ReadOnlyMemory<byte>, int> externalEncoder = (externalTokenizer as Tiktoken)!.Encoder;
|
||||||
|
|
||||||
Assert.Equal(externalEncoder.Count, encoder.Count);
|
Assert.Equal(externalEncoder.Count, encoder.Count);
|
||||||
foreach (KeyValuePair<ReadOnlyMemory<byte>, int> kvp in encoder)
|
foreach (KeyValuePair<ReadOnlyMemory<byte>, int> kvp in encoder)
|
||||||
|
@@ -135,13 +135,17 @@ namespace Microsoft.ML.Tokenizers.Tests
             Assert.Equal(new List<int>() { 9906, 4435 }, encoded);
             Assert.Equal(text, tokenizer.Decode(encoded.ToArray())!);

-            EncodingResult result = tokenizer.Encode(text);
+            IReadOnlyList<Token> result = tokenizer.Encode(text, out string? normalizedString);
             int idsCount = tokenizer.CountTokens(text);
-            Assert.Equal(encoded, result.Ids);
-            Assert.Equal(new string[] { "Hello", " World" }, result.Tokens);
-            Assert.Equal(new List<(int, int)> { (0, 5), (5, 6) }, result.Offsets);
+            int[] ids = result.Select(token => token.Id).ToArray();
+            string[] tokens = result.Select(token => token.Value).ToArray();
+            (int, int)[] offsets = result.Select(token => token.Offset).ToArray();
+
+            Assert.Equal(encoded, ids);
+            Assert.Equal(new string[] { "Hello", " World" }, tokens);
+            Assert.Equal(new List<(int, int)> { (0, 5), (5, 6) }, offsets);
             Assert.Equal(encoded.Count, idsCount);
-            Assert.Equal(encoded, result.Ids);
+            Assert.Equal(encoded, ids);

             TestGPT4Tokenizer(tokenizer);
         }

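This hunk establishes the pattern repeated through the rest of the file: Encode no longer returns an EncodingResult aggregate; it returns IReadOnlyList<Token> and reports the normalized text through an out parameter. A consumption sketch, assuming the Token members (Id, Value, Offset) exercised by these tests; requires System.Linq:

    IReadOnlyList<Token> result = tokenizer.Encode("Hello World", out string? normalizedString);

    int[] ids = result.Select(token => token.Id).ToArray();                 // { 9906, 4435 }
    string[] tokens = result.Select(token => token.Value).ToArray();        // { "Hello", " World" }
    (int, int)[] offsets = result.Select(token => token.Offset).ToArray();  // { (0, 5), (5, 6) }, (start, length) into the input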
@@ -154,13 +158,18 @@ namespace Microsoft.ML.Tokenizers.Tests
             Assert.Equal(new List<int>() { 100264, 9906, 4435, 100265 }, encoded);
             Assert.Equal(text, GPT4.Decode(encoded.ToArray()));

-            EncodingResult result = GPT4.Encode(text);
+            IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
             int idsCount = GPT4.CountTokens(text);
-            Assert.Equal(encoded, result.Ids);
-            Assert.Equal(new string[] { "<|im_start|>", "Hello", " World", "<|im_end|>" }, result.Tokens);
-            Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 6), (23, 10) }, result.Offsets);
+            int[] ids = result.Select(token => token.Id).ToArray();
+            string[] tokens = result.Select(token => token.Value).ToArray();
+            (int, int)[] offsets = result.Select(token => token.Offset).ToArray();
+
+            Assert.Equal(encoded, ids);
+            Assert.Equal(new string[] { "<|im_start|>", "Hello", " World", "<|im_end|>" }, tokens);
+            Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 6), (23, 10) }, offsets);
             Assert.Equal(encoded.Count, idsCount);
-            Assert.Equal(encoded, result.Ids);
+            Assert.Equal(encoded, ids);
         }

         private void TestGPT4Tokenizer(Tokenizer gpt4Tokenizer)
@@ -192,12 +201,16 @@ namespace Microsoft.ML.Tokenizers.Tests
             string? decoded = GPT4.Decode(encoded.ToArray());
             Assert.Equal(text, decoded);

-            EncodingResult result = GPT4.Encode(text);
+            IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
+            int[] ids = result.Select(token => token.Id).ToArray();
+            string[] tokens = result.Select(token => token.Value).ToArray();
+            (int, int)[] offsets = result.Select(token => token.Offset).ToArray();
+
             int idsCount = GPT4.CountTokens(text);
-            Assert.Equal(encoded, result.Ids);
+            Assert.Equal(encoded, ids);
             Assert.Equal(encoded.Count, idsCount);
-            Assert.Equal(new string[] { "<|im_start|>", "Hello", "<|im_end|>", " World" }, result.Tokens);
+            Assert.Equal(new string[] { "<|im_start|>", "Hello", "<|im_end|>", " World" }, tokens);
-            Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 10), (27, 6) }, result.Offsets);
+            Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 10), (27, 6) }, offsets);
         }

         [Fact]
@@ -207,12 +220,10 @@ namespace Microsoft.ML.Tokenizers.Tests
             IReadOnlyList<int> encoded = GPT4.EncodeToIds(text);
             Assert.Empty(encoded);

-            EncodingResult result = GPT4.Encode(text);
+            IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
             int idsCount = GPT4.CountTokens(text);
-            Assert.Empty(result.Ids);
-            Assert.Empty(result.Tokens);
-            Assert.Empty(result.Offsets);
-            Assert.Equal(result.Ids.Count, idsCount);
+            Assert.Empty(result);
+            Assert.Equal(0, idsCount);
         }

         [Fact]
@@ -224,11 +235,11 @@ namespace Microsoft.ML.Tokenizers.Tests
             Assert.Equal(new List<int>() { 100264, 9906, 2928, 99834, 4435, 100265 }, encoded);
             Assert.Equal(text, GPT4.Decode(encoded.ToArray()));

-            EncodingResult result = GPT4.Encode(text);
+            IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
-            Assert.Equal(encoded, result.Ids);
+            Assert.Equal(encoded, result.Select(token => token.Id).ToArray());
             Assert.Equal(encoded.Count, idsCount);
-            Assert.Equal(new string[] { "<|im_start|>", "Hello", " ⭐", "", " World", "<|im_end|>" }, result.Tokens);
+            Assert.Equal(new string[] { "<|im_start|>", "Hello", " ⭐", "", " World", "<|im_end|>" }, result.Select(token => token.Value).ToArray());
-            Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 2), (19, 0), (19, 6), (25, 10) }, result.Offsets);
+            Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 2), (19, 0), (19, 6), (25, 10) }, result.Select(token => token.Offset).ToArray());
         }

         [Fact]
@@ -310,9 +321,12 @@ namespace Microsoft.ML.Tokenizers.Tests
         [Theory]
         [InlineData("gpt-4")]
         [InlineData("gpt-4-")]
+        [InlineData("gpt-3.5-")]
         [InlineData("gpt-3.5-turbo")]
         [InlineData("gpt-3.5-turbo-")]
         [InlineData("gpt-3.5-turbo-16k")]
+        [InlineData("gpt-35")]
+        [InlineData("gpt-35-")]
         [InlineData("gpt-35-turbo")]
         [InlineData("gpt-35-turbo-16k")]
         [InlineData("gpt-35-turbo-")]
@@ -351,7 +365,7 @@ namespace Microsoft.ML.Tokenizers.Tests
         public void TestAllSupportedModelNames(string modelName)
         {
             Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel(modelName);
-            Assert.NotNull(tokenizer.Model);
+            Assert.True(tokenizer is Tiktoken);
             Assert.NotNull(tokenizer.PreTokenizer);
         }

@@ -363,7 +377,7 @@ namespace Microsoft.ML.Tokenizers.Tests
         public void TestAllSupportedEncodingNames(string encodingName)
         {
             Tokenizer tokenizer = Tokenizer.CreateTiktokenForEncoding(encodingName);
-            Assert.NotNull(tokenizer.Model);
+            Assert.True(tokenizer is Tiktoken);
             Assert.NotNull(tokenizer.PreTokenizer);

             string modelName = encodingName.ToLowerInvariant() switch
@@ -377,15 +391,16 @@ namespace Microsoft.ML.Tokenizers.Tests

             Tokenizer tokenizer1 = Tokenizer.CreateTiktokenForModel(modelName);

-            Tiktoken? model1 = tokenizer.Model as Tiktoken;
-            Tiktoken? model2 = tokenizer1.Model as Tiktoken;
-            Assert.NotNull(model1);
-            Assert.NotNull(model2);
+            Assert.True(tokenizer is Tiktoken);
+            Assert.True(tokenizer1 is Tiktoken);

-            Assert.Equal(model2.Encoder, model1.Encoder);
-            Assert.Equal(model2.Decoder, model1.Decoder);
-            Assert.Equal(model2.SpecialTokens, model1.SpecialTokens);
-            Assert.Equal(model2.Vocab, model1.Vocab);
+            Tiktoken tiktoken = (tokenizer as Tiktoken)!;
+            Tiktoken tiktoken1 = (tokenizer1 as Tiktoken)!;
+
+            Assert.Equal(tiktoken1.Encoder, tiktoken.Encoder);
+            Assert.Equal(tiktoken1.Decoder, tiktoken.Decoder);
+            Assert.Equal(tiktoken1.SpecialTokens, tiktoken.SpecialTokens);
+            Assert.Equal(tiktoken1.Vocab, tiktoken.Vocab);
         }

         [Fact]
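With the Model property gone, vocabulary-level state (Encoder, Decoder, Vocab, SpecialTokens) is reached by casting the tokenizer itself, as the '+' lines above do. A short sketch of the new access pattern, assuming the members shown in this hunk:

    Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel("gpt-4");
    Tiktoken tiktoken = (tokenizer as Tiktoken)!; // the tokenizer is the model now

    IReadOnlyDictionary<ReadOnlyMemory<byte>, int> encoder = tiktoken.Encoder;
    IReadOnlyDictionary<string, int>? specialTokens = tiktoken.SpecialTokens;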
@@ -408,11 +423,99 @@ namespace Microsoft.ML.Tokenizers.Tests
             RemoteExecutor.Invoke(static (name) =>
             {
                 Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel(name);
-                Assert.NotNull(tokenizer.Model);
+                Assert.True(tokenizer is Tiktoken);
                 Assert.NotNull(tokenizer.PreTokenizer);
             }, modelName).Dispose();
         }

+        public static IEnumerable<object?[]> TokenizerTestData
+        {
+            get
+            {
+                // string to tokenize, produced tokens, the token offsets
+                yield return new object?[]
+                {
+                    "the brown fox jumped over the lazy dog!",
+                    new string[] { "the", " brown", " fox", " jumped", " over", " the", " lazy", " dog", "!" },
+                    new (int Index, int Length)[] { (0, 3), (3, 6), (9, 4), (13, 7), (20, 5), (25, 4), (29, 5), (34, 4), (38, 1) },
+                    new int[] { 1820, 14198, 39935, 27096, 927, 279, 16053, 5679, 0 }
+                };
+                yield return new object?[]
+                {
+                    "he traveled to Egypt during the summer, the weather was hot and ammunition.",
+                    new string[] { "he", " traveled", " to", " Egypt", " during", " the", " summer", ",", " the", " weather", " was", " hot", " and", " ammunition", "." },
+                    new (int Index, int Length)[] { (0, 2), (2, 9), (11, 3), (14, 6), (20, 7), (27, 4), (31, 7), (38, 1), (39, 4), (43, 8), (51, 4), (55, 4), (59, 4), (63, 11), (74, 1) },
+                    new int[] { 383, 31796, 311, 15212, 2391, 279, 7474, 11, 279, 9282, 574, 4106, 323, 37768, 13 }
+                };
+                yield return new object?[]
+                {
+                    "She played many games and she felt exhausted afterward",
+                    new string[] { "She", " played", " many", " games", " and", " she", " felt", " exhausted", " afterward" },
+                    new (int Index, int Length)[] { (0, 3), (3, 7), (10, 5), (15, 6), (21, 4), (25, 4), (29, 5), (34, 10), (44, 10) },
+                    new int[] { 8100, 6476, 1690, 3953, 323, 1364, 6612, 39019, 49043 }
+                };
+                yield return new object?[]
+                {
+                    "Hello, y'all! How are you 😁 ?",
+                    new string[] { "Hello", ",", " y", "'all", "!", " How", " are", " you", " 😁", "", " ?" },
+                    new (int Index, int Length)[] { (0, 5), (5, 1), (6, 2), (8, 4), (12, 1), (13, 4), (17, 4), (21, 4), (25, 3), (28, 0), (28, 2) },
+                    new int[] { 9906, 11, 379, 65948, 0, 2650, 527, 499, 27623, 223, 949 }
+                };
+            }
+        }
+
+        [Theory]
+        [MemberData(nameof(TokenizerTestData))]
+        public void TestTokenizerEncoding(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
+        {
+            Tokenizer tokenizer = GPT4;
+
+            IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
+            IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);
+
+            Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
+            Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
+            Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
+
+            Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray());
+            Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray());
+            Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray());
+
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text));
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan()));
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text, expectedIds.Length, out string? normalizedString, out int length));
+            Assert.Null(normalizedString);
+            Assert.Equal(text.Length, length);
+            Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, out normalizedString, out length));
+            Assert.Null(normalizedString);
+            Assert.Equal(text.Length, length);
+
+            Assert.Equal(expectedIds.Take(expectedIds.Length - 4), tokenizer.EncodeToIds(text, expectedIds.Length - 4, out normalizedString, out length));
+            Assert.Null(normalizedString);
+            int expectedLength = expectedOffsets[expectedOffsets.Length - 5].Index + expectedOffsets[expectedOffsets.Length - 5].Length;
+            Assert.Equal(expectedLength, length);
+            Assert.Equal(expectedIds.Take(expectedIds.Length - 4), tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length - 4, out normalizedString, out length));
+            Assert.Null(normalizedString);
+            Assert.Equal(expectedLength, length);
+
+            Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text));
+            Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan()));
+
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(expectedIds.Length - 3, tokenCount);
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(expectedIds.Length - 3, tokenCount);
+
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text, 3, out normalizedString, out tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(3, tokenCount);
+            Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 3, out normalizedString, out tokenCount));
+            Assert.Null(normalizedString);
+            Assert.Equal(3, tokenCount);
+        }
+
         // Test running copy the test data files to the output folder but sometimes the file content is mutated replacing '\n' with '\r\n'.
         // This method reads the file and removes the extra inserted '\r' characters. Having '\r' in the file content will cause the tests to fail.
         private string ReadAndSanitizeFile(string path)

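The new TestTokenizerEncoding theory above is effectively a specification of the budget-limited encoding surface. A condensed sketch of the semantics it asserts, using only the signatures visible in this hunk ('text' stands in for any input string):

    // Encode at most maxTokenCount tokens; 'length' reports how many input
    // characters those tokens cover. normalizedString is null when the
    // tokenizer has no normalizer.
    IReadOnlyList<int> ids = tokenizer.EncodeToIds(text, maxTokenCount: 5, out string? normalizedString, out int length);

    // Character index (from the start) reached after at most N tokens:
    int end = tokenizer.IndexOfTokenCount(text, maxTokenCount: 5, out normalizedString, out int tokenCount);

    // Character index where the last N tokens begin:
    int start = tokenizer.LastIndexOfTokenCount(text, maxTokenCount: 3, out normalizedString, out tokenCount);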
@@ -26,12 +26,12 @@ namespace Microsoft.ML.Tokenizers.Tests

             for (int i = 1; i <= fullIdsList.Count; i++)
             {
-                int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string processedText1, out int tokenCount1);
+                int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string? processedText1, out int tokenCount1);
-                int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string processedText2, out int tokenCount2);
+                int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string? processedText2, out int tokenCount2);
-                IReadOnlyList<int> partialIdsList = tokenizer.EncodeToIds(input, maxTokenCount: i, out string processedText, out int textLength);
+                IReadOnlyList<int> partialIdsList = tokenizer.EncodeToIds(input, maxTokenCount: i, out string? processedText, out int textLength);

-                Assert.True(textLength <= processedText.Length);
+                Assert.True(processedText is null || textLength <= processedText.Length);
-                Assert.True(tokenizer.Normalizer is not null || processedText == input);
+                Assert.True(tokenizer.Normalizer is not null || processedText is null);

                 Assert.Equal(fullIdsList.Take(partialIdsList.Count), partialIdsList);
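The signature changes above make the processed-text out parameter nullable: when no normalizer runs, the tokenizer now reports null instead of echoing the input, and callers fall back explicitly. A sketch of the fallback idiom the later hunks adopt:

    int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string? processedText1, out int tokenCount1);
    // processedText1 is null when no normalization happened; use the original input:
    string effectiveText = processedText1 ?? input;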
@@ -42,14 +42,13 @@ namespace Microsoft.ML.Tokenizers.Tests
                 // In this case, we'll get index1 equal to zero and nothing really will need to be tested.
                 if (tokenCount1 > 0 && index1 > 0)
                 {
-                    string prefixString = processedText1.Substring(0, index1);
+                    string prefixString = (processedText1 ?? input).Substring(0, index1);

-                    if (tokenizer.Model is SentencePieceBpe)
+                    if (tokenizer is SentencePieceBpe)
                     {
                         // SentencePieceBpe model normalize the text and insert more characters.
                         // We call the model directly to bypass the normalization step
-                        prefixIds = new List<int>();
-                        tokenizer.Model.EncodeToIds(prefixString.AsSpan(), (prefixIds as IList<int>)!, out _);
+                        prefixIds = tokenizer.EncodeToIds(prefixString.AsSpan(), considerNormalization: false);
                     }
                     else
                     {
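The '+' line above also shows the replacement for calling into tokenizer.Model directly: normalization is now bypassed through a flag on the public API. A one-line sketch, assuming the overload used in this hunk:

    // Encode an already-processed substring without re-normalizing it
    // (SentencePieceBpe would otherwise insert extra characters).
    IReadOnlyList<int> prefixIds = tokenizer.EncodeToIds(prefixString.AsSpan(), considerNormalization: false);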
@@ -61,14 +60,13 @@ namespace Microsoft.ML.Tokenizers.Tests

                 if (tokenCount2 > 0)
                 {
-                    string suffixString = processedText2.Substring(index2);
+                    string suffixString = (processedText2 ?? input).Substring(index2);

-                    if (tokenizer.Model is SentencePieceBpe)
+                    if (tokenizer is SentencePieceBpe)
                     {
                         // SentencePieceBpe model normalize the text and insert more characters.
                         // We call the model directly to bypass the normalization step
-                        suffixIds = new List<int>();
-                        tokenizer.Model.EncodeToIds(suffixString.AsSpan(), (suffixIds as IList<int>)!, out _);
+                        suffixIds = tokenizer.EncodeToIds(suffixString.AsSpan(), considerNormalization: false);
                         if (i < fullIdsList.Count)
                         {
                             suffixIds = suffixIds.Skip(1).ToList(); // Skip the start of sentence token <s>
@@ -85,12 +83,13 @@ namespace Microsoft.ML.Tokenizers.Tests

                 if (i == fullIdsList.Count)
                 {
-                    if (index1 != processedText1.Length)
+                    string s = processedText1 ?? input;
+                    if (index1 != s.Length)
                     {
                         // It's possible that the remaining text on the left doesn't produce any tokens, as in the case of BPE,
                         // where the pre-tokenizer removes spaces and the left text consists entirely of spaces.
-                        Assert.True(index1 < processedText1.Length);
+                        Assert.True(index1 < s.Length);
-                        Assert.Equal(0, tokenizer.CountTokens(processedText1.Substring(index1)));
+                        Assert.Equal(0, tokenizer.CountTokens(s.Substring(index1)));
                     }

                     if (index2 != 0)
@@ -98,7 +97,7 @@ namespace Microsoft.ML.Tokenizers.Tests
                         // It's possible that the remaining text on the right doesn't produce any tokens, as in the case of BPE,
                         // where the pre-tokenizer removes spaces and the left text consists entirely of spaces.
                         Assert.True(index2 > 0);
-                        Assert.Equal(0, tokenizer.CountTokens(processedText1.Substring(0, index2)));
+                        Assert.Equal(0, tokenizer.CountTokens(s.Substring(0, index2)));
                     }

                     Assert.Equal(fullIdsList, prefixIds);
@@ -106,13 +105,15 @@ namespace Microsoft.ML.Tokenizers.Tests
                 }
             }

+            Assert.Equal(0, tokenizer.IndexOfTokenCount((string)null!, maxTokenCount: 10, out _, out _));
+            Assert.Equal(0, tokenizer.LastIndexOfTokenCount((string)null!, maxTokenCount: 10, out _, out _));
+            Assert.Equal(0, tokenizer.IndexOfTokenCount(Span<char>.Empty, maxTokenCount: 10, out _, out _));
+            Assert.Equal(0, tokenizer.LastIndexOfTokenCount(Span<char>.Empty, maxTokenCount: 10, out _, out _));
+
             Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: 0, out _, out _));
             Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: -1, out _, out _));
             Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: 0, out _, out _));
             Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: -1, out _, out _));
-
-            Assert.Throws<ArgumentNullException>(() => tokenizer.IndexOfTokenCount(null!, maxTokenCount: 10, out _, out _));
-            Assert.Throws<ArgumentNullException>(() => tokenizer.LastIndexOfTokenCount(null!, maxTokenCount: 10, out _, out _));
         }
     }
}