Add more required Tokenizer APIs (#7114)
* Add more required Tokenizer APIs * Address the feedback
This commit is contained in:
Родитель
c96aac79e4
Коммит
c99f7e3abb
|
@ -759,10 +759,10 @@ namespace Microsoft.ML.Tokenizers
|
|||
return encoder;
|
||||
}
|
||||
|
||||
internal static (Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) GetTiktokenConfigurations(string modelName)
|
||||
{
|
||||
ModelEncoding modelEncoding = GetModelEncoding(modelName);
|
||||
internal static (Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) GetTiktokenConfigurations(string modelName) => GetTiktokenConfigurations(GetModelEncoding(modelName), modelName);
|
||||
|
||||
internal static (Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) GetTiktokenConfigurations(ModelEncoding modelEncoding, string? modelName = null)
|
||||
{
|
||||
switch (modelEncoding)
|
||||
{
|
||||
case ModelEncoding.Cl100kBase:
|
||||
|
@ -783,7 +783,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
return (new Dictionary<string, int> { { EndOfText, 50256 }, }, P50kBaseRegex(), GPT2File);
|
||||
|
||||
default:
|
||||
throw new NotSupportedException($"The model '{modelName}' is not supported.");
|
||||
throw new NotSupportedException($"The model '{modelName ?? modelEncoding.ToString()}' is not supported.");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -797,6 +797,10 @@ namespace Microsoft.ML.Tokenizers
|
|||
private const string R50RanksFile = "r50k_base.tiktoken.deflate"; // "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"
|
||||
private const string GPT2File = "gpt2.tiktoken.deflate"; // "https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken"
|
||||
|
||||
internal const string Cl100kBaseEncodingName = "cl100k_base";
|
||||
internal const string P50kBaseEncodingName = "p50k_base";
|
||||
internal const string P50kEditEncodingName = "p50k_edit";
|
||||
internal const string R50kBaseEncodingName = "r50k_base";
|
||||
|
||||
#if NET7_0_OR_GREATER
|
||||
[GeneratedRegex(Cl100kBaseRegexPattern)]
|
||||
|
@ -824,7 +828,16 @@ namespace Microsoft.ML.Tokenizers
|
|||
throw new ArgumentNullException(nameof(modelName));
|
||||
}
|
||||
|
||||
(Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) tiktokenConfiguration = Tiktoken.GetTiktokenConfigurations(modelName);
|
||||
return CreateTokenizerForModel(GetModelEncoding(modelName), modelName, extraSpecialTokens, normalizer);
|
||||
}
|
||||
|
||||
internal static Tokenizer CreateTokenizerForModel(
|
||||
ModelEncoding modelEncoding,
|
||||
string? modelName = null,
|
||||
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
|
||||
Normalizer? normalizer = null)
|
||||
{
|
||||
(Dictionary<string, int> SpecialTokens, Regex Regex, string VocabFile) tiktokenConfiguration = Tiktoken.GetTiktokenConfigurations(modelEncoding, modelName);
|
||||
|
||||
if (extraSpecialTokens is not null)
|
||||
{
|
||||
|
|
|
@ -89,7 +89,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// Encodes input text to tokens Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <returns>The tokenization result includes the tokens list, tokens Ids, tokens offset mapping.</returns>
|
||||
/// <returns>The list of encoded Ids.</returns>
|
||||
public IReadOnlyList<int> EncodeToIds(string text)
|
||||
{
|
||||
if (text is null)
|
||||
|
@ -108,6 +108,48 @@ namespace Microsoft.ML.Tokenizers
|
|||
return idsList;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Encodes input text to tokens Ids up to maximum number of tokens.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
|
||||
/// <param name="processedText">If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.</param>
|
||||
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <returns>The list of encoded Ids.</returns>
|
||||
public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string processedText, out int textLength)
|
||||
{
|
||||
processedText = text;
|
||||
textLength = 0;
|
||||
|
||||
if (text is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(text));
|
||||
}
|
||||
|
||||
if (maxTokenCount <= 0)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than 0.");
|
||||
}
|
||||
|
||||
if (Normalizer is not null)
|
||||
{
|
||||
processedText = Normalizer.Normalize(text);
|
||||
}
|
||||
|
||||
List<int> idsList = new();
|
||||
|
||||
foreach (Split split in PreTokenizer.PreTokenize(processedText))
|
||||
{
|
||||
Model.EncodeToIds(split.TokenSpan, idsList, out int length, maxTokenCount - idsList.Count);
|
||||
if (length < split.Offset.Length || idsList.Count >= maxTokenCount)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return idsList;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
|
@ -324,6 +366,45 @@ namespace Microsoft.ML.Tokenizers
|
|||
public static Tokenizer CreateTiktokenForModel(string modelName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null)
|
||||
=> Tiktoken.CreateTokenizerForModel(modelName, extraSpecialTokens, normalizer);
|
||||
|
||||
/// <summary>
|
||||
/// Create tokenizer based on encoding name
|
||||
/// </summary>
|
||||
/// <param name="encodingName">Encoding name</param>
|
||||
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoding</param>
|
||||
/// <param name="normalizer">To normalize the text before tokenization</param>
|
||||
/// <returns>The tokenizer</returns>
|
||||
public static Tokenizer CreateTiktokenForEncoding(string encodingName, IReadOnlyDictionary<string, int>? extraSpecialTokens = null, Normalizer? normalizer = null)
|
||||
{
|
||||
if (string.IsNullOrEmpty(encodingName))
|
||||
{
|
||||
throw new ArgumentNullException(nameof(encodingName));
|
||||
}
|
||||
|
||||
Tiktoken.ModelEncoding modelEncoding;
|
||||
if (encodingName.Equals(Tiktoken.Cl100kBaseEncodingName, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
modelEncoding = Tiktoken.ModelEncoding.Cl100kBase;
|
||||
}
|
||||
else if (encodingName.Equals(Tiktoken.P50kBaseEncodingName, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
modelEncoding = Tiktoken.ModelEncoding.P50kBase;
|
||||
}
|
||||
else if (encodingName.Equals(Tiktoken.P50kEditEncodingName, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
modelEncoding = Tiktoken.ModelEncoding.P50kEdit;
|
||||
}
|
||||
else if (encodingName.Equals(Tiktoken.R50kBaseEncodingName, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
modelEncoding = Tiktoken.ModelEncoding.R50kBase;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentException($"The encoding name '{encodingName}' is not supported. The only supported encoding names are: {Tiktoken.Cl100kBaseEncodingName}, {Tiktoken.P50kBaseEncodingName}, {Tiktoken.P50kEditEncodingName}, and {Tiktoken.R50kBaseEncodingName}.", nameof(encodingName));
|
||||
}
|
||||
|
||||
return Tiktoken.CreateTokenizerForModel(modelEncoding, modelName: null, extraSpecialTokens, normalizer);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a SentencePieceBpe tokenizer from the given model stream. The model stream should contain the SentencePiece Bpe model according to
|
||||
/// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto specification.
|
||||
|
|
|
@ -355,6 +355,49 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
Assert.NotNull(tokenizer.PreTokenizer);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData("r50k_base")]
|
||||
[InlineData("p50k_base")]
|
||||
[InlineData("p50k_edit")]
|
||||
[InlineData("cl100k_base")]
|
||||
public void TestAllSupportedEncodingNames(string encodingName)
|
||||
{
|
||||
Tokenizer tokenizer = Tokenizer.CreateTiktokenForEncoding(encodingName);
|
||||
Assert.NotNull(tokenizer.Model);
|
||||
Assert.NotNull(tokenizer.PreTokenizer);
|
||||
|
||||
string modelName = encodingName.ToLowerInvariant() switch
|
||||
{
|
||||
"r50k_base" => "text-davinci-001",
|
||||
"p50k_base" => "text-davinci-003",
|
||||
"p50k_edit" => "text-davinci-edit-001",
|
||||
"cl100k_base" => "gpt-4",
|
||||
_ => throw new ArgumentException("Invalid encoding name"),
|
||||
};
|
||||
|
||||
Tokenizer tokenizer1 = Tokenizer.CreateTiktokenForModel(modelName);
|
||||
|
||||
Tiktoken? model1 = tokenizer.Model as Tiktoken;
|
||||
Tiktoken? model2 = tokenizer1.Model as Tiktoken;
|
||||
Assert.NotNull(model1);
|
||||
Assert.NotNull(model2);
|
||||
|
||||
Assert.Equal(model2.Encoder, model1.Encoder);
|
||||
Assert.Equal(model2.Decoder, model1.Decoder);
|
||||
Assert.Equal(model2.SpecialTokens, model1.SpecialTokens);
|
||||
Assert.Equal(model2.Vocab, model1.Vocab);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestEncodingNamesNegativeCases()
|
||||
{
|
||||
Assert.Throws<ArgumentNullException>(() => Tokenizer.CreateTiktokenForEncoding(null!));
|
||||
Assert.Throws<ArgumentException>(() => Tokenizer.CreateTiktokenForEncoding("r50k_base_"));
|
||||
Assert.Throws<ArgumentException>(() => Tokenizer.CreateTiktokenForEncoding("p50k_base_"));
|
||||
Assert.Throws<ArgumentException>(() => Tokenizer.CreateTiktokenForEncoding("p50k_edit_"));
|
||||
Assert.Throws<ArgumentException>(() => Tokenizer.CreateTiktokenForEncoding("cl100k_base_"));
|
||||
}
|
||||
|
||||
[InlineData("gpt-4")]
|
||||
[InlineData("text-davinci-003")]
|
||||
[InlineData("text-curie-001")]
|
||||
|
|
|
@ -28,6 +28,12 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
{
|
||||
int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string processedText1, out int tokenCount1);
|
||||
int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string processedText2, out int tokenCount2);
|
||||
IReadOnlyList<int> partialIdsList = tokenizer.EncodeToIds(input, maxTokenCount: i, out string processedText, out int textLength);
|
||||
|
||||
Assert.True(textLength <= processedText.Length);
|
||||
Assert.True(tokenizer.Normalizer is not null || processedText == input);
|
||||
|
||||
Assert.Equal(fullIdsList.Take(partialIdsList.Count), partialIdsList);
|
||||
|
||||
IReadOnlyList<int>? prefixIds = null;
|
||||
IReadOnlyList<int>? suffixIds = null;
|
||||
|
|
Загрузка…
Ссылка в новой задаче