Add Span support in tokenizer's Model abstraction (#7035)

* Add Span support in tokenizer's Model abstraction

* Address the feedback

* Use stackalloc instead of the ArrayPool
Tarek Mahmoud Sayed 2024-02-29 18:01:13 -08:00, committed by GitHub
Parent: c6f5397963
Commit: 99c620ad96
No key found matching this signature
GPG key ID: B5690EEEBB952194
12 changed files with 433 additions and 317 deletions
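For context on the last commit message bullet: stackalloc lets small, bounded scratch buffers live on the stack, avoiding both heap allocation and ArrayPool rent/return bookkeeping, with a heap fallback for oversized inputs. A minimal sketch of that pattern (names and the 256-char budget are illustrative, not the library's API):

using System;

class StackallocSketch
{
    // Compose prefix + s in the caller-provided buffer when it fits;
    // otherwise fall back to a heap-allocated string.
    static ReadOnlySpan<char> Prepend(string prefix, ReadOnlySpan<char> s, Span<char> buffer)
    {
        if (prefix.Length + s.Length <= buffer.Length)
        {
            prefix.AsSpan().CopyTo(buffer);
            s.CopyTo(buffer.Slice(prefix.Length));
            return buffer.Slice(0, prefix.Length + s.Length);
        }

        return (prefix + s.ToString()).AsSpan();
    }

    static void Main()
    {
        Span<char> buffer = stackalloc char[256]; // stack scratch space, no pool bookkeeping
        Console.WriteLine(Prepend("##", "token".AsSpan(), buffer).ToString()); // prints "##token"
    }
}

The MergeWord changes below follow this same shape for the continuing-subword prefix and end-of-word suffix.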

View file

@ -3,8 +3,10 @@
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
@ -34,20 +36,21 @@ namespace Microsoft.ML.Tokenizers
{
_unknownToken = value;
if (value is null)
if (VocabReverse.TryGetValue(0, out string? v))
{
if (VocabReverse.TryGetValue(0, out string? v))
if (v == value)
{
VocabReverse.Remove(0);
if (_vocab.TryGetValue(v, out int id))
{
_vocab.Remove(v);
}
return;
}
VocabReverse.Remove(0);
_vocab.Remove(new StringSpanOrdinalKey(v));
}
else
if (value is not null)
{
_vocab[value] = 0;
_vocab[new StringSpanOrdinalKey(value)] = 0;
VocabReverse[0] = value;
}
}
@ -68,7 +71,6 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
public bool FuseUnknownTokens { get; }
/// <summary>
/// Construct a new Bpe model object to use for text encoding.
/// </summary>
@ -111,23 +113,19 @@ namespace Microsoft.ML.Tokenizers
ContinuingSubwordPrefix = continuingSubwordPrefix;
EndOfWordSuffix = endOfWordSuffix;
(Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
_vocab = vocab1 ?? new Dictionary<string, int>();
Cache = new Cache<string, Word>();
(Dictionary<StringSpanOrdinalKey, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
_vocab = vocab1 ?? new Dictionary<StringSpanOrdinalKey, int>();
Cache = new StringSpanOrdinalKeyCache<Word>();
VocabReverse = new();
foreach (KeyValuePair<string, int> kvp in Vocab)
foreach (KeyValuePair<StringSpanOrdinalKey, int> kvp in _vocab)
{
VocabReverse.Add(kvp.Value, kvp.Key);
VocabReverse.Add(kvp.Value, kvp.Key.Data!);
}
if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
{
unknownToken = unkToken;
}
UnknownToken = unknownToken;
UnknownToken = unknownToken ?? (VocabReverse.TryGetValue(0, out string? unkToken) ? unkToken : null);
int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;
@ -197,7 +195,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
@ -205,7 +203,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokens(string text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
/// <summary>
/// Map the token to encoded Id.
@ -213,15 +211,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="token">The token to map to the Id.</param>
/// <param name="considerSpecialTokens">Indicate whether to consider the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public override int? MapTokenToId(string token, bool considerSpecialTokens = true)
{
if (_vocab.TryGetValue(token, out int value))
{
return value;
}
return null;
}
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;
/// <summary>
/// Map the encoded Id to the token.
@ -242,24 +232,27 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
public IReadOnlyDictionary<string, int> Vocab => _vocab;
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
/// Read the given files to extract the vocab and merges
internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
internal static (Dictionary<StringSpanOrdinalKey, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
{
Dictionary<string, int>? dic = JsonSerializer.Deserialize<Dictionary<string, int>>(vocab) as Dictionary<string, int>;
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
Dictionary<StringSpanOrdinalKey, int>? dic = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocab, options) as Dictionary<StringSpanOrdinalKey, int>;
return (dic, ConvertMergesToHashmap(merges));
}
/// The vocabulary assigns a number to each token.
private readonly Dictionary<string, int> _vocab;
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
private Dictionary<string, int>? _vocabOriginal;
/// Contains the mapping between Pairs and their (rank, newId).
internal Dictionary<Pair<int>, (int, int)> Merges { get; }
/// Contains the cache for optimizing the encoding step.
internal Cache<string, Word>? Cache { get; }
internal StringSpanOrdinalKeyCache<Word>? Cache { get; }
internal static readonly int DefaultCacheCapacity = 10_000;
@ -309,9 +302,6 @@ namespace Microsoft.ML.Tokenizers
return merges;
}
/// Reset the cache.
internal void ClearCache() => Cache?.Clear();
private readonly Dictionary<char, string> _charToString = new Dictionary<char, string>();
[MethodImpl(MethodImplOptions.AggressiveInlining)]
@ -327,38 +317,68 @@ namespace Microsoft.ML.Tokenizers
return s;
}
internal Word MergeWord(string w)
internal Word MergeWord(ReadOnlySpan<char> w)
{
Word word = Word.WithCapacity(w.Length);
(int Id, int Len)? unk = null;
int i = 0;
Span<char> buffer = stackalloc char[256];
scoped ReadOnlySpan<char> s;
while (i < w.Length)
{
int length;
string s;
if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1]))
{
length = 2;
s = w.Substring(i, length);
s = w.Slice(i, 2);
}
else
{
length = 1;
s = CharToString(w[i]);
s = w.Slice(i, 1);
}
// Add the `continuing_subword_prefix` if relevant
if (i > 0 && ContinuingSubwordPrefix is not null)
{
s = $"{ContinuingSubwordPrefix}{s}";
if (ContinuingSubwordPrefix.Length + s.Length <= buffer.Length)
{
ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
s.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
s = buffer.Slice(0, ContinuingSubwordPrefix.Length + s.Length);
}
else
{
#if NETCOREAPP
s = $"{ContinuingSubwordPrefix}{s}".AsSpan();
#else
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
s = $"{ContinuingSubwordPrefix}{s1}".AsSpan();
#endif
}
}
// Add the `end_of_word_suffix` if relevant
if (i + length >= w.Length && EndOfWordSuffix is not null)
{
s = $"{s}{EndOfWordSuffix}";
if (s.Length + EndOfWordSuffix.Length <= buffer.Length)
{
s.CopyTo(buffer);
EndOfWordSuffix.AsSpan().CopyTo(buffer.Slice(s.Length));
s = buffer.Slice(0, s.Length + EndOfWordSuffix.Length);
}
else
{
#if NETCOREAPP
s = $"{s}{EndOfWordSuffix}".AsSpan();
#else
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
s = $"{s1}{EndOfWordSuffix}".AsSpan();
#endif
}
}
if (_vocab.TryGetValue(s, out int id))
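The loop above walks the input one Unicode code point at a time: a surrogate pair is sliced as two chars, anything else as one. A self-contained sketch of that iteration:

using System;

class CodePointIteration
{
    static void Main()
    {
        ReadOnlySpan<char> w = "a\U0001F600b".AsSpan(); // 'a', an emoji (surrogate pair), 'b'
        int i = 0;
        while (i < w.Length)
        {
            // Slice two chars for a high/low surrogate pair, one otherwise.
            int length = char.IsHighSurrogate(w[i]) && i < w.Length - 1 && char.IsLowSurrogate(w[i + 1]) ? 2 : 1;
            ReadOnlySpan<char> s = w.Slice(i, length);
            Console.WriteLine(s.ToString()); // "a", then the emoji, then "b"
            i += length;
        }
    }
}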
@ -419,17 +439,17 @@ namespace Microsoft.ML.Tokenizers
Word word;
if (Cache is not null)
{
if (Cache.TryGet(text, out word))
if (Cache.TryGetValue(text, out word))
{
return WordToTokens(ref word);
}
word = MergeWord(text);
word = MergeWord(text.AsSpan());
Cache.Set(text, word);
}
else
{
word = MergeWord(text);
word = MergeWord(text.AsSpan());
}
return WordToTokens(ref word);
@ -445,19 +465,19 @@ namespace Microsoft.ML.Tokenizers
return word.SymbolsCount;
}
internal int EncodeToIdsWithCache(string text, IList<int>? accumulatedIds)
internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
{
Word word;
if (Cache is not null)
{
if (Cache.TryGet(text, out Word hit))
if (Cache.TryGetValue(text, out Word hit))
{
return WordToIds(ref hit, accumulatedIds);
}
word = MergeWord(text);
Cache.Set(text, word);
Cache.Set(text.ToString(), word);
}
else
{

View file

@ -4,113 +4,54 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
namespace Microsoft.ML.Tokenizers
{
internal sealed class Cache<TKey, TValue> where TKey : notnull where TValue : notnull
{
private readonly int _capacity;
private readonly Dictionary<TKey, TValue> _map;
private object SyncObj => _map;
internal Cache() : this(Bpe.DefaultCacheCapacity) { }
internal Cache(int capacity)
{
Capacity = capacity;
Map = new Dictionary<TKey, TValue>(Capacity);
_capacity = capacity;
_map = new Dictionary<TKey, TValue>(capacity);
}
private readonly object _lock = new();
internal Dictionary<TKey, TValue> Map { get; set; }
internal int Capacity { get; set; }
internal void Fresh() => Map = new Dictionary<TKey, TValue>(Capacity);
internal void Clear()
internal bool TryGetValue(TKey key, out TValue value)
{
lock (_lock)
lock (SyncObj)
{
Map.Clear();
}
}
internal List<TValue> GetValues(IEnumerable<TKey> keys)
{
List<TValue> values = new();
lock (_lock)
{
foreach (TKey key in keys)
{
if (Map.TryGetValue(key, out TValue? value))
{
values.Add(value);
}
}
}
return values;
}
internal bool TryGet(TKey key, out TValue value)
{
lock (_lock)
{
return Map.TryGetValue(key, out value!);
}
}
internal void SetValues(IEnumerable<(TKey, TValue)> entries)
{
lock (_lock)
{
foreach ((TKey, TValue) entry in entries)
{
if (Capacity <= Map.Count)
{
break;
}
Map[entry.Item1] = entry.Item2;
}
}
}
internal void Set(TKey k, TValue v)
{
lock (_lock)
{
if (Capacity > Map.Count)
{
Map[k] = v;
}
}
}
internal KeyValuePair<TKey, TValue>[] ToArray()
{
lock (_lock)
{
return Map.ToArray();
return _map.TryGetValue(key, out value!);
}
}
internal TValue GetOrAdd(TKey key, TValue value)
{
lock (_lock)
lock (SyncObj)
{
if (Map.TryGetValue(key, out TValue? v))
if (_map.TryGetValue(key, out TValue? v))
{
return v;
}
if (Capacity > Map.Count)
{
Map[key] = value;
return v!;
}
_map[key] = value;
return value;
}
}
internal void Set(TKey key, TValue value)
{
lock (SyncObj)
{
if (_map.Count < _capacity)
{
_map[key] = value;
}
}
}
}
}
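The refactor above replaces the dedicated _lock object with a SyncObj property that locks on the private dictionary itself. That saves one allocation per cache and is safe only because _map never escapes the class. The shape, reduced to its essentials:

using System.Collections.Generic;

internal sealed class SyncObjSketch<TValue>
{
    private readonly Dictionary<string, TValue> _map = new();

    // Lock target: the map itself. Safe because _map is never handed out.
    private object SyncObj => _map;

    internal bool TryGetValue(string key, out TValue value)
    {
        lock (SyncObj)
        {
            return _map.TryGetValue(key, out value!);
        }
    }
}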

View file

@ -18,13 +18,14 @@ namespace Microsoft.ML.Tokenizers
public sealed class EnglishRoberta : Model
{
private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence;
private readonly IReadOnlyDictionary<string, int> _vocab;
private readonly SortedDictionary<int, string> _vocabReverse;
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
private Dictionary<string, int>? _vocabOriginal;
private readonly SortedDictionary<int, StringSpanOrdinalKey> _vocabReverse;
private readonly Cache<(string, string), int> _mergeRanks;
private readonly IReadOnlyDictionary<char, char> _byteToUnicode;
private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
private readonly string[] _charToString;
private readonly Cache<string, List<Token>> _cache;
private readonly StringSpanOrdinalKeyCache<List<Token>> _cache;
/// <summary>
/// Indicate whether to filter the unsupported characters during the decoding.
@ -77,7 +78,7 @@ namespace Microsoft.ML.Tokenizers
}
_unicodeToByte = _byteToUnicode.Reverse();
_cache = new Cache<string, List<Token>>();
_cache = new StringSpanOrdinalKeyCache<List<Token>>();
}
/// <summary>
@ -118,13 +119,13 @@ namespace Microsoft.ML.Tokenizers
}
_unicodeToByte = _byteToUnicode.Reverse();
_cache = new Cache<string, List<Token>>();
_cache = new StringSpanOrdinalKeyCache<List<Token>>();
}
/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
public IReadOnlyDictionary<string, int> Vocab => _vocab;
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
//
// Public Model interfaces implementation
@ -145,14 +146,15 @@ namespace Microsoft.ML.Tokenizers
if (_vocabReverse.TryGetValue(id, out var value))
{
string v = value.Data!;
if (FilterUnsupportedChars)
{
char[] buffer = ArrayPool<char>.Shared.Rent(value.Length);
char[] buffer = ArrayPool<char>.Shared.Rent(v.Length);
int i = 0;
for (int j = 0; j < value.Length; j++)
for (int j = 0; j < v.Length; j++)
{
if (_unicodeToByte.TryGetValue(value[j], out var c))
if (_unicodeToByte.TryGetValue(v[j], out var c))
{
buffer[i++] = c;
}
@ -164,7 +166,7 @@ namespace Microsoft.ML.Tokenizers
}
else
{
return value;
return v;
}
}
@ -205,7 +207,7 @@ namespace Microsoft.ML.Tokenizers
return Array.Empty<Token>();
}
if (_cache.TryGet(text, out List<Token>? hit))
if (_cache.TryGetValue(text, out List<Token>? hit))
{
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
@ -225,7 +227,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIds(text, accumulatedIds);
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIds(text, accumulatedIds);
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
@ -233,16 +235,16 @@ namespace Microsoft.ML.Tokenizers
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokens(string text, bool isSpecialToken) => EncodeToIds(text, null);
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIds(text, null);
private int EncodeToIds(string text, IList<int>? accumulatedIds)
private int EncodeToIds(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
{
if (string.IsNullOrEmpty(text))
if (text.IsEmpty)
{
return 0;
}
if (_cache.TryGet(text, out List<Token>? hit))
if (_cache.TryGetValue(text, out List<Token>? hit))
{
if (accumulatedIds is not null)
{
@ -255,17 +257,41 @@ namespace Microsoft.ML.Tokenizers
return hit.Count;
}
// If the cache doesn't have the text, then encode it and add it to the cache
IReadOnlyList<Token> tokens = Encode(text);
char[] token = ArrayPool<char>.Shared.Rent(text.Length);
int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);
int newTokenIndex = 0;
for (int i = 0; i < text.Length; i++)
{
if (_byteToUnicode.TryGetValue(text[i], out var value))
{
token[newTokenIndex] = value;
indexMapping[newTokenIndex] = i;
newTokenIndex++;
}
}
if (newTokenIndex == 0)
{
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
return 0;
}
List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
_cache.Set(text.ToString(), result);
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
if (accumulatedIds is not null)
{
foreach (var t in tokens)
foreach (var t in result)
{
accumulatedIds.Add(t.Id);
}
}
return tokens.Count;
return result.Count;
}
/// <summary>
@ -274,7 +300,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="token">The token to map to the Id.</param>
/// <param name="considerSpecialTokens">Indicate whether to consider the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public override int? MapTokenToId(string token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out var value) ? value : null;
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;
/// <summary>
/// Convert a list of tokens Ids to highest occurrence rankings.
@ -397,12 +423,13 @@ namespace Microsoft.ML.Tokenizers
private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) =>
HighestOccurrenceMapping.Load(highestOccurrenceMappingStream);
private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
private Dictionary<StringSpanOrdinalKey, int> GetVocabulary(Stream vocabularyStream)
{
Dictionary<string, int>? vocab;
Dictionary<StringSpanOrdinalKey, int>? vocab;
try
{
vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(vocabularyStream) as Dictionary<string, int>;
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
vocab = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocabularyStream, options) as Dictionary<StringSpanOrdinalKey, int>;
}
catch (Exception e)
{
@ -416,22 +443,22 @@ namespace Microsoft.ML.Tokenizers
if (_vocabIdToHighestOccurrence.BosWord is not null)
{
vocab[_vocabIdToHighestOccurrence.BosWord] = -_vocabIdToHighestOccurrence.BosIndex;
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.BosWord)] = -_vocabIdToHighestOccurrence.BosIndex;
}
if (_vocabIdToHighestOccurrence.EosWord is not null)
{
vocab[_vocabIdToHighestOccurrence.EosWord] = -_vocabIdToHighestOccurrence.EosIndex;
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.EosWord)] = -_vocabIdToHighestOccurrence.EosIndex;
}
if (_vocabIdToHighestOccurrence.UnkWord is not null)
{
vocab[_vocabIdToHighestOccurrence.UnkWord] = -_vocabIdToHighestOccurrence.UnkIndex;
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.UnkWord)] = -_vocabIdToHighestOccurrence.UnkIndex;
}
if (_vocabIdToHighestOccurrence.PadWord is not null)
{
vocab[_vocabIdToHighestOccurrence.PadWord] = -_vocabIdToHighestOccurrence.PadIndex;
vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.PadWord)] = -_vocabIdToHighestOccurrence.PadIndex;
}
return vocab;
@ -510,7 +537,7 @@ namespace Microsoft.ML.Tokenizers
if (token.Length == 1)
{
string tokenValue = _charToString[token[0]];
return new List<Token> { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], 1)) };
return new List<Token> { new Token(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, (indexMapping[0], 1)) };
}
List<string> word = new(token.Length);
@ -539,7 +566,7 @@ namespace Microsoft.ML.Tokenizers
// get the most frequent bi-gram pair
var (first, second) = pairs.ArgMin(pair => _mergeRanks.GetOrAdd(pair, int.MaxValue));
if (!_mergeRanks.TryGet((first, second), out int _))
if (!_mergeRanks.TryGetValue((first, second), out int _))
{
break;
}
@ -599,7 +626,7 @@ namespace Microsoft.ML.Tokenizers
foreach (string w in word)
{
tokens.Add(new Token(_vocab[w], w, (indexMapping[index], w.Length)));
tokens.Add(new Token(_vocab[new StringSpanOrdinalKey(w)], w, (indexMapping[index], w.Length)));
index += w.Length;
}
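EncodeToIds above now rents its scratch buffers, filters the input through the byte-to-unicode map while recording an indexMapping from filtered positions back to original offsets, and returns the buffers when done. A minimal sketch of that rent/filter/return flow (the letter filter stands in for the real byte-to-unicode map, and the try/finally is a hardening not present in the diff):

using System;
using System.Buffers;

class RentAndFilterSketch
{
    static int CountMapped(ReadOnlySpan<char> text)
    {
        char[] token = ArrayPool<char>.Shared.Rent(text.Length);
        int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);
        try
        {
            int n = 0;
            for (int i = 0; i < text.Length; i++)
            {
                if (char.IsLetter(text[i])) // stand-in for _byteToUnicode.TryGetValue
                {
                    token[n] = text[i];
                    indexMapping[n] = i; // remember where this char came from
                    n++;
                }
            }
            return n;
        }
        finally
        {
            ArrayPool<char>.Shared.Return(token);
            ArrayPool<int>.Shared.Return(indexMapping);
        }
    }

    static void Main() => Console.WriteLine(CountMapped("a1b2".AsSpan())); // prints 2
}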

View file

@ -31,14 +31,15 @@ namespace Microsoft.ML.Tokenizers
/// This method provides a default implementation that uses the Encode method to get the token Ids.
/// Tokenizer models that care about performance may override this method to provide a more efficient implementation.
/// </remarks>
public virtual void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds)
public virtual void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds)
{
if (accumulatedIds is null)
{
throw new ArgumentNullException(nameof(accumulatedIds));
}
var tokens = Encode(text);
// The default implementation is not optimized for memory allocation. Override this method for better performance.
var tokens = Encode(text.ToString());
foreach (var token in tokens)
{
accumulatedIds.Add(token.Id);
@ -55,7 +56,7 @@ namespace Microsoft.ML.Tokenizers
/// This method provides a default implementation that uses the EncodeToIds method to get the number of token Ids.
/// Tokenizer models that care about performance may override this method to provide a more efficient implementation.
/// </remarks>
public virtual int CountTokens(string text, bool isSpecialToken)
public virtual int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
{
var ids = new List<int>();
EncodeToIds(text, isSpecialToken, ids);
@ -68,7 +69,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="token">The token to map to the Id.</param>
/// <param name="considerSpecialTokens">Indicate whether to consider the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public abstract int? MapTokenToId(string token, bool considerSpecialTokens = true);
public abstract int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true);
/// <summary>
/// Map the encoded Id to the token.
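The pattern in this file: the abstract Model grows span-based virtual overloads whose defaults bridge to the existing string path (paying one ToString), so derived models can opt into allocation-free overrides at their own pace. A standalone sketch of that shape, not the library's actual class:

using System;

abstract class ModelSketch
{
    public abstract int CountTokens(string text);

    // Default bridges span callers to the string implementation.
    public virtual int CountTokens(ReadOnlySpan<char> text) => CountTokens(text.ToString());
}

sealed class WhitespaceModel : ModelSketch
{
    public override int CountTokens(string text) => CountTokens(text.AsSpan());

    // Optimized override: works directly on the span, no allocation.
    public override int CountTokens(ReadOnlySpan<char> text)
    {
        int count = 0;
        bool inToken = false;
        foreach (char c in text)
        {
            bool ws = char.IsWhiteSpace(c);
            if (!ws && !inToken) count++;
            inToken = !ws;
        }
        return count;
    }
}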

View file

@ -19,12 +19,14 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
public sealed class Tiktoken : Model
{
private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder = null!;
private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder = null!;
private readonly LruCache<string, int[]>? _cache;
private readonly IReadOnlyDictionary<string, int>? _specialTokensEncoder;
private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder;
private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder;
private readonly LruCache<int[]> _cache;
private readonly Dictionary<StringSpanOrdinalKey, int>? _specialTokensEncoder;
private Dictionary<string, int>? _specialTokensEncoderOriginal;
private readonly Dictionary<int, string>? _specialTokensDecoder;
private readonly Dictionary<string, int> _vocab = null!;
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
private IReadOnlyDictionary<string, int>? _vocabOriginal;
/// <summary>
/// Create a new Tiktoken tokenizer's model object.
@ -34,7 +36,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="cacheSize">The size of the cache to use.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabFilePath"/> is null or empty.</exception>
/// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception>
public Tiktoken(string vocabFilePath, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<string, int[]>.DefaultCacheSize) :
public Tiktoken(string vocabFilePath, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), specialTokens, cacheSize, disposeStream: true)
{
}
@ -47,7 +49,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="cacheSize">The size of the cache to use.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabStream"/> is null or empty.</exception>
/// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception>
public Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<string, int[]>.DefaultCacheSize) :
public Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
this(vocabStream ?? throw new ArgumentNullException(nameof(vocabStream)), specialTokens, cacheSize, disposeStream: false)
{
}
@ -63,9 +65,9 @@ namespace Microsoft.ML.Tokenizers
internal Tiktoken(
Dictionary<ReadOnlyMemory<byte>, int> encoder,
Dictionary<int, ReadOnlyMemory<byte>> decoder,
Dictionary<string, int> vocab,
Dictionary<StringSpanOrdinalKey, int> vocab,
IReadOnlyDictionary<string, int>? specialTokens,
int cacheSize = LruCache<string, int[]>.DefaultCacheSize) : this(cacheSize)
int cacheSize = LruCache<int[]>.DefaultCacheSize)
{
_encoder = encoder ?? throw new ArgumentNullException(nameof(encoder));
_decoder = decoder ?? throw new ArgumentNullException(nameof(decoder));
@ -73,24 +75,21 @@ namespace Microsoft.ML.Tokenizers
Debug.Assert(encoder.Count == decoder.Count);
_specialTokensEncoder = specialTokens;
if (_specialTokensEncoder is not null)
{
_specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
}
_encoder = encoder!;
_decoder = decoder!;
_vocab = vocab!;
_cache = new LruCache<int[]>(cacheSize);
(_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens);
}
private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens, int cacheSize, bool disposeStream) : this(cacheSize)
private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens, int cacheSize, bool disposeStream)
{
try
{
_cache = new LruCache<int[]>(cacheSize);
(_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
_specialTokensEncoder = specialTokens;
if (_specialTokensEncoder is not null)
{
_specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
}
(_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens);
}
finally
{
@ -101,17 +100,15 @@ namespace Microsoft.ML.Tokenizers
}
}
private Tiktoken(int cacheSize)
private static (Dictionary<StringSpanOrdinalKey, int>?, Dictionary<int, string>?) CreateEncoderDecoder(IReadOnlyDictionary<string, int>? specialTokens)
{
if (cacheSize < 0)
if (specialTokens is not null)
{
throw new ArgumentOutOfRangeException(nameof(cacheSize));
var encoder = specialTokens.ToDictionary(e => new StringSpanOrdinalKey(e.Key), e => e.Value);
return (encoder, encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key.Data!));
}
if (cacheSize > 0)
{
_cache = new LruCache<string, int[]>(cacheSize);
}
return (null, null);
}
/// <summary>
@ -125,7 +122,7 @@ namespace Microsoft.ML.Tokenizers
public static async Task<Tiktoken> CreateAsync(
Stream vocabStream,
IReadOnlyDictionary<string, int>? specialTokens = null,
int cacheSize = LruCache<string, int[]>.DefaultCacheSize,
int cacheSize = LruCache<int[]>.DefaultCacheSize,
CancellationToken cancellationToken = default)
{
if (vocabStream is null)
@ -133,7 +130,7 @@ namespace Microsoft.ML.Tokenizers
throw new ArgumentNullException(nameof(vocabStream));
}
(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
await LoadTikTokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);
return new Tiktoken(encoder, decoder, vocab, specialTokens, cacheSize);
@ -150,7 +147,7 @@ namespace Microsoft.ML.Tokenizers
public static async Task<Tiktoken> CreateAsync(
string vocabFilePath,
IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
int cacheSize = LruCache<string, int[]>.DefaultCacheSize,
int cacheSize = LruCache<int[]>.DefaultCacheSize,
CancellationToken cancellationToken = default)
{
if (vocabFilePath is null)
@ -170,11 +167,11 @@ namespace Microsoft.ML.Tokenizers
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>Map of byte[] to integer token id</returns>
/// <exception cref="InvalidOperationException"></exception>
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTikTokenBpeAsync(
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<StringSpanOrdinalKey, int>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTikTokenBpeAsync(
Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
{
var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
var vocab = new Dictionary<string, int>();
var vocab = new Dictionary<StringSpanOrdinalKey, int>();
var decoder = new Dictionary<int, ReadOnlyMemory<byte>>();
try
@ -212,7 +209,7 @@ namespace Microsoft.ML.Tokenizers
if (decodedToken.IndexOf('\uFFFD') < 0)
{
vocab[decodedToken] = rank;
vocab[new StringSpanOrdinalKey(decodedToken)] = rank;
}
}
else
@ -230,12 +227,6 @@ namespace Microsoft.ML.Tokenizers
return (encoder, vocab, decoder);
}
/// <summary>
/// Gets the dictionary mapping special tokens to Ids.
/// </summary>
/// <returns>The dictionary mapping special tokens to Ids.</returns>
public IReadOnlyDictionary<string, int>? SpecialTokensEncoder => _specialTokensEncoder;
/// <summary>
/// Encode a split text string to a list of tokens.
/// </summary>
@ -253,12 +244,7 @@ namespace Microsoft.ML.Tokenizers
if (isSpecialToken)
{
if (_specialTokensEncoder is null)
{
throw new InvalidOperationException($"The tokenizer doesn't have special tokens");
}
if (_specialTokensEncoder.TryGetValue(text, out int id))
if (_specialTokensEncoder?.TryGetValue(text, out int id) is true)
{
return new List<Token> { new(id, text, (0, text.Length)) };
}
@ -266,7 +252,7 @@ namespace Microsoft.ML.Tokenizers
throw new InvalidOperationException($"The special token {text} doesn't exist in the tokenizer");
}
if (_cache?.Lookup(text, out int[] ids) is true)
if (_cache.TryGetValue(text, out int[]? ids))
{
tokens = new Token[ids.Length];
tokens[0] = new Token(ids[0], text, (0, text.Length));
@ -290,7 +276,7 @@ namespace Microsoft.ML.Tokenizers
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
Debug.Assert(encodedIds.Length > 0);
_cache?.Add(text, encodedIds);
_cache.Add(text, encodedIds);
tokens = new Token[encodedIds.Length];
tokens[0] = new Token(encodedIds[0], text, (0, text.Length));
@ -305,21 +291,21 @@ namespace Microsoft.ML.Tokenizers
}
/// <summary>
/// Encode a split text string to a list of Ids.
/// Encode text to a list of Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <param name="accumulatedIds">The list of accumulated Ids.</param>
public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds)
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds)
{
if (string.IsNullOrEmpty(text))
if (text.IsEmpty)
{
return;
}
if (isSpecialToken)
{
if (_specialTokensEncoder is not null && _specialTokensEncoder.TryGetValue(text, out int id))
if (_specialTokensEncoder?.TryGetValue(text, out int id) is true)
{
accumulatedIds.Add(id);
}
@ -327,7 +313,7 @@ namespace Microsoft.ML.Tokenizers
return;
}
if (_cache?.Lookup(text, out int[] tokenIds) is true)
if (_cache.TryGetValue(text, out int[]? tokenIds))
{
accumulatedIds.AddRange(tokenIds);
return;
@ -340,10 +326,10 @@ namespace Microsoft.ML.Tokenizers
}
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
int encodedLength = GetUtf8Bytes(text.AsSpan(), arrayPoolArray);
int encodedLength = GetUtf8Bytes(text, arrayPoolArray);
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
_cache?.Add(text, encodedIds);
_cache.Add(text.ToString(), encodedIds);
accumulatedIds.AddRange(encodedIds);
@ -354,12 +340,12 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="text">The text to tokenize.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokens(string text, bool isSpecialToken)
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
{
if (string.IsNullOrEmpty(text))
if (text.IsEmpty)
{
return 0;
}
@ -369,7 +355,7 @@ namespace Microsoft.ML.Tokenizers
return _specialTokensEncoder.TryGetValue(text, out _) ? 1 : 0;
}
if (_cache?.Lookup(text, out int[] ids) is true)
if (_cache.TryGetValue(text, out int[] ids))
{
return ids.Length;
}
@ -380,10 +366,10 @@ namespace Microsoft.ML.Tokenizers
}
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
int encodedLength = GetUtf8Bytes(text.AsSpan(), arrayPoolArray);
int encodedLength = GetUtf8Bytes(text, arrayPoolArray);
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
_cache?.Add(text, encodedIds);
_cache.Add(text.ToString(), encodedIds);
ArrayPool<byte>.Shared.Return(arrayPoolArray);
return encodedIds.Length;
@ -395,19 +381,22 @@ namespace Microsoft.ML.Tokenizers
/// <param name="token">The token to map to the Id.</param>
/// <param name="considerSpecialTokens">Indicate whether to consider the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public override int? MapTokenToId(string token, bool considerSpecialTokens = true)
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true)
{
if (string.IsNullOrEmpty(token))
if (token.IsEmpty)
{
return 0;
}
if (considerSpecialTokens && _specialTokensEncoder is not null && _specialTokensEncoder.TryGetValue(token, out int specialTokenId))
if (considerSpecialTokens && _specialTokensEncoder is not null)
{
return specialTokenId;
if (_specialTokensEncoder.TryGetValue(token, out int specialTokenId))
{
return specialTokenId;
}
}
if (_cache?.Lookup(token, out int[] ids) is true)
if (_cache.TryGetValue(token, out int[] ids))
{
if (ids.Length == 1)
{
@ -425,10 +414,10 @@ namespace Microsoft.ML.Tokenizers
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(token.Length));
try
{
int encodedLength = GetUtf8Bytes(token.AsSpan(), arrayPoolArray);
int encodedLength = GetUtf8Bytes(token, arrayPoolArray);
int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
_cache?.Add(token, idsToCache);
_cache.Add(token.ToString(), idsToCache);
if (idsToCache.Length == 1)
{
@ -550,7 +539,12 @@ namespace Microsoft.ML.Tokenizers
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
/// <remarks>This may not contain the full set of vocabulary tokens, use Encoder to get the full set of vocabulary.</remarks>
public IReadOnlyDictionary<string, int> Vocab => _vocab;
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
/// <summary>
/// Gets the dictionary mapping special tokens to Ids.
/// </summary>
public IReadOnlyDictionary<string, int>? SpecialTokensEncoder => _specialTokensEncoderOriginal ??= _specialTokensEncoder?.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
/// <summary>
/// Gets the dictionary mapping token bytes to Ids.
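Note how Vocab and SpecialTokensEncoder above are materialized: the hot path keys its dictionaries by StringSpanOrdinalKey, and the public string-keyed views are built lazily with ??= on first access, then cached. Reduced to its essentials (hypothetical names):

using System.Collections.Generic;
using System.Linq;

sealed class LazyVocabSketch
{
    private readonly Dictionary<int, string> _idToToken = new() { [0] = "<unk>" };
    private Dictionary<string, int>? _tokenToId;

    // Built once, on first read; later reads reuse the cached dictionary.
    public IReadOnlyDictionary<string, int> Vocab =>
        _tokenToId ??= _idToToken.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
}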

View file

@ -104,7 +104,7 @@ namespace Microsoft.ML.Tokenizers
foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
{
Model.EncodeToIds(split.TokenString, split.IsSpecialToken, idsList);
Model.EncodeToIds(split.TokenSpan, split.IsSpecialToken, idsList);
}
return idsList;
@ -130,7 +130,7 @@ namespace Microsoft.ML.Tokenizers
int idsCount = 0;
foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
{
idsCount += Model.CountTokens(split.TokenString, split.IsSpecialToken);
idsCount += Model.CountTokens(split.TokenSpan, split.IsSpecialToken);
}
return idsCount;
@ -343,7 +343,7 @@ namespace Microsoft.ML.Tokenizers
}
}
private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, Dictionary<int, ReadOnlyMemory<byte>>)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
/// <summary>
/// Create tokenizer based on regex pattern, BPE rank file and special tokens
@ -371,7 +371,7 @@ namespace Microsoft.ML.Tokenizers
}
}
if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
{
using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
{
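The call-site changes above pass split.TokenSpan instead of split.TokenString, so each pre-tokenized split flows to the model as a slice of the original text rather than a substring. A self-contained sketch of that idea:

using System;
using System.Collections.Generic;

class SplitSpanSketch
{
    static void Main()
    {
        string text = "hello world again";
        var ids = new List<int>();
        int start = 0;
        for (int i = 0; i <= text.Length; i++)
        {
            if (i == text.Length || text[i] == ' ')
            {
                // A span over the original text; no substring allocation.
                ReadOnlySpan<char> split = text.AsSpan(start, i - start);
                ids.Add(split.Length); // stand-in for Model.EncodeToIds(split, ...)
                start = i + 1;
            }
        }
        Console.WriteLine(string.Join(",", ids)); // prints 5,5,5
    }
}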

View file

@ -37,5 +37,7 @@ namespace Microsoft.ML.Tokenizers
internal static bool TryParseInt32(string s, int offset, out int result)
=> int.TryParse(s.AsSpan().Slice(offset), NumberStyles.None, CultureInfo.InvariantCulture, out result);
internal static int GetHashCode(ReadOnlySpan<char> span) => string.GetHashCode(span);
}
}

View file

@ -48,6 +48,17 @@ namespace Microsoft.ML.Tokenizers
return true;
}
internal static int GetHashCode(ReadOnlySpan<char> span)
{
int hash = 17;
foreach (char c in span)
{
hash = hash * 31 + c;
}
return hash;
}
}
}
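This polyfill matters because StringSpanOrdinalKey hashes through Helpers.GetHashCode on every target framework: a key stored via its backing string must later be found via a span over entirely different memory, so both must hash the same characters. A quick illustration using the same ordinal hash:

using System;

class OrdinalHashSketch
{
    static int GetOrdinalHash(ReadOnlySpan<char> span)
    {
        int hash = 17;
        foreach (char c in span)
        {
            hash = hash * 31 + c; // unchecked overflow is fine for hashing
        }
        return hash;
    }

    static void Main()
    {
        ReadOnlySpan<char> a = "hello".AsSpan();
        ReadOnlySpan<char> b = "say hello".AsSpan(4); // same chars, different memory
        Console.WriteLine(GetOrdinalHash(a) == GetOrdinalHash(b)); // True
    }
}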

View file

@ -2,47 +2,37 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
namespace Microsoft.ML.Tokenizers
{
internal class LruCache<TKey, TValue> where TKey : notnull where TValue : notnull
internal sealed class LruCache<TValue>
{
/// <summary>
/// The default LRU cache size.
/// </summary>
public const int DefaultCacheSize = 8192; // 4096;
public const int DefaultCacheSize = 8192;
private readonly object _lockObject = new object();
private class CacheItem
{
public readonly TKey Key;
public TValue Value;
public CacheItem(TKey key, TValue value)
{
Key = key;
Value = value;
}
}
private readonly Dictionary<TKey, LinkedListNode<CacheItem>> _cache;
private readonly LinkedList<CacheItem> _lruList;
private readonly Dictionary<StringSpanOrdinalKey, LinkedListNode<KeyValuePair<string, TValue>>> _cache = new();
private readonly LinkedList<KeyValuePair<string, TValue>> _lruList = new();
private readonly int _cacheSize;
private object SyncObj => _cache;
/// <summary>
/// Constructs an <see cref="LruCache{TKey,TValue}" /> object.
/// Constructs an <see cref="LruCache{TValue}" /> object.
/// </summary>
/// <param name="cacheSize">
/// The maximum number of <typeparamref name="TKey" /> to <typeparamref name="TValue" /> mappings
/// that can be cached. This defaults to <see cref="DefaultCacheSize" />, which is set to
/// <value>4096</value>.
/// The maximum number of mappings that can be cached. This defaults to <see cref="DefaultCacheSize" />, which is set to <value>8192</value>.
/// </param>
public LruCache(int cacheSize = DefaultCacheSize)
{
_cache = new Dictionary<TKey, LinkedListNode<CacheItem>>();
_lruList = new LinkedList<CacheItem>();
if (cacheSize <= 0)
{
throw new ArgumentOutOfRangeException(nameof(cacheSize), "Cache size must be a positive number.");
}
_cacheSize = cacheSize;
}
@ -54,11 +44,11 @@ namespace Microsoft.ML.Tokenizers
/// <returns>
/// true if the cache contains a mapping for key, false otherwise.
/// </returns>
public bool Lookup(TKey key, out TValue value)
public bool TryGetValue(string key, out TValue value)
{
lock (_lockObject)
lock (SyncObj)
{
if (_cache.TryGetValue(key, out LinkedListNode<CacheItem>? cached))
if (_cache.TryGetValue(new StringSpanOrdinalKey(key), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
{
_lruList.Remove(cached);
_lruList.AddFirst(cached);
@ -71,16 +61,31 @@ namespace Microsoft.ML.Tokenizers
}
}
protected virtual void OnEviction(TValue evictedValue) { }
private void EvictIfNeeded()
/// <summary>
/// Retrieves the value associated with the specified key.
/// </summary>
/// <param name="key">The object to be used as a key.</param>
/// <param name="value">An out parameter that is set to the value mapped to the key, if the cache contains such a mapping.</param>
/// <returns>
/// true if the cache contains a mapping for key, false otherwise.
/// </returns>
public unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
{
while (_cache.Count >= _cacheSize)
lock (SyncObj)
{
LinkedListNode<CacheItem>? nodeToEvict = _lruList.Last;
_lruList.RemoveLast();
_cache.Remove(nodeToEvict!.Value.Key);
OnEviction(nodeToEvict.Value.Value);
fixed (char* ptr = key)
{
if (_cache.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
{
_lruList.Remove(cached);
_lruList.AddFirst(cached);
value = cached.Value.Value;
return true;
}
}
value = default!;
return false;
}
}
@ -89,46 +94,29 @@ namespace Microsoft.ML.Tokenizers
/// </summary>
/// <param name="key">The key whose mapped <paramref name="value" /> is to be created or replaced.</param>
/// <param name="value">The new value to be mapped to the <paramref name="key" />.</param>
public void Add(TKey key, TValue value) => Replace(key, value, out _);
public bool Replace(TKey key, TValue value, out TValue oldValue)
public void Add(string key, TValue value)
{
lock (_lockObject)
lock (SyncObj)
{
return ReplaceInternal(key, value, out oldValue);
if (_cache.TryGetValue(new StringSpanOrdinalKey(key), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
{
cached.Value = new KeyValuePair<string, TValue>(key, value);
_lruList.Remove(cached);
_lruList.AddFirst(cached);
return;
}
while (_cache.Count >= _cacheSize)
{
LinkedListNode<KeyValuePair<string, TValue>>? nodeToEvict = _lruList.Last;
_lruList.RemoveLast();
_cache.Remove(new StringSpanOrdinalKey(nodeToEvict!.Value.Key));
}
var node = new LinkedListNode<KeyValuePair<string, TValue>>(new KeyValuePair<string, TValue>(key, value));
_cache[new StringSpanOrdinalKey(key)] = node;
_lruList.AddFirst(node);
}
}
private bool ReplaceInternal(TKey key, TValue value, out TValue oldValue)
{
if (_cache.TryGetValue(key, out LinkedListNode<CacheItem>? cached))
{
oldValue = cached.Value.Value;
cached.Value.Value = value;
_lruList.Remove(cached);
_lruList.AddFirst(cached);
return true;
}
EvictIfNeeded();
var node = new LinkedListNode<CacheItem>(new CacheItem(key, value));
_cache[key] = node;
_lruList.AddFirst(node);
oldValue = default!;
return false;
}
/// <summary>
/// The number of entries currently present in the cache.
/// </summary>
public int Count => _cache.Count;
/// <summary>
/// Clears the contents of this cache.
/// </summary>
public void Clear()
{
_cache.Clear();
_lruList.Clear();
}
}
}
}
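A usage sketch against the rewritten cache above: hits go through the span overload with no string allocation, and Add evicts from the tail of the LRU list once capacity is reached.

var cache = new LruCache<int[]>(cacheSize: 2);
cache.Add("foo", new[] { 1, 2 });
cache.Add("bar", new[] { 3 });

// Span lookup: no string allocation, and "foo" moves to most-recently used.
ReadOnlySpan<char> key = "food".AsSpan(0, 3);
bool hit = cache.TryGetValue(key, out int[] ids); // true; ids is { 1, 2 }

cache.Add("baz", new[] { 4 }); // at capacity: evicts "bar", the LRU entry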

View file

@ -0,0 +1,132 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Microsoft.ML.Tokenizers
{
/// <summary>Used as a key in a dictionary to enable querying with either a string or a span.</summary>
/// <remarks>
/// This should only be used with a Ptr/Length for querying. For storing in a dictionary, this should
/// always be used with a string.
/// </remarks>
internal unsafe readonly struct StringSpanOrdinalKey : IEquatable<StringSpanOrdinalKey>
{
public readonly char* Ptr;
public readonly int Length;
public readonly string? Data;
public StringSpanOrdinalKey(char* ptr, int length)
{
Ptr = ptr;
Length = length;
}
public StringSpanOrdinalKey(string data) =>
Data = data;
private ReadOnlySpan<char> Span => Ptr is not null ?
new ReadOnlySpan<char>(Ptr, Length) :
Data.AsSpan();
public override bool Equals(object? obj) =>
obj is StringSpanOrdinalKey wrapper && Equals(wrapper);
public bool Equals(StringSpanOrdinalKey other) =>
Span.SequenceEqual(other.Span);
public override int GetHashCode() => Helpers.GetHashCode(Span);
}
internal sealed class StringSpanOrdinalKeyCache<TValue>
{
private readonly int _capacity;
private readonly Dictionary<StringSpanOrdinalKey, TValue> _map;
private object SyncObj => _map;
internal StringSpanOrdinalKeyCache() : this(Bpe.DefaultCacheCapacity) { }
internal StringSpanOrdinalKeyCache(int capacity)
{
_capacity = capacity;
_map = new Dictionary<StringSpanOrdinalKey, TValue>(capacity);
}
internal bool TryGetValue(string key, out TValue value)
{
lock (SyncObj)
{
return _map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
}
}
internal unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
{
lock (SyncObj)
{
fixed (char* ptr = key)
{
return _map.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out value!);
}
}
}
internal void Remove(string key)
{
lock (SyncObj)
{
_map.Remove(new StringSpanOrdinalKey(key));
}
}
internal void Set(string k, TValue v)
{
lock (SyncObj)
{
if (_map.Count < _capacity)
{
_map[new StringSpanOrdinalKey(k)] = v;
}
}
}
}
/// <summary>
/// Custom JSON converter for <see cref="StringSpanOrdinalKey"/>.
/// </summary>
internal sealed class StringSpanOrdinalKeyConverter : JsonConverter<StringSpanOrdinalKey>
{
public static StringSpanOrdinalKeyConverter Instance { get; } = new StringSpanOrdinalKeyConverter();
public override StringSpanOrdinalKey ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) =>
new StringSpanOrdinalKey(reader.GetString()!);
public override void WriteAsPropertyName(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) =>
writer.WriteStringValue(value.Data!);
public override StringSpanOrdinalKey Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) => new StringSpanOrdinalKey(reader.GetString()!);
public override void Write(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) => writer.WriteStringValue(value.Data!);
}
/// <summary>
/// Extension methods for <see cref="StringSpanOrdinalKey"/>.
/// </summary>
internal static class StringSpanOrdinalKeyExtensions
{
public unsafe static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, ReadOnlySpan<char> key, out TValue value)
{
fixed (char* ptr = key)
{
return map.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out value!);
}
}
public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, string key, out TValue value) =>
map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
}
}
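A usage sketch for the types above (assumes using System and System.Collections.Generic): keys stored in the dictionary are always string-backed, while lookups build a transient pointer-backed key over the pinned span, which never outlives the fixed block inside the extension method.

var map = new Dictionary<StringSpanOrdinalKey, int>();

// Store with a string-backed key: safe to keep in the dictionary.
map[new StringSpanOrdinalKey("hello")] = 42;

// Query with a span: no string allocation on the lookup path.
ReadOnlySpan<char> query = "say hello".AsSpan(4);
bool found = map.TryGetValue(query, out int id); // true; id == 42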

View file

@ -156,7 +156,7 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(ids[i], encoding.Ids[i]);
Assert.Equal(ids[i], idsList[i]);
Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i]));
Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan()));
Assert.Equal(encoding.Tokens[i], tokenizer.Decode(encoding.Ids[i]));
}
}

View file

@ -201,7 +201,7 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.Equal(unfilteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i], considerSpecialTokens: false));
}
Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i]));
Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan()));
}
}
}