Embed the Tokenizer data files inside the assembly (#6403)
This commit is contained in:
Parent: 1903fa5eda
Commit: c69acbeb97
@@ -36,14 +36,71 @@ namespace Microsoft.ML.Tokenizers
        /// <param name="highestOccurrenceMappingPath">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
        public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath)
        {
            if (vocabularyPath is null)
            {
                throw new ArgumentNullException(nameof(vocabularyPath));
            }

            if (mergePath is null)
            {
                throw new ArgumentNullException(nameof(mergePath));
            }

            if (highestOccurrenceMappingPath is null)
            {
                throw new ArgumentNullException(nameof(highestOccurrenceMappingPath));
            }

+            using Stream vocabularyStream = File.OpenRead(vocabularyPath);
+            using Stream mergeStream = File.OpenRead(mergePath);
+            using Stream highestOccurrenceMappingStream = File.OpenRead(highestOccurrenceMappingPath);
+
            // vocabularyPath like encoder.json
            // merge file like vocab.bpe
            // highestOccurrenceMappingPath like dict.txt

-            _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingPath);
-            _vocab = GetVocabulary(vocabularyPath);
+            _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
+            _vocab = GetVocabulary(vocabularyStream);
            _vocabReverse = _vocab.ReverseSorted();
-            _mergeRanks = GetMergeRanks(mergePath);
+            _mergeRanks = GetMergeRanks(mergeStream);
            int maxCharValue = GetByteToUnicode(out _byteToUnicode);
            _charToString = new string[maxCharValue];
            for (char c = (char)0; c < (char)maxCharValue; c++)
            {
                _charToString[c] = c.ToString();
            }

            _unicodeToByte = _byteToUnicode.Reverse();
            _cache = new Cache<string, IReadOnlyList<Token>>();
        }

+        /// <summary>
+        /// Construct tokenizer object to use with the English Roberta model.
+        /// </summary>
+        /// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
+        /// <param name="mergeStream">The stream of a file containing the tokens' pairs list.</param>
+        /// <param name="highestOccurrenceMappingStream">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
+        public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream)
+        {
+            if (vocabularyStream is null)
+            {
+                throw new ArgumentNullException(nameof(vocabularyStream));
+            }
+
+            if (mergeStream is null)
+            {
+                throw new ArgumentNullException(nameof(mergeStream));
+            }
+
+            if (highestOccurrenceMappingStream is null)
+            {
+                throw new ArgumentNullException(nameof(highestOccurrenceMappingStream));
+            }
+
+            _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
+            _vocab = GetVocabulary(vocabularyStream);
+            _vocabReverse = _vocab.ReverseSorted();
+            _mergeRanks = GetMergeRanks(mergeStream);
+            int maxCharValue = GetByteToUnicode(out _byteToUnicode);
+            _charToString = new string[maxCharValue];
+            for (char c = (char)0; c < (char)maxCharValue; c++)
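For context, the path-based constructor above now simply opens the three files and defers to the same stream-based logic as the new overload. A minimal usage sketch, assuming local copies of the GPT-2 BPE data files (the file names here are illustrative):

    using System.IO;
    using Microsoft.ML.Tokenizers;

    // Path-based overload (unchanged public surface):
    var fromPaths = new EnglishRoberta("encoder.json", "vocab.bpe", "dict.txt");

    // Stream-based overload added by this commit; any readable Stream works,
    // which is what makes embedded-resource loading possible:
    using Stream vocab = File.OpenRead("encoder.json");
    using Stream merges = File.OpenRead("vocab.bpe");
    using Stream dict = File.OpenRead("dict.txt");
    var fromStreams = new EnglishRoberta(vocab, merges, dict);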
@@ -298,28 +355,24 @@ namespace Microsoft.ML.Tokenizers
            return tokens;
        }

-        private static HighestOccurrenceMapping GetHighestOccurrenceMapping(string highestOccurrenceMappingPath) =>
-            HighestOccurrenceMapping.Load(highestOccurrenceMappingPath);
+        private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) =>
+            HighestOccurrenceMapping.Load(highestOccurrenceMappingStream);

-        private Dictionary<string, int> GetVocabulary(string vocabularyPath)
+        private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
        {
            Dictionary<string, int>? vocab;
            try
            {
-                using (Stream stream = File.OpenRead(vocabularyPath))
-                {
-                    vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(stream) as Dictionary<string, int>;
-                }
+                vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(vocabularyStream) as Dictionary<string, int>;
            }
            catch (Exception e)
            {
-                throw new ArgumentException($"Problems met when parsing JSON object in {vocabularyPath}.{Environment.NewLine}Error message: {e.Message}");
+                throw new ArgumentException($"Problems met when parsing the JSON vocabulary object.{Environment.NewLine}Error message: {e.Message}");
            }

            if (vocab is null)
            {
-                throw new ArgumentException($"Failed to read the vocabulary file '{vocabularyPath}'");
+                throw new ArgumentException("Failed to read the vocabulary file.");
            }

            if (_vocabIdToHighestOccurrence.BosWord is not null)
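For reference, GetVocabulary expects the stream to contain a single JSON object mapping each token string to its integer id, in the GPT-2 encoder.json layout. A hypothetical two-entry excerpt (the "Ġ" prefix is the GPT-2 convention for a leading space; the ids are illustrative):

    {
      "hello": 31373,
      "Ġworld": 995
    }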
@@ -345,28 +398,28 @@ namespace Microsoft.ML.Tokenizers
            return vocab;
        }

-        private Dictionary<(string, string), int> GetMergeRanks(string mergePath)
+        private Dictionary<(string, string), int> GetMergeRanks(Stream mergeStream)
        {
-            string[] splitContents;
+            List<string> splitContents = new();

            try
            {
-                splitContents = File.ReadAllLines(mergePath);
+                using StreamReader reader = new StreamReader(mergeStream);
+                while (reader.Peek() >= 0)
+                {
+                    splitContents.Add(reader.ReadLine());
+                }
            }
            catch (Exception e)
            {
-                throw new IOException($"Cannot read the file '{mergePath}'.{Environment.NewLine}Error message: {e.Message}", e);
+                throw new IOException($"Cannot read the merge file.{Environment.NewLine}Error message: {e.Message}", e);
            }

            var mergeRanks = new Dictionary<(string, string), int>();

-            for (int i = 0; i < splitContents.Length; i++)
+            // We ignore the first and last lines in the file
+            for (int i = 1; i < splitContents.Count - 1; i++)
            {
-                if (i == 0 || i == splitContents.Length - 1)
-                {
-                    continue;
-                }
-
                var split = splitContents[i].Split(' ');
                if (split.Length != 2 || string.IsNullOrEmpty(split[0]) || string.IsNullOrEmpty(split[1]))
                {
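The index bounds above encode an assumption about the vocab.bpe layout: the first line is a version banner and the last line is typically empty, so only the interior lines carry space-separated merge pairs, whose line order defines the merge rank. A hypothetical excerpt:

    #version: 0.2
    Ġ t
    h e
    Ġa n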
@@ -664,22 +717,25 @@ namespace Microsoft.ML.Tokenizers
        /// 284 432911125
        /// ...
        /// </summary>
-        public static HighestOccurrenceMapping Load(string fileName)
+        public static HighestOccurrenceMapping Load(Stream stream)
        {
            var mapping = new HighestOccurrenceMapping();
-            mapping.AddFromFile(fileName);
+            mapping.AddFromStream(stream);
            return mapping;
        }

        /// <summary>
-        /// Loads a pre-existing vocabulary from a text file and adds its symbols to this instance.
+        /// Loads a pre-existing vocabulary from a text stream and adds its symbols to this instance.
        /// </summary>
-        public void AddFromFile(string fileName)
+        public void AddFromStream(Stream stream)
        {
-            var lines = File.ReadAllLines(fileName, Encoding.UTF8);
+            Debug.Assert(stream is not null);
+            using StreamReader reader = new StreamReader(stream);

-            foreach (var line in lines)
+            while (reader.Peek() >= 0)
            {
+                string line = reader.ReadLine();
+
                var splitLine = line.Trim().Split(' ');
                if (splitLine.Length != 2)
                {
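As the surrounding doc comment shows, each line AddFromStream consumes is expected to hold two space-separated fields: an original GPT-2 model id and its occurrence count, e.g. "284 432911125". A minimal sketch of parsing one such line under that assumption (the real parsing lives in the loop above):

    string line = "284 432911125";
    string[] parts = line.Trim().Split(' ');
    int modelId = int.Parse(parts[0]);      // original GPT-2 vocabulary id
    long occurrence = long.Parse(parts[1]); // occurrence count used for ranking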
@@ -4,6 +4,7 @@

 using System;
 using System.Collections.Generic;
+using System.Reflection;
 using System.Runtime.CompilerServices;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Tokenizers;

@@ -13,25 +14,22 @@ namespace Microsoft.ML.TorchSharp.Extensions
 {
    internal static class TokenizerExtensions
    {
-        private const string EncoderJsonName = "encoder.json";
-        private const string MergeName = "vocab.bpe";
-        private const string DictName = "dict.txt";
-
-        private static readonly Uri _encoderJsonUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json");
-        private static readonly Uri _mergeUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe");
-        private static readonly Uri _dictUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt");
-
        private static Tokenizer _instance;

        internal static Tokenizer GetInstance(IChannel ch)
        {
            if (_instance is null)
            {
-                FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, EncoderJsonName, _encoderJsonUrl, ch);
-                FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, MergeName, _mergeUrl, ch);
-                FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, DictName, _dictUrl, ch);
+                // encoder.json, vocab.bpe, and dict.txt are picked up from the following source:
+                // "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
+                // "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
+                // "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt"
+                Assembly assembly = typeof(TokenizerExtensions).Assembly;

-                EnglishRoberta model = new EnglishRoberta(EncoderJsonName, MergeName, DictName);
+                EnglishRoberta model = new EnglishRoberta(
+                    assembly.GetManifestResourceStream("encoder.json"),
+                    assembly.GetManifestResourceStream("vocab.bpe"),
+                    assembly.GetManifestResourceStream("dict.txt"));
                model.AddMaskSymbol();
                _instance = new Tokenizer(model, new RobertaPreTokenizer());
            }
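Note that Assembly.GetManifestResourceStream resolves names against the assembly manifest, which is why the project file below pins each file's <LogicalName>; without it, the default manifest name would typically include the root namespace and folder (e.g. "Microsoft.ML.TorchSharp.Resources.encoder.json"). Since GetManifestResourceStream returns null on a miss, a defensive variant might look like this sketch (not part of the commit; the helper is hypothetical):

    using System;
    using System.IO;
    using System.Reflection;

    internal static class ResourceLoader
    {
        // Hypothetical helper: fail fast with a clear message instead of
        // passing a null Stream into the EnglishRoberta constructor.
        public static Stream GetRequiredResource(Assembly assembly, string name)
        {
            Stream stream = assembly.GetManifestResourceStream(name);
            if (stream is null)
            {
                throw new InvalidOperationException(
                    $"Embedded resource '{name}' was not found in '{assembly.FullName}'.");
            }
            return stream;
        }
    }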
@@ -10,9 +10,9 @@

  <ItemGroup>
    <PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />
-    <PackageReference Include="libtorch-cpu-win-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Windows'))" PrivateAssets="all"/>
-    <PackageReference Include="libtorch-cpu-linux-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Linux'))" PrivateAssets="all"/>
-    <PackageReference Include="libtorch-cpu-osx-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('OSX'))" PrivateAssets="all"/>
+    <PackageReference Include="libtorch-cpu-win-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Windows'))" PrivateAssets="all" />
+    <PackageReference Include="libtorch-cpu-linux-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Linux'))" PrivateAssets="all" />
+    <PackageReference Include="libtorch-cpu-osx-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('OSX'))" PrivateAssets="all" />
  </ItemGroup>

  <ItemGroup>
@@ -24,4 +24,16 @@
    </ProjectReference>
  </ItemGroup>

+  <ItemGroup>
+    <EmbeddedResource Include="Resources\dict.txt">
+      <LogicalName>dict.txt</LogicalName>
+    </EmbeddedResource>
+    <EmbeddedResource Include="Resources\encoder.json">
+      <LogicalName>encoder.json</LogicalName>
+    </EmbeddedResource>
+    <EmbeddedResource Include="Resources\vocab.bpe">
+      <LogicalName>vocab.bpe</LogicalName>
+    </EmbeddedResource>
+  </ItemGroup>
+
 </Project>
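With the <LogicalName> entries above, a quick way to confirm the embedding at runtime is to list the manifest resource names; under these project settings the sketch below should print exactly dict.txt, encoder.json, and vocab.bpe (it assumes the assembly name "Microsoft.ML.TorchSharp"; adjust if the project produces a different name):

    using System;
    using System.Reflection;

    // List every resource embedded in the assembly manifest.
    Assembly assembly = Assembly.Load("Microsoft.ML.TorchSharp");
    foreach (string name in assembly.GetManifestResourceNames())
    {
        Console.WriteLine(name); // expected: dict.txt, encoder.json, vocab.bpe
    }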
Diff not shown due to its large size.
File diff suppressed because one or more lines are too long.
Diff not shown due to its large size.
@@ -28,54 +28,6 @@ namespace Microsoft.ML.TorchSharp.Utils
            typeof(double),
        };

-        public static string LoadFromFileOrDownloadFromWeb(string path, string fileName, Uri url, IChannel ch)
-        {
-            Contracts.AssertNonWhiteSpace(fileName, "Filename can't be empty");
-
-            var contents = "";
-            var filePath = Path.Combine(path, fileName);
-            if (!File.Exists(filePath))
-            {
-                try
-                {
-                    using var webClient = new WebClient();
-                    contents = webClient.DownloadString(url);
-                }
-                catch (WebException e)
-                {
-                    throw new WebException($"File {fileName} not found and cannot be downloaded from {url}.\n" +
-                        $"Error message: {e.Message}", e);
-                }
-
-                try
-                {
-                    File.WriteAllText(filePath, contents);
-                    ch.Info($"File {fileName} successfully downloaded from {url} and saved to {path}.");
-                }
-                catch (Exception e)
-                {
-                    ch.Warning($"{DateTime.Now} - WARNING: File {fileName} successfully downloaded from {url}, " +
-                        $"but error occurs when saving file {fileName} into {path}.\n" +
-                        $"Error message: {e.Message}");
-                }
-            }
-            else
-            {
-                try
-                {
-                    contents = File.ReadAllText(filePath);
-                }
-                catch (Exception e)
-                {
-                    throw new IOException($"Problems met when reading {filePath}.\n" +
-                        $"Error message: {e.Message}", e);
-                }
-            }
-
-            return contents;
-        }
-
        /// <summary>
        /// Load a continuous segment of bytes from stream and parse them into a number array.
        /// NOTE: this function is only for little-endian storage!
@@ -255,20 +255,13 @@ namespace Microsoft.ML.Tests
            var transformedData = transformer.Transform(dataView).Preview();

            Assert.NotNull(transformedData);
-#if NET461
-            Assert.Equal("Class One", transformedData.ColumnView[4].Values[0].ToString());
-            Assert.Equal("Class Two", transformedData.ColumnView[4].Values[1].ToString());
-            Assert.Equal("Class Three", transformedData.ColumnView[4].Values[2].ToString());
-            Assert.Equal("Class Three", transformedData.ColumnView[4].Values[4].ToString());
-            Assert.Equal("Class One", transformedData.ColumnView[4].Values[6].ToString());
-#else
-
            Assert.Equal("Class One", transformedData.ColumnView[4].Values[0].ToString());
            Assert.Equal("Class Two", transformedData.ColumnView[4].Values[1].ToString());
            Assert.Equal("Class Three", transformedData.ColumnView[4].Values[2].ToString());
            Assert.Equal("Class One", transformedData.ColumnView[4].Values[4].ToString());
            Assert.Equal("Class One", transformedData.ColumnView[4].Values[6].ToString());
-
-#endif
            Assert.Equal("Class One", transformedData.ColumnView[4].Values[3].ToString());
            Assert.Equal("Class Two", transformedData.ColumnView[4].Values[7].ToString());
        }
@@ -81,7 +81,7 @@ namespace Microsoft.ML.Tokenizers.Tests
        }

        [Fact]
-        public async void ToknizationTest()
+        public async void TokenizationTest()
        {
            string vocabFile = Utils.CreateTemporaryFile("json");
            string mergeFile = Utils.CreateTemporaryFile("txt");
@@ -100,6 +100,12 @@ namespace Microsoft.ML.Tokenizers.Tests
            paths = tokenizer.Model.Save(Path.GetTempPath(), "roberta");
            Tokenizer tokenizer1 = new Tokenizer(new EnglishRoberta(paths[0], paths[1], paths[2]), RobertaPreTokenizer.Instance);
            TestTokenizer(tokenizer1);
+
+            using Stream vocabStream = File.OpenRead(vocabFile);
+            using Stream mergeStream = File.OpenRead(mergeFile);
+            using Stream translationStream = File.OpenRead(translationFile);
+            tokenizer = new Tokenizer(new EnglishRoberta(vocabStream, mergeStream, translationStream), RobertaPreTokenizer.Instance);
+            TestTokenizer(tokenizer);
        }
        finally
        {