diff --git a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
index fbff32071..8294d9954 100644
--- a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
+++ b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
@@ -32,17 +32,37 @@
+          do
+          {
+            if ((eolIndex = fileContent.IndexOf('\n', eolIndex)) >= 0)
+            {
+              eolIndex++;
+              capacity++;
+            }
+            else
+            {
+              break;
+            }
+          } while (eolIndex < fileContent.Length);
+
           using var sourceStream = File.OpenRead(fileName);
           using var reader = new StreamReader(sourceStream);
           using var destStream = new DeflateStream(File.Create(file.GetMetadata("Destination")), CompressionLevel.Optimal);
           using var streamWriter = new StreamWriter(destStream);
+          streamWriter.WriteLine($"Capacity: {capacity.ToString(CultureInfo.InvariantCulture)}");
+
           string line;
           int destLineNumber = 0;
diff --git a/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs
index cda85b118..08bbf5763 100644
--- a/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs
@@ -156,21 +156,37 @@ namespace Microsoft.ML.Tokenizers
         internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<StringSpanOrdinalKey, (int Id, string Token)>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTiktokenBpeAsync(
             Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
         {
-            var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
-            var vocab = new Dictionary<StringSpanOrdinalKey, (int Id, string Token)>();
-            var decoder = new Dictionary<int, ReadOnlyMemory<byte>>();
+            Dictionary<ReadOnlyMemory<byte>, int> encoder;
+            Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab;
+            Dictionary<int, ReadOnlyMemory<byte>> decoder;

             try
             {
                 // Don't dispose the reader as it will dispose the underlying stream vocabStream. The caller is responsible for disposing the stream.
                 StreamReader reader = new StreamReader(vocabStream);

-                string? line;
-                do
+                string? line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
+
+                const string capacity = "Capacity: ";
+                int suggestedCapacity = 0; // default capacity
+                if (line is not null && line.StartsWith(capacity, StringComparison.Ordinal))
                 {
-                    line = useAsync ?
-                        await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
-                        reader.ReadLine();
-                } while (line is not null && line.Length == 0);
+                    if (!Helpers.TryParseInt32(line, capacity.Length, out suggestedCapacity))
+                    {
+                        throw new FormatException($"Invalid format in the BPE vocab file stream");
+                    }
+
+                    line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
+                }
+
+                encoder = new Dictionary<ReadOnlyMemory<byte>, int>(suggestedCapacity, ReadOnlyMemoryByteComparer.Instance);
+                vocab = new Dictionary<StringSpanOrdinalKey, (int Id, string Token)>(suggestedCapacity);
+                decoder = new Dictionary<int, ReadOnlyMemory<byte>>(suggestedCapacity);
+
+                // skip empty lines
+                while (line is not null && line.Length == 0)
+                {
+                    line = useAsync ?
+                        await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
+                }

                 if (line is not null && line.IndexOf(' ') < 0)
                 {
diff --git a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs
index 231a8b22e..791e24527 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs
@@ -12,7 +12,6 @@ using System.Linq;
 using System.Reflection;
 using System.Text;
 using System.Text.Json;
-using System.Text.Json.Serialization;
 using System.Threading.Tasks;
 using Xunit;
@@ -501,9 +500,22 @@ namespace Microsoft.ML.Tokenizers.Tests
         {
             RemoteExecutor.Invoke(static (name) =>
             {
+#if NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
+                long allocation = GC.GetAllocatedBytesForCurrentThread();
+#endif // NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
+
                 Tokenizer tokenizer = TiktokenTokenizer.CreateForModel(name);
                 Assert.True(tokenizer is TiktokenTokenizer);
                 Assert.NotNull(tokenizer.PreTokenizer);
+
+#if NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
+                int entriesCount = GetEncoder((tokenizer as TiktokenTokenizer)!)!.Count;
+                allocation = GC.GetAllocatedBytesForCurrentThread() - allocation;
+
+                // entriesCount * 260 is the average memory allocation during initialization for the models we carry data files for.
+                // This is not the size of the cache; it includes all temporary allocations made during initialization.
+                Assert.True((entriesCount * 260) > allocation, $"Memory allocation of {entriesCount} entries for {name}: {allocation} bytes");
+#endif // NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
             }, modelName).Dispose();
         }
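For context, the change has two halves: the build task prepends a `Capacity: <line count>` header to each compressed tiktoken data file, and `LoadTiktokenBpeAsync` reads that header so it can pre-size the `encoder`, `vocab`, and `decoder` dictionaries instead of letting them grow and rehash while the vocabulary streams in; the new test then asserts an upper bound on bytes allocated per entry during model creation. The sketch below is a minimal, standalone illustration of the header handshake, not the library's code: `CapacityHeaderSketch`, `WriteHeader`, and `ReadHeader` are hypothetical names, and `int.TryParse` stands in for the repo-internal `Helpers.TryParseInt32`.

```csharp
using System;
using System.Globalization;
using System.IO;

static class CapacityHeaderSketch
{
    private const string CapacityPrefix = "Capacity: ";

    // Writer side: prepend the data-file line count so readers can pre-size their dictionaries.
    public static void WriteHeader(TextWriter writer, int lineCount) =>
        writer.WriteLine($"{CapacityPrefix}{lineCount.ToString(CultureInfo.InvariantCulture)}");

    // Reader side: peel the optional header off the stream. Files without the header still load;
    // they simply fall back to the default dictionary capacity of 0.
    public static int ReadHeader(TextReader reader, out string? firstDataLine)
    {
        string? line = reader.ReadLine();
        int suggestedCapacity = 0; // default when no header is present (older data files)

        if (line is not null && line.StartsWith(CapacityPrefix, StringComparison.Ordinal))
        {
            if (!int.TryParse(line.Substring(CapacityPrefix.Length), NumberStyles.None,
                              CultureInfo.InvariantCulture, out suggestedCapacity))
            {
                throw new FormatException("Invalid format in the BPE vocab file stream");
            }

            line = reader.ReadLine(); // advance to the first real vocab entry
        }

        firstDataLine = line;
        return suggestedCapacity;
    }
}
```

A loader would then pass the result to capacity-taking constructors, e.g. `new Dictionary<ReadOnlyMemory<byte>, int>(suggestedCapacity, comparer)` as in the diff; pre-sizing avoids the intermediate bucket arrays that `Dictionary<TKey, TValue>` would otherwise allocate and discard while resizing, which is what the new allocation assertion in the test is meant to guard.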