Reduce Tiktoken Creation Memory Allocation (#7202)
Parent: 34eb579d4b
Commit: f72c9d22fc
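The change has two halves: at build time, the MSBuild inline task that compresses each tiktoken vocab file now counts the file's lines and writes a `Capacity: N` header as the first line of the compressed output; at load time, `LoadTiktokenBpeAsync` reads that header and pre-sizes its encoder/vocab/decoder dictionaries so they never grow (and rehash) while entries are inserted. A minimal standalone sketch of the idea, with made-up vocab content:

```csharp
using System;
using System.Collections.Generic;

class CapacityHintSketch
{
    static void Main()
    {
        // Hypothetical tiktoken-style content: "<base64 token> <rank>" per line.
        string fileContent = "IQ== 0\nIg== 1\nIw== 2\n";

        // Count lines the way the build task does: one per '\n', starting at 1
        // to cover a final line without a trailing newline (so the hint may
        // overshoot by one, which is harmless for a capacity).
        int capacity = 1;
        int eolIndex = 0;
        while ((eolIndex = fileContent.IndexOf('\n', eolIndex)) >= 0)
        {
            eolIndex++;
            capacity++;
        }

        // Pre-sizing the dictionary allocates its buckets once up front
        // instead of repeatedly doubling while entries are added.
        var encoder = new Dictionary<string, int>(capacity);
        Console.WriteLine($"Capacity hint: {capacity}"); // prints 4
    }
}
```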
@@ -32,17 +32,37 @@
     <Files ParameterType="Microsoft.Build.Framework.ITaskItem[]" Required="true" />
   </ParameterGroup>
   <Task>
+    <Using Namespace="System.Globalization" />
     <Using Namespace="System.IO" />
     <Using Namespace="System.IO.Compression" />
     <Code Type="Fragment" Language="cs">
       <![CDATA[
-      foreach(var file in Files)
+      foreach (var file in Files)
       {
-        using var sourceStream = File.OpenRead(file.GetMetadata("FullPath"));
+        string fileName = file.GetMetadata("FullPath");
+        string fileContent = File.ReadAllText(fileName);
+        int capacity = 1;
+        int eolIndex = 0;
+        do
+        {
+            if ((eolIndex = fileContent.IndexOf('\n', eolIndex)) >= 0)
+            {
+                eolIndex++;
+                capacity++;
+            }
+            else
+            {
+                break;
+            }
+        } while (eolIndex < fileContent.Length);
+
+        using var sourceStream = File.OpenRead(fileName);
         using var reader = new StreamReader(sourceStream);
         using var destStream = new DeflateStream(File.Create(file.GetMetadata("Destination")), CompressionLevel.Optimal);
         using var streamWriter = new StreamWriter(destStream);
 
+        streamWriter.WriteLine($"Capacity: {capacity.ToString(CultureInfo.InvariantCulture)}");
+
        string line;
        int destLineNumber = 0;
 
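The task writes the header through the same `DeflateStream`/`StreamWriter` pair as the payload, so the header itself is part of the compressed data and costs almost nothing on disk. A round-trip sketch under that format (the temp file name and sample entries are illustrative):

```csharp
using System;
using System.IO;
using System.IO.Compression;

class DeflateHeaderRoundTrip
{
    static void Main()
    {
        string path = Path.Combine(Path.GetTempPath(), "vocab.deflate");

        // Write: capacity header first, then the vocab lines, all deflated.
        using (var destStream = new DeflateStream(File.Create(path), CompressionLevel.Optimal))
        using (var writer = new StreamWriter(destStream))
        {
            writer.WriteLine("Capacity: 3");
            writer.WriteLine("IQ== 0");
            writer.WriteLine("Ig== 1");
            writer.WriteLine("Iw== 2");
        }

        // Read: decompress and peel the header off before the entries.
        using (var srcStream = new DeflateStream(File.OpenRead(path), CompressionMode.Decompress))
        using (var reader = new StreamReader(srcStream))
        {
            Console.WriteLine(reader.ReadLine()); // "Capacity: 3"
            string? entry;
            while ((entry = reader.ReadLine()) is not null)
            {
                Console.WriteLine(entry);
            }
        }

        File.Delete(path);
    }
}
```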
@@ -156,21 +156,37 @@ namespace Microsoft.ML.Tokenizers
         internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<StringSpanOrdinalKey, (int Id, string Token)>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTiktokenBpeAsync(
             Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
         {
-            var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
-            var vocab = new Dictionary<StringSpanOrdinalKey, (int Id, string Token)>();
-            var decoder = new Dictionary<int, ReadOnlyMemory<byte>>();
+            Dictionary<ReadOnlyMemory<byte>, int> encoder;
+            Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab;
+            Dictionary<int, ReadOnlyMemory<byte>> decoder;
 
             try
             {
                 // Don't dispose the reader as it will dispose the underlying stream vocabStream. The caller is responsible for disposing the stream.
                 StreamReader reader = new StreamReader(vocabStream);
-                string? line;
-                do
+                string? line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
+
+                const string capacity = "Capacity: ";
+                int suggestedCapacity = 0; // default capacity
+                if (line is not null && line.StartsWith(capacity, StringComparison.Ordinal))
                 {
-                    line = useAsync ?
-                        await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
-                        reader.ReadLine();
-                } while (line is not null && line.Length == 0);
+                    if (!Helpers.TryParseInt32(line, capacity.Length, out suggestedCapacity))
+                    {
+                        throw new FormatException($"Invalid format in the BPE vocab file stream");
+                    }
+
+                    line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
+                }
+
+                encoder = new Dictionary<ReadOnlyMemory<byte>, int>(suggestedCapacity, ReadOnlyMemoryByteComparer.Instance);
+                vocab = new Dictionary<StringSpanOrdinalKey, (int Id, string Token)>(suggestedCapacity);
+                decoder = new Dictionary<int, ReadOnlyMemory<byte>>(suggestedCapacity);
+
+                // skip empty lines
+                while (line is not null && line.Length == 0)
+                {
+                    line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
+                }
 
                 if (line is not null && line.IndexOf(' ') < 0)
                 {
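`Helpers.TryParseInt32(line, offset, out value)` is an internal helper not shown in this diff; a rough stand-in with the same shape, assuming a runtime with span-based `int.TryParse` (.NET Core 2.1+), could look like:

```csharp
using System;
using System.Globalization;

static class HeaderParseSketch
{
    // Approximates the internal Helpers.TryParseInt32(string, int, out int):
    // parse the digits that follow a known prefix without a substring allocation.
    static bool TryParseInt32(string line, int offset, out int value) =>
        int.TryParse(line.AsSpan(offset), NumberStyles.None, CultureInfo.InvariantCulture, out value);

    static void Main()
    {
        const string capacity = "Capacity: ";
        string line = "Capacity: 100263"; // made-up header value

        int suggestedCapacity = 0; // default when the header is absent
        if (line.StartsWith(capacity, StringComparison.Ordinal) &&
            TryParseInt32(line, capacity.Length, out suggestedCapacity))
        {
            Console.WriteLine(suggestedCapacity); // 100263
        }
    }
}
```

Note that when no header is present the loader falls back to capacity 0, which `Dictionary<TKey, TValue>` accepts and treats as "allocate on first insert", so older vocab files without the header still load correctly.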
@@ -12,7 +12,6 @@ using System.Linq;
 using System.Reflection;
-using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
 using System.Threading.Tasks;
 using Xunit;
 
@@ -501,9 +500,22 @@ namespace Microsoft.ML.Tokenizers.Tests
         {
             RemoteExecutor.Invoke(static (name) =>
             {
+#if NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
+                long allocation = GC.GetAllocatedBytesForCurrentThread();
+#endif // NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
+
                 Tokenizer tokenizer = TiktokenTokenizer.CreateForModel(name);
                 Assert.True(tokenizer is TiktokenTokenizer);
                 Assert.NotNull(tokenizer.PreTokenizer);
+
+#if NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
+                int entriesCount = GetEncoder((tokenizer as TiktokenTokenizer)!)!.Count;
+                allocation = GC.GetAllocatedBytesForCurrentThread() - allocation;
+
+                // entriesCount * 260 is the average memory allocation during initialization for the models we carry data files for.
+                // This is not the size of the cache; it includes all temporary allocations made during initialization.
+                Assert.True((entriesCount * 260) > allocation, $"Memory allocation of {entriesCount} entries for {name}: {allocation} bytes");
+#endif // NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
             }, modelName).Dispose();
         }
 
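The test measures with `GC.GetAllocatedBytesForCurrentThread()`, hence the `#if` fences for runtimes where that API exists (.NET Framework 4.8 and modern .NET). The measurement pattern in isolation, with placeholder work standing in for tokenizer creation:

```csharp
using System;

class AllocationProbeSketch
{
    static void Main()
    {
        // Snapshot the per-thread allocation counter, do the work under
        // measurement, then diff the counter. Everything allocated on this
        // thread in between is counted, including temporary garbage.
        long allocation = GC.GetAllocatedBytesForCurrentThread();

        var work = new byte[64 * 1024]; // stand-in for tokenizer creation

        allocation = GC.GetAllocatedBytesForCurrentThread() - allocation;
        Console.WriteLine($"Allocated ~{allocation} bytes");
        GC.KeepAlive(work);
    }
}
```

Running the probe in a fresh child process, as the test does via `RemoteExecutor.Invoke`, keeps allocations from earlier test code on the same thread out of the measurement.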