diff --git a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
index fbff32071..8294d9954 100644
--- a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
+++ b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
@@ -32,17 +32,37 @@
+
+ // Count the lines in the data file; the count is written below as a capacity hint
+ // so the loader can pre-size its dictionaries.
+ string fileContent = File.ReadAllText(fileName);
+
+ int capacity = 1; // newline count + 1 covers a final line without a trailing newline
+ int eolIndex = 0;
+ do
+ {
+     eolIndex = fileContent.IndexOf('\n', eolIndex);
+     if (eolIndex >= 0)
+     {
+         eolIndex++;
+         capacity++;
+     }
+     else
+     {
+         break;
+     }
+ } while (eolIndex < fileContent.Length);
+
+ using var sourceStream = File.OpenRead(fileName);
using var reader = new StreamReader(sourceStream);
using var destStream = new DeflateStream(File.Create(file.GetMetadata("Destination")), CompressionLevel.Optimal);
using var streamWriter = new StreamWriter(destStream);
+ streamWriter.WriteLine($"Capacity: {capacity.ToString(CultureInfo.InvariantCulture)}");
+
string line;
int destLineNumber = 0;
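
For context, the build task above now prepends a capacity hint to each compressed data file. Below is a minimal round-trip sketch of that layout; the output file name and the vocab lines are hypothetical, and the snippet only illustrates the "Capacity: N" header followed by the ordinary entries, not the real build task.

```csharp
// Sketch only: shows the "Capacity: N" header layout the build task emits.
// The file name and vocab entries here are made up.
using System;
using System.Globalization;
using System.IO;
using System.IO.Compression;

class CapacityHeaderSketch
{
    static void Main()
    {
        string[] vocabLines = { "IQ== 0", "Ig== 1", "Iw== 2" }; // hypothetical "token rank" lines

        // Write: capacity header first, then the entries, all deflate-compressed.
        using (var destStream = new DeflateStream(File.Create("sample.bin"), CompressionLevel.Optimal))
        using (var writer = new StreamWriter(destStream))
        {
            writer.WriteLine($"Capacity: {vocabLines.Length.ToString(CultureInfo.InvariantCulture)}");
            foreach (string line in vocabLines)
            {
                writer.WriteLine(line);
            }
        }

        // Read back: the first line is the hint the tokenizer loader looks for.
        using var sourceStream = new DeflateStream(File.OpenRead("sample.bin"), CompressionMode.Decompress);
        using var reader = new StreamReader(sourceStream);
        Console.WriteLine(reader.ReadLine()); // prints "Capacity: 3"
    }
}
```
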
diff --git a/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs
index cda85b118..08bbf5763 100644
--- a/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs
@@ -156,21 +156,37 @@ namespace Microsoft.ML.Tokenizers
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTiktokenBpeAsync(
Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
{
- var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
- var vocab = new Dictionary<string, int>();
- var decoder = new Dictionary<int, ReadOnlyMemory<byte>>();
+ Dictionary<ReadOnlyMemory<byte>, int> encoder;
+ Dictionary<string, int> vocab;
+ Dictionary<int, ReadOnlyMemory<byte>> decoder;
try
{
// Don't dispose the reader as it will dispose the underlying stream vocabStream. The caller is responsible for disposing the stream.
StreamReader reader = new StreamReader(vocabStream);
- string? line;
- do
+ string? line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
+
+ const string capacity = "Capacity: ";
+ int suggestedCapacity = 0; // default capacity
+ if (line is not null && line.StartsWith(capacity, StringComparison.Ordinal))
{
- line = useAsync ?
- await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
- reader.ReadLine();
- } while (line is not null && line.Length == 0);
+ if (!Helpers.TryParseInt32(line, capacity.Length, out suggestedCapacity))
+ {
+ throw new FormatException($"Invalid format in the BPE vocab file stream");
+ }
+
+ line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
+ }
+
+ encoder = new Dictionary<ReadOnlyMemory<byte>, int>(suggestedCapacity, ReadOnlyMemoryByteComparer.Instance);
+ vocab = new Dictionary<string, int>(suggestedCapacity);
+ decoder = new Dictionary<int, ReadOnlyMemory<byte>>(suggestedCapacity);
+
+ // skip empty lines
+ while (line is not null && line.Length == 0)
+ {
+ line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
+ }
if (line is not null && line.IndexOf(' ') < 0)
{
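
On the load side, the hunk above consumes the optional header before allocating the dictionaries. Here is a standalone sketch of the same pattern, assuming a simplified uncompressed stream of "token rank" lines and a plain Dictionary<string, int>; int.TryParse stands in for the library-internal Helpers.TryParseInt32, and unlike the real loader this sketch disposes the stream it is given.

```csharp
// Sketch, not the library implementation: reads an optional "Capacity: " header
// and uses it to pre-size the dictionary before filling it.
using System;
using System.Collections.Generic;
using System.IO;

public static class CapacityHintLoaderSketch
{
    private const string CapacityPrefix = "Capacity: ";

    public static Dictionary<string, int> Load(Stream vocabStream)
    {
        using var reader = new StreamReader(vocabStream);
        string? line = reader.ReadLine();

        int suggestedCapacity = 0; // 0 falls back to the dictionary's default sizing
        if (line is not null && line.StartsWith(CapacityPrefix, StringComparison.Ordinal))
        {
            if (!int.TryParse(line.Substring(CapacityPrefix.Length), out suggestedCapacity))
            {
                throw new FormatException("Invalid capacity header in the vocab stream");
            }

            line = reader.ReadLine(); // advance past the header to the first entry
        }

        // Pre-sizing avoids repeated resize/rehash passes while the entries are inserted.
        var vocab = new Dictionary<string, int>(suggestedCapacity);

        while (line is not null)
        {
            int spaceIndex = line.IndexOf(' ');
            if (spaceIndex > 0)
            {
                vocab[line.Substring(0, spaceIndex)] = int.Parse(line.Substring(spaceIndex + 1));
            }

            line = reader.ReadLine();
        }

        return vocab;
    }
}
```

Passing the hint to the Dictionary constructors is the point of the change: the backing arrays are sized once up front instead of growing repeatedly while the vocabulary is inserted.
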
diff --git a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs
index 231a8b22e..791e24527 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs
@@ -12,7 +12,6 @@ using System.Linq;
using System.Reflection;
using System.Text;
using System.Text.Json;
-using System.Text.Json.Serialization;
using System.Threading.Tasks;
using Xunit;
@@ -501,9 +500,22 @@ namespace Microsoft.ML.Tokenizers.Tests
{
RemoteExecutor.Invoke(static (name) =>
{
+#if NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
+ long allocation = GC.GetAllocatedBytesForCurrentThread();
+#endif // NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
+
Tokenizer tokenizer = TiktokenTokenizer.CreateForModel(name);
Assert.True(tokenizer is TiktokenTokenizer);
Assert.NotNull(tokenizer.PreTokenizer);
+
+#if NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
+ int entriesCount = GetEncoder((tokenizer as TiktokenTokenizer)!)!.Count;
+ allocation = GC.GetAllocatedBytesForCurrentThread() - allocation;
+
+ // entriesCount * 260 is the average memory allocation during initialization for the models we carry data files for.
+ // This allocation is not just the size of the cache; it includes all temporary allocations made during initialization.
+ Assert.True((entriesCount * 260) > allocation, $"Memory allocation of {entriesCount} entries for {name}: {allocation} bytes");
+#endif // NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
}, modelName).Dispose();
}
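
The assertion above relies on GC.GetAllocatedBytesForCurrentThread, which reports bytes allocated by the calling thread only, so the measured code has to stay synchronous on that thread. Below is a minimal sketch of the same measurement pattern with a generic workload standing in for TiktokenTokenizer.CreateForModel; the 260-bytes-per-entry budget mirrors the test above but is otherwise arbitrary.

```csharp
// Sketch of a per-thread allocation budget check; the dictionary workload is a
// stand-in for the tokenizer creation measured in the real test.
using System;
using System.Collections.Generic;

class AllocationBudgetSketch
{
    static void Main()
    {
        // Snapshot before the work; only allocations made on this thread are counted.
        long before = GC.GetAllocatedBytesForCurrentThread();

        var entries = new Dictionary<int, string>(1_000); // pre-sized, like the tokenizer change
        for (int i = 0; i < 1_000; i++)
        {
            entries[i] = i.ToString();
        }

        long allocated = GC.GetAllocatedBytesForCurrentThread() - before;

        // Fail if the work exceeded the per-entry budget (260 bytes, as in the test above).
        const long budgetPerEntry = 260;
        if (allocated >= entries.Count * budgetPerEntry)
        {
            throw new InvalidOperationException(
                $"Allocated {allocated} bytes for {entries.Count} entries");
        }

        Console.WriteLine($"Allocated {allocated} bytes for {entries.Count} entries");
    }
}
```

In the actual test the check runs inside RemoteExecutor.Invoke, so the measurement happens in a separate process and is not skewed by allocations left over from other tests.
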