Embed the Tokenizer data files inside the assembly (#6403)

Tarek Mahmoud Sayed 2022-10-24 10:30:48 -07:00 committed by GitHub
Parent 1903fa5eda
Commit c69acbeb97
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 100380 additions and 101 deletions


@@ -36,14 +36,71 @@ namespace Microsoft.ML.Tokenizers
/// <param name="highestOccurrenceMappingPath">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath)
{
+ if (vocabularyPath is null)
+ {
+ throw new ArgumentNullException(nameof(vocabularyPath));
+ }
+ if (mergePath is null)
+ {
+ throw new ArgumentNullException(nameof(mergePath));
+ }
+ if (highestOccurrenceMappingPath is null)
+ {
+ throw new ArgumentNullException(nameof(highestOccurrenceMappingPath));
+ }
+ using Stream vocabularyStream = File.OpenRead(vocabularyPath);
+ using Stream mergeStream = File.OpenRead(mergePath);
+ using Stream highestOccurrenceMappingStream = File.OpenRead(highestOccurrenceMappingPath);
// vocabularyPath like encoder.json
// merge file like vocab.bpe
// highestOccurrenceMappingPath like dict.txt
- _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingPath);
- _vocab = GetVocabulary(vocabularyPath);
+ _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
+ _vocab = GetVocabulary(vocabularyStream);
_vocabReverse = _vocab.ReverseSorted();
- _mergeRanks = GetMergeRanks(mergePath);
+ _mergeRanks = GetMergeRanks(mergeStream);
int maxCharValue = GetByteToUnicode(out _byteToUnicode);
_charToString = new string[maxCharValue];
for (char c = (char)0; c < (char)maxCharValue; c++)
{
_charToString[c] = c.ToString();
}
_unicodeToByte = _byteToUnicode.Reverse();
_cache = new Cache<string, IReadOnlyList<Token>>();
}
+ /// <summary>
+ /// Construct tokenizer object to use with the English Roberta model.
+ /// </summary>
+ /// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
+ /// <param name="mergeStream">The stream of a file containing the tokens' pairs list.</param>
+ /// <param name="highestOccurrenceMappingStream">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
+ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream)
+ {
+ if (vocabularyStream is null)
+ {
+ throw new ArgumentNullException(nameof(vocabularyStream));
+ }
+ if (mergeStream is null)
+ {
+ throw new ArgumentNullException(nameof(mergeStream));
+ }
+ if (highestOccurrenceMappingStream is null)
+ {
+ throw new ArgumentNullException(nameof(highestOccurrenceMappingStream));
+ }
+ _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
+ _vocab = GetVocabulary(vocabularyStream);
+ _vocabReverse = _vocab.ReverseSorted();
+ _mergeRanks = GetMergeRanks(mergeStream);
+ int maxCharValue = GetByteToUnicode(out _byteToUnicode);
+ _charToString = new string[maxCharValue];
+ for (char c = (char)0; c < (char)maxCharValue; c++)
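For illustration, a minimal sketch of calling the new stream-based constructor (file names here are hypothetical placeholders; any readable Stream works, including one returned by Assembly.GetManifestResourceStream):

    using System.IO;
    using Microsoft.ML.Tokenizers;

    // Open the three data files as streams (local copies assumed for this example).
    using Stream vocabStream = File.OpenRead("encoder.json");   // vocabulary JSON
    using Stream mergeStream = File.OpenRead("vocab.bpe");      // BPE merge list
    using Stream mappingStream = File.OpenRead("dict.txt");     // occurrence mapping

    var model = new EnglishRoberta(vocabStream, mergeStream, mappingStream);
    var tokenizer = new Tokenizer(model, new RobertaPreTokenizer());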
@@ -298,28 +355,24 @@ namespace Microsoft.ML.Tokenizers
return tokens;
}
- private static HighestOccurrenceMapping GetHighestOccurrenceMapping(string highestOccurrenceMappingPath) =>
- HighestOccurrenceMapping.Load(highestOccurrenceMappingPath);
+ private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) =>
+ HighestOccurrenceMapping.Load(highestOccurrenceMappingStream);
- private Dictionary<string, int> GetVocabulary(string vocabularyPath)
+ private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
{
Dictionary<string, int>? vocab;
try
{
- using (Stream stream = File.OpenRead(vocabularyPath))
- {
- vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(stream) as Dictionary<string, int>;
- }
+ vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(vocabularyStream) as Dictionary<string, int>;
}
catch (Exception e)
{
- throw new ArgumentException($"Problems met when parsing JSON object in {vocabularyPath}.{Environment.NewLine}Error message: {e.Message}");
+ throw new ArgumentException($"Problems met when parsing JSON vocabulary object.{Environment.NewLine}Error message: {e.Message}");
}
if (vocab is null)
{
- throw new ArgumentException($"Failed to read the vocabulary file '{vocabularyPath}'");
+ throw new ArgumentException($"Failed to read the vocabulary file.");
}
if (_vocabIdToHighestOccurrence.BosWord is not null)
@@ -345,28 +398,28 @@ namespace Microsoft.ML.Tokenizers
return vocab;
}
- private Dictionary<(string, string), int> GetMergeRanks(string mergePath)
+ private Dictionary<(string, string), int> GetMergeRanks(Stream mergeStream)
{
- string[] splitContents;
+ List<string> splitContents = new();
try
{
- splitContents = File.ReadAllLines(mergePath);
+ using StreamReader reader = new StreamReader(mergeStream);
+ while (reader.Peek() >= 0)
+ {
+ splitContents.Add(reader.ReadLine());
+ }
}
catch (Exception e)
{
- throw new IOException($"Cannot read the file '{mergePath}'.{Environment.NewLine}Error message: {e.Message}", e);
+ throw new IOException($"Cannot read the merge file.{Environment.NewLine}Error message: {e.Message}", e);
}
var mergeRanks = new Dictionary<(string, string), int>();
- for (int i = 0; i < splitContents.Length; i++)
+ // We ignore the first and last line in the file
+ for (int i = 1; i < splitContents.Count - 1; i++)
{
- if (i == 0 || i == splitContents.Length - 1)
- {
- continue;
- }
var split = splitContents[i].Split(' ');
if (split.Length != 2 || string.IsNullOrEmpty(split[0]) || string.IsNullOrEmpty(split[1]))
{
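For context on why the first and last lines are skipped: in the GPT-2-style merge file (vocab.bpe), the first line is a version header rather than a token pair, and the trailing line is skipped defensively (for example, an empty final line). An illustrative excerpt of the assumed layout, one space-separated pair per line:

    #version: 0.2
    Ġ t
    Ġ a
    h e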
@@ -664,22 +717,25 @@ namespace Microsoft.ML.Tokenizers
/// 284 432911125
/// ...
/// </summary>
- public static HighestOccurrenceMapping Load(string fileName)
+ public static HighestOccurrenceMapping Load(Stream stream)
{
var mapping = new HighestOccurrenceMapping();
- mapping.AddFromFile(fileName);
+ mapping.AddFromStream(stream);
return mapping;
}
/// <summary>
- /// Loads a pre-existing vocabulary from a text file and adds its symbols to this instance.
+ /// Loads a pre-existing vocabulary from a text stream and adds its symbols to this instance.
/// </summary>
- public void AddFromFile(string fileName)
+ public void AddFromStream(Stream stream)
{
- var lines = File.ReadAllLines(fileName, Encoding.UTF8);
+ Debug.Assert(stream is not null);
+ using StreamReader reader = new StreamReader(stream);
- foreach (var line in lines)
+ while (reader.Peek() >= 0)
{
+ string line = reader.ReadLine();
var splitLine = line.Trim().Split(' ');
if (splitLine.Length != 2)
{


@@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Reflection;
using System.Runtime.CompilerServices;
using Microsoft.ML.Runtime;
using Microsoft.ML.Tokenizers;
@@ -13,25 +14,22 @@ namespace Microsoft.ML.TorchSharp.Extensions
{
internal static class TokenizerExtensions
{
- private const string EncoderJsonName = "encoder.json";
- private const string MergeName = "vocab.bpe";
- private const string DictName = "dict.txt";
- private static readonly Uri _encoderJsonUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json");
- private static readonly Uri _mergeUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe");
- private static readonly Uri _dictUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt");
private static Tokenizer _instance;
internal static Tokenizer GetInstance(IChannel ch)
{
if (_instance is null)
{
- FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, EncoderJsonName, _encoderJsonUrl, ch);
- FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, MergeName, _mergeUrl, ch);
- FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, DictName, _dictUrl, ch);
+ // encoder.json, vocab.bpe, and dict.txt are picked up from the following source:
+ // "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
+ // "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
+ // "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt"
+ Assembly assembly = typeof(TokenizerExtensions).Assembly;
- EnglishRoberta model = new EnglishRoberta(EncoderJsonName, MergeName, DictName);
+ EnglishRoberta model = new EnglishRoberta(
+ assembly.GetManifestResourceStream("encoder.json"),
+ assembly.GetManifestResourceStream("vocab.bpe"),
+ assembly.GetManifestResourceStream("dict.txt"));
model.AddMaskSymbol();
_instance = new Tokenizer(model, new RobertaPreTokenizer());
}
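Assembly.GetManifestResourceStream returns null when no embedded resource matches the requested name, so a quick way to verify the names during development is to enumerate the manifest (a debugging sketch, not part of this change):

    using System;
    using System.Reflection;

    Assembly assembly = typeof(TokenizerExtensions).Assembly;
    // With the <LogicalName> entries added to the .csproj below, this prints
    // "dict.txt", "encoder.json", and "vocab.bpe".
    foreach (string name in assembly.GetManifestResourceNames())
    {
        Console.WriteLine(name);
    }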


@@ -10,9 +10,9 @@
<ItemGroup>
<PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />
- <PackageReference Include="libtorch-cpu-win-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Windows'))" PrivateAssets="all"/>
- <PackageReference Include="libtorch-cpu-linux-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Linux'))" PrivateAssets="all"/>
- <PackageReference Include="libtorch-cpu-osx-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('OSX'))" PrivateAssets="all"/>
+ <PackageReference Include="libtorch-cpu-win-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Windows'))" PrivateAssets="all" />
+ <PackageReference Include="libtorch-cpu-linux-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Linux'))" PrivateAssets="all" />
+ <PackageReference Include="libtorch-cpu-osx-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('OSX'))" PrivateAssets="all" />
</ItemGroup>
<ItemGroup>
@@ -24,4 +24,16 @@
</ProjectReference>
</ItemGroup>
+ <ItemGroup>
+ <EmbeddedResource Include="Resources\dict.txt">
+ <LogicalName>dict.txt</LogicalName>
+ </EmbeddedResource>
+ <EmbeddedResource Include="Resources\encoder.json">
+ <LogicalName>encoder.json</LogicalName>
+ </EmbeddedResource>
+ <EmbeddedResource Include="Resources\vocab.bpe">
+ <LogicalName>vocab.bpe</LogicalName>
+ </EmbeddedResource>
+ </ItemGroup>
</Project>
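The <LogicalName> metadata is what keeps the lookup strings in TokenizerExtensions short: without it, MSBuild derives the manifest resource name from the root namespace plus the folder path. A hedged illustration (the exact default name depends on the project's RootNamespace setting):

    // Hypothetical default name if <LogicalName> were omitted:
    Stream s1 = assembly.GetManifestResourceStream("Microsoft.ML.TorchSharp.Resources.encoder.json");
    // With <LogicalName>encoder.json</LogicalName>, the short name suffices:
    Stream s2 = assembly.GetManifestResourceStream("encoder.json");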

File diff not shown because it is too large.

File diff hidden because one or more lines are too long.

File diff not shown because it is too large.


@@ -28,54 +28,6 @@ namespace Microsoft.ML.TorchSharp.Utils
typeof(double),
};
- public static string LoadFromFileOrDownloadFromWeb(string path, string fileName, Uri url, IChannel ch)
- {
- Contracts.AssertNonWhiteSpace(fileName, "Filename can't be empty");
- var contents = "";
- var filePath = Path.Combine(path, fileName);
- if (!File.Exists(filePath))
- {
- try
- {
- using var webClient = new WebClient();
- contents = webClient.DownloadString(url);
- }
- catch (WebException e)
- {
- throw new WebException($"File {fileName} not found and cannot be downloaded from {url}.\n" +
- $"Error message: {e.Message}", e);
- }
- try
- {
- File.WriteAllText(filePath, contents);
- ch.Info($"File {fileName} successfully downloaded from {url} and saved to {path}.");
- }
- catch (Exception e)
- {
- ch.Warning($"{DateTime.Now} - WARNING: File {fileName} successfully downloaded from {url}, " +
- $"but error occurs when saving file {fileName} into {path}.\n" +
- $"Error message: {e.Message}");
- }
- }
- else
- {
- try
- {
- contents = File.ReadAllText(filePath);
- }
- catch (Exception e)
- {
- throw new IOException($"Problems met when reading {filePath}.\n" +
- $"Error message: {e.Message}", e);
- }
- }
- return contents;
- }
/// <summary>
/// Load a continuous segment of bytes from stream and parse them into a number array.
/// NOTE: this function is only for little-endian storage!


@@ -255,20 +255,13 @@ namespace Microsoft.ML.Tests
var transformedData = transformer.Transform(dataView).Preview();
Assert.NotNull(transformedData);
- #if NET461
- Assert.Equal("Class One", transformedData.ColumnView[4].Values[0].ToString());
- Assert.Equal("Class Two", transformedData.ColumnView[4].Values[1].ToString());
- Assert.Equal("Class Three", transformedData.ColumnView[4].Values[2].ToString());
- Assert.Equal("Class Three", transformedData.ColumnView[4].Values[4].ToString());
- Assert.Equal("Class One", transformedData.ColumnView[4].Values[6].ToString());
- #else
Assert.Equal("Class One", transformedData.ColumnView[4].Values[0].ToString());
Assert.Equal("Class Two", transformedData.ColumnView[4].Values[1].ToString());
Assert.Equal("Class Three", transformedData.ColumnView[4].Values[2].ToString());
Assert.Equal("Class One", transformedData.ColumnView[4].Values[4].ToString());
Assert.Equal("Class One", transformedData.ColumnView[4].Values[6].ToString());
- #endif
+ Assert.Equal("Class One", transformedData.ColumnView[4].Values[3].ToString());
+ Assert.Equal("Class Two", transformedData.ColumnView[4].Values[7].ToString());
}


@@ -81,7 +81,7 @@ namespace Microsoft.ML.Tokenizers.Tests
}
[Fact]
- public async void ToknizationTest()
+ public async void TokenizationTest()
{
string vocabFile = Utils.CreateTemporaryFile("json");
string mergeFile = Utils.CreateTemporaryFile("txt");
@@ -100,6 +100,12 @@ namespace Microsoft.ML.Tokenizers.Tests
paths = tokenizer.Model.Save(Path.GetTempPath(), "roberta");
Tokenizer tokenizer1 = new Tokenizer(new EnglishRoberta(paths[0], paths[1], paths[2]), RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer1);
+ using Stream vocabStream = File.OpenRead(vocabFile);
+ using Stream mergeStream = File.OpenRead(mergeFile);
+ using Stream translationStream = File.OpenRead(translationFile);
+ tokenizer = new Tokenizer(new EnglishRoberta(vocabStream, mergeStream, translationStream), RobertaPreTokenizer.Instance);
+ TestTokenizer(tokenizer);
}
finally
{