* Introducing CodeGen Tokenizer

* Mark a method as private. It was not intended to be public

* Init vocab atomically.

* Prevent returning tokens that are only partially mapped to a code point.

* Ensure Tiktoken returns a precise token count with IndexOf & LastIndexOf, and accurate offsets.

* Address the feedback
Tarek Mahmoud Sayed 2024-05-02 11:58:56 -07:00 committed by GitHub
Parent 72cfdf611a
Commit e9097ce6d6
No known key found for this signature
GPG key ID: B5690EEEBB952194
18 changed files with 3551 additions and 322 deletions

View File

@ -133,6 +133,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
License notice for CodeGen Tokenizer
--------------------------------------------
https://github.com/huggingface/transformers/blob/8c12690cecbb97e187861e386f7a0ac790e4236c/src/transformers/models/codegen/tokenization_codegen.py#L2
Copyright 2022 The Salesforce authors, The Open AI Team Authors and The HuggingFace Inc. team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
License notice for BitUtility
------------------------------------------

View File

@ -36,6 +36,7 @@
<ApacheArrowVersion>11.0.0</ApacheArrowVersion>
<GoogleProtobufVersion>3.19.6</GoogleProtobufVersion>
<LightGBMVersion>3.3.5</LightGBMVersion>
<MicrosoftBclHashCodeVersion>1.1.1</MicrosoftBclHashCodeVersion>
<MicrosoftCodeAnalysisAnalyzersVersion>3.3.0</MicrosoftCodeAnalysisAnalyzersVersion>
<MicrosoftCodeAnalysisCSharpVersion>3.9.0</MicrosoftCodeAnalysisCSharpVersion>
<MicrosoftDotNetInteractiveVersion>1.0.0-beta.23509.3</MicrosoftDotNetInteractiveVersion>
@ -87,7 +88,7 @@
<MicrosoftMLTensorFlowTestModelsVersion>0.0.13-test</MicrosoftMLTensorFlowTestModelsVersion>
<MicrosoftMLTestDatabasesVersion>0.0.6-test</MicrosoftMLTestDatabasesVersion>
<MicrosoftMLTestModelsVersion>0.0.7-test</MicrosoftMLTestModelsVersion>
<MicrosoftMLTestTokenizersVersion>2.0.0-beta.24218.2</MicrosoftMLTestTokenizersVersion>
<MicrosoftMLTestTokenizersVersion>2.0.0-beta.24219.1</MicrosoftMLTestTokenizersVersion>
<SystemDataSqlClientVersion>4.8.6</SystemDataSqlClientVersion>
<SystemDataSQLiteCoreVersion>1.0.118</SystemDataSQLiteCoreVersion>
<XunitCombinatorialVersion>1.2.7</XunitCombinatorialVersion>

View File

@ -21,6 +21,10 @@
<PackageReference Include="System.Text.Json" Version="$(SystemTextJsonVersion)" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="Microsoft.Bcl.HashCode" Version="$(MicrosoftBclHashCodeVersion)" />
</ItemGroup>
<UsingTask TaskName="CompressFile"
TaskFactory="RoslynCodeTaskFactory"
AssemblyFile="$(MSBuildToolsPath)\Microsoft.Build.Tasks.Core.dll" >

File diff not shown because it is too large. Load Diff

View File

@ -10,6 +10,7 @@ using System.IO;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading;
namespace Microsoft.ML.Tokenizers
{
@ -23,9 +24,6 @@ namespace Microsoft.ML.Tokenizers
private Dictionary<string, int>? _vocabOriginal;
private readonly SortedDictionary<int, StringSpanOrdinalKey> _vocabReverse;
private readonly Cache<(string, string), int> _mergeRanks;
private readonly IReadOnlyDictionary<char, char> _byteToUnicode;
private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
private readonly string[] _charToString;
private readonly StringSpanOrdinalKeyCache<List<Token>> _cache;
private readonly PreTokenizer? _preTokenizer;
private readonly Normalizer? _normalizer;
@ -95,14 +93,6 @@ namespace Microsoft.ML.Tokenizers
_vocab = GetVocabulary(vocabularyStream);
_vocabReverse = _vocab.ReverseSorted();
_mergeRanks = GetMergeRanks(mergeStream);
int maxCharValue = GetByteToUnicode(out _byteToUnicode);
_charToString = new string[maxCharValue];
for (char c = (char)0; c < (char)maxCharValue; c++)
{
_charToString[c] = c.ToString();
}
_unicodeToByte = _byteToUnicode.Reverse();
_cache = new StringSpanOrdinalKeyCache<List<Token>>();
if (disposeStream)
@ -113,6 +103,80 @@ namespace Microsoft.ML.Tokenizers
}
}
private static Dictionary<StringSpanOrdinalKey, int> GetVocabulary(Stream vocabularyStream)
{
Dictionary<StringSpanOrdinalKey, int>? vocab;
try
{
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
vocab = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocabularyStream, options) as Dictionary<StringSpanOrdinalKey, int>;
}
catch (Exception e)
{
throw new ArgumentException($"Problems met when parsing JSON vocabulary object.{Environment.NewLine}Error message: {e.Message}");
}
if (vocab is null)
{
throw new ArgumentException($"Failed to read the vocabulary file.");
}
return vocab;
}
private static Cache<(string, string), int> GetMergeRanks(Stream mergeStream)
{
var mergeRanks = new Cache<(string, string), int>(60_000);
try
{
using StreamReader reader = new StreamReader(mergeStream);
// We ignore the first and last line in the file
if (reader.Peek() >= 0)
{
string ignored = reader.ReadLine()!;
}
int rank = 1;
while (reader.Peek() >= 0)
{
string line = reader.ReadLine()!;
int index = line.IndexOf(' ');
if (index < 1 || index == line.Length - 1 || line.IndexOf(' ', index + 1) != -1)
{
throw new FormatException($"Invalid format of merge file at line: \"{line}\"");
}
mergeRanks.Set((line.Substring(0, index), line.Substring(index + 1)), rank++);
}
}
catch (Exception e)
{
// Report any issues encountered while consuming a data file as IOExceptions.
throw new IOException($"Cannot read the file Merge file.{Environment.NewLine}Error message: {e.Message}", e);
}
return mergeRanks;
}
private Dictionary<string, int> GetVocab()
{
Dictionary<string, int>? publicVocab = Volatile.Read(ref _vocabOriginal);
if (publicVocab is null)
{
var vocab = new Dictionary<string, int>();
foreach (var item in _vocab)
{
vocab.Add(item.Key.ToString(), item.Value);
}
Interlocked.CompareExchange(ref _vocabOriginal, vocab, null);
publicVocab = _vocabOriginal;
}
return publicVocab;
}
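The same lazy, atomically published vocabulary pattern in isolation (a minimal standalone sketch with illustrative names, not the library API):
private Dictionary<string, int>? _publicVocab;
private Dictionary<string, int> GetPublicVocab(IEnumerable<KeyValuePair<string, int>> source)
{
// Read the currently published instance, if any.
Dictionary<string, int>? vocab = Volatile.Read(ref _publicVocab);
if (vocab is null)
{
var built = new Dictionary<string, int>();
foreach (KeyValuePair<string, int> pair in source)
{
built.Add(pair.Key, pair.Value);
}
// Publish at most once; if another thread won the race, its dictionary is kept.
Interlocked.CompareExchange(ref _publicVocab, built, null);
vocab = _publicVocab;
}
return vocab;
}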
/// <summary>
/// Gets the PreTokenizer used by the Tokenizer.
/// </summary>
@ -126,7 +190,7 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
public IReadOnlyDictionary<string, int> Vocab => GetVocab();
//
// Public Model interfaces implementation
@ -147,9 +211,10 @@ namespace Microsoft.ML.Tokenizers
char[] buffer = ArrayPool<char>.Shared.Rent(v.Length);
int i = 0;
IReadOnlyDictionary<char, char> unicodeToByte = ByteToUnicodeEncoding.Instance.UnicodeToByte;
for (int j = 0; j < v.Length; j++)
{
if (_unicodeToByte.TryGetValue(v[j], out var c))
if (unicodeToByte.TryGetValue(v[j], out var c))
{
buffer[i++] = c;
}
@ -232,9 +297,11 @@ namespace Microsoft.ML.Tokenizers
int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);
int newTokenIndex = 0;
IReadOnlyDictionary<char, char> byteToUnicode = ByteToUnicodeEncoding.Instance.ByteToUnicode;
for (int i = 0; i < text.Length; i++)
{
if (_byteToUnicode.TryGetValue(text[i], out var value))
if (byteToUnicode.TryGetValue(text[i], out var value))
{
token[newTokenIndex] = value;
indexMapping[newTokenIndex] = i;
@ -607,9 +674,11 @@ namespace Microsoft.ML.Tokenizers
int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);
int newTokenIndex = 0;
IReadOnlyDictionary<char, char> byteToUnicode = ByteToUnicodeEncoding.Instance.ByteToUnicode;
for (int i = 0; i < text.Length; i++)
{
if (_byteToUnicode.TryGetValue(text[i], out var value))
if (byteToUnicode.TryGetValue(text[i], out var value))
{
token[newTokenIndex] = value;
indexMapping[newTokenIndex] = i;
@ -650,9 +719,11 @@ namespace Microsoft.ML.Tokenizers
int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);
int newTokenIndex = 0;
IReadOnlyDictionary<char, char> byteToUnicode = ByteToUnicodeEncoding.Instance.ByteToUnicode;
for (int i = 0; i < text.Length; i++)
{
if (_byteToUnicode.TryGetValue(text[i], out var value))
if (byteToUnicode.TryGetValue(text[i], out var value))
{
token[newTokenIndex] = value;
indexMapping[newTokenIndex] = i;
@ -829,107 +900,6 @@ namespace Microsoft.ML.Tokenizers
private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) =>
HighestOccurrenceMapping.Load(highestOccurrenceMappingStream);
private Dictionary<StringSpanOrdinalKey, int> GetVocabulary(Stream vocabularyStream)
{
Dictionary<StringSpanOrdinalKey, int>? vocab;
try
{
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
vocab = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocabularyStream, options) as Dictionary<StringSpanOrdinalKey, int>;
}
catch (Exception e)
{
throw new ArgumentException($"Problems met when parsing JSON vocabulary object.{Environment.NewLine}Error message: {e.Message}");
}
if (vocab is null)
{
throw new ArgumentException($"Failed to read the vocabulary file.");
}
if (_vocabIdToHighestOccurrence.BosWord is not null)
{
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.BosWord)] = -_vocabIdToHighestOccurrence.BosIndex;
}
if (_vocabIdToHighestOccurrence.EosWord is not null)
{
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.EosWord)] = -_vocabIdToHighestOccurrence.EosIndex;
}
if (_vocabIdToHighestOccurrence.UnkWord is not null)
{
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.UnkWord)] = -_vocabIdToHighestOccurrence.UnkIndex;
}
if (_vocabIdToHighestOccurrence.PadWord is not null)
{
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.PadWord)] = -_vocabIdToHighestOccurrence.PadIndex;
}
return vocab;
}
private Cache<(string, string), int> GetMergeRanks(Stream mergeStream)
{
var mergeRanks = new Cache<(string, string), int>(60_000);
try
{
using StreamReader reader = new StreamReader(mergeStream);
// We ignore the first and last line in the file
if (reader.Peek() >= 0)
{
string ignored = reader.ReadLine()!;
}
int rank = 1;
while (reader.Peek() >= 0)
{
string line = reader.ReadLine()!;
int index = line.IndexOf(' ');
if (index < 1 || index == line.Length - 1 || line.IndexOf(' ', index + 1) != -1)
{
throw new Exception($"Invalid format of merge file: \"{line}\"");
}
mergeRanks.Set((line.Substring(0, index), line.Substring(index + 1)), rank++);
}
}
catch (Exception e)
{
throw new IOException($"Cannot read the file Merge file.{Environment.NewLine}Error message: {e.Message}", e);
}
return mergeRanks;
}
/// <summary>
/// Returns list of utf-8 bytes and a corresponding list of unicode chars.
/// This mapping is to make unseen characters (such as control characters) displayable.
/// </summary>
private static int GetByteToUnicode(out IReadOnlyDictionary<char, char> byteToUnicode)
{
var byteToUnicodeMapping = Enumerable.Range('!', '~' - '!' + 1)
.Concat(Enumerable.Range('¡', '¬' - '¡' + 1))
.Concat(Enumerable.Range('®', 'ÿ' - '®' + 1))
.ToDictionary(b => (char)b, b => (char)b);
const int numChars = 256;
var n = 0;
foreach (var b in Enumerable.Range(0, numChars))
{
if (!byteToUnicodeMapping.ContainsKey((char)b))
{
byteToUnicodeMapping.Add((char)b, (char)(numChars + n));
++n;
}
}
byteToUnicode = byteToUnicodeMapping;
return numChars + n;
}
/// <summary>
/// Encode a token into BPE-ed sub-tokens. E.g., "playing" into ["play", "ing"].
/// </summary>
@ -940,17 +910,20 @@ namespace Microsoft.ML.Tokenizers
return [];
}
string[] charToString = ByteToUnicodeEncoding.Instance.CharToString;
if (token.Length == 1)
{
string tokenValue = _charToString[token[0]];
Debug.Assert(token[0] < charToString.Length);
string tokenValue = charToString[token[0]];
return new List<Token> { new Token(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, (indexMapping[0], 1)) };
}
List<string> word = new(token.Length);
foreach (char c in token)
{
Debug.Assert(c < _charToString.Length);
word.Add(_charToString[c]);
Debug.Assert(c < charToString.Length);
word.Add(charToString[c]);
}
HashSet<(string, string)> pairs = new();
@ -1065,10 +1038,7 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
/// <param name="ch">The character to check.</param>
/// <returns>True if the character is supported, otherwise false.</returns>
public bool IsSupportedChar(char ch)
{
return _byteToUnicode.ContainsKey(ch);
}
public bool IsSupportedChar(char ch) => ByteToUnicodeEncoding.Instance.ByteToUnicode.ContainsKey(ch);
}
/// <summary>

View File

@ -25,7 +25,7 @@ namespace Microsoft.ML.Tokenizers
{
private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder;
private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder;
private readonly LruCache<(int[] Bytes, string Token)> _cache;
private readonly LruCache<(int Id, int TokenIndex, int TokenLength)[]> _cache;
private readonly Dictionary<StringSpanOrdinalKey, (int Id, string Token)> _vocab;
private IReadOnlyDictionary<string, int>? _vocabOriginal;
private const int MaxWordLengthToCache = 15;
@ -92,7 +92,7 @@ namespace Microsoft.ML.Tokenizers
_preTokenizer = preTokenizer;
_normalizer = normalizer;
_cache = new LruCache<(int[] Bytes, string Token)>(cacheSize);
_cache = new LruCache<(int Id, int TokenIndex, int TokenLength)[]>(cacheSize);
SpecialTokens = specialTokens;
CacheSpecialTokensEncoding(specialTokens);
@ -102,7 +102,7 @@ namespace Microsoft.ML.Tokenizers
{
try
{
_cache = new LruCache<(int[] Bytes, string Token)>(cacheSize);
_cache = new LruCache<(int Id, int TokenIndex, int TokenLength)[]>(cacheSize);
(_encoder, _vocab, _decoder) = LoadTiktokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
_preTokenizer = preTokenizer;
@ -140,7 +140,7 @@ namespace Microsoft.ML.Tokenizers
foreach (KeyValuePair<string, int> specialToken in specialTokens)
{
_decoder![specialToken.Value] = Encoding.UTF8.GetBytes(specialToken.Key);
_cache!.Add(specialToken.Key, (new[] { specialToken.Value }, specialToken.Key));
_cache!.Add(specialToken.Key, new[] { (Id: specialToken.Value, TokenIndex: 0, TokenLength: specialToken.Key.Length) });
}
}
}
@ -290,13 +290,14 @@ namespace Microsoft.ML.Tokenizers
{
Debug.Assert(!text.IsEmpty);
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
if (_cache.TryGetValue(text, out (int Id, int TokenIndex, int TokenLength)[] value))
{
tokens.Add(new Token(value.Ids[0], value.Token, (offset, value.Token.Length)));
for (int i = 1; i < value.Ids.Length; i++)
for (int i = 0; i < value.Length; i++)
{
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
tokens.Add(new Token(value.Ids[i], "", (offset + text.Length, 0)));
tokens.Add(new Token(
value[i].Id,
value[i].TokenLength == 0 ? string.Empty : text.Slice(value[i].TokenIndex, value[i].TokenLength).ToString(),
(value[i].TokenIndex + offset, value[i].TokenLength)));
}
return;
@ -309,25 +310,35 @@ namespace Microsoft.ML.Tokenizers
return;
}
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
int utf8Length = Encoding.UTF8.GetMaxByteCount(text.Length);
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(utf8Length);
int[]? indexMappingArray = null;
Span<int> indexMappingSpan = utf8Length + 1 <= 128 ? stackalloc int[128] : (indexMappingArray = ArrayPool<int>.Shared.Rent(utf8Length + 1));
int encodedLength = Helpers.EncodeToUtf8(text, arrayPoolArray, indexMappingSpan);
Debug.Assert(encodedLength < indexMappingSpan.Length);
indexMappingSpan[encodedLength] = text.Length;
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
(int Id, int TokenIndex, int TokenLength)[] encodedTokens = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder, indexMappingSpan.Slice(0, encodedLength + 1));
ArrayPool<byte>.Shared.Return(arrayPoolArray);
if (indexMappingArray is not null)
{
ArrayPool<int>.Shared.Return(indexMappingArray);
}
Debug.Assert(encodedIds.Length > 0);
Debug.Assert(encodedTokens.Length > 0);
string textAsString = text.ToString();
if (text.Length <= MaxWordLengthToCache)
{
_cache.Add(textAsString, (encodedIds, textAsString));
_cache.Add(textAsString, encodedTokens);
}
tokens.Add(new Token(encodedIds[0], textAsString, (offset, text.Length)));
for (int i = 1; i < encodedIds.Length; i++)
for (int i = 0; i < encodedTokens.Length; i++)
{
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
tokens.Add(new Token(encodedIds[i], "", (offset + text.Length, 0)));
tokens.Add(new Token(
encodedTokens[i].Id,
encodedTokens[i].TokenLength == 0 ? string.Empty : text.Slice(encodedTokens[i].TokenIndex, encodedTokens[i].TokenLength).ToString(),
(encodedTokens[i].TokenIndex + offset, encodedTokens[i].TokenLength)));
}
}
@ -435,17 +446,9 @@ namespace Microsoft.ML.Tokenizers
return 0;
}
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
if (_cache.TryGetValue(text, out (int Id, int TokenIndex, int TokenLength)[] value))
{
if (value.Ids.Length <= maxTokenCount)
{
accumulatedIds.AddRange(value.Ids);
textLength = text.Length;
return value.Ids.Length;
}
textLength = 0;
return 0;
return EncodeToIdsResult(value, accumulatedIds, maxTokenCount, text.Length, out textLength);
}
if (_vocab.TryGetValue(text, out (int Id, string Token) mappedId))
@ -455,32 +458,85 @@ namespace Microsoft.ML.Tokenizers
return 1;
}
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
int utf8Length = Encoding.UTF8.GetMaxByteCount(text.Length);
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(utf8Length);
int[]? indexMappingArray = null;
Span<int> indexMappingSpan = utf8Length + 1 <= 128 ? stackalloc int[128] : (indexMappingArray = ArrayPool<int>.Shared.Rent(utf8Length + 1));
int encodedLength = Helpers.EncodeToUtf8(text, arrayPoolArray, indexMappingSpan);
Debug.Assert(encodedLength < indexMappingSpan.Length);
indexMappingSpan[encodedLength] = text.Length;
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
(int Id, int TokenIndex, int TokenLength)[] encodedTokens = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder, indexMappingSpan.Slice(0, encodedLength + 1));
ArrayPool<byte>.Shared.Return(arrayPoolArray);
if (indexMappingArray is not null)
{
ArrayPool<int>.Shared.Return(indexMappingArray);
}
if (text.Length <= MaxWordLengthToCache)
{
string textAsString = text.ToString();
_cache.Add(textAsString, (encodedIds, textAsString));
_cache.Add(textAsString, encodedTokens);
}
int result;
if (encodedIds.Length <= maxTokenCount)
return EncodeToIdsResult(encodedTokens, accumulatedIds, maxTokenCount, text.Length, out textLength);
}
private int EncodeToIdsResult((int Id, int TokenIndex, int TokenLength)[] tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textLength)
{
textLength = 0;
if (tokens.Length <= maxTokens)
{
accumulatedIds.AddRange(encodedIds);
textLength = text.Length;
result = encodedIds.Length;
}
else
{
textLength = 0;
result = 0;
if (accumulatedIds is not null)
{
foreach (var t in tokens)
{
accumulatedIds.Add(t.Id);
}
}
textLength = fullTextLength;
return tokens.Length;
}
ArrayPool<byte>.Shared.Return(arrayPoolArray);
return result;
int tokenCount;
for (tokenCount = 0; tokenCount < maxTokens; tokenCount++)
{
int overlapIndex = tokens[tokenCount].TokenIndex + tokens[tokenCount].TokenLength;
// tokens.Length is greater than maxTokens here, so indexing tokenCount + 1 is safe.
if (tokens[tokenCount + 1].TokenIndex < overlapIndex)
{
// Ensure we don't break the text in the middle of a code point
int j = tokenCount + 2;
while (j < tokens.Length && tokens[j].TokenIndex < overlapIndex)
{
j++;
}
if (j <= maxTokens)
{
// append encountered tokens to the accumulatedIds
for (int k = tokenCount; k < j; k++)
{
accumulatedIds?.Add(tokens[k].Id);
}
tokenCount = j - 1;
textLength = tokens[tokenCount].TokenIndex + tokens[tokenCount].TokenLength;
}
else
{
break;
}
}
else
{
accumulatedIds?.Add(tokens[tokenCount].Id);
textLength = tokens[tokenCount].TokenIndex + tokens[tokenCount].TokenLength;
}
}
return tokenCount;
}
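The loop above never splits a run of tokens that map into the same code point: either the whole run fits within maxTokens or encoding stops before it, which is what keeps the reported count and textLength on code-point boundaries. The counting rule in isolation (a simplified sketch over (Index, Length) offsets, not the library API):
// Returns how many leading tokens can be taken without cutting through an overlapping run.
static int CountWithoutSplitting((int Index, int Length)[] tokens, int max)
{
if (tokens.Length <= max)
{
return tokens.Length;
}
int count = 0;
while (count < max)
{
int end = tokens[count].Index + tokens[count].Length;
// Tokens whose Index falls before 'end' still cover the same code point.
int runEnd = count + 1;
while (runEnd < tokens.Length && tokens[runEnd].Index < end)
{
runEnd++;
}
if (runEnd > max)
{
break; // the run does not fit entirely; stop before it
}
count = runEnd;
}
return count;
}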
/// <summary>
@ -598,16 +654,9 @@ namespace Microsoft.ML.Tokenizers
return 0;
}
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
if (_cache.TryGetValue(text, out (int Id, int TokenIndex, int TokenLength)[] value))
{
if (value.Ids.Length <= maxTokens)
{
textLength = text.Length;
return value.Ids.Length;
}
textLength = 0;
return 0;
return EncodeToIdsResult(value, accumulatedIds: null, maxTokens, text.Length, out textLength);
}
if (_vocab.TryGetValue(text, out _))
@ -616,30 +665,28 @@ namespace Microsoft.ML.Tokenizers
return 1;
}
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
int utf8Length = Encoding.UTF8.GetMaxByteCount(text.Length);
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(utf8Length);
int[]? indexMappingArray = null;
Span<int> indexMappingSpan = utf8Length + 1 <= 128 ? stackalloc int[128] : (indexMappingArray = ArrayPool<int>.Shared.Rent(utf8Length + 1));
int encodedLength = Helpers.EncodeToUtf8(text, arrayPoolArray, indexMappingSpan);
Debug.Assert(encodedLength < indexMappingSpan.Length);
indexMappingSpan[encodedLength] = text.Length;
(int Id, int TokenIndex, int TokenLength)[] encodedTokens = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder, indexMappingSpan.Slice(0, encodedLength + 1));
ArrayPool<byte>.Shared.Return(arrayPoolArray);
if (indexMappingArray is not null)
{
ArrayPool<int>.Shared.Return(indexMappingArray);
}
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
if (text.Length <= MaxWordLengthToCache)
{
string textAsString = text.ToString();
_cache.Add(textAsString, (encodedIds, textAsString));
_cache.Add(textAsString, encodedTokens);
}
int result;
if (encodedIds.Length <= maxTokens)
{
textLength = text.Length;
result = encodedIds.Length;
}
else
{
textLength = 0;
result = 0;
}
ArrayPool<byte>.Shared.Return(arrayPoolArray);
return result;
return EncodeToIdsResult(encodedTokens, accumulatedIds: null, maxTokens, text.Length, out textLength);
}
/// <summary>
@ -729,16 +776,9 @@ namespace Microsoft.ML.Tokenizers
return 0;
}
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
if (_cache.TryGetValue(text, out (int Id, int TokenIndex, int TokenLength)[] value))
{
if (value.Ids.Length <= maxTokens)
{
textIndex = 0;
return value.Ids.Length;
}
textIndex = text.Length;
return 0;
return EncodeToIdsFromEndResult(value, accumulatedIds: null, maxTokens, text.Length, out textIndex);
}
if (_vocab.TryGetValue(text, out _))
@ -747,31 +787,63 @@ namespace Microsoft.ML.Tokenizers
return 1;
}
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
int utf8Length = Encoding.UTF8.GetMaxByteCount(text.Length);
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(utf8Length);
int[]? indexMappingArray = null;
Span<int> indexMappingSpan = utf8Length + 1 <= 128 ? stackalloc int[128] : (indexMappingArray = ArrayPool<int>.Shared.Rent(utf8Length + 1));
int encodedLength = Helpers.EncodeToUtf8(text, arrayPoolArray, indexMappingSpan);
Debug.Assert(encodedLength < indexMappingSpan.Length);
indexMappingSpan[encodedLength] = text.Length;
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
(int Id, int TokenIndex, int TokenLength)[] encodedTokens = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder, indexMappingSpan.Slice(0, encodedLength + 1));
ArrayPool<byte>.Shared.Return(arrayPoolArray);
if (indexMappingArray is not null)
{
ArrayPool<int>.Shared.Return(indexMappingArray);
}
if (text.Length <= MaxWordLengthToCache)
{
string textAsString = text.ToString();
_cache.Add(textAsString, (encodedIds, textAsString));
_cache.Add(textAsString, encodedTokens);
}
int result;
if (encodedIds.Length <= maxTokens)
return EncodeToIdsFromEndResult(encodedTokens, accumulatedIds: null, maxTokens, text.Length, out textIndex);
}
private int EncodeToIdsFromEndResult((int Id, int TokenIndex, int TokenLength)[] tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textIndex)
{
textIndex = fullTextLength;
if (tokens.Length <= maxTokens)
{
if (accumulatedIds is not null)
{
foreach (var t in tokens)
{
accumulatedIds.Add(t.Id);
}
}
textIndex = 0;
result = encodedIds.Length;
}
else
{
textIndex = text.Length;
result = 0;
return tokens.Length;
}
ArrayPool<byte>.Shared.Return(arrayPoolArray);
return result;
int index = tokens.Length - maxTokens;
// avoid breaking the text in the middle of a code-point
while (index < tokens.Length && tokens[index].TokenIndex < tokens[index - 1].TokenIndex + tokens[index - 1].TokenLength)
{
index++;
}
for (int i = index; i < tokens.Length; i++)
{
accumulatedIds?.Add(tokens[i].Id);
}
textIndex = index >= tokens.Length ? fullTextLength : tokens[index].TokenIndex;
return tokens.Length - index;
}
/// <summary>
@ -786,11 +858,11 @@ namespace Microsoft.ML.Tokenizers
return null;
}
if (_cache.TryGetValue(token, out (int[] Ids, string Token) value))
if (_cache.TryGetValue(token, out (int Id, int TokenIndex, int TokenLength)[] value))
{
if (value.Ids.Length == 1)
if (value.Length == 1)
{
return value.Ids[0];
return value[0].Id;
}
return null;
@ -801,30 +873,34 @@ namespace Microsoft.ML.Tokenizers
return id.Id;
}
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(token.Length));
try
int utf8Length = Encoding.UTF8.GetMaxByteCount(token.Length);
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(utf8Length);
int[]? indexMappingArray = null;
Span<int> indexMappingSpan = utf8Length + 1 <= 128 ? stackalloc int[128] : (indexMappingArray = ArrayPool<int>.Shared.Rent(utf8Length + 1));
int encodedLength = Helpers.EncodeToUtf8(token, arrayPoolArray, indexMappingSpan);
Debug.Assert(encodedLength < indexMappingSpan.Length);
indexMappingSpan[encodedLength] = token.Length;
(int Id, int TokenIndex, int TokenLength)[] encodedTokens = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder, indexMappingSpan.Slice(0, encodedLength + 1));
ArrayPool<byte>.Shared.Return(arrayPoolArray);
if (indexMappingArray is not null)
{
int encodedLength = Helpers.GetUtf8Bytes(token, arrayPoolArray);
int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
if (token.Length <= MaxWordLengthToCache)
{
string tokenAsString = token.ToString();
_cache.Add(tokenAsString, (idsToCache, tokenAsString));
}
if (idsToCache.Length == 1)
{
return idsToCache[0];
}
return null;
ArrayPool<int>.Shared.Return(indexMappingArray);
}
finally
if (token.Length <= MaxWordLengthToCache)
{
ArrayPool<byte>.Shared.Return(arrayPoolArray);
string tokenAsString = token.ToString();
_cache.Add(tokenAsString, encodedTokens);
}
if (encodedTokens.Length == 1)
{
return encodedTokens[0].Id;
}
return null;
}
/// <summary>

View File

@ -455,6 +455,74 @@ namespace Microsoft.ML.Tokenizers
return new SentencePieceBpe(modelProto, addBeginOfSentence, addEndOfSentence);
}
/// <summary>
/// Create a CodeGen tokenizer from the given vocab and merges streams.
/// </summary>
/// <param name="vocabStream">The stream containing the vocab file.</param>
/// <param name="mergesStream">The stream containing the merges file.</param>
/// <param name="addPrefixSpace">Indicate whether to add a space before the token.</param>
/// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <returns>The CodeGen tokenizer object.</returns>
/// <remarks>
/// The tokenizer will be created according to the configuration specified in https://huggingface.co/Salesforce/codegen-350M-mono/raw/main/tokenizer.json.
/// It is important to provide vocab and merges files similar to the ones used during the model's training.
/// The vocab and merges files can be downloaded from the following links:
/// https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/vocab.json?download=true
/// https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/merges.txt?download=true
/// </remarks>
public static Tokenizer CreateCodeGen(
Stream vocabStream,
Stream mergesStream,
bool addPrefixSpace = false,
bool addBeginOfSentence = false,
bool addEndOfSentence = false)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
if (mergesStream is null)
{
throw new ArgumentNullException(nameof(mergesStream));
}
return new CodeGen(
vocabStream,
mergesStream,
new TiktokenPreTokenizer(Tiktoken.P50kBaseRegex(), CodeGen.CodeGenAddedTokens),
normalizer: null,
CodeGen.CodeGenAddedTokens,
addPrefixSpace: addPrefixSpace,
addBeginningOfSentence: addBeginOfSentence,
addEndOfSentence: addEndOfSentence);
}
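A minimal usage sketch (the local file paths below are hypothetical stand-ins for downloaded copies of the vocab.json and merges.txt linked in the remarks above):
using Stream vocabStream = File.OpenRead("vocab.json");   // downloaded Salesforce/codegen-350M-mono vocab
using Stream mergesStream = File.OpenRead("merges.txt");  // downloaded Salesforce/codegen-350M-mono merges
Tokenizer codeGen = Tokenizer.CreateCodeGen(vocabStream, mergesStream, addPrefixSpace: false, addBeginOfSentence: false, addEndOfSentence: false);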
/// <summary>
/// Create a CodeGen Phi2 tokenizer from the given vocab and merges streams.
/// </summary>
/// <param name="vocabStream">The stream containing the vocab file.</param>
/// <param name="mergesStream">The stream containing the merges file.</param>
/// <param name="addPrefixSpace">Indicate whether to add a space before the token.</param>
/// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <returns>The CodeGen tokenizer object.</returns>
/// <remarks>
/// The tokenizer will be created according to the configuration specified in https://huggingface.co/microsoft/phi-2/raw/main/tokenizer.json.
/// It is important to provide vocab and merges files similar to the ones used during the model's training.
/// The vocab and merges files can be downloaded from the following links:
/// https://huggingface.co/microsoft/phi-2/resolve/main/vocab.json?download=true
/// https://huggingface.co/microsoft/phi-2/resolve/main/merges.txt?download=true
/// </remarks>
public static Tokenizer CreatePhi2(
Stream vocabStream,
Stream mergesStream,
bool addPrefixSpace = false,
bool addBeginOfSentence = false,
bool addEndOfSentence = false)
=> CreateCodeGen(vocabStream, mergesStream, addPrefixSpace, addBeginOfSentence, addEndOfSentence);
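The Phi-2 variant is the same factory applied to the Phi-2 vocab and merges files (paths below are hypothetical, mirroring the test setup later in this change):
using Stream vocabStream = File.OpenRead(Path.Combine("Phi-2", "vocab.json"));
using Stream mergesStream = File.OpenRead(Path.Combine("Phi-2", "merges.txt"));
Tokenizer phi2 = Tokenizer.CreatePhi2(vocabStream, mergesStream);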
internal static IEnumerable<(int Offset, int Length)>? InitializeForEncoding(
string? text,
ReadOnlySpan<char> textSpan,

View File

@ -13,11 +13,11 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
internal static class BytePairEncoder
{
public static int[] BytePairEncode(ReadOnlyMemory<byte> mergingBytes, IReadOnlyDictionary<ReadOnlyMemory<byte>, int> ranks)
public static (int Id, int TokenIndex, int TokenLength)[] BytePairEncode(ReadOnlyMemory<byte> mergingBytes, IReadOnlyDictionary<ReadOnlyMemory<byte>, int> ranks, ReadOnlySpan<int> indexMappingSpan)
{
if (mergingBytes.Length == 1)
{
return [ranks[mergingBytes]];
return [(ranks[mergingBytes], 0, 1)];
}
(int Index, int Rank)[]? arrayPoolArray = null;
@ -84,10 +84,28 @@ namespace Microsoft.ML.Tokenizers
}
}
var result = new int[byteIndicesAndRanks.Length - 1];
var result = new (int Id, int TokenIndex, int TokenLength)[byteIndicesAndRanks.Length - 1];
for (int i = 0; i < result.Length; i++)
{
result[i] = ranks[mergingBytes.SliceStartEnd(byteIndicesAndRanks[i].Index, byteIndicesAndRanks[i + 1].Index)];
int startIndex = byteIndicesAndRanks[i].Index;
int endIndex = byteIndicesAndRanks[i + 1].Index;
int mappedStartIndex = indexMappingSpan[startIndex];
int mappedEndIndex = indexMappingSpan[endIndex];
int finalEndIndex = endIndex;
if (finalEndIndex > 0 && indexMappingSpan[finalEndIndex - 1] == mappedEndIndex)
{
// The partial character/element should be included in the current token.
finalEndIndex++;
while (finalEndIndex < indexMappingSpan.Length && indexMappingSpan[finalEndIndex] == mappedEndIndex)
{
finalEndIndex++;
}
}
result[i] = (ranks[mergingBytes.SliceStartEnd(startIndex, endIndex)], mappedStartIndex, indexMappingSpan[finalEndIndex] - mappedStartIndex);
}
if (arrayPoolArray is not null)

View File

@ -0,0 +1,51 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Linq;
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// Maps UTF-8 bytes to Unicode characters while avoiding mappings to whitespace and control characters.
/// </summary>
internal sealed class ByteToUnicodeEncoding
{
public static ByteToUnicodeEncoding Instance { get; } = new ByteToUnicodeEncoding();
public ByteToUnicodeEncoding()
{
var byteToUnicodeMapping = Enumerable.Range('!', '~' - '!' + 1)
.Concat(Enumerable.Range('¡', '¬' - '¡' + 1))
.Concat(Enumerable.Range('®', 'ÿ' - '®' + 1))
.ToDictionary(b => (char)b, b => (char)b);
const int numChars = 256;
var n = 0;
foreach (var b in Enumerable.Range(0, numChars))
{
if (!byteToUnicodeMapping.ContainsKey((char)b))
{
byteToUnicodeMapping.Add((char)b, (char)(numChars + n));
++n;
}
}
ByteToUnicode = byteToUnicodeMapping;
UnicodeToByte = ByteToUnicode.ToDictionary(kv => kv.Value, kv => kv.Key);
int count = numChars + n;
CharToString = new string[count];
for (char c = (char)0; c < (char)count; c++)
{
CharToString[c] = c.ToString();
}
}
public IReadOnlyDictionary<char, char> ByteToUnicode { get; }
public IReadOnlyDictionary<char, char> UnicodeToByte { get; }
public string[] CharToString { get; }
}
}
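A quick illustration of the mapping this class builds (an illustrative check, not library code); the shifted values explain the "Ġ", "Ċ", and "č" characters that appear in the token strings of the tests below:
ByteToUnicodeEncoding enc = ByteToUnicodeEncoding.Instance;
Console.WriteLine(enc.ByteToUnicode['A']);        // 'A' - printable bytes map to themselves
Console.WriteLine((int)enc.ByteToUnicode[' ']);   // 288 = 0x120 -> 'Ġ', the leading-space marker
Console.WriteLine((int)enc.ByteToUnicode['\n']);  // 266 = 0x10A -> 'Ċ'
Console.WriteLine((int)enc.ByteToUnicode['\r']);  // 269 = 0x10D -> 'č'
// UnicodeToByte inverts the mapping, and CharToString caches single-character strings.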

View File

@ -4,6 +4,8 @@
using System;
using System.Buffers;
using System.Diagnostics;
using System.Text;
namespace Microsoft.ML.Tokenizers
{
@ -16,5 +18,114 @@ namespace Microsoft.ML.Tokenizers
ArrayPool<T>.Shared.Return(arrayPoolArray);
arrayPoolArray = tmp;
}
internal static int EncodeToUtf8(ReadOnlySpan<char> text, Span<byte> destination, Span<int> indexMapping)
{
Debug.Assert(!text.IsEmpty);
Debug.Assert(Encoding.UTF8.GetMaxByteCount(text.Length) <= destination.Length);
Debug.Assert(indexMapping.Length >= destination.Length);
int targetIndex = 0;
for (int i = 0; i < text.Length; i++)
{
uint c = (uint)text[i];
if (c <= 0x7Fu)
{
destination[targetIndex] = (byte)c;
indexMapping[targetIndex] = i;
targetIndex++;
continue;
}
if (c <= 0x7FFu)
{
// Scalar 00000yyy yyxxxxxx -> bytes [ 110yyyyy 10xxxxxx ]
destination[targetIndex] = (byte)((c + (0b110u << 11)) >> 6);
destination[targetIndex + 1] = (byte)((c & 0x3Fu) + 0x80u);
indexMapping[targetIndex] = indexMapping[targetIndex + 1] = i;
targetIndex += 2;
continue;
}
if (i < text.Length - 1 && char.IsSurrogatePair((char)c, text[i + 1]))
{
// Scalar 000uuuuu zzzzyyyy yyxxxxxx -> bytes [ 11110uuu 10uuzzzz 10yyyyyy 10xxxxxx ]
uint value = (uint)char.ConvertToUtf32((char)c, text[i + 1]);
destination[targetIndex] = (byte)((value + (0b11110 << 21)) >> 18);
destination[targetIndex + 1] = (byte)(((value & (0x3Fu << 12)) >> 12) + 0x80u);
destination[targetIndex + 2] = (byte)(((value & (0x3Fu << 6)) >> 6) + 0x80u);
destination[targetIndex + 3] = (byte)((value & 0x3Fu) + 0x80u);
indexMapping[targetIndex] = indexMapping[targetIndex + 1] = indexMapping[targetIndex + 2] = indexMapping[targetIndex + 3] = i;
i++;
targetIndex += 4;
continue;
}
// Scalar zzzzyyyy yyxxxxxx -> bytes [ 1110zzzz 10yyyyyy 10xxxxxx ]
destination[targetIndex] = (byte)((c + (0b1110 << 16)) >> 12);
destination[targetIndex + 1] = (byte)(((c & (0x3Fu << 6)) >> 6) + 0x80u);
destination[targetIndex + 2] = (byte)((c & 0x3Fu) + 0x80u);
indexMapping[targetIndex] = indexMapping[targetIndex + 1] = indexMapping[targetIndex + 2] = i;
targetIndex += 3;
}
return targetIndex;
}
internal static int EncodeToUtf8AndTransform(ReadOnlySpan<char> text, Span<char> destination, Span<int> indexMapping)
{
Debug.Assert(!text.IsEmpty);
Debug.Assert(Encoding.UTF8.GetMaxByteCount(text.Length) <= destination.Length);
Debug.Assert(indexMapping.Length >= destination.Length);
ByteToUnicodeEncoding byteToUnicodeEncoder = ByteToUnicodeEncoding.Instance;
int targetIndex = 0;
for (int i = 0; i < text.Length; i++)
{
uint c = (uint)text[i];
if (c <= 0x7Fu)
{
destination[targetIndex] = byteToUnicodeEncoder.ByteToUnicode[(char)c];
indexMapping[targetIndex] = i;
targetIndex++;
continue;
}
if (c <= 0x7FFu)
{
// Scalar 00000yyy yyxxxxxx -> bytes [ 110yyyyy 10xxxxxx ]
destination[targetIndex] = byteToUnicodeEncoder.ByteToUnicode[(char)((c + (0b110u << 11)) >> 6)];
destination[targetIndex + 1] = byteToUnicodeEncoder.ByteToUnicode[(char)((c & 0x3Fu) + 0x80u)];
indexMapping[targetIndex] = indexMapping[targetIndex + 1] = i;
targetIndex += 2;
continue;
}
if (i < text.Length - 1 && char.IsSurrogatePair((char)c, text[i + 1]))
{
// Scalar 000uuuuu zzzzyyyy yyxxxxxx -> bytes [ 11110uuu 10uuzzzz 10yyyyyy 10xxxxxx ]
uint value = (uint)char.ConvertToUtf32((char)c, text[i + 1]);
destination[targetIndex] = byteToUnicodeEncoder.ByteToUnicode[(char)((value + (0b11110 << 21)) >> 18)];
destination[targetIndex + 1] = byteToUnicodeEncoder.ByteToUnicode[(char)(((value & (0x3Fu << 12)) >> 12) + 0x80u)];
destination[targetIndex + 2] = byteToUnicodeEncoder.ByteToUnicode[(char)(((value & (0x3Fu << 6)) >> 6) + 0x80u)];
destination[targetIndex + 3] = byteToUnicodeEncoder.ByteToUnicode[(char)((value & 0x3Fu) + 0x80u)];
indexMapping[targetIndex] = indexMapping[targetIndex + 1] = indexMapping[targetIndex + 2] = indexMapping[targetIndex + 3] = i;
i++;
targetIndex += 4;
continue;
}
// Scalar zzzzyyyy yyxxxxxx -> bytes [ 1110zzzz 10yyyyyy 10xxxxxx ]
destination[targetIndex] = byteToUnicodeEncoder.ByteToUnicode[(char)((c + (0b1110 << 16)) >> 12)];
destination[targetIndex + 1] = byteToUnicodeEncoder.ByteToUnicode[(char)(((c & (0x3Fu << 6)) >> 6) + 0x80u)];
destination[targetIndex + 2] = byteToUnicodeEncoder.ByteToUnicode[(char)((c & 0x3Fu) + 0x80u)];
indexMapping[targetIndex] = indexMapping[targetIndex + 1] = indexMapping[targetIndex + 2] = i;
targetIndex += 3;
}
return targetIndex;
}
}
}
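For reference, the index mapping produced by EncodeToUtf8 assigns every UTF-8 byte of a code point to that code point's first char index; the callers then append one sentinel entry equal to text.Length. A standalone way to reproduce the same mapping with Rune (an illustrative sketch, not the library helper):
string text = "a\uD83D\uDE00b";                    // 'a', 😀 (a surrogate pair), 'b'
byte[] utf8 = new byte[Encoding.UTF8.GetMaxByteCount(text.Length)];
int[] map = new int[utf8.Length + 1];
int byteIndex = 0;
for (int i = 0; i < text.Length; )
{
Rune.DecodeFromUtf16(text.AsSpan(i), out Rune rune, out int charsConsumed);
int written = rune.EncodeToUtf8(utf8.AsSpan(byteIndex));
for (int k = 0; k < written; k++)
{
map[byteIndex + k] = i;                    // every byte of the code point -> its first char index
}
byteIndex += written;
i += charsConsumed;
}
map[byteIndex] = text.Length;                      // sentinel entry, as the callers set it
// For "a😀b": map[0..5] = { 0, 1, 1, 1, 1, 3 } and map[6] = 4.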

View File

@ -61,5 +61,32 @@ namespace Microsoft.ML.Tokenizers
=> Encoding.UTF8.GetChars(bytes, chars);
internal static void Replace(Span<char> span, char oldValue, char newValue) => span.Replace(oldValue, newValue);
/// <summary>
/// Encode the next code point in the text to UTF-8.
/// </summary>
/// <param name="text">The text to encode the first code point from.</param>
/// <param name="textIndex">The index of the first code point to encode.</param>
/// <param name="destination">The buffer to write the UTF-8 bytes to.</param>
/// <param name="bytesIndex">The index in the buffer to write the UTF-8 encoded bytes to.</param>
/// <returns>The number of characters consumed from the text.</returns>
internal static int EncodeCodePointToUtf8(ReadOnlySpan<char> text, int textIndex, ref byte[] destination, ref int bytesIndex)
{
Debug.Assert(textIndex < text.Length);
Rune.DecodeFromUtf16(text.Slice(textIndex), out Rune rune, out int charsConsumed);
Span<byte> buffer = stackalloc byte[4]; // max number of bytes for a single code point utf-8 encoding.
int bytesWritten = rune.EncodeToUtf8(buffer);
if (bytesIndex + bytesWritten > destination.Length)
{
Helpers.ArrayPoolGrow(ref destination, destination.Length * 2);
}
buffer.Slice(0, bytesWritten).CopyTo(destination.AsSpan(bytesIndex));
bytesIndex += bytesWritten;
return charsConsumed;
}
}
}

View File

@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.
using System;
using System.Diagnostics;
using System.IO;
using System.Net.Http;
using System.Text;
@ -111,6 +112,73 @@ namespace Microsoft.ML.Tokenizers
if (span[i] == oldValue)
span[i] = newValue;
}
/// <summary>
/// Encode the next code point in the text to UTF-8.
/// </summary>
/// <param name="text">The text to encode the first code point from.</param>
/// <param name="textIndex">The index of the first code point to encode.</param>
/// <param name="destination">The buffer to write the UTF-8 bytes to.</param>
/// <param name="bytesIndex">The index in the buffer to write the UTF-8 encoded bytes to.</param>
/// <returns>The number of characters consumed from the text.</returns>
internal static int EncodeCodePointToUtf8(ReadOnlySpan<char> text, int textIndex, ref byte[] destination, ref int bytesIndex)
{
Debug.Assert(textIndex < text.Length);
uint c = (uint)text[textIndex];
if (c <= 0x7Fu)
{
if (bytesIndex + 1 > destination.Length)
{
Helpers.ArrayPoolGrow(ref destination, destination.Length * 2);
}
destination[bytesIndex] = (byte)c;
bytesIndex++;
return 1;
}
if (c <= 0x7FFu)
{
// Scalar 00000yyy yyxxxxxx -> bytes [ 110yyyyy 10xxxxxx ]
if (bytesIndex + 2 > destination.Length)
{
Helpers.ArrayPoolGrow(ref destination, destination.Length * 2);
}
destination[bytesIndex] = (byte)((c + (0b110u << 11)) >> 6);
destination[bytesIndex + 1] = (byte)((c & 0x3Fu) + 0x80u);
bytesIndex += 2;
return 1;
}
if (textIndex < text.Length - 1 && char.IsSurrogatePair((char)c, text[textIndex + 1]))
{
// Scalar 000uuuuu zzzzyyyy yyxxxxxx -> bytes [ 11110uuu 10uuzzzz 10yyyyyy 10xxxxxx ]
if (bytesIndex + 4 > destination.Length)
{
Helpers.ArrayPoolGrow(ref destination, Math.Max(destination.Length, 4) * 2);
}
uint value = (uint)char.ConvertToUtf32((char)c, text[textIndex + 1]);
destination[bytesIndex] = (byte)((value + (0b11110 << 21)) >> 18);
destination[bytesIndex + 1] = (byte)(((value & (0x3Fu << 12)) >> 12) + 0x80u);
destination[bytesIndex + 2] = (byte)(((value & (0x3Fu << 6)) >> 6) + 0x80u);
destination[bytesIndex + 3] = (byte)((value & 0x3Fu) + 0x80u);
bytesIndex += 4;
return 2;
}
if (bytesIndex + 3 > destination.Length)
{
Helpers.ArrayPoolGrow(ref destination, Math.Max(destination.Length, 3) * 2);
}
// Scalar zzzzyyyy yyxxxxxx -> bytes [ 1110zzzz 10yyyyyy 10xxxxxx ]
destination[bytesIndex] = (byte)((c + (0b1110 << 16)) >> 12);
destination[bytesIndex + 1] = (byte)(((c & (0x3Fu << 6)) >> 6) + 0x80u);
destination[bytesIndex + 2] = (byte)((c & 0x3Fu) + 0x80u);
bytesIndex += 3;
return 1;
}
}
}

View File

@ -45,6 +45,31 @@ namespace Microsoft.ML.Tokenizers
public override int GetHashCode() => Helpers.GetHashCode(Span);
}
internal unsafe readonly struct StringSpanOrdinalKeyPair : IEquatable<StringSpanOrdinalKeyPair>
{
private readonly StringSpanOrdinalKey _left;
private readonly StringSpanOrdinalKey _right;
public StringSpanOrdinalKeyPair(char* ptr1, int length1, char* ptr2, int length2)
{
_left = new StringSpanOrdinalKey(ptr1, length1);
_right = new StringSpanOrdinalKey(ptr2, length2);
}
public StringSpanOrdinalKeyPair(string data1, string data2)
{
_left = new StringSpanOrdinalKey(data1);
_right = new StringSpanOrdinalKey(data2);
}
public override bool Equals(object? obj) =>
obj is StringSpanOrdinalKeyPair wrapper && wrapper._left.Equals(_left) && wrapper._right.Equals(_right);
public bool Equals(StringSpanOrdinalKeyPair other) => other._left.Equals(_left) && other._right.Equals(_right);
public override int GetHashCode() => HashCode.Combine(_left.GetHashCode(), _right.GetHashCode());
}
internal sealed class StringSpanOrdinalKeyCache<TValue>
{
private readonly int _capacity;
@ -115,6 +140,34 @@ namespace Microsoft.ML.Tokenizers
public override void Write(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) => writer.WriteStringValue(value.Data!);
}
internal class StringSpanOrdinalKeyCustomConverter : JsonConverter<Dictionary<StringSpanOrdinalKey, (int, string)>>
{
public static StringSpanOrdinalKeyCustomConverter Instance { get; } = new StringSpanOrdinalKeyCustomConverter();
public override Dictionary<StringSpanOrdinalKey, (int, string)> Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var dictionary = new Dictionary<StringSpanOrdinalKey, (int, string)>();
while (reader.Read())
{
if (reader.TokenType == JsonTokenType.EndObject)
{
return dictionary;
}
if (reader.TokenType == JsonTokenType.PropertyName)
{
var key = reader.GetString();
reader.Read();
var value = reader.GetInt32();
dictionary.Add(new StringSpanOrdinalKey(key!), (value, key!));
}
}
throw new JsonException("Invalid JSON.");
}
public override void Write(Utf8JsonWriter writer, Dictionary<StringSpanOrdinalKey, (int, string)> value, JsonSerializerOptions options) => throw new NotImplementedException();
}
/// <summary>
/// Extension methods for <see cref="StringSpanOrdinalKey"/>.
/// </summary>
@ -130,5 +183,17 @@ namespace Microsoft.ML.Tokenizers
public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, string key, out TValue value) =>
map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
public unsafe static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKeyPair, TValue> map, ReadOnlySpan<char> key1, ReadOnlySpan<char> key2, out TValue value)
{
fixed (char* ptr1 = key1)
fixed (char* ptr2 = key2)
{
return map.TryGetValue(new StringSpanOrdinalKeyPair(ptr1, key1.Length, ptr2, key2.Length), out value!);
}
}
public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKeyPair, TValue> map, string key1, string key2, out TValue value) =>
map.TryGetValue(new StringSpanOrdinalKeyPair(key1, key2), out value!);
}
}

View File

@ -30,7 +30,7 @@ namespace Microsoft.ML.TorchSharp.Extensions
assembly.GetManifestResourceStream("encoder.json"),
assembly.GetManifestResourceStream("vocab.bpe"),
assembly.GetManifestResourceStream("dict.txt"),
new RobertaPreTokenizer());
RobertaPreTokenizer.Instance);
(_instance as EnglishRoberta).AddMaskSymbol();
}

View File

@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Tokenizers;
using System;
using System.Collections.Generic;
using System.Net.Http;

View File

@ -0,0 +1,957 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
using Xunit;
namespace Microsoft.ML.Tokenizers.Tests
{
public class CodeGenTests
{
private static Tokenizer _codegen350MMonoTokenizer = CreateCodegen350MMonoTokenizer();
private static Tokenizer _codegen350MMonoTokenizerWithSpace = CreateCodegen350MMonoTokenizer(addPrefixSpace: true);
private static Tokenizer _codegen350MMonoTokenizerWithBeginningOfSentence = CreateCodegen350MMonoTokenizer(bos: true);
private static Tokenizer _codegen350MMonoTokenizerWithEndOfSentence = CreateCodegen350MMonoTokenizer(eos: true);
private static Tokenizer _codegen350MMonoTokenizerWithBeginningAndEndOfSentence = CreateCodegen350MMonoTokenizer(bos: true, eos: true);
private static Tokenizer CreateCodegen350MMonoTokenizer(bool addPrefixSpace = false, bool bos = false, bool eos = false)
{
// @"https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/vocab.json?download=true";
// @"https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/merges.txt?download=true";
using Stream vocabStream = File.OpenRead(Path.Combine(@"Codegen-350M-mono", "vocab.json"));
using Stream mergesStream = File.OpenRead(Path.Combine(@"Codegen-350M-mono", "merges.txt"));
return Tokenizer.CreateCodeGen(vocabStream, mergesStream, addPrefixSpace, bos, eos);
}
private static Tokenizer CreateCodegenPhi2Tokenizer()
{
// https://huggingface.co/microsoft/phi-2/resolve/main/vocab.json?download=true
// https://huggingface.co/microsoft/phi-2/resolve/main/merges.txt?download=true
using Stream vocabStream = File.OpenRead(Path.Combine(@"Phi-2", "vocab.json"));
using Stream mergesStream = File.OpenRead(Path.Combine(@"Phi-2", "merges.txt"));
return Tokenizer.CreateCodeGen(vocabStream, mergesStream);
}
public static IEnumerable<object?[]> CodeGenTestData
{
get
{
// string to tokenize,
// produced tokens,
// the token offsets,
// the tokens ids, produced tokens when AddPrefixSpace is enabled,
// the token offsets when AddPrefixSpace is enabled,
// the tokens ids when AddPrefixSpace is enabled
yield return new object?[]
{
"Hello World",
new string[] { "Hello", "ĠWorld" },
new (int Index, int Length)[] { (0, 5), (5, 6) },
new int[] { 15496, 2159 },
new string[] { "ĠHello", "ĠWorld" },
new (int Index, int Length)[] { (0, 5), (5, 6) },
new int[] { 18435, 2159 },
};
yield return new object?[]
{
" Hello World", // with space prefix this depends on the AddedTokens
new string[] { "ĠHello", "ĠWorld" },
new (int Index, int Length)[] { (0, 6), (6, 6) },
new int[] { 18435, 2159 },
new string[] { " ", "Hello", "ĠWorld" },
new (int Index, int Length)[] { (0, 1), (1, 5), (6, 6) },
new int[] { 50286, 15496, 2159 },
};
yield return new object?[]
{
"the brown fox jumped over the lazy dog!\r\n", // text in range 0 ~ FF
new string[] { "the", "Ġbrown", "Ġfox", "Ġjumped", "Ġover", "Ġthe", "Ġlazy", "Ġdog", "!", "č", "Ċ" },
new (int Index, int Length)[] { (0, 3), (3, 6), (9, 4), (13, 7), (20, 5), (25, 4), (29, 5), (34, 4), (38, 1), (39, 1), (40, 1) },
new int[] { 1169, 7586, 21831, 11687, 625, 262, 16931, 3290, 0, 201, 198 },
new string[] { "Ġthe", "Ġbrown", "Ġfox", "Ġjumped", "Ġover", "Ġthe", "Ġlazy", "Ġdog", "!", "č", "Ċ" },
new (int Index, int Length)[] { (0, 3), (3, 6), (9, 4), (13, 7), (20, 5), (25, 4), (29, 5), (34, 4), (38, 1), (39, 1), (40, 1) },
new int[] { 262, 7586, 21831, 11687, 625, 262, 16931, 3290, 0, 201, 198 }
};
yield return new object?[]
{
"\u0924\u1009\u1129\u1241\uE860\u3438.", // text greater than 7FF Devanagari, Myanmar, Hangul, Ethiopic, Palmyrene, CJK तဉᄩቁ㐸.
new string[] { "à¤", "¤", "á", "Ģ", "ī", "á", "Ħ", "©", "á", "ī", "ģ", "î", "¡", "ł", "ã", "IJ", "¸", "." },
new (int Index, int Length)[] { (0, 0), (0, 1), (1, 0), (1, 0), (1, 1), (2, 0), (2, 0), (2, 1), (3, 0), (3, 0), (3, 1), (4, 0), (4, 0), (4, 1), (5, 0), (5, 0), (5, 1), (6, 1) },
new int[] { 11976, 97, 157, 222, 231, 157, 226, 102, 157, 231, 223, 170, 94, 254, 159, 238, 116, 13 },
new string[] { "Ġà¤", "¤", "á", "Ģ", "ī", "á", "Ħ", "©", "á", "ī", "ģ", "î", "¡", "ł", "ã", "IJ", "¸", "." },
new (int Index, int Length)[] { (0, 0), (0, 1), (1, 0), (1, 0), (1, 1), (2, 0), (2, 0), (2, 1), (3, 0), (3, 0), (3, 1), (4, 0), (4, 0), (4, 1), (5, 0), (5, 0), (5, 1), (6, 1) },
new int[] { 28225, 97, 157, 222, 231, 157, 226, 102, 157, 231, 223, 170, 94, 254, 159, 238, 116, 13 }
};
yield return new object?[]
{
"Some Greek letters ΣΦΩ αβγδε.", // text in range 100 ~ 7FF
new string[] { "Some", "ĠGreek", "Ġletters", "ĠÎ", "£", "Î", "¦", "Î", "©", "Ġα", "β", "γ", "Î", "´", "ε", "." },
new (int Index, int Length)[] { (0, 4), (4, 6), (10, 8), (18, 1), (19, 1), (20, 0), (20, 1), (21, 0), (21, 1), (22, 2), (24, 1), (25, 1), (26, 0), (26, 1), (27, 1), (28, 1) },
new int[] { 4366, 8312, 7475, 7377, 96, 138, 99, 138, 102, 26367, 26638, 42063, 138, 112, 30950, 13 },
new string[] { "ĠSome", "ĠGreek", "Ġletters", "ĠÎ", "£", "Î", "¦", "Î", "©", "Ġα", "β", "γ", "Î", "´", "ε", "." },
new (int Index, int Length)[] { (0, 4), (4, 6), (10, 8), (18, 1), (19, 1), (20, 0), (20, 1), (21, 0), (21, 1), (22, 2), (24, 1), (25, 1), (26, 0), (26, 1), (27, 1), (28, 1) },
new int[] { 2773, 8312, 7475, 7377, 96, 138, 99, 138, 102, 26367, 26638, 42063, 138, 112, 30950, 13 }
};
yield return new object?[]
{
"αβγδε", // no spaces
new string[] { "α", "β", "γ", "Î", "´", "ε" },
new (int Index, int Length)[] { (0, 1), (1, 1), (2, 1), (3, 0), (3, 1), (4, 1) },
new int[] { 17394, 26638, 42063, 138, 112, 30950 },
new string[] { "Ġα", "β", "γ", "Î", "´", "ε" },
new (int Index, int Length)[] { (0, 1), (1, 1), (2, 1), (3, 0), (3, 1), (4, 1) },
new int[] { 26367, 26638, 42063, 138, 112, 30950 }
};
yield return new object?[]
{
"Surrogates: 😀😂😍😘",
new string[] { "Sur", "rog", "ates", ":", "ĠðŁĺ", "Ģ", "ðŁĺ", "Ĥ", "ðŁĺ", "į", "ðŁĺ", "ĺ" },
new (int Index, int Length)[] { (0, 3), (3, 3), (6, 4), (10, 1), (11, 1), (12, 2), (14, 0), (14, 2), (16, 0), (16, 2), (18, 0), (18, 2) },
new int[] { 14214, 3828, 689, 25, 30325, 222, 47249, 224, 47249, 235, 47249, 246 },
new string[] { "ĠSur", "rog", "ates", ":", "ĠðŁĺ", "Ģ", "ðŁĺ", "Ĥ", "ðŁĺ", "į", "ðŁĺ", "ĺ" },
new (int Index, int Length)[] { (0, 3), (3, 3), (6, 4), (10, 1), (11, 1), (12, 2), (14, 0), (14, 2), (16, 0), (16, 2), (18, 0), (18, 2) },
new int[] { 4198, 3828, 689, 25, 30325, 222, 47249, 224, 47249, 235, 47249, 246 }
};
yield return new object?[]
{
"Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides " +
"general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet...) for Natural " +
"Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained " +
"models in 100+ languages and deep interoperability between Jax, PyTorch and TensorFlow.",
new string[] { "Transform", "ers", "Ġ(", "formerly", "Ġknown", "Ġas", "Ġpy", "tor", "ch", "-", "transform", "ers", "Ġand", "Ġpy", "tor", "ch", "-",
"pret", "rained", "-", "bert", ")", "Ġprovides", "Ġgeneral", "-", "purpose", "Ġarchitectures", "Ġ(", "BER", "T", ",", "ĠG", "PT",
"-", "2", ",", "ĠRo", "BER", "Ta", ",", "ĠXL", "M", ",", "ĠDist", "il", "B", "ert", ",", "ĠXL", "Net", "...)", "Ġfor", "ĠNatural",
"ĠLanguage", "ĠUnderstanding", "Ġ(", "NL", "U", ")", "Ġand", "ĠNatural", "ĠLanguage", "ĠGeneration", "Ġ(", "NL", "G", ")", "Ġwith",
"Ġover", "Ġ32", "+", "Ġpret", "rained", "Ġmodels", "Ġin", "Ġ100", "+", "Ġlanguages", "Ġand", "Ġdeep", "Ġinteroper", "ability",
"Ġbetween", "ĠJ", "ax", ",", "ĠPy", "Tor", "ch", "Ġand", "ĠT", "ensor", "Flow", "." },
new (int Index, int Length)[] { (0, 9), (9, 3), (12, 2), (14, 8), (22, 6), (28, 3), (31, 3), (34, 3), (37, 2), (39, 1), (40, 9), (49, 3), (52, 4),
(56, 3), (59, 3), (62, 2), (64, 1), (65, 4), (69, 6), (75, 1), (76, 4), (80, 1), (81, 9), (90, 8), (98, 1), (99, 7), (106, 14),
(120, 2), (122, 3), (125, 1), (126, 1), (127, 2), (129, 2), (131, 1), (132, 1), (133, 1), (134, 3), (137, 3), (140, 2), (142, 1),
(143, 3), (146, 1), (147, 1), (148, 5), (153, 2), (155, 1), (156, 3), (159, 1), (160, 3), (163, 3), (166, 4), (170, 4), (174, 8),
(182, 9), (191, 14), (205, 2), (207, 2), (209, 1), (210, 1), (211, 4), (215, 8), (223, 9), (232, 11), (243, 2), (245, 2), (247, 1),
(248, 1), (249, 5), (254, 5), (259, 3), (262, 1), (263, 5), (268, 6), (274, 7), (281, 3), (284, 4), (288, 1), (289, 10), (299, 4),
(303, 5), (308, 10), (318, 7), (325, 8), (333, 2), (335, 2), (337, 1), (338, 3), (341, 3), (344, 2), (346, 4), (350, 2), (352, 5),
(357, 4), (361, 1) },
new int[] { 41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276,
12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276,
7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363,
4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13 },
new string[] { "ĠTransformers", "Ġ(", "formerly", "Ġknown", "Ġas", "Ġpy", "tor", "ch", "-", "transform", "ers", "Ġand", "Ġpy", "tor", "ch", "-",
"pret", "rained", "-", "bert", ")", "Ġprovides", "Ġgeneral", "-", "purpose", "Ġarchitectures", "Ġ(", "BER", "T", ",", "ĠG", "PT",
"-", "2", ",", "ĠRo", "BER", "Ta", ",", "ĠXL", "M", ",", "ĠDist", "il", "B", "ert", ",", "ĠXL", "Net", "...)", "Ġfor", "ĠNatural",
"ĠLanguage", "ĠUnderstanding", "Ġ(", "NL", "U", ")", "Ġand", "ĠNatural", "ĠLanguage", "ĠGeneration", "Ġ(", "NL", "G", ")", "Ġwith",
"Ġover", "Ġ32", "+", "Ġpret", "rained", "Ġmodels", "Ġin", "Ġ100", "+", "Ġlanguages", "Ġand", "Ġdeep", "Ġinteroper", "ability",
"Ġbetween", "ĠJ", "ax", ",", "ĠPy", "Tor", "ch", "Ġand", "ĠT", "ensor", "Flow", "." },
new (int Index, int Length)[] { (0, 12), (12, 2), (14, 8), (22, 6), (28, 3), (31, 3), (34, 3), (37, 2), (39, 1), (40, 9), (49, 3), (52, 4),
(56, 3), (59, 3), (62, 2), (64, 1), (65, 4), (69, 6), (75, 1), (76, 4), (80, 1), (81, 9), (90, 8), (98, 1), (99, 7), (106, 14),
(120, 2), (122, 3), (125, 1), (126, 1), (127, 2), (129, 2), (131, 1), (132, 1), (133, 1), (134, 3), (137, 3), (140, 2), (142, 1),
(143, 3), (146, 1), (147, 1), (148, 5), (153, 2), (155, 1), (156, 3), (159, 1), (160, 3), (163, 3), (166, 4), (170, 4), (174, 8),
(182, 9), (191, 14), (205, 2), (207, 2), (209, 1), (210, 1), (211, 4), (215, 8), (223, 9), (232, 11), (243, 2), (245, 2), (247, 1),
(248, 1), (249, 5), (254, 5), (259, 3), (262, 1), (263, 5), (268, 6), (274, 7), (281, 3), (284, 4), (288, 1), (289, 10), (299, 4),
(303, 5), (308, 10), (318, 7), (325, 8), (333, 2), (335, 2), (337, 1), (338, 3), (341, 3), (344, 2), (346, 4), (350, 2), (352, 5),
(357, 4), (361, 1) },
new int[] { 39185, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276,
12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276,
7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363,
4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13 }
};
yield return new object?[]
{
"BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly " +
"conditioning on both left and right context in all layers.",
new string[] { "BER", "T", "Ġis", "Ġdesigned", "Ġto", "Ġpre", "-", "train", "Ġdeep", "Ġbid", "irection", "al", "Ġrepresentations", "Ġfrom", "Ġunl",
"abel", "ed", "Ġtext", "Ġby", "Ġjointly", "Ġconditioning", "Ġon", "Ġboth", "Ġleft", "Ġand", "Ġright", "Ġcontext", "Ġin", "Ġall",
"Ġlayers", "." },
new (int Index, int Length)[] { (0, 3), (3, 1), (4, 3), (7, 9), (16, 3), (19, 4), (23, 1), (24, 5), (29, 5), (34, 4), (38, 8), (46, 2), (48, 16),
(64, 5), (69, 4), (73, 4), (77, 2), (79, 5), (84, 3), (87, 8), (95, 13), (108, 3), (111, 5), (116, 5), (121, 4), (125, 6), (131, 8),
(139, 3), (142, 4), (146, 7), (153, 1) },
new int[] { 13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13 },
new string[] { "ĠB", "ERT", "Ġis", "Ġdesigned", "Ġto", "Ġpre", "-", "train", "Ġdeep", "Ġbid", "irection", "al", "Ġrepresentations", "Ġfrom", "Ġunl",
"abel", "ed", "Ġtext", "Ġby", "Ġjointly", "Ġconditioning", "Ġon", "Ġboth", "Ġleft", "Ġand", "Ġright", "Ġcontext", "Ġin", "Ġall",
"Ġlayers", "." },
new (int Index, int Length)[] { (0, 1), (1, 3), (4, 3), (7, 9), (16, 3), (19, 4), (23, 1), (24, 5), (29, 5), (34, 4), (38, 8), (46, 2), (48, 16),
(64, 5), (69, 4), (73, 4), (77, 2), (79, 5), (84, 3), (87, 8), (95, 13), (108, 3), (111, 5), (116, 5), (121, 4), (125, 6), (131, 8),
(139, 3), (142, 4), (146, 7), (153, 1) },
new int[] { 347, 17395, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13 }
};
yield return new object?[]
{
"The quick brown fox jumps over the lazy dog.",
new string[] { "The", "Ġquick", "Ġbrown", "Ġfox", "Ġjumps", "Ġover", "Ġthe", "Ġlazy", "Ġdog", "." },
new (int Index, int Length)[] { (0, 3), (3, 6), (9, 6), (15, 4), (19, 6), (25, 5), (30, 4), (34, 5), (39, 4), (43, 1) },
new int[] { 464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13 },
new string[] { "ĠThe", "Ġquick", "Ġbrown", "Ġfox", "Ġjumps", "Ġover", "Ġthe", "Ġlazy", "Ġdog", "." },
new (int Index, int Length)[] { (0, 3), (3, 6), (9, 6), (15, 4), (19, 6), (25, 5), (30, 4), (34, 5), (39, 4), (43, 1) },
new int[] { 383, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13 }
};
}
}
[Theory]
[MemberData(nameof(CodeGenTestData))]
public void TestTokenizerEncoding(
string text,
string[] expectedTokens,
(int Index, int Length)[] expectedOffsets,
int[] expectedIds,
string[] expectedTokensWithSpace,
(int Index, int Length)[] expectedOffsetsWithSpace,
int[] expectedIdsWithSpace)
{
TestTokenizer(_codegen350MMonoTokenizer, text, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
TestTokenizer(_codegen350MMonoTokenizerWithSpace, text, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
Tokenizer phi2Tokenizer = CreateCodegenPhi2Tokenizer();
TestTokenizer(phi2Tokenizer, text, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
TestDecoding(_codegen350MMonoTokenizer, text);
TestDecoding(_codegen350MMonoTokenizerWithSpace, text);
TestDecoding(phi2Tokenizer, text);
}
private void ValidateEncoding(IReadOnlyList<Token> encoding, bool addPrefixSpace, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds,
string[] expectedTokensWithSpace, (int Index, int Length)[] expectedOffsetsWithSpace, int[] expectedIdsWithSpace)
{
if (addPrefixSpace)
{
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsetsWithSpace, encoding.Select(t => t.Offset).ToArray());
}
else
{
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
}
}
private void TestDecoding(Tokenizer tokenizer, string text)
{
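// Round-trip check: decoding the encoded ids should reproduce the original text for every combination of
// prefix-space and BOS/EOS options; with considerSpecialTokens: true the special tokens appear in the decoded output.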
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
Assert.Equal(text, tokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
encoding = tokenizer.Encode(text.AsSpan(), out _);
Assert.Equal(text, tokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
CodeGen codeGenTokenizer = (tokenizer as CodeGen)!;
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray(), hasPrefixSpace: true, considerSpecialTokens: false));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray(), hasPrefixSpace: true, considerSpecialTokens: false));
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal($"{codeGenTokenizer.BeginningOfSentenceToken}{text}{codeGenTokenizer.EndOfSentenceToken}", codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray(), hasPrefixSpace: true, considerSpecialTokens: true));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal($"{codeGenTokenizer.BeginningOfSentenceToken}{text}{codeGenTokenizer.EndOfSentenceToken}", codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray(), hasPrefixSpace: true, considerSpecialTokens: true));
}
private void TestTokenizer(
Tokenizer tokenizer,
string text,
string[] expectedTokens,
(int Index, int Length)[] expectedOffsets,
int[] expectedIds,
string[] expectedTokensWithSpace,
(int Index, int Length)[] expectedOffsetsWithSpace,
int[] expectedIdsWithSpace)
{
CodeGen codeGenTokenizer = (tokenizer as CodeGen)!;
//
// Full Encoding
//
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
ValidateEncoding(encoding, codeGenTokenizer.AddPrefixSpace, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
encoding = tokenizer.Encode(text.AsSpan(), out _);
ValidateEncoding(encoding, codeGenTokenizer.AddPrefixSpace, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
ValidateEncoding(encoding, addPrefixSpace: false, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
ValidateEncoding(encoding, addPrefixSpace: false, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
ValidateEncoding(encoding, addPrefixSpace: true, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
ValidateEncoding(encoding, addPrefixSpace: true, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
//
// Encode To Ids
//
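// The maxTokenCount overloads also report, through the out 'length' argument, how many characters of the input the returned ids cover.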
var ids = codeGenTokenizer.AddPrefixSpace ? expectedIdsWithSpace : expectedIds;
Assert.Equal(ids, tokenizer.EncodeToIds(text));
Assert.Equal(ids, tokenizer.EncodeToIds(text.AsSpan()));
Assert.Equal(expectedIdsWithSpace, codeGenTokenizer.EncodeToIds(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false));
Assert.Equal(expectedIdsWithSpace, codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false));
Assert.Equal(expectedIds, codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false));
Assert.Equal(expectedIds, codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false));
Assert.Equal(ids, codeGenTokenizer.EncodeToIds(text, ids.Length, out string? normalizedString, out int length));
Assert.Null(normalizedString);
Assert.Equal(text.Length, length);
Assert.Equal(ids, codeGenTokenizer.EncodeToIds(text.AsSpan(), ids.Length, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(text.Length, length);
Assert.Equal(expectedIds, codeGenTokenizer.EncodeToIds(text, expectedIds.Length, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(text.Length, length);
Assert.Equal(expectedIds, codeGenTokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(text.Length, length);
Assert.Equal(expectedIdsWithSpace, codeGenTokenizer.EncodeToIds(text, expectedIdsWithSpace.Length, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(text.Length, length);
Assert.Equal(expectedIdsWithSpace, codeGenTokenizer.EncodeToIds(text.AsSpan(), expectedIdsWithSpace.Length, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(text.Length, length);
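// If the last two tokens start at the same offset they belong to a single code point; trimming the budget by one token
// must drop both so no partially mapped code point is returned.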
int expectedTokensToExclude = expectedOffsets.Length > 1 && expectedOffsets[expectedOffsets.Length - 1].Index == expectedOffsets[expectedOffsets.Length - 2].Index ? 2 : 1;
Assert.Equal(ids.Take(ids.Length - expectedTokensToExclude), codeGenTokenizer.EncodeToIds(text, ids.Length - 1, out normalizedString, out length));
Assert.Null(normalizedString);
var offsets = codeGenTokenizer.AddPrefixSpace ? expectedOffsetsWithSpace : expectedOffsets;
int expectedLength = offsets.Length > expectedTokensToExclude ? offsets[offsets.Length - expectedTokensToExclude - 1].Index + offsets[offsets.Length - expectedTokensToExclude - 1].Length : 0;
Assert.Equal(expectedLength, length);
Assert.Equal(ids.Take(ids.Length - expectedTokensToExclude), codeGenTokenizer.EncodeToIds(text.AsSpan(), ids.Length - 1, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(expectedLength, length);
Assert.Equal(expectedIds.Take(expectedIds.Length - expectedTokensToExclude), codeGenTokenizer.EncodeToIds(text, expectedIds.Length - 1, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(expectedLength, length);
Assert.Equal(expectedIds.Take(expectedIds.Length - expectedTokensToExclude), codeGenTokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length - 1, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(expectedLength, length);
Assert.Equal(expectedIdsWithSpace.Take(expectedIdsWithSpace.Length - expectedTokensToExclude), codeGenTokenizer.EncodeToIds(text, expectedIdsWithSpace.Length - 1, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(expectedLength, length);
Assert.Equal(expectedIdsWithSpace.Take(expectedIdsWithSpace.Length - expectedTokensToExclude), codeGenTokenizer.EncodeToIds(text.AsSpan(), expectedIdsWithSpace.Length - 1, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
Assert.Null(normalizedString);
Assert.Equal(expectedLength, length);
//
// CountTokens
//
Assert.Equal(ids.Length, codeGenTokenizer.CountTokens(text));
Assert.Equal(ids.Length, codeGenTokenizer.CountTokens(text.AsSpan()));
Assert.Equal(expectedIds.Length, codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false));
Assert.Equal(expectedIds.Length, codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false));
Assert.Equal(expectedIdsWithSpace.Length, codeGenTokenizer.CountTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false));
Assert.Equal(expectedIdsWithSpace.Length, codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false));
//
// IndexOf
//
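// IndexOfTokenCount reports the end position (offset index + length) of the last token that fits within the given budget.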
offsets = codeGenTokenizer.AddPrefixSpace ? expectedOffsetsWithSpace : expectedOffsets;
Assert.Equal(offsets[offsets.Length - 1].Index + offsets[offsets.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text, ids.Length, out normalizedString, out int tokenCount));
Assert.Null(normalizedString);
Assert.Equal(ids.Length, tokenCount);
Assert.Equal(offsets[offsets.Length - 1].Index + offsets[offsets.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), ids.Length, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(ids.Length, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 1].Index + expectedOffsets[expectedOffsets.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text, expectedIds.Length, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length, tokenCount);
Assert.Equal(expectedOffsets[expectedOffsets.Length - 1].Index + expectedOffsets[expectedOffsets.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIds.Length, tokenCount);
Assert.Equal(expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index + expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text, expectedIdsWithSpace.Length, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIdsWithSpace.Length, tokenCount);
Assert.Equal(expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index + expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), expectedIdsWithSpace.Length, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedIdsWithSpace.Length, tokenCount);
//
// LastIndexOf
//
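// With a budget of one token, a trailing code point that spans multiple tokens cannot be included at all,
// so the expected index falls back to text.Length with a token count of zero.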
int expectedIndex = offsets.Length > 1 && offsets[offsets.Length - 1].Index == offsets[offsets.Length - 2].Index ? text.Length : offsets[offsets.Length - 1].Index;
int expectedTokenCount = expectedIndex == text.Length ? 0 : 1;
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text, 1, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), 1, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text, 1, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), 1, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedTokenCount, tokenCount);
expectedIndex = offsets.Length > 1 && expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index == expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 2].Index ? text.Length : expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index;
expectedTokenCount = expectedIndex == text.Length ? 0 : 1;
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text, 1, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), 1, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
Assert.Null(normalizedString);
Assert.Equal(expectedTokenCount, tokenCount);
//
// Id to Token and Token to Id mapping
//
var tokens = codeGenTokenizer.AddPrefixSpace ? expectedTokensWithSpace : expectedTokens;
for (int i = 0; i < tokens.Length; i++)
{
Assert.Equal(tokens[i], codeGenTokenizer.MapIdToToken(ids[i]));
Assert.Equal(ids[i], codeGenTokenizer.MapTokenToId(tokens[i]));
}
}
[Theory]
[MemberData(nameof(CodeGenTestData))]
public void TestBeginningAndEndOfSentenceEncoding(
string text,
string[] expectedTokens,
(int Index, int Length)[] expectedOffsets,
int[] expectedIds,
string[] expectedTokensWithSpace,
(int Index, int Length)[] expectedOffsetsWithSpace,
int[] expectedIdsWithSpace)
{
Assert.NotNull(expectedOffsets);
Assert.NotNull(expectedOffsetsWithSpace);
//
// Beginning of Sentence
//
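// This tokenizer is configured with AddBeginningOfSentence, so every encoding is expected to start with the BOS token at a zero-width (0, 0) offset.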
CodeGen codeGenTokenizer = (_codegen350MMonoTokenizerWithBeginningOfSentence as CodeGen)!;
IReadOnlyList<Token> encoding = codeGenTokenizer.Encode(text, out _);
Assert.True(codeGenTokenizer.BeginningOfSentenceToken is not null);
Assert.True(codeGenTokenizer.BeginningOfSentenceId.HasValue);
var idList = new List<int>(expectedIds);
idList.Insert(0, codeGenTokenizer.BeginningOfSentenceId!.Value);
var tokensList = new List<string>(expectedTokens);
tokensList.Insert(0, codeGenTokenizer.BeginningOfSentenceToken!);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
idList = new List<int>(expectedIdsWithSpace);
idList.Insert(0, codeGenTokenizer.BeginningOfSentenceId!.Value);
tokensList = new List<string>(expectedTokensWithSpace);
tokensList.Insert(0, codeGenTokenizer.BeginningOfSentenceToken!);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
IReadOnlyList<int> ids = codeGenTokenizer.EncodeToIds(text);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan());
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
ids = codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
ids = codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.NotEqual(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.NotEqual(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
ids = codeGenTokenizer.EncodeToIds(text, maxTokenCount: 5, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out string? normalizedString, out int textLength);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), maxTokenCount: 5, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out textLength);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
int tokenCount = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
int count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.Equal(tokenCount, count);
count = codeGenTokenizer.CountTokens(text);
Assert.Equal(tokenCount + 1, count);
count = codeGenTokenizer.CountTokens(text.AsSpan());
Assert.Equal(tokenCount + 1, count);
count = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false);
Assert.Equal(tokenCount + 1, count);
count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false);
Assert.Equal(tokenCount + 1, count);
int length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);
int index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
//
// End of Sentence
//
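// This tokenizer is configured with AddEndOfSentence, so every encoding is expected to end with the EOS token at a zero-width (text.Length, 0) offset.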
codeGenTokenizer = (_codegen350MMonoTokenizerWithEndOfSentence as CodeGen)!;
encoding = codeGenTokenizer.Encode(text, out _);
Assert.True(codeGenTokenizer.EndOfSentenceToken is not null);
Assert.True(codeGenTokenizer.EndOfSentenceId.HasValue);
idList = new List<int>(expectedIds);
idList.Add(codeGenTokenizer.EndOfSentenceId!.Value);
tokensList = new List<string>(expectedTokens);
tokensList.Add(codeGenTokenizer.EndOfSentenceToken!);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
idList = new List<int>(expectedIdsWithSpace);
idList.Add(codeGenTokenizer.EndOfSentenceId!.Value);
tokensList = new List<string>(expectedTokensWithSpace);
tokensList.Add(codeGenTokenizer.EndOfSentenceToken!);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
ids = codeGenTokenizer.EncodeToIds(text);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan());
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.NotEqual(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.NotEqual(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out textLength);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out textLength);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
tokenCount = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.Equal(tokenCount, count);
count = codeGenTokenizer.CountTokens(text);
Assert.Equal(tokenCount + 1, count);
count = codeGenTokenizer.CountTokens(text.AsSpan());
Assert.Equal(tokenCount + 1, count);
count = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true);
Assert.Equal(tokenCount + 1, count);
count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true);
Assert.Equal(tokenCount + 1, count);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 1, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
//
// Beginning & End of Sentence
//
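// With both options enabled, encodings should be bracketed by BOS at (0, 0) and EOS at (text.Length, 0).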
codeGenTokenizer = (_codegen350MMonoTokenizerWithBeginningAndEndOfSentence as CodeGen)!;
encoding = codeGenTokenizer.Encode(text, out _);
Assert.True(codeGenTokenizer.BeginningOfSentenceToken is not null);
Assert.True(codeGenTokenizer.BeginningOfSentenceId.HasValue);
idList = new List<int>(expectedIds);
idList.Insert(0, codeGenTokenizer.BeginningOfSentenceId!.Value);
idList.Add(codeGenTokenizer.EndOfSentenceId!.Value);
tokensList = new List<string>(expectedTokens);
tokensList.Insert(0, codeGenTokenizer.BeginningOfSentenceToken!);
tokensList.Add(codeGenTokenizer.EndOfSentenceToken!);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
idList = new List<int>(expectedIdsWithSpace);
idList.Insert(0, codeGenTokenizer.BeginningOfSentenceId!.Value);
idList.Add(codeGenTokenizer.EndOfSentenceId!.Value);
tokensList = new List<string>(expectedTokensWithSpace);
tokensList.Insert(0, codeGenTokenizer.BeginningOfSentenceToken!);
tokensList.Add(codeGenTokenizer.EndOfSentenceToken!);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
ids = codeGenTokenizer.EncodeToIds(text);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan());
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.NotEqual(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.NotEqual(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.NotEqual(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.NotEqual(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out textLength);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out textLength);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
tokenCount = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.Equal(tokenCount, count);
count = codeGenTokenizer.CountTokens(text);
Assert.Equal(tokenCount + 2, count);
count = codeGenTokenizer.CountTokens(text.AsSpan());
Assert.Equal(tokenCount + 2, count);
count = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true);
Assert.Equal(tokenCount + 2, count);
count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true);
Assert.Equal(tokenCount + 2, count);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
}
private const string DefaultSpecialToken = "<|endoftext|>";
[Fact]
public void TestDefaultValues()
{
CodeGen codeGenTokenizer = (_codegen350MMonoTokenizer as CodeGen)!;
Assert.False(codeGenTokenizer.AddPrefixSpace);
Assert.False(codeGenTokenizer.AddBeginningOfSentence);
Assert.False(codeGenTokenizer.AddEndOfSentence);
Assert.Equal(codeGenTokenizer.MapTokenToId(DefaultSpecialToken), codeGenTokenizer.BeginningOfSentenceId!.Value);
Assert.Equal(codeGenTokenizer.MapTokenToId(DefaultSpecialToken), codeGenTokenizer.EndOfSentenceId!.Value);
Assert.Equal(codeGenTokenizer.MapTokenToId(DefaultSpecialToken), codeGenTokenizer.UnknownTokenId!.Value);
Assert.Equal(DefaultSpecialToken, codeGenTokenizer.BeginningOfSentenceToken);
Assert.Equal(DefaultSpecialToken, codeGenTokenizer.EndOfSentenceToken);
Assert.Equal(DefaultSpecialToken, codeGenTokenizer.UnknownToken);
}
[Theory]
[InlineData(1, 0, 0, 0, 3)]
[InlineData(2, 2, 1, 2, 1)]
[InlineData(3, 2, 1, 2, 1)]
[InlineData(4, 4, 3, 4, 0)]
[InlineData(5, 4, 3, 4, 0)]
public void TestTokenLimits(int maxTokenCount, int expectedTokenCount, int expectedTextLength, int expectedTokenCountFromEnd, int expectedTextIndexFromEnd)
{
// cannot split between the first two tokens nor last two tokens
string input = "δ😀";
int[] encodingIds = [138, 112, 47249, 222];
(int Index, int Length)[] offsets = [(0, 0), (0, 1), (1, 0), (1, 2)];
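// 'δ' encodes to two tokens covering text index 0, and the surrogate pair '😀' to two tokens covering indices 1-2;
// the zero-length offsets mark tokens that only partially cover a code point, so truncation cannot stop there.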
int calculatedLengthUsingOffsets = expectedTokenCount > 0 ? offsets[expectedTokenCount - 1].Index + offsets[expectedTokenCount - 1].Length : 0;
IReadOnlyList<int> ids = _codegen350MMonoTokenizer.EncodeToIds(input, maxTokenCount, out _, out int textLength);
Assert.Equal(expectedTokenCount, ids.Count);
Assert.Equal(expectedTextLength, textLength);
Assert.Equal(encodingIds.Take(expectedTokenCount), ids);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
ids = _codegen350MMonoTokenizer.EncodeToIds(input.AsSpan(), maxTokenCount, out _, out textLength);
Assert.Equal(expectedTokenCount, ids.Count);
Assert.Equal(expectedTextLength, textLength);
Assert.Equal(encodingIds.Take(expectedTokenCount), ids);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
textLength = _codegen350MMonoTokenizer.IndexOfTokenCount(input, maxTokenCount, out _, out int tokenCount);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedTextLength, textLength);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
textLength = _codegen350MMonoTokenizer.IndexOfTokenCount(input.AsSpan(), maxTokenCount, out _, out tokenCount);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedTextLength, textLength);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
calculatedLengthUsingOffsets = expectedTokenCountFromEnd > 0 ? offsets[offsets.Length - expectedTokenCountFromEnd].Index : input.Length;
textLength = _codegen350MMonoTokenizer.LastIndexOfTokenCount(input, maxTokenCount, out _, out tokenCount);
Assert.Equal(expectedTokenCountFromEnd, tokenCount);
Assert.Equal(expectedTextIndexFromEnd, textLength);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
textLength = _codegen350MMonoTokenizer.LastIndexOfTokenCount(input.AsSpan(), maxTokenCount, out _, out tokenCount);
Assert.Equal(expectedTokenCountFromEnd, tokenCount);
Assert.Equal(expectedTextIndexFromEnd, textLength);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
}
}
}

View file

@ -6,20 +6,13 @@ using System;
using System.IO;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using Xunit;
using System.Diagnostics;
using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers.Tests
{
public class EnglishRobertaTests
{
private static readonly string _vocabUrl = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json";
private static readonly string _mergeUrl = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe";
private static readonly string _dictUrl = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt";
public static IEnumerable<object[]> BertaData
{
get
@ -83,82 +76,62 @@ namespace Microsoft.ML.Tokenizers.Tests
}
private static Tokenizer? _robertaTokenizer = null;
private async static Task<Tokenizer> GetRobertaTokenizer()
private static Tokenizer GetRobertaTokenizer()
{
if (_robertaTokenizer is null)
{
string vocabFile = Utils.CreateTemporaryFile("json");
string mergeFile = Utils.CreateTemporaryFile("txt");
string translationFile = Utils.CreateTemporaryFile("txt");
// encoder.json is same as vocab.json
// vocab.bpe is same as merges.txt
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json";
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe";
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt";
try
{
await Utils.DownloadFile(_vocabUrl, vocabFile);
await Utils.DownloadFile(_mergeUrl, mergeFile);
await Utils.DownloadFile(_dictUrl, translationFile);
_robertaTokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
}
finally
{
Utils.DeleteFile(vocabFile);
Utils.DeleteFile(mergeFile);
Utils.DeleteFile(translationFile);
}
_robertaTokenizer = new EnglishRoberta(
Path.Combine(@"Gpt-2", "vocab.json"),
Path.Combine(@"Gpt-2", "merges.txt"),
Path.Combine(@"Gpt-2", "dict.txt"),
RobertaPreTokenizer.Instance);
}
return _robertaTokenizer;
}
[Fact]
public void TokenizationTest()
{
    string vocabFile = Path.Combine(@"Gpt-2", "vocab.json");
    string mergeFile = Path.Combine(@"Gpt-2", "merges.txt");
    string translationFile = Path.Combine(@"Gpt-2", "dict.txt");

    Tokenizer tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
    TestTokenizer(tokenizer);
    TokenizerTests.TestTokenLimits(tokenizer);

    tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false);
    TestTokenizer(tokenizer);

    using Stream vocabStream = File.OpenRead(vocabFile);
    using Stream mergeStream = File.OpenRead(mergeFile);
    using Stream translationStream = File.OpenRead(translationFile);

    tokenizer = new EnglishRoberta(vocabStream, mergeStream, translationStream, RobertaPreTokenizer.Instance);
    TestTokenizer(tokenizer);

    // Ensure caching works regardless of which method is called first.
    for (CallingOrder order = CallingOrder.Encode; order <= CallingOrder.CountTokens; order++)
    {
        tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
        TestTokenizer(tokenizer, order);

        tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false);
        TestTokenizer(tokenizer, order);
    }
}
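The CallingOrder loop above exists because EnglishRoberta caches encoding state that whichever entry point runs first will populate. As a rough standalone illustration (a sketch only, reusing the vocabFile/mergeFile/translationFile paths and only the Tokenizer methods already exercised in this file), consistency between entry points can be checked directly:

// Sketch: whichever entry point warms the cache first, results must agree.
Tokenizer roberta = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
string sample = "Hello World";
int countFirst = roberta.CountTokens(sample);                      // touch the cache via CountTokens first
IReadOnlyList<Token> encodedAfter = roberta.Encode(sample, out _); // then fully encode the same text
Assert.Equal(countFirst, encodedAfter.Count);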
@ -200,9 +173,9 @@ namespace Microsoft.ML.Tokenizers.Tests
[Theory]
[MemberData(nameof(RobertaTestData))]
public void TestTokenizerEncoding(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
{
Tokenizer tokenizer = GetRobertaTokenizer();
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);

View file

@ -238,8 +238,8 @@ namespace Microsoft.ML.Tokenizers.Tests
IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
Assert.Equal(encoded, result.Select(token => token.Id).ToArray());
Assert.Equal(encoded.Count, idsCount);
Assert.Equal(new string[] { "<|im_start|>", "Hello", " ⭐", "⭐", " World", "<|im_end|>" }, result.Select(token => token.Value).ToArray());
Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 2), (18, 1), (19, 6), (25, 10) }, result.Select(token => token.Offset).ToArray());
}
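For context on the two assertions above: ⭐ (U+2B50) is three UTF-8 bytes, so the GPT-4 (cl100k_base) vocabulary splits it across two tokens, and after this change both tokens report the full character and a real (Index, Length) offset instead of an empty value and a zero length. A minimal sketch of what that looks like, assuming the GPT4 test tokenizer instance used throughout this file:

IReadOnlyList<Token> tokens = GPT4.Encode("Hello ⭐ World", out _);
foreach (Token t in tokens)
{
    // Prints "Hello" (0, 5), " ⭐" (5, 2), "⭐" (6, 1), " World" (7, 6) per the test data below.
    Console.WriteLine($"{t.Id}\t'{t.Value}'\t({t.Offset.Index}, {t.Offset.Length})");
}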
[Fact]
@ -457,8 +457,8 @@ namespace Microsoft.ML.Tokenizers.Tests
yield return new object?[]
{
"Hello, y'all! How are you 😁 ?",
new string[] { "Hello", ",", " y", "'all", "!", " How", " are", " you", " 😁", "", " ?" },
new (int Index, int Length)[] { (0, 5), (5, 1), (6, 2), (8, 4), (12, 1), (13, 4), (17, 4), (21, 4), (25, 3), (28, 0), (28, 2) },
new string[] { "Hello", ",", " y", "'all", "!", " How", " are", " you", " 😁", "😁", " ?" },
new (int Index, int Length)[] { (0, 5), (5, 1), (6, 2), (8, 4), (12, 1), (13, 4), (17, 4), (21, 4), (25, 3), (26, 2), (28, 2) },
new int[] { 9906, 11, 379, 65948, 0, 2650, 527, 499, 27623, 223, 949 }
};
}
@ -533,6 +533,106 @@ namespace Microsoft.ML.Tokenizers.Tests
}
return sb.ToString();
}
public static IEnumerable<object?[]> TokenizerLimitsTestData
{
get
{
// string to tokenize, produced tokens, the token offsets, the token ids
yield return new object?[]
{
"Hello ⭐ World",
new string[] { "Hello", " ⭐", "⭐", " World" },
new (int Index, int Length)[] { (0, 5), (5, 2), (6, 1), (7, 6) },
new int[] { 9906, 2928, 99834, 4435 }
};
yield return new object?[]
{
"⭐", // encoded to multiple tokens
new string[] { "⭐", "⭐" },
new (int Index, int Length)[] { (0, 1), (0, 1) },
new int[] { 158, 99834 }
};
yield return new object?[]
{
"Hi 😀", // Surrogates
new string[] { "Hi", " 😀" },
new (int Index, int Length)[] { (0, 2), (2, 3) },
new int[] { 13347, 91416 }
};
yield return new object?[]
{
"⭐😀", // character encoded to multiple tokens and surrogates
new string[] { "⭐", "⭐", "😀", "😀" },
new (int Index, int Length)[] { (0, 1), (0, 1), (1, 2), (1, 2) },
new int[] { 158, 99834, 76460, 222 }
};
yield return new object?[]
{
"From: Adele Vance\nSubject: TestSubject\nTestBodyContent",
new string[] { "From", ":", " Ade", "le", " Vance", "\n", "Subject", ":", " Test", "Subject", "\n", "Test", "Body", "Content" },
new (int Index, int Length)[] { (0, 4), (4, 1), (5, 4), (9, 2), (11, 6), (17, 1), (18, 7), (25, 1), (26, 5), (31, 7), (38, 1), (39, 4), (43, 4), (47, 7)},
new int[] { 3915, 25, 63140, 273, 92368, 198, 13317, 25, 3475, 13317, 198, 2323, 5561, 2831 }
};
}
}
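Several of these cases intentionally produce repeated offsets: two tokens that both cover the same ⭐ or 😀. A consumer mapping tokens back to substrings has to tolerate that overlap; a small sketch (again assuming the GPT4 instance and the Token shape used in these tests) groups ids by their shared offset:

IReadOnlyList<Token> tokens = GPT4.Encode("⭐😀", out _);
var idsByOffset = tokens.GroupBy(t => t.Offset)
                        .Select(g => (g.Key.Index, g.Key.Length, Ids: g.Select(t => t.Id).ToArray()));

foreach (var (index, length, ids) in idsByOffset)
{
    // Expected per the data above: "⭐" -> [158, 99834] and "😀" -> [76460, 222].
    Console.WriteLine($"'{"⭐😀".Substring(index, length)}' -> [{string.Join(", ", ids)}]");
}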
[Theory]
[MemberData(nameof(TokenizerLimitsTestData))]
public void TestPreciseTokenLimits(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
{
IReadOnlyList<Token> result = GPT4.Encode(text, out _);
int[] ids = result.Select(r => r.Id).ToArray();
(int Index, int Length)[] offsets = result.Select(r => r.Offset).ToArray();
Assert.Equal(expectedTokens, result.Select(r => r.Value));
Assert.Equal(expectedIds, ids);
Assert.Equal(expectedOffsets, offsets);
Assert.Equal(expectedIds, GPT4.EncodeToIds(text));
Assert.Equal(expectedIds.Length, GPT4.CountTokens(text));
for (int tokenCount = 1; tokenCount <= ids.Length; tokenCount++)
{
int length = GPT4.IndexOfTokenCount(text, tokenCount, out _, out int count);
Assert.True(count <= ids.Length);
if (count < tokenCount)
{
Assert.True(count < ids.Length - 1);
Assert.True(offsets[count + 1].Index < offsets[count].Index + offsets[count].Length);
}
if (count > 0)
{
Assert.Equal(offsets[count - 1].Index + offsets[count - 1].Length, length);
}
else
{
Assert.Equal(0, length);
}
int index = GPT4.LastIndexOfTokenCount(text, tokenCount, out _, out count);
Assert.True(count <= ids.Length);
if (count < tokenCount)
{
Assert.True(ids.Length - tokenCount > 0);
Assert.True(offsets[offsets.Length - tokenCount].Index < offsets[offsets.Length - tokenCount - 1].Index + offsets[offsets.Length - tokenCount - 1].Length);
}
if (count > 0)
{
Assert.Equal(offsets[offsets.Length - count].Index, index);
}
else
{
Assert.Equal(text.Length, index);
}
}
}
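The loop above validates both directions of the token-budget APIs exercised by this change. As a standalone usage sketch (assuming the same GPT4 instance and the IndexOfTokenCount/LastIndexOfTokenCount signatures used in the test), trimming text to a token budget without splitting a code point looks roughly like this:

string text = "From: Adele Vance\nSubject: TestSubject\nTestBodyContent";
int budget = 5;

// IndexOfTokenCount returns the char length of the prefix that holds up to `budget` complete tokens;
// `fromStart` reports how many tokens actually fit without cutting a code point in half.
int prefixLength = GPT4.IndexOfTokenCount(text, budget, out _, out int fromStart);
string head = text.Substring(0, prefixLength);

// LastIndexOfTokenCount works from the end: it returns the start index of the suffix that holds
// up to `budget` complete tokens, with `fromEnd` reporting the number actually included.
int suffixStart = GPT4.LastIndexOfTokenCount(text, budget, out _, out int fromEnd);
string tail = text.Substring(suffixStart);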
}
}