Add more required Tokenizer APIs (#7114)

* Add more required Tokenizer APIs

* Address the feedback
This commit is contained in:
Tarek Mahmoud Sayed 2024-04-03 11:09:24 -07:00 коммит произвёл GitHub
Родитель c96aac79e4
Коммит c99f7e3abb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
4 изменённых файлов: 149 добавлений и 6 удалений

Просмотреть файл

@ -759,10 +759,10 @@ namespace Microsoft.ML.Tokenizers
return encoder;
}
internal static (Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) GetTiktokenConfigurations(string modelName)
{
ModelEncoding modelEncoding = GetModelEncoding(modelName);
internal static (Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) GetTiktokenConfigurations(string modelName) => GetTiktokenConfigurations(GetModelEncoding(modelName), modelName);
internal static (Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) GetTiktokenConfigurations(ModelEncoding modelEncoding, string? modelName = null)
{
switch (modelEncoding)
{
case ModelEncoding.Cl100kBase:
@ -783,7 +783,7 @@ namespace Microsoft.ML.Tokenizers
return (new Dictionary<string, int> { { EndOfText, 50256 }, }, P50kBaseRegex(), GPT2File);
default:
throw new NotSupportedException($"The model '{modelName}' is not supported.");
throw new NotSupportedException($"The model '{modelName ?? modelEncoding.ToString()}' is not supported.");
}
}
@ -797,6 +797,10 @@ namespace Microsoft.ML.Tokenizers
private const string R50RanksFile = "r50k_base.tiktoken.deflate"; // "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"
private const string GPT2File = "gpt2.tiktoken.deflate"; // "https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken"
internal const string Cl100kBaseEncodingName = "cl100k_base";
internal const string P50kBaseEncodingName = "p50k_base";
internal const string P50kEditEncodingName = "p50k_edit";
internal const string R50kBaseEncodingName = "r50k_base";
#if NET7_0_OR_GREATER
[GeneratedRegex(Cl100kBaseRegexPattern)]
@ -824,7 +828,16 @@ namespace Microsoft.ML.Tokenizers
throw new ArgumentNullException(nameof(modelName));
}
(Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) tiktokenConfiguration = Tiktoken.GetTiktokenConfigurations(modelName);
return CreateTokenizerForModel(GetModelEncoding(modelName), modelName, extraSpecialTokens, normalizer);
}
internal static Tokenizer CreateTokenizerForModel(
ModelEncoding modelEncoding,
string? modelName = null,
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
Normalizer? normalizer = null)
{
(Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) tiktokenConfiguration = Tiktoken.GetTiktokenConfigurations(modelEncoding, modelName);
if (extraSpecialTokens is not null)
{

Просмотреть файл

@ -89,7 +89,7 @@ namespace Microsoft.ML.Tokenizers
/// Encodes input text to tokens Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <returns>The tokenization result includes the tokens list, tokens Ids, tokens offset mapping.</returns>
/// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(string text)
{
if (text is null)
@ -108,6 +108,48 @@ namespace Microsoft.ML.Tokenizers
return idsList;
}
/// <summary>
/// Encodes input text to tokens Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="processedText">If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <returns>The list of encoded Ids.</returns>
/// <summary>
/// Encodes input text to tokens Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <param name="processedText">If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
/// <returns>The list of encoded Ids.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="text"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="maxTokenCount"/> is not positive.</exception>
public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string processedText, out int textLength)
{
    // Out parameters must be definitely assigned on every path, including the throws below.
    processedText = text;
    textLength = 0;

    if (text is null)
    {
        throw new ArgumentNullException(nameof(text));
    }

    if (maxTokenCount <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than 0.");
    }

    if (Normalizer is not null)
    {
        processedText = Normalizer.Normalize(text);
    }

    List<int> idsList = new();

    foreach (Split split in PreTokenizer.PreTokenize(processedText))
    {
        // Ask the model to encode this split, capped by the remaining token budget.
        Model.EncodeToIds(split.TokenSpan, idsList, out int length, maxTokenCount - idsList.Count);

        // BUG FIX: textLength was never updated and always reported 0, contradicting
        // its documented contract. Track how far into processedText the encoded tokens
        // reach. NOTE(review): assumes Split.Offset exposes the split's start Index
        // alongside Length (Length is used below) — confirm against the Split type.
        if (length > 0)
        {
            textLength = split.Offset.Index + length;
        }

        // Stop when the model consumed only part of the split (budget exhausted mid-split)
        // or the budget is fully used.
        if (length < split.Offset.Length || idsList.Count >= maxTokenCount)
        {
            break;
        }
    }

    return idsList;
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
@ -324,6 +366,45 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Create a Tiktoken tokenizer for the given model name.
/// </summary>
/// <param name="modelName">Model name</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <returns>The tokenizer</returns>
public static Tokenizer CreateTiktokenForModel(string modelName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null)
{
    // Delegate directly to the Tiktoken factory; all validation happens there.
    return Tiktoken.CreateTokenizerForModel(modelName, extraSpecialTokens, normalizer);
}
/// <summary>
/// Create tokenizer based on encoding name
/// </summary>
/// <param name="encodingName">Encoding name</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoding</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <returns>The tokenizer</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="encodingName"/> is null.</exception>
/// <exception cref="ArgumentException">Thrown when <paramref name="encodingName"/> is empty or is not a supported encoding name.</exception>
public static Tokenizer CreateTiktokenForEncoding(string encodingName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null)
{
    // Per the Framework Design Guidelines, ArgumentNullException is reserved for null
    // arguments; an empty string now falls through to the unsupported-name
    // ArgumentException below (previously it incorrectly threw ArgumentNullException).
    if (encodingName is null)
    {
        throw new ArgumentNullException(nameof(encodingName));
    }

    // Encoding names are matched case-insensitively against the built-in encodings.
    Tiktoken.ModelEncoding modelEncoding;
    if (encodingName.Equals(Tiktoken.Cl100kBaseEncodingName, StringComparison.OrdinalIgnoreCase))
    {
        modelEncoding = Tiktoken.ModelEncoding.Cl100kBase;
    }
    else if (encodingName.Equals(Tiktoken.P50kBaseEncodingName, StringComparison.OrdinalIgnoreCase))
    {
        modelEncoding = Tiktoken.ModelEncoding.P50kBase;
    }
    else if (encodingName.Equals(Tiktoken.P50kEditEncodingName, StringComparison.OrdinalIgnoreCase))
    {
        modelEncoding = Tiktoken.ModelEncoding.P50kEdit;
    }
    else if (encodingName.Equals(Tiktoken.R50kBaseEncodingName, StringComparison.OrdinalIgnoreCase))
    {
        modelEncoding = Tiktoken.ModelEncoding.R50kBase;
    }
    else
    {
        throw new ArgumentException($"The encoding name '{encodingName}' is not supported. The only supported encoding names are: {Tiktoken.Cl100kBaseEncodingName}, {Tiktoken.P50kBaseEncodingName}, {Tiktoken.P50kEditEncodingName}, and {Tiktoken.R50kBaseEncodingName}.", nameof(encodingName));
    }

    // No model name is associated with a raw encoding; pass null so error messages
    // fall back to the encoding name.
    return Tiktoken.CreateTokenizerForModel(modelEncoding, modelName: null, extraSpecialTokens, normalizer);
}
/// <summary>
/// Create a SentencePieceBpe tokenizer from the given model stream. The model stream should contain the SentencePiece Bpe model according to
/// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto specification.

Просмотреть файл

@ -355,6 +355,49 @@ namespace Microsoft.ML.Tokenizers.Tests
Assert.NotNull(tokenizer.PreTokenizer);
}
[Theory]
[InlineData("r50k_base")]
[InlineData("p50k_base")]
[InlineData("p50k_edit")]
[InlineData("cl100k_base")]
public void TestAllSupportedEncodingNames(string encodingName)
{
    // A tokenizer created from an encoding name must expose the same vocabularies
    // as one created from a model that uses that encoding.
    Tokenizer tokenizerFromEncoding = Tokenizer.CreateTiktokenForEncoding(encodingName);
    Assert.NotNull(tokenizerFromEncoding.Model);
    Assert.NotNull(tokenizerFromEncoding.PreTokenizer);

    // Pick one representative model backed by each encoding.
    string equivalentModelName = encodingName.ToLowerInvariant() switch
    {
        "r50k_base" => "text-davinci-001",
        "p50k_base" => "text-davinci-003",
        "p50k_edit" => "text-davinci-edit-001",
        "cl100k_base" => "gpt-4",
        _ => throw new ArgumentException("Invalid encoding name"),
    };

    Tokenizer tokenizerFromModel = Tokenizer.CreateTiktokenForModel(equivalentModelName);

    Tiktoken? encodingTiktoken = tokenizerFromEncoding.Model as Tiktoken;
    Tiktoken? modelTiktoken = tokenizerFromModel.Model as Tiktoken;
    Assert.NotNull(encodingTiktoken);
    Assert.NotNull(modelTiktoken);

    Assert.Equal(modelTiktoken.Encoder, encodingTiktoken.Encoder);
    Assert.Equal(modelTiktoken.Decoder, encodingTiktoken.Decoder);
    Assert.Equal(modelTiktoken.SpecialTokens, encodingTiktoken.SpecialTokens);
    Assert.Equal(modelTiktoken.Vocab, encodingTiktoken.Vocab);
}
[Fact]
public void TestEncodingNamesNegativeCases()
{
    // Null is rejected with ArgumentNullException ...
    Assert.Throws<ArgumentNullException>(() => Tokenizer.CreateTiktokenForEncoding(null!));

    // ... while near-miss encoding names are rejected with ArgumentException.
    foreach (string invalidEncodingName in new[] { "r50k_base_", "p50k_base_", "p50k_edit_", "cl100k_base_" })
    {
        Assert.Throws<ArgumentException>(() => Tokenizer.CreateTiktokenForEncoding(invalidEncodingName));
    }
}
[InlineData("gpt-4")]
[InlineData("text-davinci-003")]
[InlineData("text-curie-001")]

Просмотреть файл

@ -28,6 +28,12 @@ namespace Microsoft.ML.Tokenizers.Tests
{
int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string processedText1, out int tokenCount1);
int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string processedText2, out int tokenCount2);
IReadOnlyList<int> partialIdsList = tokenizer.EncodeToIds(input, maxTokenCount: i, out string processedText, out int textLength);
Assert.True(textLength <= processedText.Length);
Assert.True(tokenizer.Normalizer is not null || processedText == input);
Assert.Equal(fullIdsList.Take(partialIdsList.Count), partialIdsList);
IReadOnlyList<int>? prefixIds = null;
IReadOnlyList<int>? suffixIds = null;