Introducing CodeGen Tokenizer (#7139)

* Introducing CodeGen Tokenizer
* Mark a method as private; it was not intended to be public.
* Initialize the vocab atomically.
* Prevent returning tokens that are only partially mapped to a code point.
* Ensure Tiktoken reports a precise token count and accurate offsets in IndexOf & LastIndexOf.
* Address the feedback.
This commit is contained in:
Parent: 72cfdf611a
Commit: e9097ce6d6
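For reference, a minimal usage sketch of the new API introduced by this PR (not part of the diff). It assumes the Salesforce codegen-350M-mono vocab.json and merges.txt linked later in this diff have already been downloaded locally, and that the Tokenizer base class exposes an EncodeToIds overload in this preview of the API; the exact encode method name and return type may differ.

```csharp
using System;
using System.IO;
using Microsoft.ML.Tokenizers;

class CodeGenTokenizerDemo
{
    static void Main()
    {
        // Assumed local copies of:
        // https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/vocab.json?download=true
        // https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/merges.txt?download=true
        using Stream vocabStream = File.OpenRead("vocab.json");
        using Stream mergesStream = File.OpenRead("merges.txt");

        // Factory method added in this PR.
        Tokenizer tokenizer = Tokenizer.CreateCodeGen(vocabStream, mergesStream);

        // Per the CodeGenTests data in this PR, "Hello World" maps to ids 15496, 2159
        // (tokens "Hello", "ĠWorld") when addPrefixSpace is false.
        var ids = tokenizer.EncodeToIds("Hello World"); // assumed encode API
        Console.WriteLine(string.Join(", ", ids));
    }
}
```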
@@ -133,6 +133,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

License notice for CodeGen Tokenizer
--------------------------------------------

https://github.com/huggingface/transformers/blob/8c12690cecbb97e187861e386f7a0ac790e4236c/src/transformers/models/codegen/tokenization_codegen.py#L2

Copyright 2022 The Salesforce authors, The Open AI Team Authors and The HuggingFace Inc. team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

License notice for BitUtility
------------------------------------------
@@ -36,6 +36,7 @@
    <ApacheArrowVersion>11.0.0</ApacheArrowVersion>
    <GoogleProtobufVersion>3.19.6</GoogleProtobufVersion>
    <LightGBMVersion>3.3.5</LightGBMVersion>
    <MicrosoftBclHashCodeVersion>1.1.1</MicrosoftBclHashCodeVersion>
    <MicrosoftCodeAnalysisAnalyzersVersion>3.3.0</MicrosoftCodeAnalysisAnalyzersVersion>
    <MicrosoftCodeAnalysisCSharpVersion>3.9.0</MicrosoftCodeAnalysisCSharpVersion>
    <MicrosoftDotNetInteractiveVersion>1.0.0-beta.23509.3</MicrosoftDotNetInteractiveVersion>
@@ -87,7 +88,7 @@
    <MicrosoftMLTensorFlowTestModelsVersion>0.0.13-test</MicrosoftMLTensorFlowTestModelsVersion>
    <MicrosoftMLTestDatabasesVersion>0.0.6-test</MicrosoftMLTestDatabasesVersion>
    <MicrosoftMLTestModelsVersion>0.0.7-test</MicrosoftMLTestModelsVersion>
    <MicrosoftMLTestTokenizersVersion>2.0.0-beta.24218.2</MicrosoftMLTestTokenizersVersion>
    <MicrosoftMLTestTokenizersVersion>2.0.0-beta.24219.1</MicrosoftMLTestTokenizersVersion>
    <SystemDataSqlClientVersion>4.8.6</SystemDataSqlClientVersion>
    <SystemDataSQLiteCoreVersion>1.0.118</SystemDataSQLiteCoreVersion>
    <XunitCombinatorialVersion>1.2.7</XunitCombinatorialVersion>
@@ -21,6 +21,10 @@
    <PackageReference Include="System.Text.Json" Version="$(SystemTextJsonVersion)" />
  </ItemGroup>

  <ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
    <PackageReference Include="Microsoft.Bcl.HashCode" Version="$(MicrosoftBclHashCodeVersion)" />
  </ItemGroup>

  <UsingTask TaskName="CompressFile"
             TaskFactory="RoslynCodeTaskFactory"
             AssemblyFile="$(MSBuildToolsPath)\Microsoft.Build.Tasks.Core.dll" >
Diff for one file not shown because it is too large.
@ -10,6 +10,7 @@ using System.IO;
|
|||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Threading;
|
||||
|
||||
namespace Microsoft.ML.Tokenizers
|
||||
{
|
||||
|
@ -23,9 +24,6 @@ namespace Microsoft.ML.Tokenizers
|
|||
private Dictionary<string, int>? _vocabOriginal;
|
||||
private readonly SortedDictionary<int, StringSpanOrdinalKey> _vocabReverse;
|
||||
private readonly Cache<(string, string), int> _mergeRanks;
|
||||
private readonly IReadOnlyDictionary<char, char> _byteToUnicode;
|
||||
private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
|
||||
private readonly string[] _charToString;
|
||||
private readonly StringSpanOrdinalKeyCache<List<Token>> _cache;
|
||||
private readonly PreTokenizer? _preTokenizer;
|
||||
private readonly Normalizer? _normalizer;
|
||||
|
@ -95,14 +93,6 @@ namespace Microsoft.ML.Tokenizers
|
|||
_vocab = GetVocabulary(vocabularyStream);
|
||||
_vocabReverse = _vocab.ReverseSorted();
|
||||
_mergeRanks = GetMergeRanks(mergeStream);
|
||||
int maxCharValue = GetByteToUnicode(out _byteToUnicode);
|
||||
_charToString = new string[maxCharValue];
|
||||
for (char c = (char)0; c < (char)maxCharValue; c++)
|
||||
{
|
||||
_charToString[c] = c.ToString();
|
||||
}
|
||||
|
||||
_unicodeToByte = _byteToUnicode.Reverse();
|
||||
_cache = new StringSpanOrdinalKeyCache<List<Token>>();
|
||||
|
||||
if (disposeStream)
|
||||
|
@ -113,6 +103,80 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
}
|
||||
|
||||
private static Dictionary<StringSpanOrdinalKey, int> GetVocabulary(Stream vocabularyStream)
|
||||
{
|
||||
Dictionary<StringSpanOrdinalKey, int>? vocab;
|
||||
try
|
||||
{
|
||||
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
|
||||
vocab = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocabularyStream, options) as Dictionary<StringSpanOrdinalKey, int>;
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
throw new ArgumentException($"Problems met when parsing JSON vocabulary object.{Environment.NewLine}Error message: {e.Message}");
|
||||
}
|
||||
|
||||
if (vocab is null)
|
||||
{
|
||||
throw new ArgumentException($"Failed to read the vocabulary file.");
|
||||
}
|
||||
|
||||
return vocab;
|
||||
}
|
||||
|
||||
private static Cache<(string, string), int> GetMergeRanks(Stream mergeStream)
|
||||
{
|
||||
var mergeRanks = new Cache<(string, string), int>(60_000);
|
||||
try
|
||||
{
|
||||
using StreamReader reader = new StreamReader(mergeStream);
|
||||
|
||||
// We ignore the first and last line in the file
|
||||
if (reader.Peek() >= 0)
|
||||
{
|
||||
string ignored = reader.ReadLine()!;
|
||||
}
|
||||
|
||||
int rank = 1;
|
||||
while (reader.Peek() >= 0)
|
||||
{
|
||||
string line = reader.ReadLine()!;
|
||||
int index = line.IndexOf(' ');
|
||||
if (index < 1 || index == line.Length - 1 || line.IndexOf(' ', index + 1) != -1)
|
||||
{
|
||||
throw new FormatException($"Invalid format of merge file at line: \"{line}\"");
|
||||
}
|
||||
|
||||
mergeRanks.Set((line.Substring(0, index), line.Substring(index + 1)), rank++);
|
||||
}
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
// Report any issues encountered while consuming a data file as IOExceptions.
|
||||
throw new IOException($"Cannot read the file Merge file.{Environment.NewLine}Error message: {e.Message}", e);
|
||||
}
|
||||
|
||||
return mergeRanks;
|
||||
}
|
||||
|
||||
private Dictionary<string, int> GetVocab()
|
||||
{
|
||||
Dictionary<string, int>? publicVocab = Volatile.Read(ref _vocabOriginal);
|
||||
if (publicVocab is null)
|
||||
{
|
||||
var vocab = new Dictionary<string, int>();
|
||||
foreach (var item in _vocab)
|
||||
{
|
||||
vocab.Add(item.Key.ToString(), item.Value);
|
||||
}
|
||||
|
||||
Interlocked.CompareExchange(ref _vocabOriginal, vocab, null);
|
||||
publicVocab = _vocabOriginal;
|
||||
}
|
||||
|
||||
return publicVocab;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the PreTokenizer used by the Tokenizer.
|
||||
/// </summary>
|
||||
|
@ -126,7 +190,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <summary>
|
||||
/// Gets the dictionary mapping tokens to Ids.
|
||||
/// </summary>
|
||||
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
|
||||
public IReadOnlyDictionary<string, int> Vocab => GetVocab();
|
||||
|
||||
//
|
||||
// Public Model interfaces implementation
|
||||
|
@ -147,9 +211,10 @@ namespace Microsoft.ML.Tokenizers
|
|||
char[] buffer = ArrayPool<char>.Shared.Rent(v.Length);
|
||||
int i = 0;
|
||||
|
||||
IReadOnlyDictionary<char, char> unicodeToByte = ByteToUnicodeEncoding.Instance.UnicodeToByte;
|
||||
for (int j = 0; j < v.Length; j++)
|
||||
{
|
||||
if (_unicodeToByte.TryGetValue(v[j], out var c))
|
||||
if (unicodeToByte.TryGetValue(v[j], out var c))
|
||||
{
|
||||
buffer[i++] = c;
|
||||
}
|
||||
|
@ -232,9 +297,11 @@ namespace Microsoft.ML.Tokenizers
|
|||
int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);
|
||||
|
||||
int newTokenIndex = 0;
|
||||
IReadOnlyDictionary<char, char> byteToUnicode = ByteToUnicodeEncoding.Instance.ByteToUnicode;
|
||||
|
||||
for (int i = 0; i < text.Length; i++)
|
||||
{
|
||||
if (_byteToUnicode.TryGetValue(text[i], out var value))
|
||||
if (byteToUnicode.TryGetValue(text[i], out var value))
|
||||
{
|
||||
token[newTokenIndex] = value;
|
||||
indexMapping[newTokenIndex] = i;
|
||||
|
@ -607,9 +674,11 @@ namespace Microsoft.ML.Tokenizers
|
|||
int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);
|
||||
|
||||
int newTokenIndex = 0;
|
||||
IReadOnlyDictionary<char, char> byteToUnicode = ByteToUnicodeEncoding.Instance.ByteToUnicode;
|
||||
|
||||
for (int i = 0; i < text.Length; i++)
|
||||
{
|
||||
if (_byteToUnicode.TryGetValue(text[i], out var value))
|
||||
if (byteToUnicode.TryGetValue(text[i], out var value))
|
||||
{
|
||||
token[newTokenIndex] = value;
|
||||
indexMapping[newTokenIndex] = i;
|
||||
|
@ -650,9 +719,11 @@ namespace Microsoft.ML.Tokenizers
|
|||
int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);
|
||||
|
||||
int newTokenIndex = 0;
|
||||
IReadOnlyDictionary<char, char> byteToUnicode = ByteToUnicodeEncoding.Instance.ByteToUnicode;
|
||||
|
||||
for (int i = 0; i < text.Length; i++)
|
||||
{
|
||||
if (_byteToUnicode.TryGetValue(text[i], out var value))
|
||||
if (byteToUnicode.TryGetValue(text[i], out var value))
|
||||
{
|
||||
token[newTokenIndex] = value;
|
||||
indexMapping[newTokenIndex] = i;
|
||||
|
@ -829,107 +900,6 @@ namespace Microsoft.ML.Tokenizers
|
|||
private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) =>
|
||||
HighestOccurrenceMapping.Load(highestOccurrenceMappingStream);
|
||||
|
||||
private Dictionary<StringSpanOrdinalKey, int> GetVocabulary(Stream vocabularyStream)
|
||||
{
|
||||
Dictionary<StringSpanOrdinalKey, int>? vocab;
|
||||
try
|
||||
{
|
||||
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
|
||||
vocab = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocabularyStream, options) as Dictionary<StringSpanOrdinalKey, int>;
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
throw new ArgumentException($"Problems met when parsing JSON vocabulary object.{Environment.NewLine}Error message: {e.Message}");
|
||||
}
|
||||
|
||||
if (vocab is null)
|
||||
{
|
||||
throw new ArgumentException($"Failed to read the vocabulary file.");
|
||||
}
|
||||
|
||||
if (_vocabIdToHighestOccurrence.BosWord is not null)
|
||||
{
|
||||
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.BosWord)] = -_vocabIdToHighestOccurrence.BosIndex;
|
||||
}
|
||||
|
||||
if (_vocabIdToHighestOccurrence.EosWord is not null)
|
||||
{
|
||||
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.EosWord)] = -_vocabIdToHighestOccurrence.EosIndex;
|
||||
}
|
||||
|
||||
if (_vocabIdToHighestOccurrence.UnkWord is not null)
|
||||
{
|
||||
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.UnkWord)] = -_vocabIdToHighestOccurrence.UnkIndex;
|
||||
}
|
||||
|
||||
if (_vocabIdToHighestOccurrence.PadWord is not null)
|
||||
{
|
||||
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.PadWord)] = -_vocabIdToHighestOccurrence.PadIndex;
|
||||
}
|
||||
|
||||
return vocab;
|
||||
}
|
||||
|
||||
private Cache<(string, string), int> GetMergeRanks(Stream mergeStream)
|
||||
{
|
||||
var mergeRanks = new Cache<(string, string), int>(60_000);
|
||||
try
|
||||
{
|
||||
using StreamReader reader = new StreamReader(mergeStream);
|
||||
|
||||
// We ignore the first and last line in the file
|
||||
if (reader.Peek() >= 0)
|
||||
{
|
||||
string ignored = reader.ReadLine()!;
|
||||
}
|
||||
|
||||
int rank = 1;
|
||||
while (reader.Peek() >= 0)
|
||||
{
|
||||
string line = reader.ReadLine()!;
|
||||
int index = line.IndexOf(' ');
|
||||
if (index < 1 || index == line.Length - 1 || line.IndexOf(' ', index + 1) != -1)
|
||||
{
|
||||
throw new Exception($"Invalid format of merge file: \"{line}\"");
|
||||
}
|
||||
|
||||
mergeRanks.Set((line.Substring(0, index), line.Substring(index + 1)), rank++);
|
||||
}
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
throw new IOException($"Cannot read the file Merge file.{Environment.NewLine}Error message: {e.Message}", e);
|
||||
}
|
||||
|
||||
return mergeRanks;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns list of utf-8 bytes and a corresponding list of unicode chars.
|
||||
/// This mapping is to make unseen characters (such as control characters) displayable.
|
||||
/// </summary>
|
||||
private static int GetByteToUnicode(out IReadOnlyDictionary<char, char> byteToUnicode)
|
||||
{
|
||||
var byteToUnicodeMapping = Enumerable.Range('!', '~' - '!' + 1)
|
||||
.Concat(Enumerable.Range('¡', '¬' - '¡' + 1))
|
||||
.Concat(Enumerable.Range('®', 'ÿ' - '®' + 1))
|
||||
.ToDictionary(b => (char)b, b => (char)b);
|
||||
|
||||
const int numChars = 256;
|
||||
var n = 0;
|
||||
foreach (var b in Enumerable.Range(0, numChars))
|
||||
{
|
||||
if (!byteToUnicodeMapping.ContainsKey((char)b))
|
||||
{
|
||||
byteToUnicodeMapping.Add((char)b, (char)(numChars + n));
|
||||
++n;
|
||||
}
|
||||
}
|
||||
|
||||
byteToUnicode = byteToUnicodeMapping;
|
||||
return numChars + n;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Encode a token into BPE-ed sub-tokens. E.g., "playing" into ["play", "ing"].
|
||||
/// </summary>
|
||||
|
@ -940,17 +910,20 @@ namespace Microsoft.ML.Tokenizers
|
|||
return [];
|
||||
}
|
||||
|
||||
string[] charToString = ByteToUnicodeEncoding.Instance.CharToString;
|
||||
|
||||
if (token.Length == 1)
|
||||
{
|
||||
string tokenValue = _charToString[token[0]];
|
||||
Debug.Assert(token[0] < charToString.Length);
|
||||
string tokenValue = charToString[token[0]];
|
||||
return new List<Token> { new Token(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, (indexMapping[0], 1)) };
|
||||
}
|
||||
|
||||
List<string> word = new(token.Length);
|
||||
foreach (char c in token)
|
||||
{
|
||||
Debug.Assert(c < _charToString.Length);
|
||||
word.Add(_charToString[c]);
|
||||
Debug.Assert(c < charToString.Length);
|
||||
word.Add(charToString[c]);
|
||||
}
|
||||
|
||||
HashSet<(string, string)> pairs = new();
|
||||
|
@ -1065,10 +1038,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// </summary>
|
||||
/// <param name="ch">The character to check.</param>
|
||||
/// <returns>True if the character is supported, otherwise false.</returns>
|
||||
public bool IsSupportedChar(char ch)
|
||||
{
|
||||
return _byteToUnicode.ContainsKey(ch);
|
||||
}
|
||||
public bool IsSupportedChar(char ch) => ByteToUnicodeEncoding.Instance.ByteToUnicode.ContainsKey(ch);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -25,7 +25,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
{
|
||||
private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder;
|
||||
private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder;
|
||||
private readonly LruCache<(int[] Bytes, string Token)> _cache;
|
||||
private readonly LruCache<(int Id, int TokenIndex, int TokenLength)[]> _cache;
|
||||
private readonly Dictionary<StringSpanOrdinalKey, (int Id, string Token)> _vocab;
|
||||
private IReadOnlyDictionary<string, int>? _vocabOriginal;
|
||||
private const int MaxWordLengthToCache = 15;
|
||||
|
@ -92,7 +92,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
_preTokenizer = preTokenizer;
|
||||
_normalizer = normalizer;
|
||||
|
||||
_cache = new LruCache<(int[] Bytes, string Token)>(cacheSize);
|
||||
_cache = new LruCache<(int Id, int TokenIndex, int TokenLength)[]>(cacheSize);
|
||||
|
||||
SpecialTokens = specialTokens;
|
||||
CacheSpecialTokensEncoding(specialTokens);
|
||||
|
@ -102,7 +102,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
{
|
||||
try
|
||||
{
|
||||
_cache = new LruCache<(int[] Bytes, string Token)>(cacheSize);
|
||||
_cache = new LruCache<(int Id, int TokenIndex, int TokenLength)[]>(cacheSize);
|
||||
(_encoder, _vocab, _decoder) = LoadTiktokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
|
||||
|
||||
_preTokenizer = preTokenizer;
|
||||
|
@ -140,7 +140,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
foreach (KeyValuePair<string, int> specialToken in specialTokens)
|
||||
{
|
||||
_decoder![specialToken.Value] = Encoding.UTF8.GetBytes(specialToken.Key);
|
||||
_cache!.Add(specialToken.Key, (new[] { specialToken.Value }, specialToken.Key));
|
||||
_cache!.Add(specialToken.Key, new[] { (Id: specialToken.Value, TokenIndex: 0, TokenLength: specialToken.Key.Length) });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -290,13 +290,14 @@ namespace Microsoft.ML.Tokenizers
|
|||
{
|
||||
Debug.Assert(!text.IsEmpty);
|
||||
|
||||
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
|
||||
if (_cache.TryGetValue(text, out (int Id, int TokenIndex, int TokenLength)[] value))
|
||||
{
|
||||
tokens.Add(new Token(value.Ids[0], value.Token, (offset, value.Token.Length)));
|
||||
for (int i = 1; i < value.Ids.Length; i++)
|
||||
for (int i = 0; i < value.Length; i++)
|
||||
{
|
||||
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
|
||||
tokens.Add(new Token(value.Ids[i], "", (offset + text.Length, 0)));
|
||||
tokens.Add(new Token(
|
||||
value[i].Id,
|
||||
value[i].TokenLength == 0 ? string.Empty : text.Slice(value[i].TokenIndex, value[i].TokenLength).ToString(),
|
||||
(value[i].TokenIndex + offset, value[i].TokenLength)));
|
||||
}
|
||||
|
||||
return;
|
||||
|
@ -309,25 +310,35 @@ namespace Microsoft.ML.Tokenizers
|
|||
return;
|
||||
}
|
||||
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
|
||||
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
|
||||
int utf8Length = Encoding.UTF8.GetMaxByteCount(text.Length);
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(utf8Length);
|
||||
int[]? indexMappingArray = null;
|
||||
Span<int> indexMappingSpan = utf8Length + 1 <= 128 ? stackalloc int[128] : (indexMappingArray = ArrayPool<int>.Shared.Rent(utf8Length + 1));
|
||||
int encodedLength = Helpers.EncodeToUtf8(text, arrayPoolArray, indexMappingSpan);
|
||||
Debug.Assert(encodedLength < indexMappingSpan.Length);
|
||||
indexMappingSpan[encodedLength] = text.Length;
|
||||
|
||||
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
(int Id, int TokenIndex, int TokenLength)[] encodedTokens = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder, indexMappingSpan.Slice(0, encodedLength + 1));
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
if (indexMappingArray is not null)
|
||||
{
|
||||
ArrayPool<int>.Shared.Return(indexMappingArray);
|
||||
}
|
||||
|
||||
Debug.Assert(encodedIds.Length > 0);
|
||||
Debug.Assert(encodedTokens.Length > 0);
|
||||
string textAsString = text.ToString();
|
||||
|
||||
if (text.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
_cache.Add(textAsString, (encodedIds, textAsString));
|
||||
_cache.Add(textAsString, encodedTokens);
|
||||
}
|
||||
|
||||
tokens.Add(new Token(encodedIds[0], textAsString, (offset, text.Length)));
|
||||
for (int i = 1; i < encodedIds.Length; i++)
|
||||
for (int i = 0; i < encodedTokens.Length; i++)
|
||||
{
|
||||
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
|
||||
tokens.Add(new Token(encodedIds[i], "", (offset + text.Length, 0)));
|
||||
tokens.Add(new Token(
|
||||
encodedTokens[i].Id,
|
||||
encodedTokens[i].TokenLength == 0 ? string.Empty : text.Slice(encodedTokens[i].TokenIndex, encodedTokens[i].TokenLength).ToString(),
|
||||
(encodedTokens[i].TokenIndex + offset, encodedTokens[i].TokenLength)));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -435,17 +446,9 @@ namespace Microsoft.ML.Tokenizers
|
|||
return 0;
|
||||
}
|
||||
|
||||
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
|
||||
if (_cache.TryGetValue(text, out (int Id, int TokenIndex, int TokenLength)[] value))
|
||||
{
|
||||
if (value.Ids.Length <= maxTokenCount)
|
||||
{
|
||||
accumulatedIds.AddRange(value.Ids);
|
||||
textLength = text.Length;
|
||||
return value.Ids.Length;
|
||||
}
|
||||
|
||||
textLength = 0;
|
||||
return 0;
|
||||
return EncodeToIdsResult(value, accumulatedIds, maxTokenCount, text.Length, out textLength);
|
||||
}
|
||||
|
||||
if (_vocab.TryGetValue(text, out (int Id, string Token) mappedId))
|
||||
|
@ -455,32 +458,85 @@ namespace Microsoft.ML.Tokenizers
|
|||
return 1;
|
||||
}
|
||||
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
|
||||
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
|
||||
int utf8Length = Encoding.UTF8.GetMaxByteCount(text.Length);
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(utf8Length);
|
||||
int[]? indexMappingArray = null;
|
||||
Span<int> indexMappingSpan = utf8Length + 1 <= 128 ? stackalloc int[128] : (indexMappingArray = ArrayPool<int>.Shared.Rent(utf8Length + 1));
|
||||
int encodedLength = Helpers.EncodeToUtf8(text, arrayPoolArray, indexMappingSpan);
|
||||
Debug.Assert(encodedLength < indexMappingSpan.Length);
|
||||
indexMappingSpan[encodedLength] = text.Length;
|
||||
|
||||
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
(int Id, int TokenIndex, int TokenLength)[] encodedTokens = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder, indexMappingSpan.Slice(0, encodedLength + 1));
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
if (indexMappingArray is not null)
|
||||
{
|
||||
ArrayPool<int>.Shared.Return(indexMappingArray);
|
||||
}
|
||||
|
||||
if (text.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
string textAsString = text.ToString();
|
||||
_cache.Add(textAsString, (encodedIds, textAsString));
|
||||
_cache.Add(textAsString, encodedTokens);
|
||||
}
|
||||
|
||||
int result;
|
||||
if (encodedIds.Length <= maxTokenCount)
|
||||
return EncodeToIdsResult(encodedTokens, accumulatedIds, maxTokenCount, text.Length, out textLength);
|
||||
}
|
||||
|
||||
private int EncodeToIdsResult((int Id, int TokenIndex, int TokenLength)[] tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textLength)
|
||||
{
|
||||
textLength = 0;
|
||||
|
||||
if (tokens.Length <= maxTokens)
|
||||
{
|
||||
accumulatedIds.AddRange(encodedIds);
|
||||
textLength = text.Length;
|
||||
result = encodedIds.Length;
|
||||
}
|
||||
else
|
||||
{
|
||||
textLength = 0;
|
||||
result = 0;
|
||||
if (accumulatedIds is not null)
|
||||
{
|
||||
foreach (var t in tokens)
|
||||
{
|
||||
accumulatedIds.Add(t.Id);
|
||||
}
|
||||
}
|
||||
|
||||
textLength = fullTextLength;
|
||||
return tokens.Length;
|
||||
}
|
||||
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
return result;
|
||||
int tokenCount;
|
||||
for (tokenCount = 0; tokenCount < maxTokens; tokenCount++)
|
||||
{
|
||||
int overlapIndex = tokens[tokenCount].TokenIndex + tokens[tokenCount].TokenLength;
|
||||
// maxTokens is less than tokens.Count, so it is safe to index maxTokens.
|
||||
if (tokens[tokenCount + 1].TokenIndex < overlapIndex)
|
||||
{
|
||||
// Ensure we'll not break the text in the middle of a code-point
|
||||
int j = tokenCount + 2;
|
||||
while (j < tokens.Length && tokens[j].TokenIndex < overlapIndex)
|
||||
{
|
||||
j++;
|
||||
}
|
||||
|
||||
if (j <= maxTokens)
|
||||
{
|
||||
// append encountered tokens to the accumulatedIds
|
||||
for (int k = tokenCount; k < j; k++)
|
||||
{
|
||||
accumulatedIds?.Add(tokens[k].Id);
|
||||
}
|
||||
tokenCount = j - 1;
|
||||
textLength = tokens[tokenCount].TokenIndex + tokens[tokenCount].TokenLength;
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
accumulatedIds?.Add(tokens[tokenCount].Id);
|
||||
textLength = tokens[tokenCount].TokenIndex + tokens[tokenCount].TokenLength;
|
||||
}
|
||||
}
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -598,16 +654,9 @@ namespace Microsoft.ML.Tokenizers
|
|||
return 0;
|
||||
}
|
||||
|
||||
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
|
||||
if (_cache.TryGetValue(text, out (int Id, int TokenIndex, int TokenLength)[] value))
|
||||
{
|
||||
if (value.Ids.Length <= maxTokens)
|
||||
{
|
||||
textLength = text.Length;
|
||||
return value.Ids.Length;
|
||||
}
|
||||
|
||||
textLength = 0;
|
||||
return 0;
|
||||
return EncodeToIdsResult(value, accumulatedIds: null, maxTokens, text.Length, out textLength);
|
||||
}
|
||||
|
||||
if (_vocab.TryGetValue(text, out _))
|
||||
|
@ -616,30 +665,28 @@ namespace Microsoft.ML.Tokenizers
|
|||
return 1;
|
||||
}
|
||||
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
|
||||
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
|
||||
int utf8Length = Encoding.UTF8.GetMaxByteCount(text.Length);
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(utf8Length);
|
||||
int[]? indexMappingArray = null;
|
||||
Span<int> indexMappingSpan = utf8Length + 1 <= 128 ? stackalloc int[128] : (indexMappingArray = ArrayPool<int>.Shared.Rent(utf8Length + 1));
|
||||
int encodedLength = Helpers.EncodeToUtf8(text, arrayPoolArray, indexMappingSpan);
|
||||
Debug.Assert(encodedLength < indexMappingSpan.Length);
|
||||
indexMappingSpan[encodedLength] = text.Length;
|
||||
|
||||
(int Id, int TokenIndex, int TokenLength)[] encodedTokens = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder, indexMappingSpan.Slice(0, encodedLength + 1));
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
if (indexMappingArray is not null)
|
||||
{
|
||||
ArrayPool<int>.Shared.Return(indexMappingArray);
|
||||
}
|
||||
|
||||
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
if (text.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
string textAsString = text.ToString();
|
||||
_cache.Add(textAsString, (encodedIds, textAsString));
|
||||
_cache.Add(textAsString, encodedTokens);
|
||||
}
|
||||
|
||||
int result;
|
||||
if (encodedIds.Length <= maxTokens)
|
||||
{
|
||||
textLength = text.Length;
|
||||
result = encodedIds.Length;
|
||||
}
|
||||
else
|
||||
{
|
||||
textLength = 0;
|
||||
result = 0;
|
||||
}
|
||||
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
return result;
|
||||
return EncodeToIdsResult(encodedTokens, accumulatedIds: null, maxTokens, text.Length, out textLength);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -729,16 +776,9 @@ namespace Microsoft.ML.Tokenizers
|
|||
return 0;
|
||||
}
|
||||
|
||||
if (_cache.TryGetValue(text, out (int[] Ids, string Token) value))
|
||||
if (_cache.TryGetValue(text, out (int Id, int TokenIndex, int TokenLength)[] value))
|
||||
{
|
||||
if (value.Ids.Length <= maxTokens)
|
||||
{
|
||||
textIndex = 0;
|
||||
return value.Ids.Length;
|
||||
}
|
||||
|
||||
textIndex = text.Length;
|
||||
return 0;
|
||||
return EncodeToIdsFromEndResult(value, accumulatedIds: null, maxTokens, text.Length, out textIndex);
|
||||
}
|
||||
|
||||
if (_vocab.TryGetValue(text, out _))
|
||||
|
@ -747,31 +787,63 @@ namespace Microsoft.ML.Tokenizers
|
|||
return 1;
|
||||
}
|
||||
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
|
||||
int encodedLength = Helpers.GetUtf8Bytes(text, arrayPoolArray);
|
||||
int utf8Length = Encoding.UTF8.GetMaxByteCount(text.Length);
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(utf8Length);
|
||||
int[]? indexMappingArray = null;
|
||||
Span<int> indexMappingSpan = utf8Length + 1 <= 128 ? stackalloc int[128] : (indexMappingArray = ArrayPool<int>.Shared.Rent(utf8Length + 1));
|
||||
int encodedLength = Helpers.EncodeToUtf8(text, arrayPoolArray, indexMappingSpan);
|
||||
Debug.Assert(encodedLength < indexMappingSpan.Length);
|
||||
indexMappingSpan[encodedLength] = text.Length;
|
||||
|
||||
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
(int Id, int TokenIndex, int TokenLength)[] encodedTokens = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder, indexMappingSpan.Slice(0, encodedLength + 1));
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
if (indexMappingArray is not null)
|
||||
{
|
||||
ArrayPool<int>.Shared.Return(indexMappingArray);
|
||||
}
|
||||
|
||||
if (text.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
string textAsString = text.ToString();
|
||||
_cache.Add(textAsString, (encodedIds, textAsString));
|
||||
_cache.Add(textAsString, encodedTokens);
|
||||
}
|
||||
|
||||
int result;
|
||||
if (encodedIds.Length <= maxTokens)
|
||||
return EncodeToIdsFromEndResult(encodedTokens, accumulatedIds: null, maxTokens, text.Length, out textIndex);
|
||||
}
|
||||
|
||||
private int EncodeToIdsFromEndResult((int Id, int TokenIndex, int TokenLength)[] tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textIndex)
|
||||
{
|
||||
textIndex = fullTextLength;
|
||||
|
||||
if (tokens.Length <= maxTokens)
|
||||
{
|
||||
if (accumulatedIds is not null)
|
||||
{
|
||||
foreach (var t in tokens)
|
||||
{
|
||||
accumulatedIds.Add(t.Id);
|
||||
}
|
||||
}
|
||||
|
||||
textIndex = 0;
|
||||
result = encodedIds.Length;
|
||||
}
|
||||
else
|
||||
{
|
||||
textIndex = text.Length;
|
||||
result = 0;
|
||||
return tokens.Length;
|
||||
}
|
||||
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
return result;
|
||||
int index = tokens.Length - maxTokens;
|
||||
|
||||
// avoid breaking the text in the middle of a code-point
|
||||
while (index < tokens.Length && tokens[index].TokenIndex < tokens[index - 1].TokenIndex + tokens[index - 1].TokenLength)
|
||||
{
|
||||
index++;
|
||||
}
|
||||
|
||||
for (int i = index; i < tokens.Length; i++)
|
||||
{
|
||||
accumulatedIds?.Add(tokens[i].Id);
|
||||
}
|
||||
|
||||
textIndex = index >= tokens.Length ? fullTextLength : tokens[index].TokenIndex;
|
||||
return tokens.Length - index;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -786,11 +858,11 @@ namespace Microsoft.ML.Tokenizers
|
|||
return null;
|
||||
}
|
||||
|
||||
if (_cache.TryGetValue(token, out (int[] Ids, string Token) value))
|
||||
if (_cache.TryGetValue(token, out (int Id, int TokenIndex, int TokenLength)[] value))
|
||||
{
|
||||
if (value.Ids.Length == 1)
|
||||
if (value.Length == 1)
|
||||
{
|
||||
return value.Ids[0];
|
||||
return value[0].Id;
|
||||
}
|
||||
|
||||
return null;
|
||||
|
@ -801,30 +873,34 @@ namespace Microsoft.ML.Tokenizers
|
|||
return id.Id;
|
||||
}
|
||||
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(token.Length));
|
||||
try
|
||||
int utf8Length = Encoding.UTF8.GetMaxByteCount(token.Length);
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(utf8Length);
|
||||
int[]? indexMappingArray = null;
|
||||
Span<int> indexMappingSpan = utf8Length + 1 <= 128 ? stackalloc int[128] : (indexMappingArray = ArrayPool<int>.Shared.Rent(utf8Length + 1));
|
||||
int encodedLength = Helpers.EncodeToUtf8(token, arrayPoolArray, indexMappingSpan);
|
||||
Debug.Assert(encodedLength < indexMappingSpan.Length);
|
||||
indexMappingSpan[encodedLength] = token.Length;
|
||||
|
||||
(int Id, int TokenIndex, int TokenLength)[] encodedTokens = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder, indexMappingSpan.Slice(0, encodedLength + 1));
|
||||
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
if (indexMappingArray is not null)
|
||||
{
|
||||
int encodedLength = Helpers.GetUtf8Bytes(token, arrayPoolArray);
|
||||
|
||||
int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
|
||||
if (token.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
string tokenAsString = token.ToString();
|
||||
_cache.Add(tokenAsString, (idsToCache, tokenAsString));
|
||||
}
|
||||
|
||||
if (idsToCache.Length == 1)
|
||||
{
|
||||
return idsToCache[0];
|
||||
}
|
||||
|
||||
return null;
|
||||
ArrayPool<int>.Shared.Return(indexMappingArray);
|
||||
}
|
||||
finally
|
||||
|
||||
if (token.Length <= MaxWordLengthToCache)
|
||||
{
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
string tokenAsString = token.ToString();
|
||||
_cache.Add(tokenAsString, encodedTokens);
|
||||
}
|
||||
|
||||
if (encodedTokens.Length == 1)
|
||||
{
|
||||
return encodedTokens[0].Id;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@@ -455,6 +455,74 @@
            return new SentencePieceBpe(modelProto, addBeginOfSentence, addEndOfSentence);
        }

        /// <summary>
        /// Create a CodeGen tokenizer from the given vocab and merges streams.
        /// </summary>
        /// <param name="vocabStream">The stream containing the vocab file.</param>
        /// <param name="mergesStream">The stream containing the merges file.</param>
        /// <param name="addPrefixSpace">Indicate whether to add a space before the token.</param>
        /// <param name="addBeginOfSentence">Indicate whether to emit the beginning-of-sentence token during encoding.</param>
        /// <param name="addEndOfSentence">Indicate whether to emit the end-of-sentence token during encoding.</param>
        /// <returns>The CodeGen tokenizer object.</returns>
        /// <remarks>
        /// The tokenizer will be created according to the configuration specified in https://huggingface.co/Salesforce/codegen-350M-mono/raw/main/tokenizer.json.
        /// It is important to provide vocab and merges files similar to the ones used to train the model.
        /// The vocab and merges files can be downloaded from the following links:
        /// https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/vocab.json?download=true
        /// https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/merges.txt?download=true
        /// </remarks>
        public static Tokenizer CreateCodeGen(
            Stream vocabStream,
            Stream mergesStream,
            bool addPrefixSpace = false,
            bool addBeginOfSentence = false,
            bool addEndOfSentence = false)
        {
            if (vocabStream is null)
            {
                throw new ArgumentNullException(nameof(vocabStream));
            }

            if (mergesStream is null)
            {
                throw new ArgumentNullException(nameof(mergesStream));
            }

            return new CodeGen(
                        vocabStream,
                        mergesStream,
                        new TiktokenPreTokenizer(Tiktoken.P50kBaseRegex(), CodeGen.CodeGenAddedTokens),
                        normalizer: null,
                        CodeGen.CodeGenAddedTokens,
                        addPrefixSpace: addPrefixSpace,
                        addBeginningOfSentence: addBeginOfSentence,
                        addEndOfSentence: addEndOfSentence);
        }

        /// <summary>
        /// Create a CodeGen Phi2 tokenizer from the given vocab and merges streams.
        /// </summary>
        /// <param name="vocabStream">The stream containing the vocab file.</param>
        /// <param name="mergesStream">The stream containing the merges file.</param>
        /// <param name="addPrefixSpace">Indicate whether to add a space before the token.</param>
        /// <param name="addBeginOfSentence">Indicate whether to emit the beginning-of-sentence token during encoding.</param>
        /// <param name="addEndOfSentence">Indicate whether to emit the end-of-sentence token during encoding.</param>
        /// <returns>The CodeGen tokenizer object.</returns>
        /// <remarks>
        /// The tokenizer will be created according to the configuration specified in https://huggingface.co/microsoft/phi-2/raw/main/tokenizer.json.
        /// It is important to provide vocab and merges files similar to the ones used to train the model.
        /// The vocab and merges files can be downloaded from the following links:
        /// https://huggingface.co/microsoft/phi-2/resolve/main/vocab.json?download=true
        /// https://huggingface.co/microsoft/phi-2/resolve/main/merges.txt?download=true
        /// </remarks>
        public static Tokenizer CreatePhi2(
            Stream vocabStream,
            Stream mergesStream,
            bool addPrefixSpace = false,
            bool addBeginOfSentence = false,
            bool addEndOfSentence = false)
            => CreateCodeGen(vocabStream, mergesStream, addPrefixSpace, addBeginOfSentence, addEndOfSentence);

        internal static IEnumerable<(int Offset, int Length)>? InitializeForEncoding(
            string? text,
            ReadOnlySpan<char> textSpan,
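The addPrefixSpace flag only changes how the first word of the input is matched. A hedged sketch of the difference, with expected ids taken from the CodeGenTests data added later in this PR; the EncodeToIds call is an assumption about this preview of the Tokenizer API surface.

```csharp
using System;
using System.IO;
using Microsoft.ML.Tokenizers;

class PrefixSpaceDemo
{
    static void Main()
    {
        // Each tokenizer consumes its streams, so open a fresh pair for each one.
        using Stream vocab1 = File.OpenRead("vocab.json");
        using Stream merges1 = File.OpenRead("merges.txt");
        Tokenizer noPrefix = Tokenizer.CreateCodeGen(vocab1, merges1);

        using Stream vocab2 = File.OpenRead("vocab.json");
        using Stream merges2 = File.OpenRead("merges.txt");
        Tokenizer withPrefix = Tokenizer.CreateCodeGen(vocab2, merges2, addPrefixSpace: true);

        // "Hello World" -> 15496, 2159 (tokens "Hello", "ĠWorld")
        Console.WriteLine(string.Join(", ", noPrefix.EncodeToIds("Hello World")));

        // "Hello World" -> 18435, 2159 (tokens "ĠHello", "ĠWorld") once a prefix space is added
        Console.WriteLine(string.Join(", ", withPrefix.EncodeToIds("Hello World")));
    }
}
```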
@@ -13,11 +13,11 @@
    /// </summary>
    internal static class BytePairEncoder
    {
        public static int[] BytePairEncode(ReadOnlyMemory<byte> mergingBytes, IReadOnlyDictionary<ReadOnlyMemory<byte>, int> ranks)
        public static (int Id, int TokenIndex, int TokenLength)[] BytePairEncode(ReadOnlyMemory<byte> mergingBytes, IReadOnlyDictionary<ReadOnlyMemory<byte>, int> ranks, ReadOnlySpan<int> indexMappingSpan)
        {
            if (mergingBytes.Length == 1)
            {
                return [ranks[mergingBytes]];
                return [(ranks[mergingBytes], 0, 1)];
            }

            (int Index, int Rank)[]? arrayPoolArray = null;
@@ -84,10 +84,28 @@
                }
            }

            var result = new int[byteIndicesAndRanks.Length - 1];
            var result = new (int Id, int TokenIndex, int TokenLength)[byteIndicesAndRanks.Length - 1];
            for (int i = 0; i < result.Length; i++)
            {
                result[i] = ranks[mergingBytes.SliceStartEnd(byteIndicesAndRanks[i].Index, byteIndicesAndRanks[i + 1].Index)];
                int startIndex = byteIndicesAndRanks[i].Index;
                int endIndex = byteIndicesAndRanks[i + 1].Index;

                int mappedStartIndex = indexMappingSpan[startIndex];
                int mappedEndIndex = indexMappingSpan[endIndex];

                int finalEndIndex = endIndex;

                if (finalEndIndex > 0 && indexMappingSpan[finalEndIndex - 1] == mappedEndIndex)
                {
                    // The partial character/element should be included in the current token.
                    finalEndIndex++;
                    while (finalEndIndex < indexMappingSpan.Length && indexMappingSpan[finalEndIndex] == mappedEndIndex)
                    {
                        finalEndIndex++;
                    }
                }

                result[i] = (ranks[mergingBytes.SliceStartEnd(startIndex, endIndex)], mappedStartIndex, indexMappingSpan[finalEndIndex] - mappedStartIndex);
            }

            if (arrayPoolArray is not null)
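The new BytePairEncode overload carries an index-mapping span: entry i holds the index of the source char that produced UTF-8 byte i, with a sentinel entry at the end set to the text length, so each merged token can report a char-based (offset, length) and never stops in the middle of a code point. Below is a standalone, hedged sketch of building and using such a mapping; it is not the PR's Helpers.EncodeToUtf8, just an illustration with standard System.Text APIs.

```csharp
using System;
using System.Text;

class IndexMappingSketch
{
    // Maps each UTF-8 byte index back to the index of the char that produced it,
    // mirroring the shape of the mapping used in this PR (one extra sentinel entry).
    static int[] BuildByteToCharMap(string text, out byte[] utf8)
    {
        utf8 = Encoding.UTF8.GetBytes(text);
        int[] map = new int[utf8.Length + 1];
        char[] chars = text.ToCharArray();
        int byteIndex = 0;

        for (int i = 0; i < text.Length; i++)
        {
            int charsConsumed = char.IsSurrogatePair(text, i) ? 2 : 1;
            int byteCount = Encoding.UTF8.GetByteCount(chars, i, charsConsumed);
            for (int b = 0; b < byteCount; b++)
            {
                map[byteIndex + b] = i; // every byte of this code point points at its first char
            }
            byteIndex += byteCount;
            i += charsConsumed - 1;
        }

        map[utf8.Length] = text.Length; // sentinel, like indexMappingSpan[encodedLength] in the diff
        return map;
    }

    static void Main()
    {
        string text = "a\u00E9b"; // 'é' encodes to two UTF-8 bytes
        int[] map = BuildByteToCharMap(text, out byte[] utf8);

        // A token covering UTF-8 bytes [1, 3) maps back to char offset map[1]
        // with length map[3] - map[1], i.e. the single char 'é'.
        Console.WriteLine($"bytes 1..3 -> char offset {map[1]}, length {map[3] - map[1]}");
    }
}
```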
@@ -0,0 +1,51 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Linq;

namespace Microsoft.ML.Tokenizers
{
    /// <summary>
    /// Map between utf-8 byte to unicode with avoiding mapping to whitespace/control characters.
    /// </summary>
    internal sealed class ByteToUnicodeEncoding
    {
        public static ByteToUnicodeEncoding Instance { get; } = new ByteToUnicodeEncoding();

        public ByteToUnicodeEncoding()
        {
            var byteToUnicodeMapping = Enumerable.Range('!', '~' - '!' + 1)
                                        .Concat(Enumerable.Range('¡', '¬' - '¡' + 1))
                                        .Concat(Enumerable.Range('®', 'ÿ' - '®' + 1))
                                        .ToDictionary(b => (char)b, b => (char)b);

            const int numChars = 256;
            var n = 0;
            foreach (var b in Enumerable.Range(0, numChars))
            {
                if (!byteToUnicodeMapping.ContainsKey((char)b))
                {
                    byteToUnicodeMapping.Add((char)b, (char)(numChars + n));
                    ++n;
                }
            }

            ByteToUnicode = byteToUnicodeMapping;
            UnicodeToByte = ByteToUnicode.ToDictionary(kv => kv.Value, kv => kv.Key);

            int count = numChars + n;

            CharToString = new string[count];
            for (char c = (char)0; c < (char)count; c++)
            {
                CharToString[c] = c.ToString();
            }
        }

        public IReadOnlyDictionary<char, char> ByteToUnicode { get; }
        public IReadOnlyDictionary<char, char> UnicodeToByte { get; }
        public string[] CharToString { get; }
    }
}
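This is the GPT-2 style byte-to-printable-character mapping: printable bytes map to themselves, and the remaining byte values are shifted above U+0100 so every byte stays displayable. That is why a leading space shows up as 'Ġ' and '\n' as 'Ċ' in the CodeGen test expectations. A hedged standalone sketch reproducing the mapping (the class itself is internal, so this simply duplicates its construction):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

class ByteToUnicodeSketch
{
    static void Main()
    {
        // Same construction as ByteToUnicodeEncoding in this PR.
        Dictionary<char, char> map = Enumerable.Range('!', '~' - '!' + 1)
            .Concat(Enumerable.Range('¡', '¬' - '¡' + 1))
            .Concat(Enumerable.Range('®', 'ÿ' - '®' + 1))
            .ToDictionary(b => (char)b, b => (char)b);

        int n = 0;
        for (int b = 0; b < 256; b++)
        {
            if (!map.ContainsKey((char)b))
            {
                map.Add((char)b, (char)(256 + n)); // shift non-printable bytes above U+0100
                n++;
            }
        }

        Console.WriteLine(map[' ']);      // Ġ (U+0120)
        Console.WriteLine((int)map[' ']); // 288
        Console.WriteLine(map['\n']);     // Ċ (U+010A)
    }
}
```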
@ -4,6 +4,8 @@
|
|||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Diagnostics;
|
||||
using System.Text;
|
||||
|
||||
namespace Microsoft.ML.Tokenizers
|
||||
{
|
||||
|
@ -16,5 +18,114 @@ namespace Microsoft.ML.Tokenizers
|
|||
ArrayPool<T>.Shared.Return(arrayPoolArray);
|
||||
arrayPoolArray = tmp;
|
||||
}
|
||||
|
||||
internal static int EncodeToUtf8(ReadOnlySpan<char> text, Span<byte> destination, Span<int> indexMapping)
|
||||
{
|
||||
Debug.Assert(!text.IsEmpty);
|
||||
Debug.Assert(Encoding.UTF8.GetMaxByteCount(text.Length) <= destination.Length);
|
||||
Debug.Assert(indexMapping.Length >= destination.Length);
|
||||
|
||||
int targetIndex = 0;
|
||||
|
||||
for (int i = 0; i < text.Length; i++)
|
||||
{
|
||||
uint c = (uint)text[i];
|
||||
if (c <= 0x7Fu)
|
||||
{
|
||||
destination[targetIndex] = (byte)c;
|
||||
indexMapping[targetIndex] = i;
|
||||
targetIndex++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (c <= 0x7FFu)
|
||||
{
|
||||
// Scalar 00000yyy yyxxxxxx -> bytes [ 110yyyyy 10xxxxxx ]
|
||||
destination[targetIndex] = (byte)((c + (0b110u << 11)) >> 6);
|
||||
destination[targetIndex + 1] = (byte)((c & 0x3Fu) + 0x80u);
|
||||
indexMapping[targetIndex] = indexMapping[targetIndex + 1] = i;
|
||||
targetIndex += 2;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (i < text.Length - 1 && char.IsSurrogatePair((char)c, text[i + 1]))
|
||||
{
|
||||
// Scalar 000uuuuu zzzzyyyy yyxxxxxx -> bytes [ 11110uuu 10uuzzzz 10yyyyyy 10xxxxxx ]
|
||||
uint value = (uint)char.ConvertToUtf32((char)c, text[i + 1]);
|
||||
destination[targetIndex] = (byte)((value + (0b11110 << 21)) >> 18);
|
||||
destination[targetIndex + 1] = (byte)(((value & (0x3Fu << 12)) >> 12) + 0x80u);
|
||||
destination[targetIndex + 2] = (byte)(((value & (0x3Fu << 6)) >> 6) + 0x80u);
|
||||
destination[targetIndex + 3] = (byte)((value & 0x3Fu) + 0x80u);
|
||||
indexMapping[targetIndex] = indexMapping[targetIndex + 1] = indexMapping[targetIndex + 2] = indexMapping[targetIndex + 3] = i;
|
||||
i++;
|
||||
targetIndex += 4;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Scalar zzzzyyyy yyxxxxxx -> bytes [ 1110zzzz 10yyyyyy 10xxxxxx ]
|
||||
destination[targetIndex] = (byte)((c + (0b1110 << 16)) >> 12);
|
||||
destination[targetIndex + 1] = (byte)(((c & (0x3Fu << 6)) >> 6) + 0x80u);
|
||||
destination[targetIndex + 2] = (byte)((c & 0x3Fu) + 0x80u);
|
||||
indexMapping[targetIndex] = indexMapping[targetIndex + 1] = indexMapping[targetIndex + 2] = i;
|
||||
targetIndex += 3;
|
||||
}
|
||||
|
||||
return targetIndex;
|
||||
}
|
||||
|
||||
internal static int EncodeToUtf8AndTransform(ReadOnlySpan<char> text, Span<char> destination, Span<int> indexMapping)
|
||||
{
|
||||
Debug.Assert(!text.IsEmpty);
|
||||
Debug.Assert(Encoding.UTF8.GetMaxByteCount(text.Length) <= destination.Length);
|
||||
Debug.Assert(indexMapping.Length >= destination.Length);
|
||||
|
||||
ByteToUnicodeEncoding byteToUnicodeEncoder = ByteToUnicodeEncoding.Instance;
|
||||
int targetIndex = 0;
|
||||
|
||||
for (int i = 0; i < text.Length; i++)
|
||||
{
|
||||
uint c = (uint)text[i];
|
||||
if (c <= 0x7Fu)
|
||||
{
|
||||
destination[targetIndex] = byteToUnicodeEncoder.ByteToUnicode[(char)c];
|
||||
indexMapping[targetIndex] = i;
|
||||
targetIndex++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (c <= 0x7FFu)
|
||||
{
|
||||
// Scalar 00000yyy yyxxxxxx -> bytes [ 110yyyyy 10xxxxxx ]
|
||||
destination[targetIndex] = byteToUnicodeEncoder.ByteToUnicode[(char)((c + (0b110u << 11)) >> 6)];
|
||||
destination[targetIndex + 1] = byteToUnicodeEncoder.ByteToUnicode[(char)((c & 0x3Fu) + 0x80u)];
|
||||
indexMapping[targetIndex] = indexMapping[targetIndex + 1] = i;
|
||||
targetIndex += 2;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (i < text.Length - 1 && char.IsSurrogatePair((char)c, text[i + 1]))
|
||||
{
|
||||
// Scalar 000uuuuu zzzzyyyy yyxxxxxx -> bytes [ 11110uuu 10uuzzzz 10yyyyyy 10xxxxxx ]
|
||||
uint value = (uint)char.ConvertToUtf32((char)c, text[i + 1]);
|
||||
destination[targetIndex] = byteToUnicodeEncoder.ByteToUnicode[(char)((value + (0b11110 << 21)) >> 18)];
|
||||
destination[targetIndex + 1] = byteToUnicodeEncoder.ByteToUnicode[(char)(((value & (0x3Fu << 12)) >> 12) + 0x80u)];
|
||||
destination[targetIndex + 2] = byteToUnicodeEncoder.ByteToUnicode[(char)(((value & (0x3Fu << 6)) >> 6) + 0x80u)];
|
||||
destination[targetIndex + 3] = byteToUnicodeEncoder.ByteToUnicode[(char)((value & 0x3Fu) + 0x80u)];
|
||||
indexMapping[targetIndex] = indexMapping[targetIndex + 1] = indexMapping[targetIndex + 2] = indexMapping[targetIndex + 3] = i;
|
||||
i++;
|
||||
targetIndex += 4;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Scalar zzzzyyyy yyxxxxxx -> bytes [ 1110zzzz 10yyyyyy 10xxxxxx ]
|
||||
destination[targetIndex] = byteToUnicodeEncoder.ByteToUnicode[(char)((c + (0b1110 << 16)) >> 12)];
|
||||
destination[targetIndex + 1] = byteToUnicodeEncoder.ByteToUnicode[(char)(((c & (0x3Fu << 6)) >> 6) + 0x80u)];
|
||||
destination[targetIndex + 2] = byteToUnicodeEncoder.ByteToUnicode[(char)((c & 0x3Fu) + 0x80u)];
|
||||
indexMapping[targetIndex] = indexMapping[targetIndex + 1] = indexMapping[targetIndex + 2] = i;
|
||||
targetIndex += 3;
|
||||
}
|
||||
|
||||
return targetIndex;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -61,5 +61,32 @@ namespace Microsoft.ML.Tokenizers
            => Encoding.UTF8.GetChars(bytes, chars);

        internal static void Replace(Span<char> span, char oldValue, char newValue) => span.Replace(oldValue, newValue);

        /// <summary>
        /// Encode the next code point in the text to UTF-8.
        /// </summary>
        /// <param name="text">The text to encode the first code point from.</param>
        /// <param name="textIndex">The index of the first code point to encode.</param>
        /// <param name="destination">The buffer to write the UTF-8 bytes to.</param>
        /// <param name="bytesIndex">The index in the buffer to write the UTF-8 encoded bytes to.</param>
        /// <returns>The number of characters consumed from the text.</returns>
        internal static int EncodeCodePointToUtf8(ReadOnlySpan<char> text, int textIndex, ref byte[] destination, ref int bytesIndex)
        {
            Debug.Assert(textIndex < text.Length);

            Rune.DecodeFromUtf16(text.Slice(textIndex), out Rune rune, out int charsConsumed);

            Span<byte> buffer = stackalloc byte[4]; // max number of bytes for a single code point utf-8 encoding.
            int bytesWritten = rune.EncodeToUtf8(buffer);
            if (bytesIndex + bytesWritten > destination.Length)
            {
                Helpers.ArrayPoolGrow(ref destination, destination.Length * 2);
            }

            buffer.Slice(0, bytesWritten).CopyTo(destination.AsSpan(bytesIndex));
            bytesIndex += bytesWritten;

            return charsConsumed;
        }
    }
}
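This overload leans on System.Text.Rune, which decodes surrogate pairs and reports how many UTF-16 chars one code point consumed; Rune is not available on netstandard2.0, which is presumably why the diff also carries a manual bit-twiddling encoder in the netstandard Helpers below. A small standalone illustration of the two Rune calls used here:

```csharp
using System;
using System.Text;

class RuneEncodeSketch
{
    static void Main()
    {
        // "𝄞" (U+1D11E) is a surrogate pair in UTF-16 and four bytes in UTF-8.
        ReadOnlySpan<char> text = "𝄞!".AsSpan();

        Rune.DecodeFromUtf16(text, out Rune rune, out int charsConsumed);

        Span<byte> utf8 = stackalloc byte[4];
        int bytesWritten = rune.EncodeToUtf8(utf8);

        Console.WriteLine($"chars consumed: {charsConsumed}, bytes written: {bytesWritten}");
        // chars consumed: 2, bytes written: 4
    }
}
```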
@ -3,6 +3,7 @@
|
|||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Net.Http;
|
||||
using System.Text;
|
||||
|
@ -111,6 +112,73 @@ namespace Microsoft.ML.Tokenizers
|
|||
if (span[i] == oldValue)
|
||||
span[i] = newValue;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Encode the next code point in the text to UTF-8.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode the first code point from.</param>
|
||||
/// <param name="textIndex">The index of the first code point to encode.</param>
|
||||
/// <param name="destination">The buffer to write the UTF-8 bytes to.</param>
|
||||
/// <param name="bytesIndex">The index in the buffer to write the UTF-8 encoded bytes to.</param>
|
||||
/// <returns>The number of characters consumed from the text.</returns>
|
||||
internal static int EncodeCodePointToUtf8(ReadOnlySpan<char> text, int textIndex, ref byte[] destination, ref int bytesIndex)
|
||||
{
|
||||
Debug.Assert(textIndex < text.Length);
|
||||
|
||||
uint c = (uint)text[textIndex];
|
||||
if (c <= 0x7Fu)
|
||||
{
|
||||
if (bytesIndex + 1 > destination.Length)
|
||||
{
|
||||
Helpers.ArrayPoolGrow(ref destination, destination.Length * 2);
|
||||
}
|
||||
destination[bytesIndex] = (byte)c;
|
||||
bytesIndex++;
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (c <= 0x7FFu)
|
||||
{
|
||||
// Scalar 00000yyy yyxxxxxx -> bytes [ 110yyyyy 10xxxxxx ]
|
||||
if (bytesIndex + 2 > destination.Length)
|
||||
{
|
||||
Helpers.ArrayPoolGrow(ref destination, destination.Length * 2);
|
||||
}
|
||||
destination[bytesIndex] = (byte)((c + (0b110u << 11)) >> 6);
|
||||
destination[bytesIndex + 1] = (byte)((c & 0x3Fu) + 0x80u);
|
||||
bytesIndex += 2;
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (textIndex < text.Length - 1 && char.IsSurrogatePair((char)c, text[textIndex + 1]))
|
||||
{
|
||||
// Scalar 000uuuuu zzzzyyyy yyxxxxxx -> bytes [ 11110uuu 10uuzzzz 10yyyyyy 10xxxxxx ]
|
||||
if (bytesIndex + 4 > destination.Length)
|
||||
{
|
||||
Helpers.ArrayPoolGrow(ref destination, Math.Max(destination.Length, 4) * 2);
|
||||
}
|
||||
|
||||
uint value = (uint)char.ConvertToUtf32((char)c, text[textIndex + 1]);
|
||||
destination[bytesIndex] = (byte)((value + (0b11110 << 21)) >> 18);
|
||||
destination[bytesIndex + 1] = (byte)(((value & (0x3Fu << 12)) >> 12) + 0x80u);
|
||||
destination[bytesIndex + 2] = (byte)(((value & (0x3Fu << 6)) >> 6) + 0x80u);
|
||||
destination[bytesIndex + 3] = (byte)((value & 0x3Fu) + 0x80u);
|
||||
bytesIndex += 4;
|
||||
return 2;
|
||||
}
|
||||
|
||||
if (bytesIndex + 3 > destination.Length)
|
||||
{
|
||||
Helpers.ArrayPoolGrow(ref destination, Math.Max(destination.Length, 3) * 2);
|
||||
}
|
||||
|
||||
// Scalar zzzzyyyy yyxxxxxx -> bytes [ 1110zzzz 10yyyyyy 10xxxxxx ]
|
||||
destination[bytesIndex] = (byte)((c + (0b1110 << 16)) >> 12);
|
||||
destination[bytesIndex + 1] = (byte)(((c & (0x3Fu << 6)) >> 6) + 0x80u);
|
||||
destination[bytesIndex + 2] = (byte)((c & 0x3Fu) + 0x80u);
|
||||
bytesIndex += 3;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -45,6 +45,31 @@ namespace Microsoft.ML.Tokenizers
|
|||
public override int GetHashCode() => Helpers.GetHashCode(Span);
|
||||
}
|
||||
|
||||
internal unsafe readonly struct StringSpanOrdinalKeyPair : IEquatable<StringSpanOrdinalKeyPair>
|
||||
{
|
||||
private readonly StringSpanOrdinalKey _left;
|
||||
private readonly StringSpanOrdinalKey _right;
|
||||
|
||||
public StringSpanOrdinalKeyPair(char* ptr1, int length1, char* ptr2, int length2)
|
||||
{
|
||||
_left = new StringSpanOrdinalKey(ptr1, length1);
|
||||
_right = new StringSpanOrdinalKey(ptr2, length2);
|
||||
}
|
||||
|
||||
public StringSpanOrdinalKeyPair(string data1, string data2)
|
||||
{
|
||||
_left = new StringSpanOrdinalKey(data1);
|
||||
_right = new StringSpanOrdinalKey(data2);
|
||||
}
|
||||
public override bool Equals(object? obj) =>
|
||||
obj is StringSpanOrdinalKeyPair wrapper && wrapper._left.Equals(_left) && wrapper._right.Equals(_right);
|
||||
|
||||
public bool Equals(StringSpanOrdinalKeyPair other) => other._left.Equals(_left) && other._right.Equals(_right);
|
||||
|
||||
public override int GetHashCode() => HashCode.Combine(_left.GetHashCode(), _right.GetHashCode());
|
||||
}
|
||||
|
||||
|
||||
internal sealed class StringSpanOrdinalKeyCache<TValue>
|
||||
{
|
||||
private readonly int _capacity;
|
||||
|
@ -115,6 +140,34 @@ namespace Microsoft.ML.Tokenizers
|
|||
public override void Write(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) => writer.WriteStringValue(value.Data!);
|
||||
}
|
||||
|
||||
internal class StringSpanOrdinalKeyCustomConverter : JsonConverter<Dictionary<StringSpanOrdinalKey, (int, string)>>
|
||||
{
|
||||
public static StringSpanOrdinalKeyCustomConverter Instance { get; } = new StringSpanOrdinalKeyCustomConverter();
|
||||
|
||||
public override Dictionary<StringSpanOrdinalKey, (int, string)> Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
|
||||
{
|
||||
var dictionary = new Dictionary<StringSpanOrdinalKey, (int, string)>();
|
||||
while (reader.Read())
|
||||
{
|
||||
if (reader.TokenType == JsonTokenType.EndObject)
|
||||
{
|
||||
return dictionary;
|
||||
}
|
||||
|
||||
if (reader.TokenType == JsonTokenType.PropertyName)
|
||||
{
|
||||
var key = reader.GetString();
|
||||
reader.Read();
|
||||
var value = reader.GetInt32();
|
||||
dictionary.Add(new StringSpanOrdinalKey(key!), (value, key!));
|
||||
}
|
||||
}
|
||||
throw new JsonException("Invalid JSON.");
|
||||
}
|
||||
|
||||
public override void Write(Utf8JsonWriter writer, Dictionary<StringSpanOrdinalKey, (int, string)> value, JsonSerializerOptions options) => throw new NotImplementedException();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for <see cref="StringSpanOrdinalKey"/>.
|
||||
/// </summary>
|
||||
|
@ -130,5 +183,17 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, string key, out TValue value) =>
|
||||
map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
|
||||
|
||||
public unsafe static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKeyPair, TValue> map, ReadOnlySpan<char> key1, ReadOnlySpan<char> key2, out TValue value)
|
||||
{
|
||||
fixed (char* ptr1 = key1)
|
||||
fixed (char* ptr2 = key2)
|
||||
{
|
||||
return map.TryGetValue(new StringSpanOrdinalKeyPair(ptr1, key1.Length, ptr2, key2.Length), out value!);
|
||||
}
|
||||
}
|
||||
|
||||
public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKeyPair, TValue> map, string key1, string key2, out TValue value) =>
|
||||
map.TryGetValue(new StringSpanOrdinalKeyPair(key1, key2), out value!);
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -30,7 +30,7 @@ namespace Microsoft.ML.TorchSharp.Extensions
                assembly.GetManifestResourceStream("encoder.json"),
                assembly.GetManifestResourceStream("vocab.bpe"),
                assembly.GetManifestResourceStream("dict.txt"),
                new RobertaPreTokenizer());
                RobertaPreTokenizer.Instance);
            (_instance as EnglishRoberta).AddMaskSymbol();
        }
@@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Tokenizers;
using System;
using System.Collections.Generic;
using System.Net.Http;
@ -0,0 +1,957 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
using Xunit;

namespace Microsoft.ML.Tokenizers.Tests
{
public class CodeGenTests
{
private static Tokenizer _codegen350MMonoTokenizer = CreateCodegen350MMonoTokenizer();
private static Tokenizer _codegen350MMonoTokenizerWithSpace = CreateCodegen350MMonoTokenizer(addPrefixSpace: true);
private static Tokenizer _codegen350MMonoTokenizerWithBeginningOfSentence = CreateCodegen350MMonoTokenizer(bos: true);
private static Tokenizer _codegen350MMonoTokenizerWithEndOfSentence = CreateCodegen350MMonoTokenizer(eos: true);
private static Tokenizer _codegen350MMonoTokenizerWithBeginningAndEndOfSentence = CreateCodegen350MMonoTokenizer(bos: true, eos: true);

private static Tokenizer CreateCodegen350MMonoTokenizer(bool addPrefixSpace = false, bool bos = false, bool eos = false)
{
// @"https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/vocab.json?download=true";
// @"https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/merges.txt?download=true";

using Stream vocabStream = File.OpenRead(Path.Combine(@"Codegen-350M-mono", "vocab.json"));
using Stream mergesStream = File.OpenRead(Path.Combine(@"Codegen-350M-mono", "merges.txt"));

return Tokenizer.CreateCodeGen(vocabStream, mergesStream, addPrefixSpace, bos, eos);
}

private static Tokenizer CreateCodegenPhi2Tokenizer()
{
// https://huggingface.co/microsoft/phi-2/resolve/main/vocab.json?download=true
// https://huggingface.co/microsoft/phi-2/resolve/main/merges.txt?download=true

using Stream vocabStream = File.OpenRead(Path.Combine(@"Phi-2", "vocab.json"));
using Stream mergesStream = File.OpenRead(Path.Combine(@"Phi-2", "merges.txt"));

return Tokenizer.CreateCodeGen(vocabStream, mergesStream);
}

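// Illustrative sketch (added for exposition, not part of the original change; kept as a
// comment so it does not alter the test class): a minimal end-to-end use of the factory
// above. It assumes the Codegen-350M-mono vocab.json/merges.txt files referenced above are
// available locally; the expected values come from the "Hello World" case in CodeGenTestData.
//
//   Tokenizer tokenizer = CreateCodegen350MMonoTokenizer();
//   IReadOnlyList<int> ids = tokenizer.EncodeToIds("Hello World");          // { 15496, 2159 }
//   IReadOnlyList<Token> tokens = tokenizer.Encode("Hello World", out _);   // "Hello", "ĠWorld"
//   string roundTrip = ((CodeGen)tokenizer).Decode(ids.ToArray());          // "Hello World"
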
public static IEnumerable<object?[]> CodeGenTestData
|
||||
{
|
||||
get
|
||||
{
|
||||
// string to tokenize,
|
||||
// produced tokens,
|
||||
// the token offsets,
|
||||
// the tokens ids,
// produced tokens when AddPrefixSpace is enabled,
|
||||
// the token offsets when AddPrefixSpace is enabled,
|
||||
// the tokens ids when AddPrefixSpace is enabled
|
||||
yield return new object?[]
|
||||
{
|
||||
"Hello World",
|
||||
new string[] { "Hello", "ĠWorld" },
|
||||
new (int Index, int Length)[] { (0, 5), (5, 6) },
|
||||
new int[] { 15496, 2159 },
|
||||
new string[] { "ĠHello", "ĠWorld" },
|
||||
new (int Index, int Length)[] { (0, 5), (5, 6) },
|
||||
new int[] { 18435, 2159 },
|
||||
};
|
||||
|
||||
yield return new object?[]
|
||||
{
|
||||
" Hello World", // with space prefix this depends on the AddedTokens
|
||||
new string[] { "ĠHello", "ĠWorld" },
|
||||
new (int Index, int Length)[] { (0, 6), (6, 6) },
|
||||
new int[] { 18435, 2159 },
|
||||
new string[] { " ", "Hello", "ĠWorld" },
|
||||
new (int Index, int Length)[] { (0, 1), (1, 5), (6, 6) },
|
||||
new int[] { 50286, 15496, 2159 },
|
||||
};
|
||||
|
||||
yield return new object?[]
|
||||
{
|
||||
"the brown fox jumped over the lazy dog!\r\n", // text in range 0 ~ FF
|
||||
new string[] { "the", "Ġbrown", "Ġfox", "Ġjumped", "Ġover", "Ġthe", "Ġlazy", "Ġdog", "!", "č", "Ċ" },
|
||||
new (int Index, int Length)[] { (0, 3), (3, 6), (9, 4), (13, 7), (20, 5), (25, 4), (29, 5), (34, 4), (38, 1), (39, 1), (40, 1) },
|
||||
new int[] { 1169, 7586, 21831, 11687, 625, 262, 16931, 3290, 0, 201, 198 },
|
||||
new string[] { "Ġthe", "Ġbrown", "Ġfox", "Ġjumped", "Ġover", "Ġthe", "Ġlazy", "Ġdog", "!", "č", "Ċ" },
|
||||
new (int Index, int Length)[] { (0, 3), (3, 6), (9, 4), (13, 7), (20, 5), (25, 4), (29, 5), (34, 4), (38, 1), (39, 1), (40, 1) },
|
||||
new int[] { 262, 7586, 21831, 11687, 625, 262, 16931, 3290, 0, 201, 198 }
|
||||
};
|
||||
|
||||
yield return new object?[]
|
||||
{
|
||||
"\u0924\u1009\u1129\u1241\uE860\u3438.", // text greater than 7FF Devanagari, Myanmar, Hangul, Ethiopic, Palmyrene, CJK तဉᄩቁ㐸.
|
||||
new string[] { "à¤", "¤", "á", "Ģ", "ī", "á", "Ħ", "©", "á", "ī", "ģ", "î", "¡", "ł", "ã", "IJ", "¸", "." },
|
||||
new (int Index, int Length)[] { (0, 0), (0, 1), (1, 0), (1, 0), (1, 1), (2, 0), (2, 0), (2, 1), (3, 0), (3, 0), (3, 1), (4, 0), (4, 0), (4, 1), (5, 0), (5, 0), (5, 1), (6, 1) },
|
||||
new int[] { 11976, 97, 157, 222, 231, 157, 226, 102, 157, 231, 223, 170, 94, 254, 159, 238, 116, 13 },
|
||||
new string[] { "Ġà¤", "¤", "á", "Ģ", "ī", "á", "Ħ", "©", "á", "ī", "ģ", "î", "¡", "ł", "ã", "IJ", "¸", "." },
|
||||
new (int Index, int Length)[] { (0, 0), (0, 1), (1, 0), (1, 0), (1, 1), (2, 0), (2, 0), (2, 1), (3, 0), (3, 0), (3, 1), (4, 0), (4, 0), (4, 1), (5, 0), (5, 0), (5, 1), (6, 1) },
|
||||
new int[] { 28225, 97, 157, 222, 231, 157, 226, 102, 157, 231, 223, 170, 94, 254, 159, 238, 116, 13 }
|
||||
};
|
||||
|
||||
yield return new object?[]
|
||||
{
|
||||
"Some Greek letters ΣΦΩ αβγδε.", // text in range 100 ~ 7FF
|
||||
new string[] { "Some", "ĠGreek", "Ġletters", "ĠÎ", "£", "Î", "¦", "Î", "©", "Ġα", "β", "γ", "Î", "´", "ε", "." },
|
||||
new (int Index, int Length)[] { (0, 4), (4, 6), (10, 8), (18, 1), (19, 1), (20, 0), (20, 1), (21, 0), (21, 1), (22, 2), (24, 1), (25, 1), (26, 0), (26, 1), (27, 1), (28, 1) },
|
||||
new int[] { 4366, 8312, 7475, 7377, 96, 138, 99, 138, 102, 26367, 26638, 42063, 138, 112, 30950, 13 },
|
||||
new string[] { "ĠSome", "ĠGreek", "Ġletters", "ĠÎ", "£", "Î", "¦", "Î", "©", "Ġα", "β", "γ", "Î", "´", "ε", "." },
|
||||
new (int Index, int Length)[] { (0, 4), (4, 6), (10, 8), (18, 1), (19, 1), (20, 0), (20, 1), (21, 0), (21, 1), (22, 2), (24, 1), (25, 1), (26, 0), (26, 1), (27, 1), (28, 1) },
|
||||
new int[] { 2773, 8312, 7475, 7377, 96, 138, 99, 138, 102, 26367, 26638, 42063, 138, 112, 30950, 13 }
|
||||
};
|
||||
|
||||
yield return new object?[]
|
||||
{
|
||||
"αβγδε", // no spaces
|
||||
new string[] { "α", "β", "γ", "Î", "´", "ε" },
|
||||
new (int Index, int Length)[] { (0, 1), (1, 1), (2, 1), (3, 0), (3, 1), (4, 1) },
|
||||
new int[] { 17394, 26638, 42063, 138, 112, 30950 },
|
||||
new string[] { "Ġα", "β", "γ", "Î", "´", "ε" },
|
||||
new (int Index, int Length)[] { (0, 1), (1, 1), (2, 1), (3, 0), (3, 1), (4, 1) },
|
||||
new int[] { 26367, 26638, 42063, 138, 112, 30950 }
|
||||
};
|
||||
|
||||
yield return new object?[]
|
||||
{
|
||||
"Surrogates: 😀😂😍😘",
|
||||
new string[] { "Sur", "rog", "ates", ":", "ĠðŁĺ", "Ģ", "ðŁĺ", "Ĥ", "ðŁĺ", "į", "ðŁĺ", "ĺ" },
|
||||
new (int Index, int Length)[] { (0, 3), (3, 3), (6, 4), (10, 1), (11, 1), (12, 2), (14, 0), (14, 2), (16, 0), (16, 2), (18, 0), (18, 2) },
|
||||
new int[] { 14214, 3828, 689, 25, 30325, 222, 47249, 224, 47249, 235, 47249, 246 },
|
||||
new string[] { "ĠSur", "rog", "ates", ":", "ĠðŁĺ", "Ģ", "ðŁĺ", "Ĥ", "ðŁĺ", "į", "ðŁĺ", "ĺ" },
|
||||
new (int Index, int Length)[] { (0, 3), (3, 3), (6, 4), (10, 1), (11, 1), (12, 2), (14, 0), (14, 2), (16, 0), (16, 2), (18, 0), (18, 2) },
|
||||
new int[] { 4198, 3828, 689, 25, 30325, 222, 47249, 224, 47249, 235, 47249, 246 }
|
||||
};
|
||||
|
||||
yield return new object?[]
|
||||
{
|
||||
"Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides " +
|
||||
"general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet...) for Natural " +
|
||||
"Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained " +
|
||||
"models in 100+ languages and deep interoperability between Jax, PyTorch and TensorFlow.",
|
||||
new string[] { "Transform", "ers", "Ġ(", "formerly", "Ġknown", "Ġas", "Ġpy", "tor", "ch", "-", "transform", "ers", "Ġand", "Ġpy", "tor", "ch", "-",
|
||||
"pret", "rained", "-", "bert", ")", "Ġprovides", "Ġgeneral", "-", "purpose", "Ġarchitectures", "Ġ(", "BER", "T", ",", "ĠG", "PT",
|
||||
"-", "2", ",", "ĠRo", "BER", "Ta", ",", "ĠXL", "M", ",", "ĠDist", "il", "B", "ert", ",", "ĠXL", "Net", "...)", "Ġfor", "ĠNatural",
|
||||
"ĠLanguage", "ĠUnderstanding", "Ġ(", "NL", "U", ")", "Ġand", "ĠNatural", "ĠLanguage", "ĠGeneration", "Ġ(", "NL", "G", ")", "Ġwith",
|
||||
"Ġover", "Ġ32", "+", "Ġpret", "rained", "Ġmodels", "Ġin", "Ġ100", "+", "Ġlanguages", "Ġand", "Ġdeep", "Ġinteroper", "ability",
|
||||
"Ġbetween", "ĠJ", "ax", ",", "ĠPy", "Tor", "ch", "Ġand", "ĠT", "ensor", "Flow", "." },
|
||||
new (int Index, int Length)[] { (0, 9), (9, 3), (12, 2), (14, 8), (22, 6), (28, 3), (31, 3), (34, 3), (37, 2), (39, 1), (40, 9), (49, 3), (52, 4),
|
||||
(56, 3), (59, 3), (62, 2), (64, 1), (65, 4), (69, 6), (75, 1), (76, 4), (80, 1), (81, 9), (90, 8), (98, 1), (99, 7), (106, 14),
|
||||
(120, 2), (122, 3), (125, 1), (126, 1), (127, 2), (129, 2), (131, 1), (132, 1), (133, 1), (134, 3), (137, 3), (140, 2), (142, 1),
|
||||
(143, 3), (146, 1), (147, 1), (148, 5), (153, 2), (155, 1), (156, 3), (159, 1), (160, 3), (163, 3), (166, 4), (170, 4), (174, 8),
|
||||
(182, 9), (191, 14), (205, 2), (207, 2), (209, 1), (210, 1), (211, 4), (215, 8), (223, 9), (232, 11), (243, 2), (245, 2), (247, 1),
|
||||
(248, 1), (249, 5), (254, 5), (259, 3), (262, 1), (263, 5), (268, 6), (274, 7), (281, 3), (284, 4), (288, 1), (289, 10), (299, 4),
|
||||
(303, 5), (308, 10), (318, 7), (325, 8), (333, 2), (335, 2), (337, 1), (338, 3), (341, 3), (344, 2), (346, 4), (350, 2), (352, 5),
|
||||
(357, 4), (361, 1) },
|
||||
new int[] { 41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276,
|
||||
12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276,
|
||||
7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363,
|
||||
4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13 },
|
||||
new string[] { "ĠTransformers", "Ġ(", "formerly", "Ġknown", "Ġas", "Ġpy", "tor", "ch", "-", "transform", "ers", "Ġand", "Ġpy", "tor", "ch", "-",
|
||||
"pret", "rained", "-", "bert", ")", "Ġprovides", "Ġgeneral", "-", "purpose", "Ġarchitectures", "Ġ(", "BER", "T", ",", "ĠG", "PT",
|
||||
"-", "2", ",", "ĠRo", "BER", "Ta", ",", "ĠXL", "M", ",", "ĠDist", "il", "B", "ert", ",", "ĠXL", "Net", "...)", "Ġfor", "ĠNatural",
|
||||
"ĠLanguage", "ĠUnderstanding", "Ġ(", "NL", "U", ")", "Ġand", "ĠNatural", "ĠLanguage", "ĠGeneration", "Ġ(", "NL", "G", ")", "Ġwith",
|
||||
"Ġover", "Ġ32", "+", "Ġpret", "rained", "Ġmodels", "Ġin", "Ġ100", "+", "Ġlanguages", "Ġand", "Ġdeep", "Ġinteroper", "ability",
|
||||
"Ġbetween", "ĠJ", "ax", ",", "ĠPy", "Tor", "ch", "Ġand", "ĠT", "ensor", "Flow", "." },
|
||||
new (int Index, int Length)[] { (0, 12), (12, 2), (14, 8), (22, 6), (28, 3), (31, 3), (34, 3), (37, 2), (39, 1), (40, 9), (49, 3), (52, 4),
|
||||
(56, 3), (59, 3), (62, 2), (64, 1), (65, 4), (69, 6), (75, 1), (76, 4), (80, 1), (81, 9), (90, 8), (98, 1), (99, 7), (106, 14),
|
||||
(120, 2), (122, 3), (125, 1), (126, 1), (127, 2), (129, 2), (131, 1), (132, 1), (133, 1), (134, 3), (137, 3), (140, 2), (142, 1),
|
||||
(143, 3), (146, 1), (147, 1), (148, 5), (153, 2), (155, 1), (156, 3), (159, 1), (160, 3), (163, 3), (166, 4), (170, 4), (174, 8),
|
||||
(182, 9), (191, 14), (205, 2), (207, 2), (209, 1), (210, 1), (211, 4), (215, 8), (223, 9), (232, 11), (243, 2), (245, 2), (247, 1),
|
||||
(248, 1), (249, 5), (254, 5), (259, 3), (262, 1), (263, 5), (268, 6), (274, 7), (281, 3), (284, 4), (288, 1), (289, 10), (299, 4),
|
||||
(303, 5), (308, 10), (318, 7), (325, 8), (333, 2), (335, 2), (337, 1), (338, 3), (341, 3), (344, 2), (346, 4), (350, 2), (352, 5),
|
||||
(357, 4), (361, 1) },
|
||||
new int[] { 39185, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276,
|
||||
12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276,
|
||||
7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363,
|
||||
4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13 }
|
||||
};
|
||||
|
||||
yield return new object?[]
|
||||
{
|
||||
"BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly " +
|
||||
"conditioning on both left and right context in all layers.",
|
||||
new string[] { "BER", "T", "Ġis", "Ġdesigned", "Ġto", "Ġpre", "-", "train", "Ġdeep", "Ġbid", "irection", "al", "Ġrepresentations", "Ġfrom", "Ġunl",
|
||||
"abel", "ed", "Ġtext", "Ġby", "Ġjointly", "Ġconditioning", "Ġon", "Ġboth", "Ġleft", "Ġand", "Ġright", "Ġcontext", "Ġin", "Ġall",
|
||||
"Ġlayers", "." },
|
||||
new (int Index, int Length)[] { (0, 3), (3, 1), (4, 3), (7, 9), (16, 3), (19, 4), (23, 1), (24, 5), (29, 5), (34, 4), (38, 8), (46, 2), (48, 16),
|
||||
(64, 5), (69, 4), (73, 4), (77, 2), (79, 5), (84, 3), (87, 8), (95, 13), (108, 3), (111, 5), (116, 5), (121, 4), (125, 6), (131, 8),
|
||||
(139, 3), (142, 4), (146, 7), (153, 1) },
|
||||
new int[] { 13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13 },
|
||||
new string[] { "ĠB", "ERT", "Ġis", "Ġdesigned", "Ġto", "Ġpre", "-", "train", "Ġdeep", "Ġbid", "irection", "al", "Ġrepresentations", "Ġfrom", "Ġunl",
|
||||
"abel", "ed", "Ġtext", "Ġby", "Ġjointly", "Ġconditioning", "Ġon", "Ġboth", "Ġleft", "Ġand", "Ġright", "Ġcontext", "Ġin", "Ġall",
|
||||
"Ġlayers", "." },
|
||||
new (int Index, int Length)[] { (0, 1), (1, 3), (4, 3), (7, 9), (16, 3), (19, 4), (23, 1), (24, 5), (29, 5), (34, 4), (38, 8), (46, 2), (48, 16),
|
||||
(64, 5), (69, 4), (73, 4), (77, 2), (79, 5), (84, 3), (87, 8), (95, 13), (108, 3), (111, 5), (116, 5), (121, 4), (125, 6), (131, 8),
|
||||
(139, 3), (142, 4), (146, 7), (153, 1) },
|
||||
new int[] { 347, 17395, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13 }
|
||||
};
|
||||
|
||||
yield return new object?[]
|
||||
{
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
new string[] { "The", "Ġquick", "Ġbrown", "Ġfox", "Ġjumps", "Ġover", "Ġthe", "Ġlazy", "Ġdog", "." },
|
||||
new (int Index, int Length)[] { (0, 3), (3, 6), (9, 6), (15, 4), (19, 6), (25, 5), (30, 4), (34, 5), (39, 4), (43, 1) },
|
||||
new int[] { 464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13 },
|
||||
new string[] { "ĠThe", "Ġquick", "Ġbrown", "Ġfox", "Ġjumps", "Ġover", "Ġthe", "Ġlazy", "Ġdog", "." },
|
||||
new (int Index, int Length)[] { (0, 3), (3, 6), (9, 6), (15, 4), (19, 6), (25, 5), (30, 4), (34, 5), (39, 4), (43, 1) },
|
||||
new int[] { 383, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13 }
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
[MemberData(nameof(CodeGenTestData))]
public void TestTokenizerEncoding(
string text,
string[] expectedTokens,
(int Index, int Length)[] expectedOffsets,
int[] expectedIds,
string[] expectedTokensWithSpace,
(int Index, int Length)[] expectedOffsetsWithSpace,
int[] expectedIdsWithSpace)
{
TestTokenizer(_codegen350MMonoTokenizer, text, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
TestTokenizer(_codegen350MMonoTokenizerWithSpace, text, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);

Tokenizer phi2Tokenizer = CreateCodegenPhi2Tokenizer();
TestTokenizer(phi2Tokenizer, text, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);

TestDecoding(_codegen350MMonoTokenizer, text);
TestDecoding(_codegen350MMonoTokenizerWithSpace, text);
TestDecoding(phi2Tokenizer, text);
}

private void ValidateEncoding(IReadOnlyList<Token> encoding, bool addPrefixSpace, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds,
string[] expectedTokensWithSpace, (int Index, int Length)[] expectedOffsetsWithSpace, int[] expectedIdsWithSpace)
{
if (addPrefixSpace)
{
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsetsWithSpace, encoding.Select(t => t.Offset).ToArray());
}
else
{
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
}
}

private void TestDecoding(Tokenizer tokenizer, string text)
|
||||
{
|
||||
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
|
||||
Assert.Equal(text, tokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
|
||||
|
||||
encoding = tokenizer.Encode(text.AsSpan(), out _);
|
||||
Assert.Equal(text, tokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
|
||||
|
||||
CodeGen codeGenTokenizer = (tokenizer as CodeGen)!;
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: false, out _);
|
||||
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: false, out _);
|
||||
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: false, addEndOfSentence: true, out _);
|
||||
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: false, addEndOfSentence: true, out _);
|
||||
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: true, out _);
|
||||
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: codeGenTokenizer.AddPrefixSpace, addBeginningOfSentence: true, addEndOfSentence: true, out _);
|
||||
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray()));
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
|
||||
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray(), hasPrefixSpace: true, considerSpecialTokens: false));
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
|
||||
Assert.Equal(text, codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray(), hasPrefixSpace: true, considerSpecialTokens: false));
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
|
||||
Assert.Equal($"{codeGenTokenizer.BeginningOfSentenceToken}{text}{codeGenTokenizer.EndOfSentenceToken}", codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray(), hasPrefixSpace: true, considerSpecialTokens: true));
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
|
||||
Assert.Equal($"{codeGenTokenizer.BeginningOfSentenceToken}{text}{codeGenTokenizer.EndOfSentenceToken}", codeGenTokenizer.Decode(encoding.Select(t => t.Id).ToArray(), hasPrefixSpace: true, considerSpecialTokens: true));
|
||||
}
|
||||
|
||||
private void TestTokenizer(
|
||||
Tokenizer tokenizer,
|
||||
string text,
|
||||
string[] expectedTokens,
|
||||
(int Index, int Length)[] expectedOffsets,
|
||||
int[] expectedIds,
|
||||
string[] expectedTokensWithSpace,
|
||||
(int Index, int Length)[] expectedOffsetsWithSpace,
|
||||
int[] expectedIdsWithSpace)
|
||||
{
|
||||
CodeGen codeGenTokenizer = (tokenizer as CodeGen)!;
|
||||
|
||||
//
|
||||
// Full Encoding
|
||||
//
|
||||
|
||||
IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
|
||||
ValidateEncoding(encoding, codeGenTokenizer.AddPrefixSpace, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
|
||||
|
||||
encoding = tokenizer.Encode(text.AsSpan(), out _);
|
||||
ValidateEncoding(encoding, codeGenTokenizer.AddPrefixSpace, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
|
||||
ValidateEncoding(encoding, addPrefixSpace: false, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
|
||||
ValidateEncoding(encoding, addPrefixSpace: false, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
|
||||
ValidateEncoding(encoding, addPrefixSpace: true, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
|
||||
ValidateEncoding(encoding, addPrefixSpace: true, expectedTokens, expectedOffsets, expectedIds, expectedTokensWithSpace, expectedOffsetsWithSpace, expectedIdsWithSpace);
|
||||
|
||||
//
|
||||
// Encode To Ids
|
||||
//
|
||||
|
||||
var ids = codeGenTokenizer.AddPrefixSpace ? expectedIdsWithSpace : expectedIds;
|
||||
|
||||
Assert.Equal(ids, tokenizer.EncodeToIds(text));
|
||||
Assert.Equal(ids, tokenizer.EncodeToIds(text.AsSpan()));
|
||||
|
||||
Assert.Equal(expectedIdsWithSpace, codeGenTokenizer.EncodeToIds(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false));
|
||||
Assert.Equal(expectedIdsWithSpace, codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false));
|
||||
Assert.Equal(expectedIds, codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false));
|
||||
Assert.Equal(expectedIds, codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false));
|
||||
|
||||
Assert.Equal(ids, codeGenTokenizer.EncodeToIds(text, ids.Length, out string? normalizedString, out int length));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(text.Length, length);
|
||||
Assert.Equal(ids, codeGenTokenizer.EncodeToIds(text.AsSpan(), ids.Length, out normalizedString, out length));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(text.Length, length);
|
||||
|
||||
Assert.Equal(expectedIds, codeGenTokenizer.EncodeToIds(text, expectedIds.Length, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(text.Length, length);
|
||||
Assert.Equal(expectedIds, codeGenTokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(text.Length, length);
|
||||
|
||||
Assert.Equal(expectedIdsWithSpace, codeGenTokenizer.EncodeToIds(text, expectedIdsWithSpace.Length, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(text.Length, length);
|
||||
Assert.Equal(expectedIdsWithSpace, codeGenTokenizer.EncodeToIds(text.AsSpan(), expectedIdsWithSpace.Length, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(text.Length, length);
|
||||
|
||||
int expectedTokensToExclude = expectedOffsets.Length > 1 && expectedOffsets[expectedOffsets.Length - 1].Index == expectedOffsets[expectedOffsets.Length - 2].Index ? 2 : 1;
|
||||
Assert.Equal(ids.Take(ids.Length - expectedTokensToExclude), codeGenTokenizer.EncodeToIds(text, ids.Length - 1, out normalizedString, out length));
|
||||
Assert.Null(normalizedString);
|
||||
var offsets = codeGenTokenizer.AddPrefixSpace ? expectedOffsetsWithSpace : expectedOffsets;
|
||||
int expectedLength = offsets.Length > expectedTokensToExclude ? offsets[offsets.Length - expectedTokensToExclude - 1].Index + offsets[offsets.Length - expectedTokensToExclude - 1].Length : 0;
|
||||
Assert.Equal(expectedLength, length);
|
||||
Assert.Equal(ids.Take(ids.Length - expectedTokensToExclude), codeGenTokenizer.EncodeToIds(text.AsSpan(), ids.Length - 1, out normalizedString, out length));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedLength, length);
|
||||
|
||||
Assert.Equal(expectedIds.Take(expectedIds.Length - expectedTokensToExclude), codeGenTokenizer.EncodeToIds(text, expectedIds.Length - 1, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedLength, length);
|
||||
Assert.Equal(expectedIds.Take(expectedIds.Length - expectedTokensToExclude), codeGenTokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length - 1, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedLength, length);
|
||||
|
||||
Assert.Equal(expectedIdsWithSpace.Take(expectedIdsWithSpace.Length - expectedTokensToExclude), codeGenTokenizer.EncodeToIds(text, expectedIdsWithSpace.Length - 1, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedLength, length);
|
||||
Assert.Equal(expectedIdsWithSpace.Take(expectedIdsWithSpace.Length - expectedTokensToExclude), codeGenTokenizer.EncodeToIds(text.AsSpan(), expectedIdsWithSpace.Length - 1, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out length));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedLength, length);
|
||||
|
||||
//
|
||||
// CountTokens
|
||||
//
|
||||
|
||||
Assert.Equal(ids.Length, codeGenTokenizer.CountTokens(text));
|
||||
Assert.Equal(ids.Length, codeGenTokenizer.CountTokens(text.AsSpan()));
|
||||
|
||||
Assert.Equal(expectedIds.Length, codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false));
|
||||
Assert.Equal(expectedIds.Length, codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false));
|
||||
|
||||
Assert.Equal(expectedIdsWithSpace.Length, codeGenTokenizer.CountTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false));
|
||||
Assert.Equal(expectedIdsWithSpace.Length, codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false));
|
||||
|
||||
//
|
||||
// IndexOf
|
||||
//
|
||||
|
||||
offsets = codeGenTokenizer.AddPrefixSpace ? expectedOffsetsWithSpace : expectedOffsets;
|
||||
|
||||
Assert.Equal(offsets[offsets.Length - 1].Index + offsets[offsets.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text, ids.Length, out normalizedString, out int tokenCount));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(ids.Length, tokenCount);
|
||||
Assert.Equal(offsets[offsets.Length - 1].Index + offsets[offsets.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), ids.Length, out normalizedString, out tokenCount));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(ids.Length, tokenCount);
|
||||
|
||||
Assert.Equal(expectedOffsets[expectedOffsets.Length - 1].Index + expectedOffsets[expectedOffsets.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text, expectedIds.Length, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedIds.Length, tokenCount);
|
||||
Assert.Equal(expectedOffsets[expectedOffsets.Length - 1].Index + expectedOffsets[expectedOffsets.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedIds.Length, tokenCount);
|
||||
|
||||
Assert.Equal(expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index + expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text, expectedIdsWithSpace.Length, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedIdsWithSpace.Length, tokenCount);
|
||||
Assert.Equal(expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index + expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Length, codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), expectedIdsWithSpace.Length, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedIdsWithSpace.Length, tokenCount);
|
||||
|
||||
//
|
||||
// LastIndexOf
|
||||
//
|
||||
|
||||
int expectedIndex = offsets.Length > 1 && offsets[offsets.Length - 1].Index == offsets[offsets.Length - 2].Index ? text.Length : offsets[offsets.Length - 1].Index;
|
||||
int expectedTokenCount = expectedIndex == text.Length ? 0 : 1;
|
||||
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text, 1, out normalizedString, out tokenCount));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedTokenCount, tokenCount);
|
||||
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), 1, out normalizedString, out tokenCount));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedTokenCount, tokenCount);
|
||||
|
||||
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text, 1, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedTokenCount, tokenCount);
|
||||
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), 1, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedTokenCount, tokenCount);
|
||||
|
||||
expectedIndex = offsets.Length > 1 && expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index == expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 2].Index ? text.Length : expectedOffsetsWithSpace[expectedOffsetsWithSpace.Length - 1].Index;
|
||||
expectedTokenCount = expectedIndex == text.Length ? 0 : 1;
|
||||
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text, 1, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedTokenCount, tokenCount);
|
||||
Assert.Equal(expectedIndex, codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), 1, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out tokenCount));
|
||||
Assert.Null(normalizedString);
|
||||
Assert.Equal(expectedTokenCount, tokenCount);
|
||||
|
||||
//
|
||||
// Id to Token and Token to Id mapping
|
||||
//
|
||||
var tokens = codeGenTokenizer.AddPrefixSpace ? expectedTokensWithSpace : expectedTokens;
|
||||
|
||||
for (int i = 0; i < tokens.Length; i++)
|
||||
{
|
||||
Assert.Equal(tokens[i], codeGenTokenizer.MapIdToToken(ids[i]));
|
||||
Assert.Equal(ids[i], codeGenTokenizer.MapTokenToId(tokens[i]));
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(CodeGenTestData))]
|
||||
public void TestBegginingAndEndOfSentenceEncoding(
|
||||
string text,
|
||||
string[] expectedTokens,
|
||||
(int Index, int Length)[] expectedOffsets,
|
||||
int[] expectedIds,
|
||||
string[] expectedTokensWithSpace,
|
||||
(int Index, int Length)[] expectedOffsetsWithSpace,
|
||||
int[] expectedIdsWithSpace)
|
||||
|
||||
{
|
||||
Assert.NotNull(expectedOffsets);
|
||||
Assert.NotNull(expectedOffsetsWithSpace);
|
||||
|
||||
//
|
||||
// Beginning of Sentence
|
||||
//
|
||||
|
||||
CodeGen codeGenTokenizer = (_codegen350MMonoTokenizerWithBeginningOfSentence as CodeGen)!;
|
||||
|
||||
IReadOnlyList<Token> encoding = codeGenTokenizer.Encode(text, out _);
|
||||
Assert.True(codeGenTokenizer.BeginningOfSentenceToken is not null);
|
||||
Assert.True(codeGenTokenizer.BeginningOfSentenceId.HasValue);
|
||||
var idList = new List<int>(expectedIds);
|
||||
idList.Insert(0, codeGenTokenizer.BeginningOfSentenceId!.Value);
|
||||
var tokensList = new List<string>(expectedTokens);
|
||||
tokensList.Insert(0, codeGenTokenizer.BeginningOfSentenceToken!);
|
||||
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.Equal((0, 0), encoding[0].Offset);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), out _);
|
||||
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.Equal((0, 0), encoding[0].Offset);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out _);
|
||||
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.Equal((0, 0), encoding[0].Offset);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out _);
|
||||
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.Equal((0, 0), encoding[0].Offset);
|
||||
|
||||
idList = new List<int>(expectedIdsWithSpace);
|
||||
idList.Insert(0, codeGenTokenizer.BeginningOfSentenceId!.Value);
|
||||
tokensList = new List<string>(expectedTokensWithSpace);
|
||||
tokensList.Insert(0, codeGenTokenizer.BeginningOfSentenceToken!);
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: false, out _);
|
||||
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.Equal((0, 0), encoding[0].Offset);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: false, out _);
|
||||
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.Equal((0, 0), encoding[0].Offset);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
|
||||
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
|
||||
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
|
||||
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
|
||||
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
|
||||
|
||||
IReadOnlyList<int> ids = codeGenTokenizer.EncodeToIds(text);
|
||||
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text.AsSpan());
|
||||
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false);
|
||||
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false);
|
||||
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
|
||||
Assert.NotEqual(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
|
||||
Assert.NotEqual(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text, maxTokenCount: 5, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out string? normalizedString, out int textLength);
|
||||
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), maxTokenCount: 5, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out textLength);
|
||||
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
|
||||
|
||||
int tokenCount = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
|
||||
int count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
|
||||
Assert.Equal(tokenCount, count);
|
||||
count = codeGenTokenizer.CountTokens(text);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
count = codeGenTokenizer.CountTokens(text.AsSpan());
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
count = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
|
||||
int length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(text.Length, length);
|
||||
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(text.Length, length);
|
||||
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(text.Length, length);
|
||||
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(text.Length, length);
|
||||
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount, count);
|
||||
Assert.Equal(text.Length, length);
|
||||
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount, count);
|
||||
Assert.Equal(text.Length, length);
|
||||
|
||||
int index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(0, index);
|
||||
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(0, index);
|
||||
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(0, index);
|
||||
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(0, index);
|
||||
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount, count);
|
||||
Assert.Equal(0, index);
|
||||
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount, count);
|
||||
Assert.Equal(0, index);
|
||||
|
||||
//
|
||||
// End of Sentence
|
||||
//
|
||||
|
||||
codeGenTokenizer = (_codegen350MMonoTokenizerWithEndOfSentence as CodeGen)!;
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, out _);
|
||||
Assert.True(codeGenTokenizer.EndOfSentenceToken is not null);
|
||||
Assert.True(codeGenTokenizer.EndOfSentenceId.HasValue);
|
||||
idList = new List<int>(expectedIds);
|
||||
idList.Add(codeGenTokenizer.EndOfSentenceId!.Value);
|
||||
tokensList = new List<string>(expectedTokens);
|
||||
tokensList.Add(codeGenTokenizer.EndOfSentenceToken!);
|
||||
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), out _);
|
||||
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out _);
|
||||
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out _);
|
||||
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
|
||||
|
||||
idList = new List<int>(expectedIdsWithSpace);
|
||||
idList.Add(codeGenTokenizer.EndOfSentenceId!.Value);
|
||||
tokensList = new List<string>(expectedTokensWithSpace);
|
||||
tokensList.Add(codeGenTokenizer.EndOfSentenceToken!);
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: true, out _);
|
||||
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: true, out _);
|
||||
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
|
||||
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
|
||||
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
|
||||
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
|
||||
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
|
||||
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
|
||||
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
|
||||
|
||||
ids = codeGenTokenizer.EncodeToIds(text);
|
||||
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text.AsSpan());
|
||||
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true);
|
||||
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true);
|
||||
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
|
||||
Assert.NotEqual(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
|
||||
Assert.NotEqual(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out textLength);
|
||||
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
|
||||
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out textLength);
|
||||
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
|
||||
|
||||
tokenCount = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
|
||||
count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
|
||||
Assert.Equal(tokenCount, count);
|
||||
count = codeGenTokenizer.CountTokens(text);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
count = codeGenTokenizer.CountTokens(text.AsSpan());
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
count = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
|
||||
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(text.Length, length);
|
||||
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(text.Length, length);
|
||||
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(text.Length, length);
|
||||
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(text.Length, length);
|
||||
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount, count);
|
||||
Assert.Equal(text.Length, length);
|
||||
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount, count);
|
||||
Assert.Equal(text.Length, length);
|
||||
|
||||
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(0, index);
|
||||
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(0, index);
|
||||
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(0, index);
|
||||
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount + 1, count);
|
||||
Assert.Equal(0, index);
|
||||
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount, count);
|
||||
Assert.Equal(0, index);
|
||||
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
|
||||
Assert.Equal(tokenCount, count);
|
||||
Assert.Equal(0, index);
|
||||
|
||||
//
|
||||
// Beginning & End of Sentence
|
||||
//
|
||||
|
||||
codeGenTokenizer = (_codegen350MMonoTokenizerWithBeginningAndEndOfSentence as CodeGen)!;
|
||||
|
||||
encoding = codeGenTokenizer.Encode(text, out _);
|
||||
Assert.True(codeGenTokenizer.BeginningOfSentenceToken is not null);
|
||||
Assert.True(codeGenTokenizer.BeginningOfSentenceId.HasValue);
|
||||
idList = new List<int>(expectedIds);
|
||||
idList.Insert(0, codeGenTokenizer.BeginningOfSentenceId!.Value);
|
||||
idList.Add(codeGenTokenizer.EndOfSentenceId!.Value);
|
||||
tokensList = new List<string>(expectedTokens);
|
||||
tokensList.Insert(0, codeGenTokenizer.BeginningOfSentenceToken!);
tokensList.Add(codeGenTokenizer.EndOfSentenceToken!);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);

encoding = codeGenTokenizer.Encode(text.AsSpan(), out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);

encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);

encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);

idList = new List<int>(expectedIdsWithSpace);
idList.Insert(0, codeGenTokenizer.BeginningOfSentenceId!.Value);
idList.Add(codeGenTokenizer.EndOfSentenceId!.Value);
tokensList = new List<string>(expectedTokensWithSpace);
tokensList.Insert(0, codeGenTokenizer.BeginningOfSentenceToken!);
tokensList.Add(codeGenTokenizer.EndOfSentenceToken!);
encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);

encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
Assert.Equal((0, 0), encoding[0].Offset);
Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);

encoding = codeGenTokenizer.Encode(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);

encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);

encoding = codeGenTokenizer.Encode(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);

encoding = codeGenTokenizer.Encode(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);

ids = codeGenTokenizer.EncodeToIds(text);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan());
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.NotEqual(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.NotEqual(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.NotEqual(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.NotEqual(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out textLength);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
ids = codeGenTokenizer.EncodeToIds(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out textLength);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);

tokenCount = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
Assert.Equal(tokenCount, count);
count = codeGenTokenizer.CountTokens(text);
Assert.Equal(tokenCount + 2, count);
count = codeGenTokenizer.CountTokens(text.AsSpan());
Assert.Equal(tokenCount + 2, count);
count = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true);
Assert.Equal(tokenCount + 2, count);
count = codeGenTokenizer.CountTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true);
Assert.Equal(tokenCount + 2, count);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);
length = codeGenTokenizer.IndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(text.Length, length);

index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out normalizedString, out count);
Assert.Equal(tokenCount + 2, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text, maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
index = codeGenTokenizer.LastIndexOfTokenCount(text.AsSpan(), maxTokenCount: 500, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out normalizedString, out count);
Assert.Equal(tokenCount, count);
Assert.Equal(0, index);
}
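
// Editorial note: the test above walks the full matrix of the addPrefixSpace /
// addBeginningOfSentence / addEndOfSentence overloads. A condensed usage sketch of the
// same surface follows (not part of the commit; it assumes an already-constructed
// CodeGen instance named codeGenTokenizer, mirroring the test fixture, and a sample text).
//
// string text = "Hello World";
//
// // Per-call switches override the tokenizer's configured defaults; with BOS/EOS enabled
// // the first and last ids are the special tokens, which is what the assertions verify.
// IReadOnlyList<int> ids = codeGenTokenizer.EncodeToIds(
//     text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true);
// bool startsWithBos = ids[0] == codeGenTokenizer.BeginningOfSentenceId!.Value;
// bool endsWithEos = ids[ids.Count - 1] == codeGenTokenizer.EndOfSentenceId!.Value;
//
// // CountTokens follows the same switches, so enabling BOS and EOS adds exactly two tokens.
// int withSpecials = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true);
// int withoutSpecials = codeGenTokenizer.CountTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false);
// // withSpecials == withoutSpecials + 2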

private const string DefaultSpecialToken = "<|endoftext|>";

[Fact]
public void TestDefaultValues()
{
CodeGen codeGenTokenizer = (_codegen350MMonoTokenizer as CodeGen)!;
Assert.False(codeGenTokenizer.AddPrefixSpace);
Assert.False(codeGenTokenizer.AddBeginningOfSentence);
Assert.False(codeGenTokenizer.AddEndOfSentence);

Assert.Equal(codeGenTokenizer.MapTokenToId(DefaultSpecialToken), codeGenTokenizer.BeginningOfSentenceId!.Value);
Assert.Equal(codeGenTokenizer.MapTokenToId(DefaultSpecialToken), codeGenTokenizer.EndOfSentenceId!.Value);
Assert.Equal(codeGenTokenizer.MapTokenToId(DefaultSpecialToken), codeGenTokenizer.UnknownTokenId!.Value);

Assert.Equal(DefaultSpecialToken, codeGenTokenizer.BeginningOfSentenceToken);
Assert.Equal(DefaultSpecialToken, codeGenTokenizer.EndOfSentenceToken);
Assert.Equal(DefaultSpecialToken, codeGenTokenizer.UnknownToken);
}

[Theory]
[InlineData(1, 0, 0, 0, 3)]
[InlineData(2, 2, 1, 2, 1)]
[InlineData(3, 2, 1, 2, 1)]
[InlineData(4, 4, 3, 4, 0)]
[InlineData(5, 4, 3, 4, 0)]
public void TestTokenLimits(int maxTokenCount, int expectedTokenCount, int expectedTextLength, int expectedTokenCountFromEnd, int expectedTextIndexFromEnd)
{
// cannot split between the first two tokens nor last two tokens
string input = "δ😀";
int[] encodingIds = [138, 112, 47249, 222];
(int Index, int Length)[] offsets = [(0, 0), (0, 1), (1, 0), (1, 2)];
int calculatedLengthUsingOffsets = expectedTokenCount > 0 ? offsets[expectedTokenCount - 1].Index + offsets[expectedTokenCount - 1].Length : 0;

IReadOnlyList<int> ids = _codegen350MMonoTokenizer.EncodeToIds(input, maxTokenCount, out _, out int textLength);
Assert.Equal(expectedTokenCount, ids.Count);
Assert.Equal(expectedTextLength, textLength);
Assert.Equal(encodingIds.Take(expectedTokenCount), ids);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
ids = _codegen350MMonoTokenizer.EncodeToIds(input.AsSpan(), maxTokenCount, out _, out textLength);
Assert.Equal(expectedTokenCount, ids.Count);
Assert.Equal(expectedTextLength, textLength);
Assert.Equal(encodingIds.Take(expectedTokenCount), ids);
Assert.Equal(calculatedLengthUsingOffsets, textLength);

textLength = _codegen350MMonoTokenizer.IndexOfTokenCount(input, maxTokenCount, out _, out int tokenCount);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedTextLength, textLength);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
textLength = _codegen350MMonoTokenizer.IndexOfTokenCount(input.AsSpan(), maxTokenCount, out _, out tokenCount);
Assert.Equal(expectedTokenCount, tokenCount);
Assert.Equal(expectedTextLength, textLength);
Assert.Equal(calculatedLengthUsingOffsets, textLength);

calculatedLengthUsingOffsets = expectedTokenCountFromEnd > 0 ? offsets[offsets.Length - expectedTokenCountFromEnd].Index : input.Length;
textLength = _codegen350MMonoTokenizer.LastIndexOfTokenCount(input, maxTokenCount, out _, out tokenCount);
Assert.Equal(expectedTokenCountFromEnd, tokenCount);
Assert.Equal(expectedTextIndexFromEnd, textLength);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
textLength = _codegen350MMonoTokenizer.LastIndexOfTokenCount(input.AsSpan(), maxTokenCount, out _, out tokenCount);
Assert.Equal(expectedTokenCountFromEnd, tokenCount);
Assert.Equal(expectedTextIndexFromEnd, textLength);
Assert.Equal(calculatedLengthUsingOffsets, textLength);
}
}
}
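
// Editorial note: TestTokenLimits above pins down the "no partial code point" rule — "δ" and "😀"
// each map to two ids, and a token budget may not cut between them. An illustrative sketch of that
// behavior (not part of the commit; it assumes the same _codegen350MMonoTokenizer fixture):
//
// string input = "δ😀";   // "δ" -> ids 138, 112; "😀" -> ids 47249, 222
//
// // A budget of 1 would split "δ" across a partial code point, so nothing is returned.
// IReadOnlyList<int> ids = _codegen350MMonoTokenizer.EncodeToIds(input, maxTokenCount: 1, out _, out int textLength);
// // ids.Count == 0, textLength == 0
//
// // A budget of 3 would split "😀", so the result stops after the two tokens covering "δ".
// ids = _codegen350MMonoTokenizer.EncodeToIds(input, maxTokenCount: 3, out _, out textLength);
// // ids.Count == 2, textLength == 1 (matching InlineData(3, 2, 1, 2, 1) above)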

@ -6,20 +6,13 @@ using System;
using System.IO;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;

using Xunit;
using System.Diagnostics;
using System.Threading.Tasks;

namespace Microsoft.ML.Tokenizers.Tests
{
public class EnglishRobertaTests
{
private static readonly string _vocabUrl = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json";
private static readonly string _mergeUrl = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe";
private static readonly string _dictUrl = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt";

public static IEnumerable<object[]> BertaData
{
get

@ -83,82 +76,62 @@ namespace Microsoft.ML.Tokenizers.Tests
}

private static Tokenizer? _robertaTokenizer = null;
private async static Task<Tokenizer> GetRobertaTokenizer()
private static Tokenizer GetRobertaTokenizer()
{
if (_robertaTokenizer is null)
{
string vocabFile = Utils.CreateTemporaryFile("json");
string mergeFile = Utils.CreateTemporaryFile("txt");
string translationFile = Utils.CreateTemporaryFile("txt");
// encoder.json is same as vocab.json
// vocab.bpe is same as merges.txt
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json";
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe";
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt";

try
{
await Utils.DownloadFile(_vocabUrl, vocabFile);
await Utils.DownloadFile(_mergeUrl, mergeFile);
await Utils.DownloadFile(_dictUrl, translationFile);

_robertaTokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
}
finally
{
Utils.DeleteFile(vocabFile);
Utils.DeleteFile(mergeFile);
Utils.DeleteFile(translationFile);
}
_robertaTokenizer = new EnglishRoberta(
Path.Combine(@"Gpt-2", "vocab.json"),
Path.Combine(@"Gpt-2", "merges.txt"),
Path.Combine(@"Gpt-2", "dict.txt"),
RobertaPreTokenizer.Instance);
}

return _robertaTokenizer;
}

[Fact]
public async void TokenizationTest()
public void TokenizationTest()
{
string vocabFile = Utils.CreateTemporaryFile("json");
string mergeFile = Utils.CreateTemporaryFile("txt");
string translationFile = Utils.CreateTemporaryFile("txt");
string[]? paths = null; ;
// encoder.json is same as vocab.json
// vocab.bpe is same as merges.txt
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json";
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe";
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt";

try
string vocabFile = Path.Combine(@"Gpt-2", "vocab.json");
string mergeFile = Path.Combine(@"Gpt-2", "merges.txt");
string translationFile = Path.Combine(@"Gpt-2", "dict.txt");

Tokenizer tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);

TestTokenizer(tokenizer);
TokenizerTests.TestTokenLimits(tokenizer);

tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false);

TestTokenizer(tokenizer);

using Stream vocabStream = File.OpenRead(vocabFile);
using Stream mergeStream = File.OpenRead(mergeFile);
using Stream translationStream = File.OpenRead(translationFile);
tokenizer = new EnglishRoberta(vocabStream, mergeStream, translationStream, RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer);

// Ensure caching works regardless of which method is called first.
for (CallingOrder order = CallingOrder.Encode; order <= CallingOrder.CountTokens; order++)
{
await Utils.DownloadFile(_vocabUrl, vocabFile);
await Utils.DownloadFile(_mergeUrl, mergeFile);
await Utils.DownloadFile(_dictUrl, translationFile);

Tokenizer tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer);
TokenizerTests.TestTokenLimits(tokenizer);
tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer, order);

tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false);
TestTokenizer(tokenizer);

using Stream vocabStream = File.OpenRead(vocabFile);
using Stream mergeStream = File.OpenRead(mergeFile);
using Stream translationStream = File.OpenRead(translationFile);
tokenizer = new EnglishRoberta(vocabStream, mergeStream, translationStream, RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer);

// Ensure caching works regardless of which method is called first.
for (CallingOrder order = CallingOrder.Encode; order <= CallingOrder.CountTokens; order++)
{
tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer, order);

tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false);
TestTokenizer(tokenizer, order);
}
}
finally
{
Utils.DeleteFile(vocabFile);
Utils.DeleteFile(mergeFile);
Utils.DeleteFile(translationFile);

if (paths is not null)
{
Utils.DeleteFile(paths[0]);
Utils.DeleteFile(paths[1]);
Utils.DeleteFile(paths[2]);
}
TestTokenizer(tokenizer, order);
}
}
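
// Editorial note: this hunk replaces the downloaded fairseq assets with the checked-in Gpt-2 test
// files. A minimal construction sketch using the same constructors (not part of the commit; the
// relative "Gpt-2" paths are the ones the updated test relies on):
//
// string vocabFile = Path.Combine(@"Gpt-2", "vocab.json");
// string mergeFile = Path.Combine(@"Gpt-2", "merges.txt");
// string translationFile = Path.Combine(@"Gpt-2", "dict.txt");
//
// // File-path based construction, as used by the lazily cached GetRobertaTokenizer above.
// Tokenizer tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance);
//
// // Stream-based construction exercises the same vocabulary through the second constructor overload.
// using Stream vocabStream = File.OpenRead(vocabFile);
// using Stream mergeStream = File.OpenRead(mergeFile);
// using Stream translationStream = File.OpenRead(translationFile);
// tokenizer = new EnglishRoberta(vocabStream, mergeStream, translationStream, RobertaPreTokenizer.Instance);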

@ -200,9 +173,9 @@ namespace Microsoft.ML.Tokenizers.Tests

[Theory]
[MemberData(nameof(RobertaTestData))]
public async void TestTokenizerEncoding(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
public void TestTokenizerEncoding(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
{
Tokenizer tokenizer = await GetRobertaTokenizer();
Tokenizer tokenizer = GetRobertaTokenizer();

IReadOnlyList<Token> encoding = tokenizer.Encode(text, out _);
IReadOnlyList<Token> encoding1 = tokenizer.Encode(text.AsSpan(), out _);

@ -238,8 +238,8 @@ namespace Microsoft.ML.Tokenizers.Tests
IReadOnlyList<Token> result = GPT4.Encode(text, out string? normalizedString);
Assert.Equal(encoded, result.Select(token => token.Id).ToArray());
Assert.Equal(encoded.Count, idsCount);
Assert.Equal(new string[] { "<|im_start|>", "Hello", " ⭐", "", " World", "<|im_end|>" }, result.Select(token => token.Value).ToArray());
Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 2), (19, 0), (19, 6), (25, 10) }, result.Select(token => token.Offset).ToArray());
Assert.Equal(new string[] { "<|im_start|>", "Hello", " ⭐", "⭐", " World", "<|im_end|>" }, result.Select(token => token.Value).ToArray());
Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 2), (18, 1), (19, 6), (25, 10) }, result.Select(token => token.Offset).ToArray());
}

[Fact]

@ -457,8 +457,8 @@ namespace Microsoft.ML.Tokenizers.Tests
yield return new object?[]
{
"Hello, y'all! How are you 😁 ?",
new string[] { "Hello", ",", " y", "'all", "!", " How", " are", " you", " 😁", "", " ?" },
new (int Index, int Length)[] { (0, 5), (5, 1), (6, 2), (8, 4), (12, 1), (13, 4), (17, 4), (21, 4), (25, 3), (28, 0), (28, 2) },
new string[] { "Hello", ",", " y", "'all", "!", " How", " are", " you", " 😁", "😁", " ?" },
new (int Index, int Length)[] { (0, 5), (5, 1), (6, 2), (8, 4), (12, 1), (13, 4), (17, 4), (21, 4), (25, 3), (26, 2), (28, 2) },
new int[] { 9906, 11, 379, 65948, 0, 2650, 527, 499, 27623, 223, 949 }
};
}

@ -533,6 +533,106 @@ namespace Microsoft.ML.Tokenizers.Tests
}
return sb.ToString();
}

public static IEnumerable<object?[]> TokenizerLimitsTestData
{
get
{
// string to tokenize, produced tokens, the token offsets, the token ids
yield return new object?[]
{
"Hello ⭐ World",
new string[] { "Hello", " ⭐", "⭐", " World" },
new (int Index, int Length)[] { (0, 5), (5, 2), (6, 1), (7, 6) },
new int[] { 9906, 2928, 99834, 4435 }
};

yield return new object?[]
{
"⭐", // encoded to multiple tokens
new string[] { "⭐", "⭐" },
new (int Index, int Length)[] { (0, 1), (0, 1) },
new int[] { 158, 99834 }
};

yield return new object?[]
{
"Hi 😀", // Surrogates
new string[] { "Hi", " 😀" },
new (int Index, int Length)[] { (0, 2), (2, 3) },
new int[] { 13347, 91416 }
};

yield return new object?[]
{
"⭐😀", // character encoded to multiple tokens and surrogates
new string[] { "⭐", "⭐", "😀", "😀" },
new (int Index, int Length)[] { (0, 1), (0, 1), (1, 2), (1, 2) },
new int[] { 158, 99834, 76460, 222 }
};

yield return new object?[]
{
"From: Adele Vance\nSubject: TestSubject\nTestBodyContent",
new string[] { "From", ":", " Ade", "le", " Vance", "\n", "Subject", ":", " Test", "Subject", "\n", "Test", "Body", "Content" },
new (int Index, int Length)[] { (0, 4), (4, 1), (5, 4), (9, 2), (11, 6), (17, 1), (18, 7), (25, 1), (26, 5), (31, 7), (38, 1), (39, 4), (43, 4), (47, 7)},
new int[] { 3915, 25, 63140, 273, 92368, 198, 13317, 25, 3475, 13317, 198, 2323, 5561, 2831 }
};
}
}

[Theory]
[MemberData(nameof(TokenizerLimitsTestData))]
public void TestPreciseTokenLimits(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds)
{
IReadOnlyList<Token> result = GPT4.Encode(text, out _);
int[] ids = result.Select(r => r.Id).ToArray();
(int Index, int Length)[] offsets = result.Select(r => r.Offset).ToArray();
Assert.Equal(expectedTokens, result.Select(r => r.Value));
Assert.Equal(expectedIds, ids);
Assert.Equal(expectedOffsets, offsets);
Assert.Equal(expectedIds, GPT4.EncodeToIds(text));
Assert.Equal(expectedIds.Length, GPT4.CountTokens(text));

for (int tokenCount = 1; tokenCount <= ids.Length; tokenCount++)
{
int length = GPT4.IndexOfTokenCount(text, tokenCount, out _, out int count);
Assert.True(count <= ids.Length);

if (count < tokenCount)
{
Assert.True(count < ids.Length - 1);
Assert.True(offsets[count + 1].Index < offsets[count].Index + offsets[count].Length);
}

if (count > 0)
{
Assert.Equal(offsets[count - 1].Index + offsets[count - 1].Length, length);
}
else
{
Assert.Equal(0, length);
}

int index = GPT4.LastIndexOfTokenCount(text, tokenCount, out _, out count);
Assert.True(count <= ids.Length);

if (count < tokenCount)
{
Assert.True(ids.Length - tokenCount > 0);
Assert.True(offsets[offsets.Length - tokenCount].Index < offsets[offsets.Length - tokenCount - 1].Index + offsets[offsets.Length - tokenCount - 1].Length);
}

if (count > 0)
{
Assert.Equal(offsets[offsets.Length - count].Index, index);
}
else
{
Assert.Equal(text.Length, index);
}
}
}
}
}
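
// Editorial note: TestPreciseTokenLimits above asserts that IndexOfTokenCount and
// LastIndexOfTokenCount on the Tiktoken GPT4 tokenizer only report cuts at clean code-point
// boundaries. An illustrative sketch of that contract using the first data row (not part of the
// commit; it assumes the same GPT4 fixture used by these tests):
//
// string text = "Hello ⭐ World";
// // Tokens: "Hello", " ⭐", "⭐", " World" -> ids 9906, 2928, 99834, 4435; the star character
// // is jointly covered by the second and third tokens, so no cut is reported between them.
//
// int length = GPT4.IndexOfTokenCount(text, maxTokenCount: 2, out _, out int count);
// // count == 1, length == 5: only "Hello" fits without splitting the ⭐ token pair.
//
// int index = GPT4.LastIndexOfTokenCount(text, maxTokenCount: 3, out _, out count);
// // count == 3, index == 5: from the end, " World" plus the ⭐ pair fit and the cut lands after "Hello".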