Add Span support in tokenizer's Model abstraction (#7035)
* Add Span support in tokenizer's Model abstraction
* Address the feedback
* Use stackalloc instead of the ArrayPool
This commit is contained in:
Родитель
c6f5397963
Коммит
99c620ad96
|
@ -3,8 +3,10 @@
|
|||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
|
@ -34,20 +36,21 @@ namespace Microsoft.ML.Tokenizers
|
|||
{
|
||||
_unknownToken = value;
|
||||
|
||||
if (value is null)
|
||||
if (VocabReverse.TryGetValue(0, out string? v))
|
||||
{
|
||||
if (VocabReverse.TryGetValue(0, out string? v))
|
||||
if (v == value)
|
||||
{
|
||||
VocabReverse.Remove(0);
|
||||
if (_vocab.TryGetValue(v, out int id))
|
||||
{
|
||||
_vocab.Remove(v);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
VocabReverse.Remove(0);
|
||||
_vocab.Remove(new StringSpanOrdinalKey(v));
|
||||
}
|
||||
else
|
||||
|
||||
|
||||
if (value is not null)
|
||||
{
|
||||
_vocab[value] = 0;
|
||||
_vocab[new StringSpanOrdinalKey(value)] = 0;
|
||||
VocabReverse[0] = value;
|
||||
}
|
||||
}
|
||||
|
@ -68,7 +71,6 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// </summary>
|
||||
public bool FuseUnknownTokens { get; }
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Construct a new Bpe model object to use for text encoding.
|
||||
/// </summary>
|
||||
|
@ -111,23 +113,19 @@ namespace Microsoft.ML.Tokenizers
|
|||
ContinuingSubwordPrefix = continuingSubwordPrefix;
|
||||
EndOfWordSuffix = endOfWordSuffix;
|
||||
|
||||
(Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
|
||||
_vocab = vocab1 ?? new Dictionary<string, int>();
|
||||
Cache = new Cache<string, Word>();
|
||||
(Dictionary<StringSpanOrdinalKey, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
|
||||
_vocab = vocab1 ?? new Dictionary<StringSpanOrdinalKey, int>();
|
||||
Cache = new StringSpanOrdinalKeyCache<Word>();
|
||||
|
||||
VocabReverse = new();
|
||||
|
||||
foreach (KeyValuePair<string, int> kvp in Vocab)
|
||||
foreach (KeyValuePair<StringSpanOrdinalKey, int> kvp in _vocab)
|
||||
{
|
||||
VocabReverse.Add(kvp.Value, kvp.Key);
|
||||
VocabReverse.Add(kvp.Value, kvp.Key.Data!);
|
||||
}
|
||||
|
||||
if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
|
||||
{
|
||||
unknownToken = unkToken;
|
||||
}
|
||||
|
||||
UnknownToken = unknownToken;
|
||||
UnknownToken = unknownToken ?? (VocabReverse.TryGetValue(0, out string? unkToken) ? unkToken : null);
|
||||
|
||||
int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;
|
||||
|
||||
|
@ -197,7 +195,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="text">The text to split.</param>
|
||||
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
|
||||
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
|
||||
public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
|
||||
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
|
@ -205,7 +203,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="isSpecialToken">Indicate if the token is special token.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
public override int CountTokens(string text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
|
||||
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
|
||||
|
||||
/// <summary>
|
||||
/// Map the token to encoded Id.
|
||||
|
@ -213,15 +211,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="token">The token to map to the Id.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>The mapped Id of the token.</returns>
|
||||
public override int? MapTokenToId(string token, bool considerSpecialTokens = true)
|
||||
{
|
||||
if (_vocab.TryGetValue(token, out int value))
|
||||
{
|
||||
return value;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;
|
||||
|
||||
/// <summary>
|
||||
/// Map the encoded Id to the token.
|
||||
|
@ -242,24 +232,27 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <summary>
|
||||
/// Gets the dictionary mapping tokens to Ids.
|
||||
/// </summary>
|
||||
public IReadOnlyDictionary<string, int> Vocab => _vocab;
|
||||
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
|
||||
|
||||
/// Read the given files to extract the vocab and merges
|
||||
internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
|
||||
internal static (Dictionary<StringSpanOrdinalKey, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
|
||||
{
|
||||
Dictionary<string, int>? dic = JsonSerializer.Deserialize<Dictionary<string, int>>(vocab) as Dictionary<string, int>;
|
||||
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
|
||||
Dictionary<StringSpanOrdinalKey, int>? dic = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocab, options) as Dictionary<StringSpanOrdinalKey, int>;
|
||||
|
||||
return (dic, ConvertMergesToHashmap(merges));
|
||||
}
|
||||
|
||||
/// The vocabulary assigns a number to each token.
|
||||
private readonly Dictionary<string, int> _vocab;
|
||||
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
|
||||
|
||||
private Dictionary<string, int>? _vocabOriginal;
|
||||
|
||||
/// Contains the mapping between Pairs and their (rank, newId).
|
||||
internal Dictionary<Pair<int>, (int, int)> Merges { get; }
|
||||
|
||||
/// Contains the cache for optimizing the encoding step.
|
||||
internal Cache<string, Word>? Cache { get; }
|
||||
internal StringSpanOrdinalKeyCache<Word>? Cache { get; }
|
||||
|
||||
internal static readonly int DefaultCacheCapacity = 10_000;
|
||||
|
||||
|
@ -309,9 +302,6 @@ namespace Microsoft.ML.Tokenizers
|
|||
return merges;
|
||||
}
|
||||
|
||||
/// Reset the cache.
|
||||
internal void ClearCache() => Cache?.Clear();
|
||||
|
||||
private readonly Dictionary<char, string> _charToString = new Dictionary<char, string>();
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
|
@ -327,38 +317,68 @@ namespace Microsoft.ML.Tokenizers
|
|||
return s;
|
||||
}
|
||||
|
||||
internal Word MergeWord(string w)
|
||||
internal Word MergeWord(ReadOnlySpan<char> w)
|
||||
{
|
||||
Word word = Word.WithCapacity(w.Length);
|
||||
(int Id, int Len)? unk = null;
|
||||
int i = 0;
|
||||
|
||||
Span<char> buffer = stackalloc char[256];
|
||||
scoped ReadOnlySpan<char> s;
|
||||
|
||||
while (i < w.Length)
|
||||
{
|
||||
int length;
|
||||
string s;
|
||||
|
||||
if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1]))
|
||||
{
|
||||
length = 2;
|
||||
s = w.Substring(i, length);
|
||||
s = w.Slice(i, 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
length = 1;
|
||||
s = CharToString(w[i]);
|
||||
s = w.Slice(i, 1);
|
||||
}
|
||||
|
||||
// Add the `continuing_subword_prefix` if relevant
|
||||
if (i > 0 && ContinuingSubwordPrefix is not null)
|
||||
{
|
||||
s = $"{ContinuingSubwordPrefix}{s}";
|
||||
if (ContinuingSubwordPrefix.Length + s.Length <= buffer.Length)
|
||||
{
|
||||
ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
|
||||
s.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
|
||||
s = buffer.Slice(0, ContinuingSubwordPrefix.Length + s.Length);
|
||||
}
|
||||
else
|
||||
{
|
||||
#if NETCOREAPP
|
||||
s = $"{ContinuingSubwordPrefix}{s}".AsSpan();
|
||||
#else
|
||||
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
|
||||
s = $"{ContinuingSubwordPrefix}{s1}".AsSpan();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// Add the `end_of_word_suffix` if relevant
|
||||
if (i + length >= w.Length && EndOfWordSuffix is not null)
|
||||
{
|
||||
s = $"{s}{EndOfWordSuffix}";
|
||||
if (s.Length + EndOfWordSuffix.Length <= buffer.Length)
|
||||
{
|
||||
s.CopyTo(buffer);
|
||||
EndOfWordSuffix.AsSpan().CopyTo(buffer.Slice(s.Length));
|
||||
s = buffer.Slice(0, s.Length + EndOfWordSuffix.Length);
|
||||
}
|
||||
else
|
||||
{
|
||||
#if NETCOREAPP
|
||||
s = $"{s}{EndOfWordSuffix}".AsSpan();
|
||||
#else
|
||||
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
|
||||
s = $"{s1}{EndOfWordSuffix}".AsSpan();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
if (_vocab.TryGetValue(s, out int id))
|
||||
|
@ -419,17 +439,17 @@ namespace Microsoft.ML.Tokenizers
|
|||
Word word;
|
||||
if (Cache is not null)
|
||||
{
|
||||
if (Cache.TryGet(text, out word))
|
||||
if (Cache.TryGetValue(text, out word))
|
||||
{
|
||||
return WordToTokens(ref word);
|
||||
}
|
||||
|
||||
word = MergeWord(text);
|
||||
word = MergeWord(text.AsSpan());
|
||||
Cache.Set(text, word);
|
||||
}
|
||||
else
|
||||
{
|
||||
word = MergeWord(text);
|
||||
word = MergeWord(text.AsSpan());
|
||||
}
|
||||
|
||||
return WordToTokens(ref word);
|
||||
|
@ -445,19 +465,19 @@ namespace Microsoft.ML.Tokenizers
|
|||
return word.SymbolsCount;
|
||||
}
|
||||
|
||||
internal int EncodeToIdsWithCache(string text, IList<int>? accumulatedIds)
|
||||
internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
|
||||
{
|
||||
Word word;
|
||||
|
||||
if (Cache is not null)
|
||||
{
|
||||
if (Cache.TryGet(text, out Word hit))
|
||||
if (Cache.TryGetValue(text, out Word hit))
|
||||
{
|
||||
return WordToIds(ref hit, accumulatedIds);
|
||||
}
|
||||
|
||||
word = MergeWord(text);
|
||||
Cache.Set(text, word);
|
||||
Cache.Set(text.ToString(), word);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
|
@ -4,113 +4,54 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
|
||||
namespace Microsoft.ML.Tokenizers
|
||||
{
|
||||
internal sealed class Cache<TKey, TValue> where TKey : notnull where TValue : notnull
|
||||
{
|
||||
private readonly int _capacity;
|
||||
private readonly Dictionary<TKey, TValue> _map;
|
||||
private object SyncObj => _map;
|
||||
|
||||
internal Cache() : this(Bpe.DefaultCacheCapacity) { }
|
||||
|
||||
internal Cache(int capacity)
|
||||
{
|
||||
Capacity = capacity;
|
||||
Map = new Dictionary<TKey, TValue>(Capacity);
|
||||
_capacity = capacity;
|
||||
_map = new Dictionary<TKey, TValue>(capacity);
|
||||
}
|
||||
|
||||
private readonly object _lock = new();
|
||||
|
||||
internal Dictionary<TKey, TValue> Map { get; set; }
|
||||
|
||||
internal int Capacity { get; set; }
|
||||
|
||||
internal void Fresh() => Map = new Dictionary<TKey, TValue>(Capacity);
|
||||
|
||||
internal void Clear()
|
||||
internal bool TryGetValue(TKey key, out TValue value)
|
||||
{
|
||||
lock (_lock)
|
||||
lock (SyncObj)
|
||||
{
|
||||
Map.Clear();
|
||||
}
|
||||
}
|
||||
|
||||
internal List<TValue> GetValues(IEnumerable<TKey> keys)
|
||||
{
|
||||
List<TValue> values = new();
|
||||
lock (_lock)
|
||||
{
|
||||
foreach (TKey key in keys)
|
||||
{
|
||||
if (Map.TryGetValue(key, out TValue? value))
|
||||
{
|
||||
values.Add(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return values;
|
||||
}
|
||||
|
||||
internal bool TryGet(TKey key, out TValue value)
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
return Map.TryGetValue(key, out value!);
|
||||
}
|
||||
}
|
||||
|
||||
internal void SetValues(IEnumerable<(TKey, TValue)> entries)
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
foreach ((TKey, TValue) entry in entries)
|
||||
{
|
||||
if (Capacity <= Map.Count)
|
||||
{
|
||||
break;
|
||||
}
|
||||
Map[entry.Item1] = entry.Item2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal void Set(TKey k, TValue v)
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
if (Capacity > Map.Count)
|
||||
{
|
||||
Map[k] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal KeyValuePair<TKey, TValue>[] ToArray()
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
return Map.ToArray();
|
||||
return _map.TryGetValue(key, out value!);
|
||||
}
|
||||
}
|
||||
|
||||
internal TValue GetOrAdd(TKey key, TValue value)
|
||||
{
|
||||
lock (_lock)
|
||||
lock (SyncObj)
|
||||
{
|
||||
if (Map.TryGetValue(key, out TValue? v))
|
||||
if (_map.TryGetValue(key, out TValue? v))
|
||||
{
|
||||
return v;
|
||||
}
|
||||
|
||||
if (Capacity > Map.Count)
|
||||
{
|
||||
Map[key] = value;
|
||||
return v!;
|
||||
}
|
||||
|
||||
_map[key] = value;
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
internal void Set(TKey key, TValue value)
|
||||
{
|
||||
lock (SyncObj)
|
||||
{
|
||||
if (_map.Count < _capacity)
|
||||
{
|
||||
_map[key] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,13 +18,14 @@ namespace Microsoft.ML.Tokenizers
|
|||
public sealed class EnglishRoberta : Model
|
||||
{
|
||||
private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence;
|
||||
private readonly IReadOnlyDictionary<string, int> _vocab;
|
||||
private readonly SortedDictionary<int, string> _vocabReverse;
|
||||
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
|
||||
private Dictionary<string, int>? _vocabOriginal;
|
||||
private readonly SortedDictionary<int, StringSpanOrdinalKey> _vocabReverse;
|
||||
private readonly Cache<(string, string), int> _mergeRanks;
|
||||
private readonly IReadOnlyDictionary<char, char> _byteToUnicode;
|
||||
private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
|
||||
private readonly string[] _charToString;
|
||||
private readonly Cache<string, List<Token>> _cache;
|
||||
private readonly StringSpanOrdinalKeyCache<List<Token>> _cache;
|
||||
|
||||
/// <summary>
|
||||
/// Indicate if want to filter the unsupported characters during the decoding.
|
||||
|
@ -77,7 +78,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
|
||||
_unicodeToByte = _byteToUnicode.Reverse();
|
||||
_cache = new Cache<string, List<Token>>();
|
||||
_cache = new StringSpanOrdinalKeyCache<List<Token>>();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -118,13 +119,13 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
|
||||
_unicodeToByte = _byteToUnicode.Reverse();
|
||||
_cache = new Cache<string, List<Token>>();
|
||||
_cache = new StringSpanOrdinalKeyCache<List<Token>>();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the dictionary mapping tokens to Ids.
|
||||
/// </summary>
|
||||
public IReadOnlyDictionary<string, int> Vocab => _vocab;
|
||||
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
|
||||
|
||||
//
|
||||
// Public Model interfaces implementation
|
||||
|
@ -145,14 +146,15 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
if (_vocabReverse.TryGetValue(id, out var value))
|
||||
{
|
||||
string v = value.Data!;
|
||||
if (FilterUnsupportedChars)
|
||||
{
|
||||
char[] buffer = ArrayPool<char>.Shared.Rent(value.Length);
|
||||
char[] buffer = ArrayPool<char>.Shared.Rent(v.Length);
|
||||
int i = 0;
|
||||
|
||||
for (int j = 0; j < value.Length; j++)
|
||||
for (int j = 0; j < v.Length; j++)
|
||||
{
|
||||
if (_unicodeToByte.TryGetValue(value[j], out var c))
|
||||
if (_unicodeToByte.TryGetValue(v[j], out var c))
|
||||
{
|
||||
buffer[i++] = c;
|
||||
}
|
||||
|
@ -164,7 +166,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
else
|
||||
{
|
||||
return value;
|
||||
return v;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -205,7 +207,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
return Array.Empty<Token>();
|
||||
}
|
||||
|
||||
if (_cache.TryGet(text, out List<Token>? hit))
|
||||
if (_cache.TryGetValue(text, out List<Token>? hit))
|
||||
{
|
||||
ArrayPool<char>.Shared.Return(token);
|
||||
ArrayPool<int>.Shared.Return(indexMapping);
|
||||
|
@ -225,7 +227,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="text">The text to split.</param>
|
||||
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
|
||||
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
|
||||
public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIds(text, accumulatedIds);
|
||||
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIds(text, accumulatedIds);
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
|
@ -233,16 +235,16 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="isSpecialToken">Indicate if the token is special token.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
public override int CountTokens(string text, bool isSpecialToken) => EncodeToIds(text, null);
|
||||
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIds(text, null);
|
||||
|
||||
private int EncodeToIds(string text, IList<int>? accumulatedIds)
|
||||
private int EncodeToIds(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
|
||||
{
|
||||
if (string.IsNullOrEmpty(text))
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (_cache.TryGet(text, out List<Token>? hit))
|
||||
if (_cache.TryGetValue(text, out List<Token>? hit))
|
||||
{
|
||||
if (accumulatedIds is not null)
|
||||
{
|
||||
|
@ -255,17 +257,41 @@ namespace Microsoft.ML.Tokenizers
|
|||
return hit.Count;
|
||||
}
|
||||
|
||||
// If the cache doesn't have the text, then encode it and add it to the cache
|
||||
IReadOnlyList<Token> tokens = Encode(text);
|
||||
char[] token = ArrayPool<char>.Shared.Rent(text.Length);
|
||||
int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);
|
||||
|
||||
int newTokenIndex = 0;
|
||||
for (int i = 0; i < text.Length; i++)
|
||||
{
|
||||
if (_byteToUnicode.TryGetValue(text[i], out var value))
|
||||
{
|
||||
token[newTokenIndex] = value;
|
||||
indexMapping[newTokenIndex] = i;
|
||||
newTokenIndex++;
|
||||
}
|
||||
}
|
||||
|
||||
if (newTokenIndex == 0)
|
||||
{
|
||||
ArrayPool<char>.Shared.Return(token);
|
||||
ArrayPool<int>.Shared.Return(indexMapping);
|
||||
return 0;
|
||||
}
|
||||
|
||||
List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
|
||||
_cache.Set(text.ToString(), result);
|
||||
ArrayPool<char>.Shared.Return(token);
|
||||
ArrayPool<int>.Shared.Return(indexMapping);
|
||||
|
||||
if (accumulatedIds is not null)
|
||||
{
|
||||
foreach (var t in tokens)
|
||||
foreach (var t in result)
|
||||
{
|
||||
accumulatedIds.Add(t.Id);
|
||||
}
|
||||
}
|
||||
|
||||
return tokens.Count;
|
||||
return result.Count;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -274,7 +300,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="token">The token to map to the Id.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>The mapped Id of the token.</returns>
|
||||
public override int? MapTokenToId(string token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out var value) ? value : null;
|
||||
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;
|
||||
|
||||
/// <summary>
|
||||
/// Convert a list of tokens Ids to highest occurrence rankings.
|
||||
|
@ -397,12 +423,13 @@ namespace Microsoft.ML.Tokenizers
|
|||
private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) =>
|
||||
HighestOccurrenceMapping.Load(highestOccurrenceMappingStream);
|
||||
|
||||
private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
|
||||
private Dictionary<StringSpanOrdinalKey, int> GetVocabulary(Stream vocabularyStream)
|
||||
{
|
||||
Dictionary<string, int>? vocab;
|
||||
Dictionary<StringSpanOrdinalKey, int>? vocab;
|
||||
try
|
||||
{
|
||||
vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(vocabularyStream) as Dictionary<string, int>;
|
||||
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
|
||||
vocab = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocabularyStream, options) as Dictionary<StringSpanOrdinalKey, int>;
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
|
@ -416,22 +443,22 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
if (_vocabIdToHighestOccurrence.BosWord is not null)
|
||||
{
|
||||
vocab[_vocabIdToHighestOccurrence.BosWord] = -_vocabIdToHighestOccurrence.BosIndex;
|
||||
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.BosWord)] = -_vocabIdToHighestOccurrence.BosIndex;
|
||||
}
|
||||
|
||||
if (_vocabIdToHighestOccurrence.EosWord is not null)
|
||||
{
|
||||
vocab[_vocabIdToHighestOccurrence.EosWord] = -_vocabIdToHighestOccurrence.EosIndex;
|
||||
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.EosWord)] = -_vocabIdToHighestOccurrence.EosIndex;
|
||||
}
|
||||
|
||||
if (_vocabIdToHighestOccurrence.UnkWord is not null)
|
||||
{
|
||||
vocab[_vocabIdToHighestOccurrence.UnkWord] = -_vocabIdToHighestOccurrence.UnkIndex;
|
||||
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.UnkWord)] = -_vocabIdToHighestOccurrence.UnkIndex;
|
||||
}
|
||||
|
||||
if (_vocabIdToHighestOccurrence.PadWord is not null)
|
||||
{
|
||||
vocab[_vocabIdToHighestOccurrence.PadWord] = -_vocabIdToHighestOccurrence.PadIndex;
|
||||
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.PadWord)] = -_vocabIdToHighestOccurrence.PadIndex;
|
||||
}
|
||||
|
||||
return vocab;
|
||||
|
@ -510,7 +537,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
if (token.Length == 1)
|
||||
{
|
||||
string tokenValue = _charToString[token[0]];
|
||||
return new List<Token> { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], 1)) };
|
||||
return new List<Token> { new Token(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, (indexMapping[0], 1)) };
|
||||
}
|
||||
|
||||
List<string> word = new(token.Length);
|
||||
|
@ -539,7 +566,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
// get the most frequent bi-gram pair
|
||||
var (first, second) = pairs.ArgMin(pair => _mergeRanks.GetOrAdd(pair, int.MaxValue));
|
||||
if (!_mergeRanks.TryGet((first, second), out int _))
|
||||
if (!_mergeRanks.TryGetValue((first, second), out int _))
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
@ -599,7 +626,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
foreach (string w in word)
|
||||
{
|
||||
tokens.Add(new Token(_vocab[w], w, (indexMapping[index], w.Length)));
|
||||
tokens.Add(new Token(_vocab[new StringSpanOrdinalKey(w)], w, (indexMapping[index], w.Length)));
|
||||
index += w.Length;
|
||||
}
|
||||
|
||||
|
|
|
@ -31,14 +31,15 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// This method does the default implementation that uses the Encode method to get the token's Ids.
|
||||
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
|
||||
/// </remarks>
|
||||
public virtual void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds)
|
||||
public virtual void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds)
|
||||
{
|
||||
if (accumulatedIds is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(accumulatedIds));
|
||||
}
|
||||
|
||||
var tokens = Encode(text);
|
||||
// Default implementation is not optimized for memory allocation. It is recommended to override this method for the sake of the performance.
|
||||
var tokens = Encode(text.ToString());
|
||||
foreach (var token in tokens)
|
||||
{
|
||||
accumulatedIds.Add(token.Id);
|
||||
|
@ -55,7 +56,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids.
|
||||
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
|
||||
/// </remarks>
|
||||
public virtual int CountTokens(string text, bool isSpecialToken)
|
||||
public virtual int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
|
||||
{
|
||||
var ids = new List<int>();
|
||||
EncodeToIds(text, isSpecialToken, ids);
|
||||
|
@ -68,7 +69,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="token">The token to map to Id</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>The mapped Id of the token.</returns>
|
||||
public abstract int? MapTokenToId(string token, bool considerSpecialTokens = true);
|
||||
public abstract int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true);
|
||||
|
||||
/// <summary>
|
||||
/// Map the encoded Id to the token.
|
||||
|
|
|
@ -19,12 +19,14 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// </summary>
|
||||
public sealed class Tiktoken : Model
|
||||
{
|
||||
private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder = null!;
|
||||
private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder = null!;
|
||||
private readonly LruCache<string, int[]>? _cache;
|
||||
private readonly IReadOnlyDictionary<string, int>? _specialTokensEncoder;
|
||||
private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder;
|
||||
private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder;
|
||||
private readonly LruCache<int[]> _cache;
|
||||
private readonly Dictionary<StringSpanOrdinalKey, int>? _specialTokensEncoder;
|
||||
private Dictionary<string, int>? _specialTokensEncoderOriginal;
|
||||
private readonly Dictionary<int, string>? _specialTokensDecoder;
|
||||
private readonly Dictionary<string, int> _vocab = null!;
|
||||
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
|
||||
private IReadOnlyDictionary<string, int>? _vocabOriginal;
|
||||
|
||||
/// <summary>
|
||||
/// Create a new Tiktoken tokenizer's model object.
|
||||
|
@ -34,7 +36,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="cacheSize">The size of the cache to use.</param>
|
||||
/// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabFilePath"/> is null or empty.</exception>
|
||||
/// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception>
|
||||
public Tiktoken(string vocabFilePath, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<string, int[]>.DefaultCacheSize) :
|
||||
public Tiktoken(string vocabFilePath, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
|
||||
this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), specialTokens, cacheSize, disposeStream: true)
|
||||
{
|
||||
}
|
||||
|
@ -47,7 +49,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="cacheSize">The size of the cache to use.</param>
|
||||
/// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabStream"/> is null or empty.</exception>
|
||||
/// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception>
|
||||
public Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<string, int[]>.DefaultCacheSize) :
|
||||
public Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
|
||||
this(vocabStream ?? throw new ArgumentNullException(nameof(vocabStream)), specialTokens, cacheSize, disposeStream: false)
|
||||
{
|
||||
}
|
||||
|
@ -63,9 +65,9 @@ namespace Microsoft.ML.Tokenizers
|
|||
internal Tiktoken(
|
||||
Dictionary<ReadOnlyMemory<byte>, int> encoder,
|
||||
Dictionary<int, ReadOnlyMemory<byte>> decoder,
|
||||
Dictionary<string, int> vocab,
|
||||
Dictionary<StringSpanOrdinalKey, int> vocab,
|
||||
IReadOnlyDictionary<string, int>? specialTokens,
|
||||
int cacheSize = LruCache<string, int[]>.DefaultCacheSize) : this(cacheSize)
|
||||
int cacheSize = LruCache<int[]>.DefaultCacheSize)
|
||||
{
|
||||
_encoder = encoder ?? throw new ArgumentNullException(nameof(encoder));
|
||||
_decoder = decoder ?? throw new ArgumentNullException(nameof(decoder));
|
||||
|
@ -73,24 +75,21 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
Debug.Assert(encoder.Count == decoder.Count);
|
||||
|
||||
_specialTokensEncoder = specialTokens;
|
||||
if (_specialTokensEncoder is not null)
|
||||
{
|
||||
_specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
|
||||
}
|
||||
_encoder = encoder!;
|
||||
_decoder = decoder!;
|
||||
_vocab = vocab!;
|
||||
_cache = new LruCache<int[]>(cacheSize);
|
||||
|
||||
(_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens);
|
||||
}
|
||||
|
||||
private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens, int cacheSize, bool disposeStream) : this(cacheSize)
|
||||
private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens, int cacheSize, bool disposeStream)
|
||||
{
|
||||
try
|
||||
{
|
||||
_cache = new LruCache<int[]>(cacheSize);
|
||||
(_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
|
||||
|
||||
_specialTokensEncoder = specialTokens;
|
||||
if (_specialTokensEncoder is not null)
|
||||
{
|
||||
_specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
|
||||
}
|
||||
(_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens);
|
||||
}
|
||||
finally
|
||||
{
|
||||
|
@ -101,17 +100,15 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
}
|
||||
|
||||
private Tiktoken(int cacheSize)
|
||||
private static (Dictionary<StringSpanOrdinalKey, int>?, Dictionary<int, string>?) CreateEncoderDecoder(IReadOnlyDictionary<string, int>? specialTokens)
|
||||
{
|
||||
if (cacheSize < 0)
|
||||
if (specialTokens is not null)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(cacheSize));
|
||||
var encoder = specialTokens.ToDictionary(e => new StringSpanOrdinalKey(e.Key), e => e.Value);
|
||||
return (encoder, encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key.Data!));
|
||||
}
|
||||
|
||||
if (cacheSize > 0)
|
||||
{
|
||||
_cache = new LruCache<string, int[]>(cacheSize);
|
||||
}
|
||||
return (null, null);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -125,7 +122,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
public static async Task<Tiktoken> CreateAsync(
|
||||
Stream vocabStream,
|
||||
IReadOnlyDictionary<string, int>? specialTokens = null,
|
||||
int cacheSize = LruCache<string, int[]>.DefaultCacheSize,
|
||||
int cacheSize = LruCache<int[]>.DefaultCacheSize,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (vocabStream is null)
|
||||
|
@ -133,7 +130,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
throw new ArgumentNullException(nameof(vocabStream));
|
||||
}
|
||||
|
||||
(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
|
||||
(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
|
||||
await LoadTikTokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
return new Tiktoken(encoder, decoder, vocab, specialTokens, cacheSize);
|
||||
|
@ -150,7 +147,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
public static async Task<Tiktoken> CreateAsync(
|
||||
string vocabFilePath,
|
||||
IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
|
||||
int cacheSize = LruCache<string, int[]>.DefaultCacheSize,
|
||||
int cacheSize = LruCache<int[]>.DefaultCacheSize,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (vocabFilePath is null)
|
||||
|
@ -170,11 +167,11 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
|
||||
/// <returns>Map of byte[] to integer token id</returns>
|
||||
/// <exception cref="InvalidOperationException"></exception>
|
||||
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTikTokenBpeAsync(
|
||||
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<StringSpanOrdinalKey, int>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTikTokenBpeAsync(
|
||||
Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
|
||||
var vocab = new Dictionary<string, int>();
|
||||
var vocab = new Dictionary<StringSpanOrdinalKey, int>();
|
||||
var decoder = new Dictionary<int, ReadOnlyMemory<byte>>();
|
||||
|
||||
try
|
||||
|
@ -212,7 +209,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
if (decodedToken.IndexOf('\uFFFD') < 0)
|
||||
{
|
||||
vocab[decodedToken] = rank;
|
||||
vocab[new StringSpanOrdinalKey(decodedToken)] = rank;
|
||||
}
|
||||
}
|
||||
else
|
||||
|
@ -230,12 +227,6 @@ namespace Microsoft.ML.Tokenizers
|
|||
return (encoder, vocab, decoder);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the dictionary mapping special tokens to Ids.
|
||||
/// </summary>
|
||||
/// <returns>The dictionary mapping special tokens to Ids.</returns>
|
||||
public IReadOnlyDictionary<string, int>? SpecialTokensEncoder => _specialTokensEncoder;
|
||||
|
||||
/// <summary>
|
||||
/// Encode a split text string to a list of tokens.
|
||||
/// </summary>
|
||||
|
@ -253,12 +244,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
if (isSpecialToken)
|
||||
{
|
||||
if (_specialTokensEncoder is null)
|
||||
{
|
||||
throw new InvalidOperationException($"The tokenizer doesn't have special tokens");
|
||||
}
|
||||
|
||||
if (_specialTokensEncoder.TryGetValue(text, out int id))
|
||||
if (_specialTokensEncoder?.TryGetValue(text, out int id) is true)
|
||||
{
|
||||
return new List<Token> { new(id, text, (0, text.Length)) };
|
||||
}
|
||||
|
@ -266,7 +252,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
throw new InvalidOperationException($"The special token {text} doesn't exist in the tokenizer");
|
||||
}
|
||||
|
||||
if (_cache?.Lookup(text, out int[] ids) is true)
|
||||
if (_cache.TryGetValue(text, out int[]? ids))
|
||||
{
|
||||
tokens = new Token[ids.Length];
|
||||
tokens[0] = new Token(ids[0], text, (0, text.Length));
|
||||
|
@ -290,7 +276,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
Debug.Assert(encodedIds.Length > 0);
|
||||
_cache?.Add(text, encodedIds);
|
||||
_cache.Add(text, encodedIds);
|
||||
|
||||
tokens = new Token[encodedIds.Length];
|
||||
tokens[0] = new Token(encodedIds[0], text, (0, text.Length));
|
||||
|
@ -305,21 +291,21 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Encode a split text string to a list of Ids.
|
||||
/// Encode text to a list of Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
|
||||
/// <param name="accumulatedIds">The list of accumulated Ids.</param>
|
||||
public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds)
|
||||
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds)
|
||||
{
|
||||
if (string.IsNullOrEmpty(text))
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (isSpecialToken)
|
||||
{
|
||||
if (_specialTokensEncoder is not null && _specialTokensEncoder.TryGetValue(text, out int id))
|
||||
if (_specialTokensEncoder?.TryGetValue(text, out int id) is true)
|
||||
{
|
||||
accumulatedIds.Add(id);
|
||||
}
|
||||
|
@ -327,7 +313,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
return;
|
||||
}
|
||||
|
||||
if (_cache?.Lookup(text, out int[] tokenIds) is true)
|
||||
if (_cache.TryGetValue(text, out int[]? tokenIds))
|
||||
{
|
||||
accumulatedIds.AddRange(tokenIds);
|
||||
return;
|
||||
|
@ -340,10 +326,10 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
|
||||
int encodedLength = GetUtf8Bytes(text.AsSpan(), arrayPoolArray);
|
||||
int encodedLength = GetUtf8Bytes(text, arrayPoolArray);
|
||||
|
||||
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
_cache?.Add(text, encodedIds);
|
||||
_cache.Add(text.ToString(), encodedIds);
|
||||
|
||||
accumulatedIds.AddRange(encodedIds);
|
||||
|
||||
|
@ -354,12 +340,12 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="text">The text to tokenize.</param>
|
||||
/// <param name="isSpecialToken">Indicate if the token is special token.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
public override int CountTokens(string text, bool isSpecialToken)
|
||||
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
|
||||
{
|
||||
if (string.IsNullOrEmpty(text))
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
@ -369,7 +355,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
return _specialTokensEncoder.TryGetValue(text, out _) ? 1 : 0;
|
||||
}
|
||||
|
||||
if (_cache?.Lookup(text, out int[] ids) is true)
|
||||
if (_cache.TryGetValue(text, out int[] ids))
|
||||
{
|
||||
return ids.Length;
|
||||
}
|
||||
|
@ -380,10 +366,10 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
|
||||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
|
||||
int encodedLength = GetUtf8Bytes(text.AsSpan(), arrayPoolArray);
|
||||
int encodedLength = GetUtf8Bytes(text, arrayPoolArray);
|
||||
|
||||
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
_cache?.Add(text, encodedIds);
|
||||
_cache.Add(text.ToString(), encodedIds);
|
||||
|
||||
ArrayPool<byte>.Shared.Return(arrayPoolArray);
|
||||
return encodedIds.Length;
|
||||
|
@ -395,19 +381,22 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="token">The token to map to the Id.</param>
|
||||
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
|
||||
/// <returns>The mapped Id of the token.</returns>
|
||||
public override int? MapTokenToId(string token, bool considerSpecialTokens = true)
|
||||
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true)
|
||||
{
|
||||
if (string.IsNullOrEmpty(token))
|
||||
if (token.IsEmpty)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (considerSpecialTokens && _specialTokensEncoder is not null && _specialTokensEncoder.TryGetValue(token, out int specialTokenId))
|
||||
if (considerSpecialTokens && _specialTokensEncoder is not null)
|
||||
{
|
||||
return specialTokenId;
|
||||
if (_specialTokensEncoder.TryGetValue(token, out int specialTokenId))
|
||||
{
|
||||
return specialTokenId;
|
||||
}
|
||||
}
|
||||
|
||||
if (_cache?.Lookup(token, out int[] ids) is true)
|
||||
if (_cache.TryGetValue(token, out int[] ids))
|
||||
{
|
||||
if (ids.Length == 1)
|
||||
{
|
||||
|
@ -425,10 +414,10 @@ namespace Microsoft.ML.Tokenizers
|
|||
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(token.Length));
|
||||
try
|
||||
{
|
||||
int encodedLength = GetUtf8Bytes(token.AsSpan(), arrayPoolArray);
|
||||
int encodedLength = GetUtf8Bytes(token, arrayPoolArray);
|
||||
|
||||
int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
|
||||
_cache?.Add(token, idsToCache);
|
||||
_cache.Add(token.ToString(), idsToCache);
|
||||
|
||||
if (idsToCache.Length == 1)
|
||||
{
|
||||
|
@ -550,7 +539,12 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Gets the dictionary mapping tokens to Ids.
|
||||
/// </summary>
|
||||
/// <remarks>This may not contain the full set of vocabulary tokens, use Encoder to get the full set of vocabulary.</remarks>
|
||||
public IReadOnlyDictionary<string, int> Vocab => _vocab;
|
||||
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the dictionary mapping special tokens to Ids.
|
||||
/// </summary>
|
||||
public IReadOnlyDictionary<string, int>? SpecialTokensEncoder => _specialTokensEncoderOriginal ??= _specialTokensEncoder?.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the dictionary mapping token bytes to Ids.
|
||||
|
|
|
@ -104,7 +104,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
|
||||
{
|
||||
Model.EncodeToIds(split.TokenString, split.IsSpecialToken, idsList);
|
||||
Model.EncodeToIds(split.TokenSpan, split.IsSpecialToken, idsList);
|
||||
}
|
||||
|
||||
return idsList;
|
||||
|
@ -130,7 +130,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
int idsCount = 0;
|
||||
foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
|
||||
{
|
||||
idsCount += Model.CountTokens(split.TokenString, split.IsSpecialToken);
|
||||
idsCount += Model.CountTokens(split.TokenSpan, split.IsSpecialToken);
|
||||
}
|
||||
|
||||
return idsCount;
|
||||
|
@ -343,7 +343,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
}
|
||||
|
||||
private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, Dictionary<int, ReadOnlyMemory<byte>>)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
|
||||
private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
|
||||
|
||||
/// <summary>
|
||||
/// Create tokenizer based on regex pattern, BPE rank file and special tokens
|
||||
|
@ -371,7 +371,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
}
|
||||
|
||||
if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
|
||||
if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
|
||||
{
|
||||
using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
|
||||
{
|
||||
|
|
|
@ -37,5 +37,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
internal static bool TryParseInt32(string s, int offset, out int result)
|
||||
=> int.TryParse(s.AsSpan().Slice(offset), NumberStyles.None, CultureInfo.InvariantCulture, out result);
|
||||
|
||||
/// <summary>
/// Computes a hash code for the characters in <paramref name="span"/> by delegating to
/// <see cref="string.GetHashCode(ReadOnlySpan{char})"/>.
/// </summary>
/// <param name="span">The characters to hash.</param>
/// <returns>The hash code produced by the framework for the given characters.</returns>
internal static int GetHashCode(ReadOnlySpan<char> span)
{
    return string.GetHashCode(span);
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,6 +48,17 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
/// Computes an ordinal hash code over the characters in <paramref name="span"/>.
/// </summary>
/// <param name="span">The characters to hash.</param>
/// <returns>
/// A hash that starts at 17 and, for every character, is multiplied by 31 and increased by
/// the character's UTF-16 code unit value. An empty span yields 17.
/// </returns>
internal static int GetHashCode(ReadOnlySpan<char> span)
{
    // The multiply-and-add deliberately wraps around on overflow; make that explicit so the
    // method does not throw OverflowException when compiled with overflow checking enabled.
    unchecked
    {
        int hash = 17;
        foreach (char c in span)
        {
            hash = hash * 31 + c;
        }

        return hash;
    }
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,47 +2,37 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Microsoft.ML.Tokenizers
|
||||
{
|
||||
internal class LruCache<TKey, TValue> where TKey : notnull where TValue : notnull
|
||||
internal sealed class LruCache<TValue>
|
||||
{
|
||||
/// <summary>
|
||||
/// The default LRU cache size.
|
||||
/// </summary>
|
||||
public const int DefaultCacheSize = 8192; // 4096;
|
||||
public const int DefaultCacheSize = 8192;
|
||||
|
||||
private readonly object _lockObject = new object();
|
||||
|
||||
private class CacheItem
|
||||
{
|
||||
public readonly TKey Key;
|
||||
public TValue Value;
|
||||
|
||||
public CacheItem(TKey key, TValue value)
|
||||
{
|
||||
Key = key;
|
||||
Value = value;
|
||||
}
|
||||
}
|
||||
|
||||
private readonly Dictionary<TKey, LinkedListNode<CacheItem>> _cache;
|
||||
private readonly LinkedList<CacheItem> _lruList;
|
||||
private readonly Dictionary<StringSpanOrdinalKey, LinkedListNode<KeyValuePair<string, TValue>>> _cache = new();
|
||||
private readonly LinkedList<KeyValuePair<string, TValue>> _lruList = new();
|
||||
private readonly int _cacheSize;
|
||||
|
||||
private object SyncObj => _cache;
|
||||
|
||||
/// <summary>
|
||||
/// Constructs an <see cref="LruCache{TKey,TValue}" /> object.
|
||||
/// Constructs an <see cref="LruCache{TValue}" /> object.
|
||||
/// </summary>
|
||||
/// <param name="cacheSize">
|
||||
/// The maximum number of <typeparamref name="TKey" /> to <typeparamref name="TValue" /> mappings
|
||||
/// that can be cached. This defaults to <see cref="DefaultCacheSize" />, which is set to
|
||||
/// <value>4096</value>.
|
||||
/// The maximum number of mappings that can be cached. This defaults to <see cref="DefaultCacheSize" />, which is set to <value>8192</value>.
|
||||
/// </param>
|
||||
public LruCache(int cacheSize = DefaultCacheSize)
|
||||
{
|
||||
_cache = new Dictionary<TKey, LinkedListNode<CacheItem>>();
|
||||
_lruList = new LinkedList<CacheItem>();
|
||||
if (cacheSize <= 0)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(cacheSize), "Cache size must be a positive number.");
|
||||
}
|
||||
|
||||
_cacheSize = cacheSize;
|
||||
}
|
||||
|
||||
|
@ -54,11 +44,11 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <returns>
|
||||
/// true if the cache contains a mapping for key, false otherwise.
|
||||
/// </returns>
|
||||
public bool Lookup(TKey key, out TValue value)
|
||||
public bool TryGetValue(string key, out TValue value)
|
||||
{
|
||||
lock (_lockObject)
|
||||
lock (SyncObj)
|
||||
{
|
||||
if (_cache.TryGetValue(key, out LinkedListNode<CacheItem>? cached))
|
||||
if (_cache.TryGetValue(new StringSpanOrdinalKey(key), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
|
||||
{
|
||||
_lruList.Remove(cached);
|
||||
_lruList.AddFirst(cached);
|
||||
|
@ -71,16 +61,31 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
}
|
||||
|
||||
protected virtual void OnEviction(TValue evictedValue) { }
|
||||
|
||||
private void EvictIfNeeded()
|
||||
/// <summary>
|
||||
/// Retrieves the value associated with the specified key.
|
||||
/// </summary>
|
||||
/// <param name="key">The object to be used as a key.</param>
|
||||
/// <param name="value">An out parameter that is set to the value of the key if key contains a mapping in the cache.</param>
|
||||
/// <returns>
|
||||
/// true if the cache contains a mapping for key, false otherwise.
|
||||
/// </returns>
|
||||
public unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
|
||||
{
|
||||
while (_cache.Count >= _cacheSize)
|
||||
lock (SyncObj)
|
||||
{
|
||||
LinkedListNode<CacheItem>? nodeToEvict = _lruList.Last;
|
||||
_lruList.RemoveLast();
|
||||
_cache.Remove(nodeToEvict!.Value.Key);
|
||||
OnEviction(nodeToEvict.Value.Value);
|
||||
fixed (char* ptr = key)
|
||||
{
|
||||
if (_cache.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
|
||||
{
|
||||
_lruList.Remove(cached);
|
||||
_lruList.AddFirst(cached);
|
||||
value = cached.Value.Value;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
value = default!;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -89,46 +94,29 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// </summary>
|
||||
/// <param name="key">The key whose mapped <paramref name="value" /> is to be created or replaced.</param>
|
||||
/// <param name="value">The new value to be mapped to the <paramref name="key" />.</param>
|
||||
public void Add(TKey key, TValue value) => Replace(key, value, out _);
|
||||
|
||||
public bool Replace(TKey key, TValue value, out TValue oldValue)
|
||||
public void Add(string key, TValue value)
|
||||
{
|
||||
lock (_lockObject)
|
||||
lock (SyncObj)
|
||||
{
|
||||
return ReplaceInternal(key, value, out oldValue);
|
||||
if (_cache.TryGetValue(new StringSpanOrdinalKey(key), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
|
||||
{
|
||||
cached.Value = new KeyValuePair<string, TValue>(key, value);
|
||||
_lruList.Remove(cached);
|
||||
_lruList.AddFirst(cached);
|
||||
return;
|
||||
}
|
||||
|
||||
while (_cache.Count >= _cacheSize)
|
||||
{
|
||||
LinkedListNode<KeyValuePair<string, TValue>>? nodeToEvict = _lruList.Last;
|
||||
_lruList.RemoveLast();
|
||||
_cache.Remove(new StringSpanOrdinalKey(nodeToEvict!.Value.Key));
|
||||
}
|
||||
|
||||
var node = new LinkedListNode<KeyValuePair<string, TValue>>(new KeyValuePair<string, TValue>(key, value));
|
||||
_cache[new StringSpanOrdinalKey(key)] = node;
|
||||
_lruList.AddFirst(node);
|
||||
}
|
||||
}
|
||||
|
||||
private bool ReplaceInternal(TKey key, TValue value, out TValue oldValue)
|
||||
{
|
||||
if (_cache.TryGetValue(key, out LinkedListNode<CacheItem>? cached))
|
||||
{
|
||||
oldValue = cached.Value.Value;
|
||||
cached.Value.Value = value;
|
||||
_lruList.Remove(cached);
|
||||
_lruList.AddFirst(cached);
|
||||
return true;
|
||||
}
|
||||
EvictIfNeeded();
|
||||
var node = new LinkedListNode<CacheItem>(new CacheItem(key, value));
|
||||
_cache[key] = node;
|
||||
_lruList.AddFirst(node);
|
||||
oldValue = default!;
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The number of entries currently present in the cache.
|
||||
/// </summary>
|
||||
public int Count => _cache.Count;
|
||||
|
||||
/// <summary>
|
||||
/// Clears the contents of this cache.
|
||||
/// </summary>
|
||||
public void Clear()
|
||||
{
|
||||
_cache.Clear();
|
||||
_lruList.Clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
|
||||
namespace Microsoft.ML.Tokenizers
|
||||
{
|
||||
/// <summary>
/// Dictionary key that wraps either a string or a raw character pointer plus length, so the
/// same dictionary can be queried with a string or with a span.
/// </summary>
/// <remarks>
/// The pointer/length form only stores the raw pointer it is given, so it should be used
/// solely for querying. Keys that are stored in a dictionary should always be created from
/// a string.
/// </remarks>
internal unsafe readonly struct StringSpanOrdinalKey : IEquatable<StringSpanOrdinalKey>
{
    /// <summary>Pointer to the first character when the key was created from a pointer; otherwise null.</summary>
    public readonly char* Ptr;

    /// <summary>Number of characters at <see cref="Ptr"/>; only set by the pointer constructor.</summary>
    public readonly int Length;

    /// <summary>The backing string when the key was created from a string; otherwise null.</summary>
    public readonly string? Data;

    /// <summary>Creates a key over <paramref name="length"/> characters starting at <paramref name="ptr"/>.</summary>
    public StringSpanOrdinalKey(char* ptr, int length)
    {
        Ptr = ptr;
        Length = length;
    }

    /// <summary>Creates a key backed by the string <paramref name="data"/>.</summary>
    public StringSpanOrdinalKey(string data)
    {
        Data = data;
    }

    // The characters this key represents: the pointer range when a pointer was supplied,
    // otherwise the characters of Data.
    private ReadOnlySpan<char> Span
    {
        get
        {
            if (Ptr is null)
            {
                return Data.AsSpan();
            }

            return new ReadOnlySpan<char>(Ptr, Length);
        }
    }

    /// <inheritdoc/>
    public override bool Equals(object? obj)
    {
        if (obj is StringSpanOrdinalKey wrapper)
        {
            return Equals(wrapper);
        }

        return false;
    }

    /// <summary>Returns true when both keys represent the same sequence of characters.</summary>
    public bool Equals(StringSpanOrdinalKey other)
    {
        return Span.SequenceEqual(other.Span);
    }

    /// <summary>Hashes the represented characters via <c>Helpers.GetHashCode</c>.</summary>
    public override int GetHashCode()
    {
        return Helpers.GetHashCode(Span);
    }
}
|
||||
|
||||
/// <summary>
/// A capacity-bounded map keyed by <see cref="StringSpanOrdinalKey"/> that can be queried with
/// either a string or a span. Every operation takes a lock on the underlying dictionary.
/// </summary>
/// <typeparam name="TValue">The type of the cached values.</typeparam>
internal sealed class StringSpanOrdinalKeyCache<TValue>
{
    private readonly int _capacity;
    private readonly Dictionary<StringSpanOrdinalKey, TValue> _map;

    // The dictionary instance itself serves as the lock object.
    private object SyncObj => _map;

    /// <summary>Creates a cache with the capacity given by <c>Bpe.DefaultCacheCapacity</c>.</summary>
    internal StringSpanOrdinalKeyCache() : this(Bpe.DefaultCacheCapacity) { }

    /// <summary>Creates a cache that holds at most <paramref name="capacity"/> entries.</summary>
    internal StringSpanOrdinalKeyCache(int capacity)
    {
        _map = new Dictionary<StringSpanOrdinalKey, TValue>(capacity);
        _capacity = capacity;
    }

    /// <summary>Looks up the value stored for the string <paramref name="key"/>.</summary>
    /// <returns>true if an entry exists for the key; otherwise false.</returns>
    internal bool TryGetValue(string key, out TValue value)
    {
        var lookupKey = new StringSpanOrdinalKey(key);

        lock (SyncObj)
        {
            return _map.TryGetValue(lookupKey, out value!);
        }
    }

    /// <summary>Looks up the value stored for the characters in <paramref name="key"/>.</summary>
    /// <returns>true if an entry exists for the key; otherwise false.</returns>
    internal unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
    {
        lock (SyncObj)
        {
            // Pin the span so a pointer-based key can be used for the duration of the lookup.
            fixed (char* ptr = key)
            {
                var lookupKey = new StringSpanOrdinalKey(ptr, key.Length);
                return _map.TryGetValue(lookupKey, out value!);
            }
        }
    }

    /// <summary>Removes the entry for <paramref name="key"/>, if there is one.</summary>
    internal void Remove(string key)
    {
        var keyToRemove = new StringSpanOrdinalKey(key);

        lock (SyncObj)
        {
            _map.Remove(keyToRemove);
        }
    }

    /// <summary>
    /// Stores <paramref name="v"/> under <paramref name="k"/>. Does nothing once the number of
    /// entries has reached the capacity.
    /// </summary>
    internal void Set(string k, TValue v)
    {
        lock (SyncObj)
        {
            if (_map.Count >= _capacity)
            {
                return;
            }

            _map[new StringSpanOrdinalKey(k)] = v;
        }
    }
}
|
||||
|
||||
/// <summary>
/// Custom JSON converter for <see cref="StringSpanOrdinalKey"/>.
/// </summary>
/// <remarks>
/// Keys are read from JSON strings and written from <see cref="StringSpanOrdinalKey.Data"/>,
/// both as regular values and as property names (dictionary keys).
/// </remarks>
internal sealed class StringSpanOrdinalKeyConverter : JsonConverter<StringSpanOrdinalKey>
{
    /// <summary>Gets a shared instance of the converter.</summary>
    public static StringSpanOrdinalKeyConverter Instance { get; } = new StringSpanOrdinalKeyConverter();

    /// <summary>Reads a JSON property name and wraps it in a string-backed key.</summary>
    public override StringSpanOrdinalKey ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) =>
        new StringSpanOrdinalKey(reader.GetString()!);

    /// <summary>Writes the key's string as a JSON property name.</summary>
    public override void WriteAsPropertyName(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) =>
        // A property name is required at this position; writing a string value here would not
        // produce a valid key for the enclosing JSON object.
        writer.WritePropertyName(value.Data!);

    /// <summary>Reads a JSON string value and wraps it in a string-backed key.</summary>
    public override StringSpanOrdinalKey Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) =>
        new StringSpanOrdinalKey(reader.GetString()!);

    /// <summary>Writes the key's string as a JSON string value.</summary>
    public override void Write(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) =>
        writer.WriteStringValue(value.Data!);
}
|
||||
|
||||
/// <summary>
/// Extension methods for dictionaries keyed by <see cref="StringSpanOrdinalKey"/>.
/// </summary>
internal static class StringSpanOrdinalKeyExtensions
{
    /// <summary>Looks up <paramref name="key"/> in <paramref name="map"/> using a span of characters.</summary>
    /// <returns>true if the dictionary contains an entry for the key; otherwise false.</returns>
    public static unsafe bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, ReadOnlySpan<char> key, out TValue value)
    {
        // Pin the span so a pointer-based key can be used for the duration of the lookup.
        fixed (char* ptr = key)
        {
            var lookupKey = new StringSpanOrdinalKey(ptr, key.Length);
            return map.TryGetValue(lookupKey, out value!);
        }
    }

    /// <summary>Looks up <paramref name="key"/> in <paramref name="map"/> using a string.</summary>
    /// <returns>true if the dictionary contains an entry for the key; otherwise false.</returns>
    public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, string key, out TValue value)
    {
        var lookupKey = new StringSpanOrdinalKey(key);
        return map.TryGetValue(lookupKey, out value!);
    }
}
|
||||
}
|
|
@ -156,7 +156,7 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
Assert.Equal(ids[i], encoding.Ids[i]);
|
||||
Assert.Equal(ids[i], idsList[i]);
|
||||
Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
|
||||
Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i]));
|
||||
Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan()));
|
||||
Assert.Equal(encoding.Tokens[i], tokenizer.Decode(encoding.Ids[i]));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -201,7 +201,7 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
Assert.Equal(unfilteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i], considerSpecialTokens: false));
|
||||
}
|
||||
|
||||
Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i]));
|
||||
Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче