Reduce Tiktoken Creation Memory Allocation (#7202)

Tarek Mahmoud Sayed 2024-07-28 15:55:03 -07:00 committed by GitHub
Parent 34eb579d4b
Commit f72c9d22fc
No known key found for this signature
GPG key ID: B5690EEEBB952194
3 changed files with 60 additions and 12 deletions

View file

@@ -32,17 +32,37 @@
     <Files ParameterType="Microsoft.Build.Framework.ITaskItem[]" Required="true" />
   </ParameterGroup>
   <Task>
+    <Using Namespace="System.Globalization" />
     <Using Namespace="System.IO" />
     <Using Namespace="System.IO.Compression" />
     <Code Type="Fragment" Language="cs">
       <![CDATA[
-      foreach(var file in Files)
+      foreach (var file in Files)
       {
-        using var sourceStream = File.OpenRead(file.GetMetadata("FullPath"));
+        string fileName = file.GetMetadata("FullPath");
+        string fileContent = File.ReadAllText(fileName);
+        int capacity = 1;
+        int eolIndex = 0;
+        do
+        {
+            if ((eolIndex = fileContent.IndexOf('\n', eolIndex)) >= 0)
+            {
+                eolIndex++;
+                capacity++;
+            }
+            else
+            {
+                break;
+            }
+        } while (eolIndex < fileContent.Length);
+        using var sourceStream = File.OpenRead(fileName);
         using var reader = new StreamReader(sourceStream);
         using var destStream = new DeflateStream(File.Create(file.GetMetadata("Destination")), CompressionLevel.Optimal);
         using var streamWriter = new StreamWriter(destStream);
+        streamWriter.WriteLine($"Capacity: {capacity.ToString(CultureInfo.InvariantCulture)}");
         string line;
         int destLineNumber = 0;
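
The build task now pre-counts '\n' occurrences and writes the line count as a "Capacity: N" header at the top of the compressed data file, so the loader can size its dictionaries up front. For reference, the counting logic in standalone form — a minimal sketch with class and method names of my own choosing; the simplified loop relies on String.IndexOf(char, int) returning -1 when the start index equals the string length, so it behaves the same as the guarded do/while above:

using System;

class CapacityCounterSketch
{
    // Counts lines the way the task above does: one line minimum,
    // plus one per '\n' found. A trailing '\n' counts an extra (empty)
    // final line, which only over-reserves slightly.
    static int CountLines(string content)
    {
        int capacity = 1;
        int eolIndex = 0;
        while ((eolIndex = content.IndexOf('\n', eolIndex)) >= 0)
        {
            eolIndex++;
            capacity++;
        }
        return capacity;
    }

    static void Main()
    {
        Console.WriteLine(CountLines("a\nb\nc"));   // 3
        Console.WriteLine(CountLines("a\nb\nc\n")); // 4
    }
}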

View file

@@ -156,21 +156,37 @@ namespace Microsoft.ML.Tokenizers
         internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<StringSpanOrdinalKey, (int Id, string Token)>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTiktokenBpeAsync(
             Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
         {
-            var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
-            var vocab = new Dictionary<StringSpanOrdinalKey, (int Id, string Token)>();
-            var decoder = new Dictionary<int, ReadOnlyMemory<byte>>();
+            Dictionary<ReadOnlyMemory<byte>, int> encoder;
+            Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab;
+            Dictionary<int, ReadOnlyMemory<byte>> decoder;
             try
             {
                 // Don't dispose the reader as it will dispose the underlying stream vocabStream. The caller is responsible for disposing the stream.
                 StreamReader reader = new StreamReader(vocabStream);
-                string? line;
-                do
+                string? line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
+                const string capacity = "Capacity: ";
+                int suggestedCapacity = 0; // default capacity
+                if (line is not null && line.StartsWith(capacity, StringComparison.Ordinal))
                 {
-                    line = useAsync ?
-                        await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
-                        reader.ReadLine();
-                } while (line is not null && line.Length == 0);
+                    if (!Helpers.TryParseInt32(line, capacity.Length, out suggestedCapacity))
+                    {
+                        throw new FormatException($"Invalid format in the BPE vocab file stream");
+                    }
+                    line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
+                }
+                encoder = new Dictionary<ReadOnlyMemory<byte>, int>(suggestedCapacity, ReadOnlyMemoryByteComparer.Instance);
+                vocab = new Dictionary<StringSpanOrdinalKey, (int Id, string Token)>(suggestedCapacity);
+                decoder = new Dictionary<int, ReadOnlyMemory<byte>>(suggestedCapacity);
+                // skip empty lines
+                while (line is not null && line.Length == 0)
+                {
+                    line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
+                }
                 if (line is not null && line.IndexOf(' ') < 0)
                 {
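
On the load side, the parsed hint goes straight into the Dictionary constructors, so the tables are allocated at roughly their final size instead of growing and re-hashing repeatedly as entries arrive. Helpers.ReadLineAsync and Helpers.TryParseInt32 are internal helpers of Microsoft.ML.Tokenizers; a rough standalone equivalent of just the header parse, using only public BCL APIs (names here are mine, not the library's), could look like:

using System;
using System.Globalization;

class CapacityHeaderSketch
{
    const string CapacityPrefix = "Capacity: ";

    // Returns the suggested capacity from a "Capacity: N" header line,
    // or 0 (meaning "no hint") when the header is absent or unreadable.
    static int ParseSuggestedCapacity(string? firstLine)
    {
        if (firstLine is not null &&
            firstLine.StartsWith(CapacityPrefix, StringComparison.Ordinal) &&
            int.TryParse(firstLine.Substring(CapacityPrefix.Length),
                         NumberStyles.Integer, CultureInfo.InvariantCulture, out int suggested))
        {
            return suggested;
        }
        return 0;
    }

    static void Main()
    {
        Console.WriteLine(ParseSuggestedCapacity("Capacity: 42")); // 42
        Console.WriteLine(ParseSuggestedCapacity("IQ== 0"));       // 0: an ordinary "token rank" vocab line
    }
}

Falling back to 0 is safe because Dictionary treats a 0 capacity as no reservation, which preserves the old behavior for data files that lack the header.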

View file

@@ -12,7 +12,6 @@ using System.Linq;
 using System.Reflection;
 using System.Text;
 using System.Text.Json;
-using System.Text.Json.Serialization;
 using System.Threading.Tasks;
 using Xunit;
@@ -501,9 +500,22 @@ namespace Microsoft.ML.Tokenizers.Tests
         {
             RemoteExecutor.Invoke(static (name) =>
             {
+#if NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
+                long allocation = GC.GetAllocatedBytesForCurrentThread();
+#endif // NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
                 Tokenizer tokenizer = TiktokenTokenizer.CreateForModel(name);
                 Assert.True(tokenizer is TiktokenTokenizer);
                 Assert.NotNull(tokenizer.PreTokenizer);
+#if NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
+                int entriesCount = GetEncoder((tokenizer as TiktokenTokenizer)!)!.Count;
+                allocation = GC.GetAllocatedBytesForCurrentThread() - allocation;
+                // entriesCount * 260 is the average memory allocation during initialization for the models we carry data files for.
+                // This allocation is not the size of the cache; it includes all temporary allocations made during initialization.
+                Assert.True((entriesCount * 260) > allocation, $"Memory allocation of {entriesCount} entries for {name}: {allocation} bytes");
+#endif // NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
             }, modelName).Dispose();
         }
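
The regression test brackets tokenizer construction with GC.GetAllocatedBytesForCurrentThread() inside RemoteExecutor, which runs the lambda in a fresh process so allocations from other tests don't skew the number. The measurement pattern in isolation — a sketch with a made-up workload and threshold, not the test's actual budget:

using System;

class AllocationBudgetSketch
{
    static void Main()
    {
        // Available on modern .NET and on .NET Framework 4.8, hence the
        // #if guard in the test above.
        long before = GC.GetAllocatedBytesForCurrentThread();

        // Workload under measurement (illustrative only).
        var buffer = new byte[16 * 1024];

        long allocated = GC.GetAllocatedBytesForCurrentThread() - before;
        Console.WriteLine($"Allocated ~{allocated} bytes");
        GC.KeepAlive(buffer);
    }
}

Because the counter is per-thread and monotonic, the delta captures every allocation the current thread made between the two reads, including temporaries that are already garbage, which is exactly what a "total allocation during initialization" budget wants to catch.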