Introducing WordPiece and Bert tokenizers (#7275)
* Introducing WordPiece and Bert tokenizers
* Fix corner case in WordPiece
This commit is contained in:
Parent
32bac5e395
Commit
81122c4c48
@ -152,6 +152,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

License notice for WordPiece and Bert tokenizers
------------------------------------------------

https://github.com/huggingface/transformers/blob/8e3e145b427196e014f37aa42ba890b9bc94275e/src/transformers/models/bert/tokenization_bert.py#L2

Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

License notice for BitUtility
------------------------------------------
@ -10,7 +10,7 @@ namespace Microsoft.ML.Tokenizers
/// Represents the token produced from the tokenization process, containing the token substring,
/// the id associated to the token substring, and the offset mapping to the original string.
/// </summary>
public readonly struct EncodedToken
public readonly struct EncodedToken : IEquatable<EncodedToken>
{
/// <summary>
/// Gets the Id value associated to the token.
@ -39,5 +39,8 @@ namespace Microsoft.ML.Tokenizers
Offset = offset;
Value = value;
}

/// <inheritdoc/>
public bool Equals(EncodedToken other) => Id == other.Id && Value == other.Value && Offset.Equals(other.Offset);
}
}
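// Illustrative sketch (not part of the diff): with the new IEquatable<EncodedToken> implementation,
// two tokens compare equal only when their Id, Value, and Offset all match. The (id, value, offset)
// constructor shape is the one used elsewhere in this PR.
EncodedToken first = new EncodedToken(5, "hello", new Range(0, 5));
EncodedToken second = new EncodedToken(5, "hello", new Range(0, 5));
bool equal = first.Equals(second);                                           // true
bool shifted = first.Equals(new EncodedToken(5, "hello", new Range(1, 6)));  // false: offsets differ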
@ -87,7 +87,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergesFile">The file path containing the list of token pairs.</param>
public static BpeTokenizer Create(string vocabFile, string? mergesFile)
=> Create(vocabFile, mergesFile, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
=> Create(vocabFile, mergesFile, preTokenizer: PreTokenizer.CreateWordOrNonWordPreTokenizer(), normalizer: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);

/// <summary>
/// Create a new Bpe tokenizer object to use for text encoding.
@ -131,7 +131,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="vocabStream">The JSON stream containing the dictionary of string keys and their ids.</param>
/// <param name="mergesStream">The stream containing the list of token pairs.</param>
public static BpeTokenizer Create(Stream vocabStream, Stream? mergesStream)
=> Create(vocabStream, mergesStream, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, addedTokens: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
=> Create(vocabStream, mergesStream, preTokenizer: PreTokenizer.CreateWordOrNonWordPreTokenizer(), normalizer: null, addedTokens: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);

/// <summary>
/// Create a new Bpe tokenizer object to use for text encoding.
@ -225,7 +225,7 @@ namespace Microsoft.ML.Tokenizers
FuseUnknownTokens = fuseUnknownTokens;
ContinuingSubwordPrefix = continuingSubwordPrefix;
EndOfWordSuffix = endOfWordSuffix;
_preTokenizer = preTokenizer ?? PreTokenizer.CreateWhiteSpace(); // Default to WhiteSpace pre-tokenizer
_preTokenizer = preTokenizer ?? PreTokenizer.CreateWordOrNonWordPreTokenizer(); // Default to WordOrNonWord pre-tokenizer
_normalizer = normalizer;

_vocab = vocab ?? new Dictionary<StringSpanOrdinalKey, int>();
@ -0,0 +1,729 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.ML.Tokenizers
|
||||
{
|
||||
/// <summary>
|
||||
/// Tokenizer for Bert model.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// The BertTokenizer is based on the WordPieceTokenizer and is used to tokenize text for Bert models.
|
||||
/// The implementation of the BertTokenizer is based on the original Bert implementation in the Hugging Face Transformers library.
|
||||
/// https://huggingface.co/transformers/v3.0.2/model_doc/bert.html?highlight=berttokenizerfast#berttokenizer
|
||||
/// </remarks>
|
||||
public sealed partial class BertTokenizer : WordPieceTokenizer
|
||||
{
|
||||
internal BertTokenizer(
|
||||
Dictionary<StringSpanOrdinalKey, int> vocab,
|
||||
Dictionary<int, string> vocabReverse,
|
||||
PreTokenizer? preTokenizer,
|
||||
Normalizer? normalizer,
|
||||
IReadOnlyDictionary<string, int>? specialTokens,
|
||||
bool doLowerCase,
|
||||
bool doBasicTokenization,
|
||||
bool splitOnSpecialTokens,
|
||||
string unknownToken,
|
||||
string sepToken,
|
||||
string padToken,
|
||||
string clsToken,
|
||||
string maskToken,
|
||||
bool tokenizeChineseChars,
|
||||
bool stripAccents) : base(vocab, vocabReverse, preTokenizer, normalizer, specialTokens, unknownToken)
|
||||
{
|
||||
DoLowerCase = doLowerCase;
|
||||
DoBasicTokenization = doBasicTokenization;
|
||||
SplitOnSpecialTokens = splitOnSpecialTokens;
|
||||
|
||||
SepToken = sepToken;
|
||||
SepTokenId = vocab[new StringSpanOrdinalKey(sepToken)];
|
||||
|
||||
PadToken = padToken;
|
||||
PadTokenId = vocab[new StringSpanOrdinalKey(padToken)];
|
||||
|
||||
ClsToken = clsToken;
|
||||
ClsTokenId = vocab[new StringSpanOrdinalKey(clsToken)];
|
||||
|
||||
MaskToken = maskToken;
|
||||
MaskTokenId = vocab[new StringSpanOrdinalKey(maskToken)];
|
||||
|
||||
TokenizeChineseChars = tokenizeChineseChars;
|
||||
StripAccents = stripAccents;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether the tokenizer should lowercase the input text.
|
||||
/// </summary>
|
||||
public bool DoLowerCase { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether the tokenizer should do basic tokenization, such as cleaning the text, normalizing it, lowercasing, etc.
|
||||
/// </summary>
|
||||
public bool DoBasicTokenization { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether the tokenizer should split on the special tokens or treat special tokens as normal text.
|
||||
/// </summary>
|
||||
public bool SplitOnSpecialTokens { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification or for a text and a question for question answering.
|
||||
/// It is also used as the last token of a sequence built with special tokens.
|
||||
/// </summary>
|
||||
public string SepToken { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the separator token Id
|
||||
/// </summary>
|
||||
public int SepTokenId { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the token used for padding, for example when batching sequences of different lengths
|
||||
/// </summary>
|
||||
public string PadToken { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets padding token Id
|
||||
/// </summary>
|
||||
public int PadTokenId { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the classifier token which is used when doing sequence classification (classification of the whole sequence instead of per-token classification).
|
||||
/// It is the first token of the sequence when built with special tokens.
|
||||
/// </summary>
|
||||
public string ClsToken { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the classifier token Id
|
||||
/// </summary>
|
||||
public int ClsTokenId { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the mask token used for masking values. This is the token used when training this model with masked language modeling.
|
||||
/// This is the token which the model will try to predict.
|
||||
/// </summary>
|
||||
public string MaskToken { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the mask token Id
|
||||
/// </summary>
|
||||
public int MaskTokenId { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether the tokenizer should split the Chinese characters into tokens.
|
||||
/// </summary>
|
||||
public bool TokenizeChineseChars { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether the tokenizer should strip accent characters.
|
||||
/// </summary>
|
||||
public bool StripAccents { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Encodes input text to token Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||
/// <returns>The list of encoded Ids.</returns>
|
||||
public new IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) =>
|
||||
EncodeToIds(text, ReadOnlySpan<char>.Empty, addSpecialTokens: true, considerPreTokenization, considerNormalization);
|
||||
|
||||
/// <summary>
|
||||
/// Encodes input text to token Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||
/// <returns>The list of encoded Ids.</returns>
|
||||
public new IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true) =>
|
||||
EncodeToIds(null, text, addSpecialTokens: true, considerPreTokenization, considerNormalization);
|
||||
|
||||
/// <summary>
|
||||
/// Encodes input text to token Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="addSpecialTokens">Indicate whether to add special tokens to the encoded Ids.</param>
|
||||
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||
/// <returns>The list of encoded Ids.</returns>
|
||||
public IReadOnlyList<int> EncodeToIds(string text, bool addSpecialTokens, bool considerPreTokenization = true, bool considerNormalization = true) =>
|
||||
EncodeToIds(text, ReadOnlySpan<char>.Empty, addSpecialTokens, considerPreTokenization, considerNormalization);
|
||||
|
||||
/// <summary>
|
||||
/// Encodes input text to token Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="addSpecialTokens">Indicate whether to add special tokens to the encoded Ids.</param>
|
||||
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||
/// <returns>The list of encoded Ids.</returns>
|
||||
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addSpecialTokens, bool considerPreTokenization = true, bool considerNormalization = true) =>
|
||||
EncodeToIds(null, text, addSpecialTokens, considerPreTokenization, considerNormalization);
|
||||
|
||||
/// <summary>
|
||||
/// Encodes input text to token Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="maxTokenCount">The maximum number of tokens to return.</param>
|
||||
/// <param name="normalizedText">The normalized text.</param>
|
||||
/// <param name="charsConsumed">The number of characters consumed from the input text.</param>
|
||||
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||
/// <returns>The list of encoded Ids.</returns>
|
||||
public new IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true) =>
|
||||
EncodeToIds(text, ReadOnlySpan<char>.Empty, maxTokenCount, addSpecialTokens: true, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization);
|
||||
|
||||
/// <summary>
|
||||
/// Encodes input text to token Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="maxTokenCount">The maximum number of tokens to return.</param>
|
||||
/// <param name="normalizedText">The normalized text.</param>
|
||||
/// <param name="charsConsumed">The number of characters consumed from the input text.</param>
|
||||
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||
/// <returns>The list of encoded Ids.</returns>
|
||||
public new IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true) =>
|
||||
EncodeToIds(null, text, maxTokenCount, addSpecialTokens: true, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization);
|
||||
|
||||
/// <summary>
|
||||
/// Encodes input text to token Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="maxTokenCount">The maximum number of tokens to return.</param>
|
||||
/// <param name="addSpecialTokens">Indicate whether to add special tokens to the encoded Ids.</param>
|
||||
/// <param name="normalizedText">The normalized text.</param>
|
||||
/// <param name="charsConsumed">The number of characters consumed from the input text.</param>
|
||||
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||
/// <returns>The list of encoded Ids.</returns>
|
||||
public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, bool addSpecialTokens, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true) =>
|
||||
EncodeToIds(text, ReadOnlySpan<char>.Empty, maxTokenCount, addSpecialTokens, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization);
|
||||
|
||||
/// <summary>
|
||||
/// Encodes input text to token Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="maxTokenCount">The maximum number of tokens to return.</param>
|
||||
/// <param name="addSpecialTokens">Indicate whether to add special tokens to the encoded Ids.</param>
|
||||
/// <param name="normalizedText">The normalized text.</param>
|
||||
/// <param name="charsConsumed">The number of characters consumed from the input text.</param>
|
||||
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
|
||||
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
|
||||
/// <returns>The list of encoded Ids.</returns>
|
||||
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, bool addSpecialTokens, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true) =>
|
||||
EncodeToIds(null, text, maxTokenCount, addSpecialTokens, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization);
|
||||
|
||||
private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool addSpecialTokens, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||
{
|
||||
if (addSpecialTokens)
|
||||
{
|
||||
if (maxTokenCount < 2)
|
||||
{
|
||||
charsConsumed = 0;
|
||||
normalizedText = null;
|
||||
return Array.Empty<int>();
|
||||
}
|
||||
|
||||
IReadOnlyList<int> ids = text is null ?
|
||||
base.EncodeToIds(textSpan, maxTokenCount - 2, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization) :
|
||||
base.EncodeToIds(text, maxTokenCount - 2, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization);
|
||||
|
||||
if (ids is not List<int> list)
|
||||
{
|
||||
list = new List<int>(ids);
|
||||
}
|
||||
|
||||
list.Insert(0, ClsTokenId);
|
||||
list.Add(SepTokenId);
|
||||
|
||||
return list;
|
||||
}
|
||||
|
||||
return text is null ?
|
||||
base.EncodeToIds(textSpan, maxTokenCount, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization) :
|
||||
base.EncodeToIds(text, maxTokenCount, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization);
|
||||
}
|
||||
|
||||
private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool addSpecialTokens, bool considerPreTokenization = true, bool considerNormalization = true)
|
||||
{
|
||||
IReadOnlyList<int> ids = text is null ? base.EncodeToIds(textSpan, considerPreTokenization, considerNormalization) : base.EncodeToIds(text, considerPreTokenization, considerNormalization);
|
||||
|
||||
if (addSpecialTokens)
|
||||
{
|
||||
if (ids is not List<int> list)
|
||||
{
|
||||
list = new List<int>(ids);
|
||||
}
|
||||
|
||||
list.Insert(0, ClsTokenId);
|
||||
list.Add(SepTokenId);
|
||||
|
||||
return list;
|
||||
}
|
||||
|
||||
return ids;
|
||||
}
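// Illustrative usage sketch (the vocab file name is an assumption, not part of the PR):
// the public EncodeToIds overloads wrap the WordPiece ids in [CLS] ... [SEP] unless
// addSpecialTokens is explicitly set to false.
BertTokenizer tokenizer = BertTokenizer.Create("vocab.txt");
IReadOnlyList<int> withSpecials = tokenizer.EncodeToIds("hello world");                             // [CLS] hello world [SEP]
IReadOnlyList<int> withoutSpecials = tokenizer.EncodeToIds("hello world", addSpecialTokens: false); // hello world only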
|
||||
|
||||
/// <summary>
|
||||
/// Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and adding special tokens. A BERT sequence has the following format:
|
||||
/// - single sequence: `[CLS] tokenIds0 [SEP]`
|
||||
/// - pair of sequences: `[CLS] tokenIds0 [SEP] tokenIds1 [SEP]`
|
||||
/// </summary>
|
||||
/// <param name="tokenIds0">List of IDs to which the special tokens will be added.</param>
|
||||
/// <param name="tokenIds1">Optional second list of IDs for sequence pairs.</param>
|
||||
/// <returns>The list of IDs with special tokens added.</returns>
|
||||
/// <exception cref="ArgumentNullException">When <paramref name="tokenIds0"/> is null.</exception>
|
||||
public IReadOnlyList<int> BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0, IEnumerable<int>? tokenIds1 = null)
|
||||
{
|
||||
if (tokenIds0 is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(tokenIds0));
|
||||
}
|
||||
|
||||
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
|
||||
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
|
||||
List<int> ids = new List<int>(capacity: capacity) { ClsTokenId };
|
||||
ids.AddRange(tokenIds0);
|
||||
ids.Add(SepTokenId);
|
||||
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
ids.AddRange(tokenIds1);
|
||||
ids.Add(SepTokenId);
|
||||
}
|
||||
|
||||
return ids;
|
||||
}
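// Illustrative sketch, reusing the tokenizer from the earlier sketch (the id values are placeholders):
// BuildInputsWithSpecialTokens prepends ClsTokenId and appends SepTokenId after each sequence.
int[] question = { 7592, 2088 };
int[] context = { 2023, 2003, 1037, 3231 };
IReadOnlyList<int> single = tokenizer.BuildInputsWithSpecialTokens(question);
// -> ClsTokenId, 7592, 2088, SepTokenId
IReadOnlyList<int> pair = tokenizer.BuildInputsWithSpecialTokens(question, context);
// -> ClsTokenId, 7592, 2088, SepTokenId, 2023, 2003, 1037, 3231, SepTokenId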
|
||||
|
||||
/// <summary>
|
||||
/// Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and adding special tokens. A BERT sequence has the following format:
|
||||
/// - single sequence: `[CLS] tokenIds0 [SEP]`
|
||||
/// - pair of sequences: `[CLS] tokenIds0 [SEP] tokenIds1 [SEP]`
|
||||
/// </summary>
|
||||
/// <param name="tokenIds0">List of IDs to which the special tokens will be added.</param>
|
||||
/// <param name="buffer">The buffer to write the token IDs with special tokens added.</param>
|
||||
/// <param name="written">The number of elements written to the buffer.</param>
|
||||
/// <param name="tokenIds1">Optional second list of IDs for sequence pairs.</param>
|
||||
/// <returns>The status of the operation.</returns>
|
||||
/// <exception cref="ArgumentNullException">When <paramref name="tokenIds0"/> is null.</exception>
|
||||
public OperationStatus BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0, Span<int> buffer, out int written, IEnumerable<int>? tokenIds1 = null)
|
||||
{
|
||||
if (tokenIds0 is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(tokenIds0));
|
||||
}
|
||||
|
||||
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
|
||||
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
|
||||
if (buffer.Length < capacity)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
|
||||
written = 0;
|
||||
buffer[written++] = ClsTokenId;
|
||||
foreach (int id in tokenIds0)
|
||||
{
|
||||
buffer[written++] = id;
|
||||
}
|
||||
buffer[written++] = SepTokenId;
|
||||
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
foreach (int id in tokenIds1)
|
||||
{
|
||||
buffer[written++] = id;
|
||||
}
|
||||
|
||||
buffer[written++] = SepTokenId;
|
||||
}
|
||||
|
||||
return OperationStatus.Done;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Retrieve the special tokens mask from a list of IDs.
|
||||
/// </summary>
|
||||
/// <param name="tokenIds0">List of IDs.</param>
|
||||
/// <param name="tokenIds1">Optional second list of IDs for sequence pairs.</param>
|
||||
/// <param name="alreadyHasSpecialTokens">Indicate whether or not the token list is already formatted with special tokens for the model.</param>
|
||||
/// <returns>A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.</returns>
|
||||
/// <exception cref="ArgumentNullException"></exception>
|
||||
public IReadOnlyList<int> GetSpecialTokensMask(IEnumerable<int> tokenIds0, IEnumerable<int>? tokenIds1 = null, bool alreadyHasSpecialTokens = false)
|
||||
{
|
||||
if (tokenIds0 is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(tokenIds0));
|
||||
}
|
||||
|
||||
int capacity = alreadyHasSpecialTokens ?
|
||||
tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
|
||||
tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
|
||||
|
||||
List<int> mask = new List<int>(capacity: capacity);
|
||||
|
||||
if (!alreadyHasSpecialTokens)
|
||||
{
|
||||
mask.Add(1); // CLS
|
||||
mask.AddRange(Enumerable.Repeat(0, tokenIds0.Count()));
|
||||
mask.Add(1); // SEP
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
mask.AddRange(Enumerable.Repeat(0, tokenIds1.Count()));
|
||||
mask.Add(1); // SEP
|
||||
}
|
||||
|
||||
return mask;
|
||||
}
|
||||
|
||||
foreach (int id in tokenIds0)
|
||||
{
|
||||
mask.Add(id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0);
|
||||
}
|
||||
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
foreach (int id in tokenIds1)
|
||||
{
|
||||
mask.Add(id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0);
|
||||
}
|
||||
}
|
||||
|
||||
return mask;
|
||||
}
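// Illustrative sketch, continuing the question/context arrays from above: for ids that do not yet
// contain special tokens, the mask marks the positions where [CLS]/[SEP] will be inserted.
IReadOnlyList<int> specialMask = tokenizer.GetSpecialTokensMask(question, context);
// -> 1 0 0 1 0 0 0 0 1   (CLS, question, SEP, context, SEP)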
|
||||
|
||||
/// <summary>
|
||||
/// Retrieve the special tokens mask from a list of IDs.
|
||||
/// </summary>
|
||||
/// <param name="tokenIds0">List of IDs.</param>
|
||||
/// <param name="buffer">The buffer to write the mask. The integers written values are in the range [0, 1]: 1 for a special token, 0 for a sequence token.</param>
|
||||
/// <param name="written">The number of elements written to the buffer.</param>
|
||||
/// <param name="tokenIds1">Optional second list of IDs for sequence pairs.</param>
|
||||
/// <param name="alreadyHasSpecialTokens">Indicate whether or not the token list is already formatted with special tokens for the model.</param>
|
||||
/// <returns>The status of the operation.</returns>
|
||||
/// <exception cref="ArgumentNullException"></exception>
|
||||
public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int> buffer, out int written, IEnumerable<int>? tokenIds1 = null, bool alreadyHasSpecialTokens = false)
|
||||
{
|
||||
if (tokenIds0 is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(tokenIds0));
|
||||
}
|
||||
|
||||
int capacity = alreadyHasSpecialTokens ?
|
||||
tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
|
||||
tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
|
||||
|
||||
written = 0;
|
||||
if (buffer.Length < capacity)
|
||||
{
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
|
||||
if (!alreadyHasSpecialTokens)
|
||||
{
|
||||
buffer[written++] = 1; // CLS
|
||||
foreach (int id in tokenIds0)
|
||||
{
|
||||
buffer[written++] = 0;
|
||||
}
|
||||
buffer[written++] = 1; // SEP
|
||||
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
foreach (int id in tokenIds1)
|
||||
{
|
||||
buffer[written++] = 0;
|
||||
}
|
||||
buffer[written++] = 1; // SEP
|
||||
}
|
||||
|
||||
return OperationStatus.Done;
|
||||
}
|
||||
|
||||
foreach (int id in tokenIds0)
|
||||
{
|
||||
buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
|
||||
}
|
||||
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
foreach (int id in tokenIds1)
|
||||
{
|
||||
buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
|
||||
}
|
||||
}
|
||||
|
||||
return OperationStatus.Done;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence pair mask has the following format:
|
||||
/// 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
||||
/// | first sequence | second sequence |
|
||||
/// If <paramref name="tokenIds1"/> is null, this method only returns the first portion of the type ids (0s).
|
||||
/// </summary>
|
||||
/// <param name="tokenIds0">List of token IDs for the first sequence.</param>
|
||||
/// <param name="tokenIds1">Optional list of token IDs for the second sequence.</param>
|
||||
/// <returns>List of token type IDs according to the given sequence(s).</returns>
|
||||
/// <exception cref="ArgumentNullException">When <paramref name="tokenIds0"/> is null.</exception>
|
||||
public IReadOnlyList<int> CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds0, IEnumerable<int>? tokenIds1 = null)
|
||||
{
|
||||
if (tokenIds0 is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(tokenIds0));
|
||||
}
|
||||
|
||||
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
|
||||
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
|
||||
|
||||
List<int> typeIds = new List<int>(capacity);
|
||||
for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
|
||||
{
|
||||
typeIds.Add(0);
|
||||
}
|
||||
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
|
||||
{
|
||||
typeIds.Add(1);
|
||||
}
|
||||
}
|
||||
|
||||
return typeIds;
|
||||
}
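// Illustrative sketch, continuing the same arrays: token type ids are 0 for the first sequence
// (including [CLS] and its [SEP]) and 1 for the second sequence (including its trailing [SEP]).
IReadOnlyList<int> typeIds = tokenizer.CreateTokenTypeIdsFromSequences(question, context);
// -> 0 0 0 0 1 1 1 1 1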
|
||||
|
||||
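/// <summary>
/// Create a mask from the two sequences passed to be used in a sequence-pair classification task, writing the token type ids to the provided buffer.
/// </summary>
/// <param name="tokenIds0">List of token IDs for the first sequence.</param>
/// <param name="buffer">The buffer to write the token type IDs to.</param>
/// <param name="written">The number of elements written to the buffer.</param>
/// <param name="tokenIds1">Optional list of token IDs for the second sequence.</param>
/// <returns>The status of the operation.</returns>
/// <exception cref="ArgumentNullException">When <paramref name="tokenIds0"/> is null.</exception>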
public OperationStatus CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds0, Span<int> buffer, out int written, IEnumerable<int>? tokenIds1 = null)
|
||||
{
|
||||
if (tokenIds0 is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(tokenIds0));
|
||||
}
|
||||
|
||||
written = 0;
|
||||
|
||||
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
|
||||
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
|
||||
if (buffer.Length < capacity)
|
||||
{
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
|
||||
for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
|
||||
{
|
||||
buffer[written++] = 0;
|
||||
}
|
||||
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
|
||||
{
|
||||
buffer[written++] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
return OperationStatus.Done;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a new instance of the <see cref="BertTokenizer"/> class.
|
||||
/// </summary>
|
||||
/// <param name="vocabFilePath">The path to the vocabulary file.</param>
|
||||
/// <param name="doLowerCase">A value indicating whether the tokenizer should lowercase the input text.</param>
|
||||
/// <param name="doBasicTokenization">A value indicating whether the tokenizer should do basic tokenization. Like clean text, normalize it, lowercasing, etc.</param>
|
||||
/// <param name="splitOnSpecialTokens">A value indicating whether the tokenizer should split on special tokens.</param>
|
||||
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
|
||||
/// <param name="sepToken">The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification or for a text and a question for question answering. It is also used as the last token of a sequence built with special tokens.</param>
|
||||
/// <param name="padToken">The token used for padding, for example when batching sequences of different lengths.</param>
|
||||
/// <param name="clsToken">The classifier token which is used when doing sequence classification (classification of the whole sequence instead of per-token classification). It is the first token of the sequence when built with special tokens.</param>
|
||||
/// <param name="maskToken">The token used for masking values. This is the token used when training this model with masked language modeling. This is the token which the model will try to predict.</param>
|
||||
/// <param name="tokenizeChineseChars">A value indicating whether the tokenizer should split the Chinese characters into tokens.</param>
|
||||
/// <param name="stripAccents">A value indicating whether the tokenizer should strip accents characters.</param>
|
||||
/// <returns>A new instance of the <see cref="BertTokenizer"/> class.</returns>
|
||||
/// <exception cref="ArgumentNullException"></exception>
|
||||
public static BertTokenizer Create(
|
||||
string vocabFilePath,
|
||||
bool doLowerCase = true,
|
||||
bool doBasicTokenization = true,
|
||||
bool splitOnSpecialTokens = true,
|
||||
string unknownToken = "[UNK]",
|
||||
string sepToken = "[SEP]",
|
||||
string padToken = "[PAD]",
|
||||
string clsToken = "[CLS]",
|
||||
string maskToken = "[MASK]",
|
||||
bool tokenizeChineseChars = true,
|
||||
bool stripAccents = false) =>
|
||||
Create(
|
||||
string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath),
|
||||
doLowerCase, doBasicTokenization, splitOnSpecialTokens, unknownToken, sepToken, padToken, clsToken, maskToken, tokenizeChineseChars, stripAccents, disposeStream: true);
|
||||
|
||||
/// <summary>
|
||||
/// Create a new instance of the <see cref="BertTokenizer"/> class.
|
||||
/// </summary>
|
||||
/// <param name="vocabStream">The stream containing the vocabulary file.</param>
|
||||
/// <param name="doLowerCase">A value indicating whether the tokenizer should lowercase the input text.</param>
|
||||
/// <param name="doBasicTokenization">A value indicating whether the tokenizer should do basic tokenization. Like clean text, normalize it, lowercasing, etc.</param>
|
||||
/// <param name="splitOnSpecialTokens">A value indicating whether the tokenizer should split on special tokens.</param>
|
||||
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
|
||||
/// <param name="sepToken">The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification or for a text and a question for question answering. It is also used as the last token of a sequence built with special tokens.</param>
|
||||
/// <param name="padToken">The token used for padding, for example when batching sequences of different lengths.</param>
|
||||
/// <param name="clsToken">The classifier token which is used when doing sequence classification (classification of the whole sequence instead of per-token classification). It is the first token of the sequence when built with special tokens.</param>
|
||||
/// <param name="maskToken">The token used for masking values. This is the token used when training this model with masked language modeling. This is the token which the model will try to predict.</param>
|
||||
/// <param name="tokenizeChineseChars">A value indicating whether the tokenizer should split the Chinese characters into tokens.</param>
|
||||
/// <param name="stripAccents">A value indicating whether the tokenizer should strip accents characters.</param>
|
||||
/// <returns>A new instance of the <see cref="BertTokenizer"/> class.</returns>
|
||||
/// <exception cref="ArgumentNullException"></exception>
|
||||
public static BertTokenizer Create(
|
||||
Stream vocabStream,
|
||||
bool doLowerCase = true,
|
||||
bool doBasicTokenization = true,
|
||||
bool splitOnSpecialTokens = true,
|
||||
string unknownToken = "[UNK]",
|
||||
string sepToken = "[SEP]",
|
||||
string padToken = "[PAD]",
|
||||
string clsToken = "[CLS]",
|
||||
string maskToken = "[MASK]",
|
||||
bool tokenizeChineseChars = true,
|
||||
bool stripAccents = false) =>
|
||||
Create(vocabStream, doLowerCase, doBasicTokenization, splitOnSpecialTokens, unknownToken, sepToken, padToken, clsToken, maskToken, tokenizeChineseChars, stripAccents, disposeStream: false);
|
||||
|
||||
/// <summary>
|
||||
/// Create a new instance of the <see cref="BertTokenizer"/> class asynchronously.
|
||||
/// </summary>
|
||||
/// <param name="vocabStream">The stream containing the vocabulary file.</param>
|
||||
/// <param name="doLowerCase">A value indicating whether the tokenizer should lowercase the input text.</param>
|
||||
/// <param name="doBasicTokenization">A value indicating whether the tokenizer should do basic tokenization. Like clean text, normalize it, lowercasing, etc.</param>
|
||||
/// <param name="splitOnSpecialTokens">A value indicating whether the tokenizer should split on special tokens.</param>
|
||||
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
|
||||
/// <param name="sepToken">The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification or for a text and a question for question answering. It is also used as the last token of a sequence built with special tokens.</param>
|
||||
/// <param name="padToken">The token used for padding, for example when batching sequences of different lengths.</param>
|
||||
/// <param name="clsToken">The classifier token which is used when doing sequence classification (classification of the whole sequence instead of per-token classification). It is the first token of the sequence when built with special tokens.</param>
|
||||
/// <param name="maskToken">The token used for masking values. This is the token used when training this model with masked language modeling. This is the token which the model will try to predict.</param>
|
||||
/// <param name="tokenizeChineseChars">A value indicating whether the tokenizer should split the Chinese characters into tokens.</param>
|
||||
/// <param name="stripAccents">A value indicating whether the tokenizer should strip accents characters.</param>
|
||||
/// <returns>A new instance of the <see cref="BertTokenizer"/> class.</returns>
|
||||
/// <exception cref="ArgumentNullException"></exception>
|
||||
public static async Task<BertTokenizer> CreateAsync(
|
||||
Stream vocabStream,
|
||||
bool doLowerCase = true,
|
||||
bool doBasicTokenization = true,
|
||||
bool splitOnSpecialTokens = true,
|
||||
string unknownToken = "[UNK]",
|
||||
string sepToken = "[SEP]",
|
||||
string padToken = "[PAD]",
|
||||
string clsToken = "[CLS]",
|
||||
string maskToken = "[MASK]",
|
||||
bool tokenizeChineseChars = true,
|
||||
bool stripAccents = false)
|
||||
{
|
||||
if (vocabStream is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(vocabStream));
|
||||
}
|
||||
|
||||
(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, string> vocabReverse) = await LoadVocabAsync(vocabStream, useAsync: true).ConfigureAwait(false);
|
||||
|
||||
return Create(vocab, vocabReverse, doLowerCase, doBasicTokenization, splitOnSpecialTokens, unknownToken, sepToken, padToken, clsToken, maskToken, tokenizeChineseChars, stripAccents);
|
||||
}
|
||||
|
||||
private static BertTokenizer Create(
|
||||
Stream vocabStream,
|
||||
bool doLowerCase,
|
||||
bool doBasicTokenization,
|
||||
bool splitOnSpecialTokens,
|
||||
string unknownToken,
|
||||
string sepToken,
|
||||
string padToken,
|
||||
string clsToken,
|
||||
string maskToken,
|
||||
bool tokenizeChineseChars,
|
||||
bool stripAccents,
|
||||
bool disposeStream)
|
||||
{
|
||||
if (vocabStream is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(vocabStream));
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, string> vocabReverse) = LoadVocabAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
|
||||
|
||||
return Create(vocab, vocabReverse, doLowerCase, doBasicTokenization, splitOnSpecialTokens, unknownToken, sepToken, padToken, clsToken, maskToken, tokenizeChineseChars, stripAccents);
|
||||
}
|
||||
finally
|
||||
{
|
||||
if (disposeStream)
|
||||
{
|
||||
vocabStream.Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static BertTokenizer Create(
|
||||
Dictionary<StringSpanOrdinalKey, int> vocab,
|
||||
Dictionary<int, string> vocabReverse,
|
||||
bool doLowerCase,
|
||||
bool doBasicTokenization,
|
||||
bool splitOnSpecialTokens,
|
||||
string unknownToken,
|
||||
string sepToken,
|
||||
string padToken,
|
||||
string clsToken,
|
||||
string maskToken,
|
||||
bool tokenizeChineseChars,
|
||||
bool stripAccents)
|
||||
{
|
||||
Normalizer? normalizer = doBasicTokenization ? new BertNormalizer(doLowerCase, tokenizeChineseChars, stripAccents) : null;
|
||||
|
||||
Dictionary<string, int>? specialTokens = new();
|
||||
bool lowerCase = doBasicTokenization && doLowerCase && splitOnSpecialTokens;
|
||||
|
||||
AddSpecialToken(vocab, specialTokens, unknownToken, lowerCase);
|
||||
AddSpecialToken(vocab, specialTokens, sepToken, lowerCase);
|
||||
AddSpecialToken(vocab, specialTokens, padToken, lowerCase);
|
||||
AddSpecialToken(vocab, specialTokens, clsToken, lowerCase);
|
||||
AddSpecialToken(vocab, specialTokens, maskToken, lowerCase);
|
||||
|
||||
PreTokenizer? preTokenizer = doBasicTokenization ?
|
||||
PreTokenizer.CreateWhiteSpaceOrPunctuationPreTokenizer(splitOnSpecialTokens ? specialTokens : null) :
|
||||
PreTokenizer.CreateWhiteSpacePreTokenizer();
|
||||
|
||||
return new BertTokenizer(vocab, vocabReverse, preTokenizer, normalizer, specialTokens, doLowerCase, doBasicTokenization,
|
||||
splitOnSpecialTokens, unknownToken, sepToken, padToken, clsToken, maskToken, tokenizeChineseChars, stripAccents);
|
||||
}
|
||||
|
||||
private static void AddSpecialToken(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<string, int> specialTokens, string token, bool lowerCase)
|
||||
{
|
||||
if (token is null || !vocab.TryGetValue(new StringSpanOrdinalKey(token), out int id))
|
||||
{
|
||||
throw new ArgumentException($"The special token '{token}' is not in the vocabulary.");
|
||||
}
|
||||
|
||||
string normalizedToken = token;
|
||||
if (lowerCase)
|
||||
{
|
||||
// Lowercase the special tokens so that the pre-tokenizer can find them, since we lowercase the input text.
|
||||
// we don't even need to do case-insensitive comparisons as we are lowercasing the input text.
|
||||
normalizedToken = token.ToLowerInvariant();
|
||||
|
||||
// Add lowercased special tokens to the vocab if they are not already there.
|
||||
// This will allow matching during the encoding process.
|
||||
vocab[new StringSpanOrdinalKey(normalizedToken)] = id;
|
||||
}
|
||||
|
||||
specialTokens[normalizedToken] = id;
|
||||
}
|
||||
}
|
||||
}
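// End-to-end sketch of the new BertTokenizer surface (the vocab file name and input text are
// illustrative assumptions; an async context is assumed for the await):
using Stream vocabStream = File.OpenRead("bert-base-uncased-vocab.txt");
BertTokenizer bert = await BertTokenizer.CreateAsync(vocabStream);
IReadOnlyList<int> ids = bert.EncodeToIds("Hello, how are you?"); // lowercased, pre-tokenized, wrapped in [CLS] ... [SEP]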
|
|
@ -0,0 +1,858 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Globalization;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.ML.Tokenizers
|
||||
{
|
||||
/// <summary>
|
||||
/// Represents the WordPiece tokenizer.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// The WordPiece tokenizer is a sub-word tokenizer that is used in BERT and other transformer models.
|
||||
/// The implementation is based on the Hugging Face WordPiece tokenizer https://huggingface.co/docs/tokenizers/api/models#tokenizers.models.WordPiece.
|
||||
/// </remarks>
|
||||
public partial class WordPieceTokenizer : Tokenizer
|
||||
{
|
||||
private readonly PreTokenizer? _preTokenizer;
|
||||
private readonly Normalizer? _normalizer;
|
||||
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
|
||||
private readonly Dictionary<int, string> _vocabReverse;
|
||||
|
||||
internal const string DefaultContinuingSubwordPrefix = "##";
|
||||
internal const int DefaultMaxInputCharsPerWord = 100;
|
||||
|
||||
internal WordPieceTokenizer(
|
||||
Dictionary<StringSpanOrdinalKey, int> vocab,
|
||||
Dictionary<int, string> vocabReverse,
|
||||
PreTokenizer? preTokenizer,
|
||||
Normalizer? normalizer,
|
||||
IReadOnlyDictionary<string, int>? specialTokens,
|
||||
string unknownToken,
|
||||
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
|
||||
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord)
|
||||
{
|
||||
Debug.Assert(vocab is not null);
|
||||
Debug.Assert(vocabReverse is not null);
|
||||
_vocab = vocab!;
|
||||
_vocabReverse = vocabReverse!;
|
||||
SpecialTokens = specialTokens;
|
||||
SpecialTokensReverse = specialTokens is not null ? specialTokens.ToDictionary(kvp => kvp.Value, kvp => kvp.Key) : null;
|
||||
|
||||
if (unknownToken is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(unknownToken));
|
||||
}
|
||||
|
||||
if (continuingSubwordPrefix is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(continuingSubwordPrefix));
|
||||
}
|
||||
|
||||
if (maxInputCharsPerWord <= 0)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(maxInputCharsPerWord), "The maximum number of characters per word must be greater than zero.");
|
||||
}
|
||||
|
||||
if (!vocab!.TryGetValue(unknownToken, out int id))
|
||||
{
|
||||
throw new ArgumentException($"The unknown token '{unknownToken}' is not in the vocabulary.");
|
||||
}
|
||||
|
||||
UnknownToken = unknownToken;
|
||||
UnknownTokenId = id;
|
||||
ContinuingSubwordPrefix = continuingSubwordPrefix;
|
||||
MaxInputCharsPerWord = maxInputCharsPerWord;
|
||||
|
||||
_preTokenizer = preTokenizer ?? PreTokenizer.CreateWhiteSpacePreTokenizer(specialTokens);
|
||||
_normalizer = normalizer;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the unknown token ID.
|
||||
/// A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.
|
||||
/// </summary>
|
||||
public int UnknownTokenId { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the prefix to use for sub-words that are not the first part of a word.
|
||||
/// </summary>
|
||||
public string ContinuingSubwordPrefix { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the maximum number of characters to authorize in a single word.
|
||||
/// </summary>
|
||||
public int MaxInputCharsPerWord { get; }
|
||||
|
||||
internal static async ValueTask<(Dictionary<StringSpanOrdinalKey, int>, Dictionary<int, string>)> LoadVocabAsync(Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (vocabStream is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(vocabStream));
|
||||
}
|
||||
|
||||
Dictionary<StringSpanOrdinalKey, int> vocab = new Dictionary<StringSpanOrdinalKey, int>();
|
||||
Dictionary<int, string> vocabReverse = new Dictionary<int, string>();
|
||||
|
||||
StreamReader reader = new StreamReader(vocabStream);
|
||||
string? line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
|
||||
int lineNumber = 0;
|
||||
|
||||
while (line is not null)
|
||||
{
|
||||
if (line.Length != 0)
|
||||
{
|
||||
vocab.Add(new StringSpanOrdinalKey(line), lineNumber);
|
||||
vocabReverse.Add(lineNumber, line);
|
||||
}
|
||||
|
||||
lineNumber++;
|
||||
line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
|
||||
}
|
||||
|
||||
return (vocab, vocabReverse);
|
||||
}
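// Vocab file format assumed by LoadVocabAsync: plain text, one token per line, with the token id
// given by the 0-based line number. An illustrative fragment (not a real vocab file):
//   [PAD]   -> id 0
//   [UNK]   -> id 1
//   [CLS]   -> id 2
//   [SEP]   -> id 3
//   [MASK]  -> id 4
//   the     -> id 5
//   ##ing   -> id 6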
|
||||
|
||||
/// <summary>
|
||||
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class.
|
||||
/// </summary>
|
||||
/// <param name="vocabFilePath">The path to the WordPiece vocab file.</param>
|
||||
/// <param name="preTokenizer">The PreTokenizer to use.</param>
|
||||
/// <param name="normalizer">The Normalizer to use.</param>
|
||||
/// <param name="specialTokens">The dictionary containing the special tokens and their corresponding ids.</param>
|
||||
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
|
||||
/// <param name="continuingSubwordPrefix">The prefix to use for sub-words that are not the first part of a word.</param>
|
||||
/// <param name="maxInputCharsPerWord">The maximum number of characters to authorize in a single word.</param>
|
||||
/// <returns>A new instance of the <see cref="WordPieceTokenizer"/> class.</returns>
|
||||
/// <remarks>
|
||||
/// If the <paramref name="preTokenizer"/> is null, the whitespace pre-tokenizer will be used.
|
||||
/// </remarks>
|
||||
public static WordPieceTokenizer Create(
|
||||
string vocabFilePath,
|
||||
PreTokenizer? preTokenizer = null,
|
||||
Normalizer? normalizer = null,
|
||||
IReadOnlyDictionary<string, int>? specialTokens = null,
|
||||
string unknownToken = "[UNK]",
|
||||
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
|
||||
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord) =>
|
||||
Create(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, disposeStream: true);
|
||||
|
||||
/// <summary>
|
||||
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class.
|
||||
/// </summary>
|
||||
/// <param name="vocabStream">The path to the WordPiece vocab file.</param>
|
||||
/// <param name="preTokenizer">The PreTokenizer to use.</param>
|
||||
/// <param name="normalizer">The Normalizer to use.</param>
|
||||
/// <param name="specialTokens">The dictionary containing the special tokens and their corresponding ids.</param>
|
||||
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
|
||||
/// <param name="continuingSubwordPrefix">The prefix to use for sub-words that are not the first part of a word.</param>
|
||||
/// <param name="maxInputCharsPerWord">The maximum number of characters to authorize in a single word.</param>
|
||||
/// <returns>A new instance of the <see cref="WordPieceTokenizer"/> class.</returns>
|
||||
/// <remarks>
|
||||
/// If the <paramref name="preTokenizer"/> is null, the whitespace pre-tokenizer will be used.
|
||||
/// </remarks>
|
||||
public static WordPieceTokenizer Create(
|
||||
Stream vocabStream,
|
||||
PreTokenizer? preTokenizer = null,
|
||||
Normalizer? normalizer = null,
|
||||
IReadOnlyDictionary<string, int>? specialTokens = null,
|
||||
string unknownToken = "[UNK]",
|
||||
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
|
||||
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord) => Create(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, disposeStream: false);
|
||||
|
||||
private static WordPieceTokenizer Create(
|
||||
Stream vocabStream,
|
||||
PreTokenizer? preTokenizer,
|
||||
Normalizer? normalizer,
|
||||
IReadOnlyDictionary<string, int>? specialTokens,
|
||||
string unknownToken,
|
||||
string continuingSubwordPrefix,
|
||||
int maxInputCharsPerWord,
|
||||
bool disposeStream)
|
||||
{
|
||||
if (vocabStream is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(vocabStream));
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, string> vocabReverse) = LoadVocabAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
|
||||
|
||||
return new WordPieceTokenizer(vocab, vocabReverse, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord);
|
||||
}
|
||||
finally
|
||||
{
|
||||
if (disposeStream)
|
||||
{
|
||||
vocabStream.Dispose();
|
||||
}
|
||||
}
|
||||
}
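// Illustrative sketch (the vocab file name is an assumption): WordPieceTokenizer can also be used
// standalone, without the Bert-specific special-token handling that BertTokenizer layers on top.
WordPieceTokenizer wordPiece = WordPieceTokenizer.Create("wordpiece-vocab.txt");
IReadOnlyList<int> pieceIds = wordPiece.EncodeToIds("unaffable");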
|
||||
|
||||
/// <summary>
|
||||
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class asynchronously.
|
||||
/// </summary>
|
||||
/// <param name="vocabFilePath">The path to the WordPiece vocab file.</param>
|
||||
/// <param name="preTokenizer">The PreTokenizer to use.</param>
|
||||
/// <param name="normalizer">The Normalizer to use.</param>
|
||||
/// <param name="specialTokens">The dictionary containing the special tokens and their corresponding ids.</param>
|
||||
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
|
||||
/// <param name="continuingSubwordPrefix">The prefix to use for sub-words that are not the first part of a word.</param>
|
||||
/// <param name="maxInputCharsPerWord">The maximum number of characters to authorize in a single word.</param>
|
||||
/// <param name="cancellationToken">The cancellation token.</param>
|
||||
/// <returns>A new instance of the <see cref="WordPieceTokenizer"/> class.</returns>
|
||||
/// <remarks>
|
||||
/// If the <paramref name="preTokenizer"/> is null, the whitespace pre-tokenizer will be used.
|
||||
/// </remarks>
|
||||
public static async Task<WordPieceTokenizer> CreateAsync(
|
||||
string vocabFilePath,
|
||||
PreTokenizer? preTokenizer = null,
|
||||
Normalizer? normalizer = null,
|
||||
IReadOnlyDictionary<string, int>? specialTokens = null,
|
||||
string unknownToken = "[UNK]",
|
||||
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
|
||||
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord,
|
||||
CancellationToken cancellationToken = default) =>
|
||||
await CreateAsync(
|
||||
string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath),
|
||||
preTokenizer,
|
||||
normalizer,
|
||||
specialTokens,
|
||||
unknownToken,
|
||||
continuingSubwordPrefix,
|
||||
maxInputCharsPerWord,
|
||||
cancellationToken,
|
||||
disposeStream: true);
|
||||
|
||||
/// <summary>
|
||||
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class asynchronously.
|
||||
/// </summary>
|
||||
/// <param name="vocabStream">The path to the WordPiece vocab file.</param>
|
||||
/// <param name="preTokenizer">The PreTokenizer to use.</param>
|
||||
/// <param name="normalizer">The Normalizer to use.</param>
|
||||
/// <param name="specialTokens">The dictionary containing the special tokens and their corresponding ids.</param>
|
||||
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
|
||||
/// <param name="continuingSubwordPrefix">The prefix to use for sub-words that are not the first part of a word.</param>
|
||||
/// <param name="maxInputCharsPerWord">The maximum number of characters to authorize in a single word.</param>
|
||||
/// <param name="cancellationToken">The cancellation token.</param>
|
||||
/// <returns>A new instance of the <see cref="WordPieceTokenizer"/> class.</returns>
|
||||
/// <remarks>
|
||||
/// If the <paramref name="preTokenizer"/> is null, the whitespace pre-tokenizer will be used.
|
||||
/// </remarks>
|
||||
public static async Task<WordPieceTokenizer> CreateAsync(
|
||||
Stream vocabStream,
|
||||
PreTokenizer? preTokenizer = null,
|
||||
Normalizer? normalizer = null,
|
||||
IReadOnlyDictionary<string, int>? specialTokens = null,
|
||||
string unknownToken = "[UNK]",
|
||||
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
|
||||
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord,
|
||||
CancellationToken cancellationToken = default) =>
|
||||
await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false);
|
||||
|
||||
private static async Task<WordPieceTokenizer> CreateAsync(
|
||||
Stream vocabStream,
|
||||
PreTokenizer? preTokenizer,
|
||||
Normalizer? normalizer,
|
||||
IReadOnlyDictionary<string, int>? specialTokens,
|
||||
string unknownToken,
|
||||
string continuingSubwordPrefix,
|
||||
int maxInputCharsPerWord,
|
||||
CancellationToken cancellationToken,
|
||||
bool disposeStream)
|
||||
{
|
||||
if (vocabStream is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(vocabStream));
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, string> vocabReverse) = await LoadVocabAsync(vocabStream, useAsync: true, cancellationToken);
|
||||
|
||||
return new WordPieceTokenizer(vocab, vocabReverse, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord);
|
||||
}
|
||||
finally
|
||||
{
|
||||
if (disposeStream)
|
||||
{
|
||||
vocabStream.Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the PreTokenizer used by the Tokenizer.
|
||||
/// </summary>
|
||||
public override PreTokenizer? PreTokenizer => _preTokenizer;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the Normalizer in use by the Tokenizer.
|
||||
/// </summary>
|
||||
public override Normalizer? Normalizer => _normalizer;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the unknown token.
|
||||
/// A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.
|
||||
/// </summary>
|
||||
public string UnknownToken { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the special tokens and their corresponding ids.
|
||||
/// </summary>
|
||||
public IReadOnlyDictionary<string, int>? SpecialTokens { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the Ids to tokens mapping for special tokens.
|
||||
/// </summary>
|
||||
internal IReadOnlyDictionary<int, string>? SpecialTokensReverse { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Encodes input text to a list of <see cref="EncodedToken" />s.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
|
||||
/// <param name="settings">The settings used to encode the text.</param>
|
||||
protected override EncodeResults<EncodedToken> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
|
||||
{
|
||||
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
|
||||
{
|
||||
return new EncodeResults<EncodedToken> { NormalizedText = null, Tokens = [], CharsConsumed = 0 };
|
||||
}
|
||||
|
||||
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
|
||||
text,
|
||||
textSpan,
|
||||
settings.ConsiderPreTokenization,
|
||||
settings.ConsiderNormalization,
|
||||
_normalizer,
|
||||
_preTokenizer,
|
||||
out string? normalizedString,
|
||||
out ReadOnlySpan<char> textSpanToEncode,
|
||||
out int charsConsumed);
|
||||
|
||||
List<EncodedToken> tokens = new();
|
||||
|
||||
if (splits is not null)
|
||||
{
|
||||
foreach ((int Offset, int Length) split in splits)
|
||||
{
|
||||
EncodeToTokens(textSpanToEncode.Slice(split.Offset, split.Length), tokens, split.Offset);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
EncodeToTokens(textSpanToEncode, tokens, 0);
|
||||
}
|
||||
|
||||
return new EncodeResults<EncodedToken> { NormalizedText = normalizedString, Tokens = tokens, CharsConsumed = charsConsumed };
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Encode text to a list of tokens.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="tokens">The list of tokens to populate.</param>
|
||||
/// <param name="offset">The offset to start encoding from.</param>
|
||||
private void EncodeToTokens(ReadOnlySpan<char> text, List<EncodedToken> tokens, int offset)
|
||||
{
|
||||
Debug.Assert(!text.IsEmpty);
|
||||
|
||||
if (text.Length > MaxInputCharsPerWord)
|
||||
{
|
||||
tokens.Add(new EncodedToken(UnknownTokenId, UnknownToken, new Range(offset, offset + text.Length)));
|
||||
return;
|
||||
}
|
||||
|
||||
int maxLength = MaxInputCharsPerWord + ContinuingSubwordPrefix.Length;
|
||||
char[]? arrayPool = maxLength <= 250 ? null : ArrayPool<char>.Shared.Rent(maxLength);
|
||||
Span<char> buffer = arrayPool is null ? stackalloc char[maxLength] : arrayPool;
|
||||
ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
|
||||
|
||||
int initialTokensCount = tokens.Count;
|
||||
int textLength = text.Length;
|
||||
bool isBad = false;
|
||||
|
||||
int start = 0;
|
||||
|
||||
while (start < textLength)
|
||||
{
|
||||
int end = textLength;
|
||||
EncodedToken curToken = default;
|
||||
|
||||
while (start < end)
|
||||
{
|
||||
scoped ReadOnlySpan<char> subStr = text.Slice(start, end - start);
|
||||
|
||||
if (start > 0)
|
||||
{
|
||||
subStr.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
|
||||
subStr = buffer.Slice(0, ContinuingSubwordPrefix.Length + subStr.Length);
|
||||
}
|
||||
|
||||
if (_vocab.TryGetValue(subStr, out int id))
|
||||
{
|
||||
Debug.Assert(_vocabReverse.ContainsKey(id));
|
||||
curToken = new EncodedToken(id, _vocabReverse[id], new Range(offset + start, offset + end));
|
||||
break;
|
||||
}
|
||||
|
||||
end -= 1;
|
||||
}
|
||||
|
||||
if (curToken.Value is null)
|
||||
{
|
||||
isBad = true;
|
||||
break;
|
||||
}
|
||||
|
||||
tokens.Add(curToken);
|
||||
start = end;
|
||||
}
|
||||
|
||||
if (isBad)
|
||||
{
|
||||
// remove previously added tokens and add the unknown token
|
||||
tokens.RemoveRange(initialTokensCount, tokens.Count - initialTokensCount);
|
||||
tokens.Add(new EncodedToken(UnknownTokenId, UnknownToken, new Range(offset, offset + textLength)));
|
||||
}
|
||||
|
||||
if (arrayPool is not null)
|
||||
{
|
||||
ArrayPool<char>.Shared.Return(arrayPool);
|
||||
}
|
||||
}
|
||||
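// Note: the loop above is WordPiece's greedy longest-match-first strategy. Starting at 'start' it
// tries the longest remaining substring first (prefixed with ContinuingSubwordPrefix when not at the
// beginning of the word) and shrinks the candidate from the right until a vocabulary match is found.
// With the vocabulary used by the tests in this change, "unwanted" encodes to "un", "##want", "##ed";
// a word with no match at some position falls back to a single unknown token.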
|
||||
/// <summary>
|
||||
/// Encodes input text to token Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
|
||||
/// <param name="settings">The settings used to encode the text.</param>
|
||||
/// <returns>The encoded results containing the list of encoded Ids.</returns>
|
||||
protected override EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
|
||||
{
|
||||
int maxTokenCount = settings.MaxTokenCount;
|
||||
if (maxTokenCount <= 0)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(settings.MaxTokenCount), "The maximum number of tokens must be greater than zero.");
|
||||
}
|
||||
|
||||
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
|
||||
{
|
||||
return new EncodeResults<int> { NormalizedText = null, Tokens = [], CharsConsumed = 0 };
|
||||
}
|
||||
|
||||
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
|
||||
text,
|
||||
textSpan,
|
||||
settings.ConsiderPreTokenization,
|
||||
settings.ConsiderNormalization,
|
||||
_normalizer,
|
||||
_preTokenizer,
|
||||
out string? normalizedString,
|
||||
out ReadOnlySpan<char> textSpanToEncode,
|
||||
out int charsConsumed);
|
||||
|
||||
List<int> ids = new();
|
||||
|
||||
if (splits is not null)
|
||||
{
|
||||
charsConsumed = 0;
|
||||
foreach ((int Offset, int Length) split in splits)
|
||||
{
|
||||
EncodeToIds(textSpanToEncode.Slice(split.Offset, split.Length), ids, out int length, maxTokenCount - ids.Count);
|
||||
|
||||
if (length < split.Length || ids.Count >= maxTokenCount)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
charsConsumed = split.Offset + length;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
EncodeToIds(textSpanToEncode, ids, out charsConsumed);
|
||||
}
|
||||
|
||||
return new EncodeResults<int> { NormalizedText = normalizedString, Tokens = ids, CharsConsumed = charsConsumed };
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Encode text to a list of Ids.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="accumulatedIds">The list of accumulated Ids.</param>
|
||||
/// <param name="charsConsumed">The length of the text that encompasses the maximum encoded tokens.</param>
|
||||
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
|
||||
/// <returns>The number of tokens that the input text will be encoded to.</returns>
|
||||
private int EncodeToIds(ReadOnlySpan<char> text, List<int>? accumulatedIds, out int charsConsumed, int maxTokenCount = int.MaxValue)
|
||||
{
|
||||
Debug.Assert(maxTokenCount > 0);
|
||||
|
||||
if (text.IsEmpty)
|
||||
{
|
||||
charsConsumed = 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (text.Length > MaxInputCharsPerWord)
|
||||
{
|
||||
accumulatedIds?.Add(UnknownTokenId);
|
||||
charsConsumed = text.Length;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int maxLength = MaxInputCharsPerWord + ContinuingSubwordPrefix.Length;
|
||||
char[]? arrayPool = maxLength <= 250 ? null : ArrayPool<char>.Shared.Rent(maxLength);
|
||||
Span<char> buffer = arrayPool is null ? stackalloc char[maxLength] : arrayPool;
|
||||
ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
|
||||
|
||||
int addedIds = 0;
|
||||
int textLength = text.Length;
|
||||
bool isBad = false;
|
||||
|
||||
int start = 0;
|
||||
|
||||
while (start < textLength)
|
||||
{
|
||||
int end = textLength;
|
||||
int curId = 0;
|
||||
bool found = false;
|
||||
|
||||
while (start < end)
|
||||
{
|
||||
scoped ReadOnlySpan<char> subStr = text.Slice(start, end - start);
|
||||
|
||||
if (start > 0)
|
||||
{
|
||||
subStr.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
|
||||
subStr = buffer.Slice(0, ContinuingSubwordPrefix.Length + subStr.Length);
|
||||
}
|
||||
|
||||
if (_vocab.TryGetValue(subStr, out curId))
|
||||
{
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
|
||||
end -= 1;
|
||||
}
|
||||
|
||||
if (!found)
|
||||
{
|
||||
isBad = true;
|
||||
break;
|
||||
}
|
||||
|
||||
accumulatedIds?.Add(curId);
|
||||
addedIds++;
|
||||
start = end;
|
||||
}
|
||||
|
||||
charsConsumed = textLength;
|
||||
if (addedIds > maxTokenCount)
|
||||
{
|
||||
// not enough space to hold added ids. Remove previously added ids
|
||||
accumulatedIds?.RemoveRange(accumulatedIds.Count - addedIds, addedIds);
|
||||
addedIds = 0;
|
||||
charsConsumed = 0;
|
||||
}
|
||||
else if (isBad)
|
||||
{
|
||||
// remove previously added ids and add the unknown token id
|
||||
accumulatedIds?.RemoveRange(accumulatedIds.Count - addedIds, addedIds);
|
||||
accumulatedIds?.Add(UnknownTokenId);
|
||||
addedIds = 1;
|
||||
}
|
||||
|
||||
if (arrayPool is not null)
|
||||
{
|
||||
ArrayPool<char>.Shared.Return(arrayPool);
|
||||
}
|
||||
|
||||
return addedIds;
|
||||
}
|
||||
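// Note on the budget handling above: a word's subword ids are committed as a unit. If they would
// exceed the remaining maxTokenCount budget, they are all rolled back and charsConsumed is reported
// as 0, so callers such as CountTokens and GetIndexByTokenCount stop at the previous word boundary.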
|
||||
/// <summary>
|
||||
/// Get the number of tokens that the input text will be encoded to.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
|
||||
/// <param name="settings">The settings used to encode the text.</param>
|
||||
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
|
||||
protected override int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
|
||||
{
|
||||
int maxTokenCount = settings.MaxTokenCount;
|
||||
if (maxTokenCount <= 0)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(settings.MaxTokenCount), "The maximum number of tokens must be greater than zero.");
|
||||
}
|
||||
|
||||
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
|
||||
text,
|
||||
textSpan,
|
||||
settings.ConsiderPreTokenization,
|
||||
settings.ConsiderNormalization,
|
||||
_normalizer,
|
||||
_preTokenizer,
|
||||
out string? normalizedString,
|
||||
out ReadOnlySpan<char> textSpanToEncode,
|
||||
out int charsConsumed);
|
||||
|
||||
int count = 0;
|
||||
if (splits is not null)
|
||||
{
|
||||
foreach ((int Offset, int Length) split in splits)
|
||||
{
|
||||
count += EncodeToIds(textSpanToEncode.Slice(split.Offset, split.Length), accumulatedIds: null, out int length, maxTokenCount - count);
|
||||
|
||||
if (length < split.Length || count >= maxTokenCount)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
count = EncodeToIds(textSpanToEncode, accumulatedIds: null, out charsConsumed, maxTokenCount);
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Find the index of the maximum encoding capacity without surpassing the token limit.
|
||||
/// </summary>
|
||||
/// <param name="text">The text to encode.</param>
|
||||
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
|
||||
/// <param name="settings">The settings used to encode the text.</param>
|
||||
/// <param name="fromEnd">Indicate whether to find the index from the end of the text.</param>
|
||||
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="settings" /> has <see cref="EncodeSettings.ConsiderNormalization"/> is <see langword="false"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
|
||||
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
|
||||
/// <returns>
|
||||
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
|
||||
/// If <paramRef name="fromEnd" /> is <see langword="false"/>, it represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
|
||||
/// if all tokens fit, the result will be the length of the input text, or of the <paramref name="normalizedString"/> if normalization is enabled.
|
||||
/// If <paramRef name="fromEnd" /> is <see langword="true"/>, it represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely,
|
||||
/// if all tokens fit, the result will be zero.
|
||||
/// </returns>
|
||||
protected override int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedString, out int tokenCount)
|
||||
{
|
||||
if (settings.MaxTokenCount <= 0)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(settings.MaxTokenCount), "The max token count must be greater than 0.");
|
||||
}
|
||||
|
||||
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
|
||||
{
|
||||
normalizedString = null;
|
||||
tokenCount = 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
|
||||
text,
|
||||
textSpan,
|
||||
settings.ConsiderPreTokenization,
|
||||
settings.ConsiderNormalization,
|
||||
_normalizer,
|
||||
_preTokenizer,
|
||||
out normalizedString,
|
||||
out ReadOnlySpan<char> textSpanToEncode,
|
||||
out _);
|
||||
|
||||
int charsConsumed;
|
||||
|
||||
if (splits is null)
|
||||
{
|
||||
tokenCount = EncodeToIds(textSpanToEncode, accumulatedIds: null, out charsConsumed, settings.MaxTokenCount);
|
||||
if (charsConsumed != textSpanToEncode.Length)
|
||||
{
|
||||
tokenCount = 0;
|
||||
return fromEnd ? textSpanToEncode.Length : 0;
|
||||
}
|
||||
|
||||
return fromEnd ? 0 : textSpanToEncode.Length;
|
||||
}
|
||||
|
||||
if (fromEnd)
|
||||
{
|
||||
splits = splits.Reverse();
|
||||
}
|
||||
|
||||
tokenCount = 0;
|
||||
foreach ((int Offset, int Length) split in splits)
|
||||
{
|
||||
int count = EncodeToIds(textSpanToEncode.Slice(split.Offset, split.Length), accumulatedIds: null, out charsConsumed, settings.MaxTokenCount - tokenCount);
|
||||
if (charsConsumed != split.Length)
|
||||
{
|
||||
return fromEnd ? split.Offset + split.Length : split.Offset;
|
||||
}
|
||||
|
||||
tokenCount += count;
|
||||
|
||||
if (tokenCount >= settings.MaxTokenCount)
|
||||
{
|
||||
return fromEnd ? split.Offset : split.Offset + split.Length;
|
||||
}
|
||||
}
|
||||
|
||||
return fromEnd ? 0 : textSpanToEncode.Length;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Decode the given ids back to a string.
|
||||
/// </summary>
|
||||
/// <param name="ids">The list of ids that we want to decode.</param>
|
||||
/// <returns>The decoded string.</returns>
|
||||
public override string Decode(IEnumerable<int> ids) => Decode(ids, skipSpecialTokens: false);
|
||||
|
||||
/// <summary>
|
||||
/// Decode the given ids back to a string.
|
||||
/// </summary>
|
||||
/// <param name="ids">The list of ids that we want to decode.</param>
|
||||
/// <param name="skipSpecialTokens">Indicate whether to skip the special tokens during the decoding.</param>
|
||||
/// <returns>The decoded string.</returns>
|
||||
public string Decode(IEnumerable<int> ids, bool skipSpecialTokens)
|
||||
{
|
||||
ValueStringBuilder sb = new ValueStringBuilder();
|
||||
bool first = true;
|
||||
bool ignoreSpecialTokens = skipSpecialTokens && SpecialTokensReverse is not null;
|
||||
|
||||
foreach (int id in ids)
|
||||
{
|
||||
if (ignoreSpecialTokens && SpecialTokensReverse!.TryGetValue(id, out _))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (_vocabReverse.TryGetValue(id, out string? token))
|
||||
{
|
||||
if (token.StartsWith(ContinuingSubwordPrefix, StringComparison.Ordinal))
|
||||
{
|
||||
sb.Append(token.AsSpan().Slice(ContinuingSubwordPrefix.Length));
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!first && token[0] is not ('.' or ',' or '!' or '?' or '\''))
|
||||
{
|
||||
sb.Append(' ');
|
||||
}
|
||||
|
||||
sb.Append(token);
|
||||
}
|
||||
}
|
||||
|
||||
first = false;
|
||||
}
|
||||
|
||||
return sb.ToString();
|
||||
}
|
||||
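// Illustrative behavior, based on the tests in this change: tokens carrying ContinuingSubwordPrefix
// are appended without a space ("un" + "##want" + "##ed" => "unwanted"), other tokens are separated
// by a space (except leading punctuation such as '.', ',', '!', '?', '\''), and when skipSpecialTokens
// is true any id found in SpecialTokensReverse (e.g. [CLS]/[SEP]) is dropped from the output.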
|
||||
/// <summary>
|
||||
/// Decode the given ids back to text and store the result in the <paramref name="destination"/> span.
|
||||
/// </summary>
|
||||
/// <param name="ids">The list of ids that we want to decode.</param>
|
||||
/// <param name="destination">The span to store the decoded text.</param>
|
||||
/// <param name="idsConsumed">The number of ids consumed during the decoding.</param>
|
||||
/// <param name="charsWritten">The number of characters written to the destination span.</param>
|
||||
/// <returns>The operation status indicates whether all IDs were successfully decoded or if the <paramref name="destination"/> is too small to contain the entire decoded result.</returns>
|
||||
public override OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten) =>
|
||||
Decode(ids, destination, skipSpecialTokens: false, out idsConsumed, out charsWritten);
|
||||
|
||||
/// <summary>
|
||||
/// Decode the given ids back to text and store the result in the <paramref name="destination"/> span.
|
||||
/// </summary>
|
||||
/// <param name="ids">The list of ids that we want to decode.</param>
|
||||
/// <param name="destination">The span to store the decoded text.</param>
|
||||
/// <param name="skipSpecialTokens">Indicate whether to skip the special tokens during the decoding.</param>
|
||||
/// <param name="idsConsumed">The number of ids consumed during the decoding.</param>
|
||||
/// <param name="charsWritten">The number of characters written to the destination span.</param>
|
||||
/// <returns>The operation status indicates whether all IDs were successfully decoded or if the <paramref name="destination"/> is too small to contain the entire decoded result.</returns>
|
||||
public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool skipSpecialTokens, out int idsConsumed, out int charsWritten)
|
||||
{
|
||||
charsWritten = 0;
|
||||
idsConsumed = 0;
|
||||
Span<char> buffer = destination;
|
||||
|
||||
bool first = true;
|
||||
bool ignoreSpecialTokens = SpecialTokensReverse is not null && skipSpecialTokens;
|
||||
|
||||
foreach (int id in ids)
|
||||
{
|
||||
if (ignoreSpecialTokens && SpecialTokensReverse!.TryGetValue(id, out _))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (_vocabReverse.TryGetValue(id, out string? token))
|
||||
{
|
||||
if (token.StartsWith(ContinuingSubwordPrefix, StringComparison.Ordinal))
|
||||
{
|
||||
if (token.Length - ContinuingSubwordPrefix.Length > buffer.Length)
|
||||
{
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
token.AsSpan().Slice(ContinuingSubwordPrefix.Length).CopyTo(buffer);
|
||||
buffer = buffer.Slice(token.Length - ContinuingSubwordPrefix.Length);
|
||||
charsWritten += token.Length - ContinuingSubwordPrefix.Length;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!first)
|
||||
{
|
||||
if (token.Length + 1 > buffer.Length)
|
||||
{
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
|
||||
buffer[0] = ' ';
|
||||
token.AsSpan().CopyTo(buffer.Slice(1));
|
||||
buffer = buffer.Slice(token.Length + 1);
|
||||
charsWritten += token.Length + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (token.Length > buffer.Length)
|
||||
{
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
|
||||
token.AsSpan().CopyTo(buffer);
|
||||
buffer = buffer.Slice(token.Length);
|
||||
charsWritten += token.Length;
|
||||
}
|
||||
}
|
||||
|
||||
first = false;
|
||||
|
||||
idsConsumed++;
|
||||
}
|
||||
else
|
||||
{
|
||||
return OperationStatus.InvalidData;
|
||||
}
|
||||
}
|
||||
|
||||
return OperationStatus.Done;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,200 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Diagnostics;
|
||||
using System.Globalization;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Text;
|
||||
|
||||
namespace Microsoft.ML.Tokenizers
|
||||
{
|
||||
/// <summary>
|
||||
/// Normalizer that performs the Bert model normalization.
|
||||
/// </summary>
|
||||
internal sealed class BertNormalizer : Normalizer
|
||||
{
|
||||
private readonly bool _doLowerCase;
|
||||
private readonly bool _tokenizeChineseChars;
|
||||
private readonly bool _stripAccents;
|
||||
|
||||
/// <summary>
|
||||
/// Normalize the input string.
|
||||
/// </summary>
|
||||
/// <param name="original">The input string to normalize.</param>
|
||||
/// <returns>The normalized string.</returns>
|
||||
public override string Normalize(string original)
|
||||
{
|
||||
if (string.IsNullOrEmpty(original))
|
||||
{
|
||||
return string.Empty;
|
||||
}
|
||||
|
||||
if (_stripAccents)
|
||||
{
|
||||
original = original.Normalize(NormalizationForm.FormD);
|
||||
}
|
||||
|
||||
Span<char> casingBuffer = stackalloc char[10];
|
||||
char[] buffer = ArrayPool<char>.Shared.Rent(original.Length);
|
||||
int index = 0;
|
||||
|
||||
for (int i = 0; i < original.Length; i++)
|
||||
{
|
||||
char c = original[i];
|
||||
|
||||
if (c == '\u0000' || c == '\uFFFD')
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
int inc = 0;
|
||||
int codePoint = (int)c;
|
||||
if (char.IsHighSurrogate(c) && i + 1 < original.Length && char.IsLowSurrogate(original[i + 1]))
|
||||
{
|
||||
codePoint = char.ConvertToUtf32(c, original[i + 1]);
|
||||
inc = 1;
|
||||
}
|
||||
|
||||
UnicodeCategory category = CharUnicodeInfo.GetUnicodeCategory(original, i);
|
||||
|
||||
if (category == UnicodeCategory.Control)
|
||||
{
|
||||
i += inc;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (category == UnicodeCategory.SpaceSeparator)
|
||||
{
|
||||
InsertChar(ref buffer, ref index, ' ');
|
||||
i += inc;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (_stripAccents && category is UnicodeCategory.NonSpacingMark or UnicodeCategory.SpacingCombiningMark)
|
||||
{
|
||||
i += inc;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (_doLowerCase && category == UnicodeCategory.UppercaseLetter)
|
||||
{
|
||||
int length = original.AsSpan().Slice(i, inc + 1).ToLowerInvariant(casingBuffer);
|
||||
Debug.Assert(length > 0);
|
||||
|
||||
InsertSpan(ref buffer, ref index, casingBuffer.Slice(0, length));
|
||||
|
||||
i += inc;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (_tokenizeChineseChars && IsChineseChar(codePoint))
|
||||
{
|
||||
InsertChar(ref buffer, ref index, ' ');
|
||||
InsertChar(ref buffer, ref index, c);
|
||||
if (inc > 0)
|
||||
{
|
||||
InsertChar(ref buffer, ref index, original[i + 1]);
|
||||
}
|
||||
InsertChar(ref buffer, ref index, ' ');
|
||||
|
||||
i += inc;
|
||||
continue;
|
||||
}
|
||||
|
||||
InsertChar(ref buffer, ref index, c);
|
||||
if (inc > 0)
|
||||
{
|
||||
InsertChar(ref buffer, ref index, original[i + 1]);
|
||||
}
|
||||
i += inc;
|
||||
}
|
||||
|
||||
string result = index == 0 ? string.Empty : new string(buffer, 0, index).Normalize(NormalizationForm.FormC);
|
||||
ArrayPool<char>.Shared.Return(buffer);
|
||||
return result;
|
||||
}
|
||||
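// Illustrative effect, based on the tests in this change: with the defaults used by BertTokenizer
// (lowercasing on, CJK tokenization on, accent stripping off), "Hello, How are you?" normalizes to
// "hello, how are you?", and every CJK character is surrounded by spaces, e.g. "叟驷" becomes " 叟 驷 ".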
|
||||
/// <summary>
|
||||
/// Normalize the input character span.
|
||||
/// </summary>
|
||||
/// <param name="original">The input character span to normalize.</param>
|
||||
/// <returns>The normalized string.</returns>
|
||||
public override string Normalize(ReadOnlySpan<char> original)
|
||||
{
|
||||
if (original.IsEmpty)
|
||||
{
|
||||
return string.Empty;
|
||||
}
|
||||
|
||||
return Normalize(original.ToString());
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="BertNormalizer"/> class.
|
||||
/// </summary>
|
||||
/// <param name="doLowerCase">Whether to lowercase the input.</param>
|
||||
/// <param name="tokenizeChineseChars">Whether to tokenize Chinese characters.</param>
|
||||
/// <param name="stripAccents">Whether to strip accents from the input.</param>
|
||||
public BertNormalizer(bool doLowerCase, bool tokenizeChineseChars, bool stripAccents)
|
||||
{
|
||||
_doLowerCase = doLowerCase;
|
||||
_tokenizeChineseChars = tokenizeChineseChars;
|
||||
_stripAccents = stripAccents;
|
||||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private static void InsertChar(ref char[] buffer, ref int index, char c)
|
||||
{
|
||||
if (index >= buffer.Length)
|
||||
{
|
||||
Helpers.ArrayPoolGrow(ref buffer, index + 40);
|
||||
}
|
||||
|
||||
buffer[index++] = c;
|
||||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private static void InsertSpan(ref char[] buffer, ref int index, Span<char> chars)
|
||||
{
|
||||
if (index + chars.Length >= buffer.Length)
|
||||
{
|
||||
Helpers.ArrayPoolGrow(ref buffer, index + chars.Length + 10);
|
||||
}
|
||||
|
||||
chars.CopyTo(buffer.AsSpan(index));
|
||||
index += chars.Length;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Checks whether CP is the codepoint of a CJK character.
|
||||
/// This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
/// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
/// </summary>
|
||||
/// <param name="codePoint">The codepoint to check.</param>
|
||||
/// <remarks>
|
||||
/// The CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
/// despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
/// as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
/// space-separated words, so they are not treated specially and are handled
|
||||
/// like all of the other languages.
|
||||
/// </remarks>
|
||||
/// <returns>True if the codepoint is a CJK character, false otherwise.</returns>
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private static bool IsChineseChar(int codePoint)
|
||||
{
|
||||
return (codePoint >= 0x3400) && // Quick check to exit early if the codepoint is outside of the CJK range
|
||||
(((uint)(codePoint - 0x3400) <= (uint)(0x4DBF - 0x3400)) ||
|
||||
((uint)(codePoint - 0xF900) <= (uint)(0xFAFF - 0xF900)) ||
|
||||
((uint)(codePoint - 0x4E00) <= (uint)(0x9FFF - 0x4E00)) ||
|
||||
((uint)(codePoint - 0x20000) <= (uint)(0x2A6DF - 0x20000)) ||
|
||||
((uint)(codePoint - 0x2A700) <= (uint)(0x2B73F - 0x2A700)) ||
|
||||
((uint)(codePoint - 0x2B740) <= (uint)(0x2B81F - 0x2B740)) ||
|
||||
((uint)(codePoint - 0x2B820) <= (uint)(0x2CEAF - 0x2B820)) ||
|
||||
((uint)(codePoint - 0x2F800) <= (uint)(0x2FA1F - 0x2F800)));
|
||||
}
|
||||
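// Each range check above uses the unsigned-subtraction idiom: (uint)(cp - low) <= (uint)(high - low)
// is true exactly when low <= cp <= high, so every CJK block (e.g. 0x4E00-0x9FFF, CJK Unified
// Ideographs) costs a single comparison.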
}
|
||||
}
|
|
@ -40,8 +40,61 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
}
|
||||
|
||||
private const string WhiteSpacePattern = /*lang=regex*/ @"\w+|[^\w\s]+";
|
||||
private const string WhiteSpaceOrPunctuationPattern = @"\w+|[\p{P}]";
|
||||
private static PreTokenizer? _whiteSpaceOrPunctuationPreTokenizer;
|
||||
#if NET7_0_OR_GREATER
|
||||
[GeneratedRegex(WhiteSpaceOrPunctuationPattern)]
|
||||
private static partial Regex WhiteSpaceOrPunctuationRegex();
|
||||
#else
|
||||
private static Regex WhiteSpaceOrPunctuationRegex() => new Regex(WhiteSpaceOrPunctuationPattern, RegexOptions.Compiled);
|
||||
#endif
|
||||
|
||||
/// <summary>
|
||||
/// Create a new instance of the <see cref="PreTokenizer"/> class which split the text at the whitespace or punctuation characters.
|
||||
/// </summary>
|
||||
/// <param name="specialTokensEncoder">The dictionary containing the special tokens and their corresponding ids.</param>
|
||||
/// <returns>The pre-tokenizer that splits the text on whitespace or punctuation characters.</returns>
|
||||
public static PreTokenizer CreateWhiteSpaceOrPunctuationPreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
|
||||
{
|
||||
if (specialTokensEncoder is null)
|
||||
{
|
||||
// return a singleton instance of the whitespace-or-punctuation pre-tokenizer
|
||||
return _whiteSpaceOrPunctuationPreTokenizer ??= new RegexPreTokenizer(WhiteSpaceOrPunctuationRegex(), null);
|
||||
}
|
||||
|
||||
return new RegexPreTokenizer(WhiteSpaceOrPunctuationRegex(), specialTokensEncoder);
|
||||
}
|
||||
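// This factory and the ones that follow share the same pattern: when no special tokens are supplied
// the regex-backed pre-tokenizer is cached in a static field and reused as a singleton; otherwise a
// fresh RegexPreTokenizer is created so the supplied special tokens are honored while splitting.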
|
||||
private const string WordOrNonWordPattern = /*lang=regex*/ @"\w+|[^\w\s]+";
|
||||
private static PreTokenizer? _wordOrNonWordPreTokenizer;
|
||||
|
||||
#if NET7_0_OR_GREATER
|
||||
[GeneratedRegex(WordOrNonWordPattern)]
|
||||
private static partial Regex WordOrNonWordRegex();
|
||||
#else
|
||||
private static Regex WordOrNonWordRegex() => new Regex(WordOrNonWordPattern, RegexOptions.Compiled);
|
||||
#endif
|
||||
|
||||
/// <summary>
|
||||
/// Create a new instance of the <see cref="PreTokenizer"/> class which split the text at the word or non-word boundary.
|
||||
/// A word is a sequence of alphabetic, numeric, and underscore characters.
|
||||
/// </summary>
|
||||
/// <param name="specialTokensEncoder">The dictionary containing the special tokens and their corresponding ids.</param>
|
||||
/// <returns>The pre-tokenizer that splits the text at word or non-word boundaries.</returns>
|
||||
public static PreTokenizer CreateWordOrNonWordPreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
|
||||
{
|
||||
if (specialTokensEncoder is null)
|
||||
{
|
||||
// return a singleton instance of the word-or-non-word pre-tokenizer
|
||||
return _wordOrNonWordPreTokenizer ??= new RegexPreTokenizer(WordOrNonWordRegex(), null);
|
||||
}
|
||||
|
||||
return new RegexPreTokenizer(WordOrNonWordRegex(), specialTokensEncoder);
|
||||
}
|
||||
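// Illustrative split, based on the pre-tokenizer tests in this change: applying the word-or-non-word
// pattern @"\w+|[^\w\s]+" to "How are you doing?" yields the segments (0, 3) "How", (4, 3) "are",
// (8, 3) "you", (12, 5) "doing", and (17, 1) "?".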
|
||||
private const string WhiteSpacePattern = @"\S+";
|
||||
private static PreTokenizer? _whiteSpacePreTokenizer;
|
||||
|
||||
#if NET7_0_OR_GREATER
|
||||
[GeneratedRegex(WhiteSpacePattern)]
|
||||
private static partial Regex WhiteSpaceRegex();
|
||||
|
@ -50,12 +103,11 @@ namespace Microsoft.ML.Tokenizers
|
|||
#endif
|
||||
|
||||
/// <summary>
|
||||
/// Create a new instance of the <see cref="PreTokenizer"/> class which split the text at the word boundary.
|
||||
/// The word is a set of alphabet, numeric, and underscore characters.
|
||||
/// Create a new instance of the <see cref="PreTokenizer"/> class which split the text at the white spaces.
|
||||
/// </summary>
|
||||
/// <param name="specialTokensEncoder">The dictionary containing the special tokens and their corresponding ids.</param>
|
||||
/// <returns>The pre-tokenizer that splits the text at the word boundary.</returns>
|
||||
public static PreTokenizer CreateWhiteSpace(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
|
||||
/// <returns>The pre-tokenizer that splits the text on white spaces.</returns>
|
||||
public static PreTokenizer CreateWhiteSpacePreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
|
||||
{
|
||||
if (specialTokensEncoder is null)
|
||||
{
|
||||
|
|
|
@ -0,0 +1,513 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics.Tracing;
|
||||
using System.IO;
|
||||
using System.Threading.Tasks;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.ML.Tokenizers.Tests
|
||||
{
|
||||
public class BertTokenizerTests
|
||||
{
|
||||
[Fact]
|
||||
public void TestWithLowerCasing()
|
||||
{
|
||||
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12
|
||||
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you"];
|
||||
|
||||
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
|
||||
|
||||
try
|
||||
{
|
||||
using Stream vocabStream = File.OpenRead(vocabFile);
|
||||
BertTokenizer[] bertTokenizers = [BertTokenizer.Create(vocabFile), BertTokenizer.Create(vocabStream)];
|
||||
|
||||
foreach (var tokenizer in bertTokenizers)
|
||||
{
|
||||
Assert.NotNull(tokenizer.PreTokenizer);
|
||||
Assert.Equal("[UNK]", tokenizer.UnknownToken);
|
||||
Assert.Equal(1, tokenizer.UnknownTokenId);
|
||||
Assert.NotNull(tokenizer.Normalizer);
|
||||
Assert.NotNull(tokenizer.PreTokenizer);
|
||||
|
||||
string text = "Hello, How are you?";
|
||||
var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
|
||||
Assert.Equal("hello, how are you?", normalizedText);
|
||||
|
||||
Assert.Equal(
|
||||
[
|
||||
new EncodedToken(8, "hello", new Range(0, 5)),
|
||||
new EncodedToken(6, ",", new Range(5, 6)),
|
||||
new EncodedToken(10, "how", new Range(7, 10)),
|
||||
new EncodedToken(11, "are", new Range(11, 14)),
|
||||
new EncodedToken(12, "you", new Range(15, 18)),
|
||||
new EncodedToken(7, "?", new Range(18, 19))
|
||||
],
|
||||
tokens);
|
||||
|
||||
var ids = tokenizer.EncodeToIds(text);
|
||||
Assert.Equal([tokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, tokenizer.SepTokenId], ids);
|
||||
|
||||
Assert.Equal("[CLS] hello, how are you? [SEP]", tokenizer.Decode(ids));
|
||||
Assert.Equal("hello, how are you?", tokenizer.Decode(ids, skipSpecialTokens: true));
|
||||
|
||||
tokens = tokenizer.EncodeToTokens(tokenizer.Decode(ids), out normalizedText);
|
||||
Assert.Equal("[cls] hello, how are you? [sep]", normalizedText);
|
||||
Assert.Equal(
|
||||
[
|
||||
new EncodedToken(2, "[CLS]", new Range(0, 5)),
|
||||
new EncodedToken(8, "hello", new Range(6, 11)),
|
||||
new EncodedToken(6, ",", new Range(11, 12)),
|
||||
new EncodedToken(10, "how", new Range(13, 16)),
|
||||
new EncodedToken(11, "are", new Range(17, 20)),
|
||||
new EncodedToken(12, "you", new Range(21, 24)),
|
||||
new EncodedToken(7, "?", new Range(24, 25)),
|
||||
new EncodedToken(3, "[SEP]", new Range(26, 31))
|
||||
],
|
||||
tokens);
|
||||
|
||||
ids = tokenizer.EncodeToIds(normalizedText!);
|
||||
Assert.Equal([tokenizer.ClsTokenId, tokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, tokenizer.SepTokenId, tokenizer.SepTokenId], ids);
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
File.Delete(vocabFile);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestWithNoLowerCasing()
|
||||
{
|
||||
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12
|
||||
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you"];
|
||||
|
||||
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
|
||||
|
||||
try
|
||||
{
|
||||
using Stream vocabStream = File.OpenRead(vocabFile);
|
||||
BertTokenizer[] bertTokenizers = [BertTokenizer.Create(vocabFile, doLowerCase: false), BertTokenizer.Create(vocabStream, doLowerCase: false)];
|
||||
|
||||
foreach (var tokenizer in bertTokenizers)
|
||||
{
|
||||
Assert.NotNull(tokenizer.PreTokenizer);
|
||||
Assert.Equal("[UNK]", tokenizer.UnknownToken);
|
||||
Assert.Equal(1, tokenizer.UnknownTokenId);
|
||||
Assert.NotNull(tokenizer.Normalizer);
|
||||
Assert.NotNull(tokenizer.PreTokenizer);
|
||||
|
||||
string text = "Hello, How are you?";
|
||||
var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
|
||||
Assert.Equal("Hello, How are you?", normalizedText);
|
||||
|
||||
Assert.Equal(
|
||||
[
|
||||
new EncodedToken(1, "[UNK]", new Range(0, 5)),
|
||||
new EncodedToken(6, ",", new Range(5, 6)),
|
||||
new EncodedToken(1, "[UNK]", new Range(7, 10)),
|
||||
new EncodedToken(11, "are", new Range(11, 14)),
|
||||
new EncodedToken(12, "you", new Range(15, 18)),
|
||||
new EncodedToken(7, "?", new Range(18, 19))
|
||||
],
|
||||
tokens);
|
||||
|
||||
var ids = tokenizer.EncodeToIds(text);
|
||||
Assert.Equal([tokenizer.ClsTokenId, 1, 6, 1, 11, 12, 7, tokenizer.SepTokenId], ids);
|
||||
|
||||
Assert.Equal("[CLS] [UNK], [UNK] are you? [SEP]", tokenizer.Decode(ids));
|
||||
Assert.Equal(", are you?", tokenizer.Decode(ids, skipSpecialTokens: true));
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
File.Delete(vocabFile);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TestWithAccentMarks()
|
||||
{
|
||||
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
|
||||
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "Café", "cafe", "café", "Über", "über", "uber", "Ångström", "ångström", "angstrom", "Résumé", "résumé", "resume",
|
||||
// Ids: 20 21 22 23
|
||||
"Cafe", "Uber", "Angstrom", "Resume"];
|
||||
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
|
||||
|
||||
try
|
||||
{
|
||||
using Stream vocabStream = File.OpenRead(vocabFile);
|
||||
BertTokenizer bertTokenizer = await BertTokenizer.CreateAsync(vocabStream); // lowercasing and no accent stripping
|
||||
|
||||
string text = "Café Über Ångström Résumé!";
|
||||
var tokens = bertTokenizer.EncodeToTokens(text, out string? normalizedText);
|
||||
Assert.Equal(
|
||||
[
|
||||
new EncodedToken(10, "café", new Range(0, 4)),
|
||||
new EncodedToken(12, "über", new Range(5, 9)),
|
||||
new EncodedToken(15, "ångström", new Range(10, 18)),
|
||||
new EncodedToken(18, "résumé", new Range(19, 25)),
|
||||
new EncodedToken(5, "!", new Range(25, 26)),
|
||||
],
|
||||
tokens);
|
||||
|
||||
Assert.Equal("café über ångström résumé!", normalizedText);
|
||||
|
||||
vocabStream.Position = 0;
|
||||
bertTokenizer = await BertTokenizer.CreateAsync(vocabStream, doLowerCase: false); // no lowercasing and no accent stripping
|
||||
tokens = bertTokenizer.EncodeToTokens(text, out normalizedText);
|
||||
Assert.Equal(
|
||||
[
|
||||
new EncodedToken(8, "Café", new Range(0, 4)),
|
||||
new EncodedToken(11, "Über", new Range(5, 9)),
|
||||
new EncodedToken(14, "Ångström", new Range(10, 18)),
|
||||
new EncodedToken(17, "Résumé", new Range(19, 25)),
|
||||
new EncodedToken(5, "!", new Range(25, 26)),
|
||||
],
|
||||
tokens);
|
||||
|
||||
Assert.Equal("Café Über Ångström Résumé!", normalizedText);
|
||||
|
||||
vocabStream.Position = 0;
|
||||
bertTokenizer = await BertTokenizer.CreateAsync(vocabStream, stripAccents: true); // lowercasing and accent stripping
|
||||
tokens = bertTokenizer.EncodeToTokens(text, out normalizedText);
|
||||
Assert.Equal("cafe uber angstrom resume!", normalizedText);
|
||||
Assert.Equal(
|
||||
[
|
||||
new EncodedToken(9, "cafe", new Range(0, 4)),
|
||||
new EncodedToken(13, "uber", new Range(5, 9)),
|
||||
new EncodedToken(16, "angstrom", new Range(10, 18)),
|
||||
new EncodedToken(19, "resume", new Range(19, 25)),
|
||||
new EncodedToken(5, "!", new Range(25, 26)),
|
||||
],
|
||||
tokens);
|
||||
|
||||
vocabStream.Position = 0;
|
||||
bertTokenizer = await BertTokenizer.CreateAsync(vocabStream, doLowerCase: false, stripAccents: true); // no lowercasing and accent stripping
|
||||
tokens = bertTokenizer.EncodeToTokens(text, out normalizedText);
|
||||
Assert.Equal("Cafe Uber Angstrom Resume!", normalizedText);
|
||||
Assert.Equal(
|
||||
[
|
||||
new EncodedToken(20, "Cafe", new Range(0, 4)),
|
||||
new EncodedToken(21, "Uber", new Range(5, 9)),
|
||||
new EncodedToken(22, "Angstrom", new Range(10, 18)),
|
||||
new EncodedToken(23, "Resume", new Range(19, 25)),
|
||||
new EncodedToken(5, "!", new Range(25, 26)),
|
||||
],
|
||||
tokens);
|
||||
}
|
||||
finally
|
||||
{
|
||||
File.Delete(vocabFile);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TestChineseCharacters()
|
||||
{
|
||||
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12
|
||||
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", "##驷", "##驸", "受", "叟", "叢", "驷", "驸"];
|
||||
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
|
||||
|
||||
try
|
||||
{
|
||||
using Stream vocabStream = File.OpenRead(vocabFile);
|
||||
BertTokenizer bertTokenizer = await BertTokenizer.CreateAsync(vocabStream); // tokenize Chinese characters
|
||||
string text = "叟驷 叢驸!";
|
||||
|
||||
var tokens = bertTokenizer.EncodeToTokens(text, out string? normalizedText);
|
||||
Assert.Equal(" 叟 驷 叢 驸 !", normalizedText);
|
||||
Assert.Equal(
|
||||
[
|
||||
new EncodedToken(9, "叟", new Range(1, 2)),
|
||||
new EncodedToken(11, "驷", new Range(4, 5)),
|
||||
new EncodedToken(10, "叢", new Range(8, 9)),
|
||||
new EncodedToken(12, "驸", new Range(11, 12)),
|
||||
new EncodedToken(5, "!", new Range(13, 14))
|
||||
],
|
||||
tokens);
|
||||
IReadOnlyList<int> ids = bertTokenizer.EncodeToIds(text);
|
||||
Assert.Equal("[CLS] 叟 驷 叢 驸! [SEP]", bertTokenizer.Decode(bertTokenizer.EncodeToIds(text)));
|
||||
Assert.Equal("叟 驷 叢 驸!", bertTokenizer.Decode(bertTokenizer.EncodeToIds(text), skipSpecialTokens: true));
|
||||
|
||||
vocabStream.Position = 0;
|
||||
bertTokenizer = await BertTokenizer.CreateAsync(vocabStream, tokenizeChineseChars: false); // do not tokenize Chinese characters
|
||||
tokens = bertTokenizer.EncodeToTokens(text, out normalizedText);
|
||||
Assert.Equal("叟驷 叢驸!", normalizedText);
|
||||
|
||||
Assert.Equal(
|
||||
[
|
||||
new EncodedToken(9, "叟", new Range(0, 1)),
|
||||
new EncodedToken(6, "##驷", new Range(1, 2)),
|
||||
new EncodedToken(10, "叢", new Range(3, 4)),
|
||||
new EncodedToken(7, "##驸", new Range(4, 5)),
|
||||
new EncodedToken(5, "!", new Range(5, 6))
|
||||
],
|
||||
tokens);
|
||||
ids = bertTokenizer.EncodeToIds(text);
|
||||
Assert.Equal("[CLS] 叟驷 叢驸! [SEP]", bertTokenizer.Decode(bertTokenizer.EncodeToIds(text)));
|
||||
Assert.Equal("叟驷 叢驸!", bertTokenizer.Decode(bertTokenizer.EncodeToIds(text), skipSpecialTokens: true));
|
||||
}
|
||||
finally
|
||||
{
|
||||
File.Delete(vocabFile);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestBuildInputsWithSpecialTokens()
|
||||
{
|
||||
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
|
||||
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "i", "am", "fine"];
|
||||
|
||||
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
|
||||
|
||||
try
|
||||
{
|
||||
using Stream vocabStream = File.OpenRead(vocabFile);
|
||||
BertTokenizer bertTokenizer = BertTokenizer.Create(vocabFile);
|
||||
|
||||
string text1 = "Hello, How are you?";
|
||||
string text2 = "I am fine!";
|
||||
|
||||
var ids1 = bertTokenizer.EncodeToIds(text1);
|
||||
Assert.Equal([bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId], ids1);
|
||||
|
||||
var ids2 = bertTokenizer.EncodeToIds(text2);
|
||||
Assert.Equal([bertTokenizer.ClsTokenId, 13, 14, 15, 5, bertTokenizer.SepTokenId], ids2);
|
||||
|
||||
Assert.Equal(
|
||||
[bertTokenizer.ClsTokenId, bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId, bertTokenizer.SepTokenId],
|
||||
bertTokenizer.BuildInputsWithSpecialTokens(ids1));
|
||||
|
||||
Span<int> ids1Span = stackalloc int[1];
|
||||
OperationStatus status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out int written);
|
||||
Assert.Equal(OperationStatus.DestinationTooSmall, status);
|
||||
Assert.Equal(0, written);
|
||||
|
||||
ids1Span = stackalloc int[ids1.Count + 2];
|
||||
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written);
|
||||
Assert.Equal(OperationStatus.Done, status);
|
||||
Assert.Equal(ids1.Count + 2, written);
|
||||
Assert.Equal(new int[] { bertTokenizer.ClsTokenId, bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId, bertTokenizer.SepTokenId }, ids1Span.ToArray());
|
||||
|
||||
Assert.Equal(
|
||||
[bertTokenizer.ClsTokenId, bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId, bertTokenizer.SepTokenId, bertTokenizer.ClsTokenId, 13, 14, 15, 5, bertTokenizer.SepTokenId, bertTokenizer.SepTokenId],
|
||||
bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids2));
|
||||
|
||||
ids1Span = stackalloc int[1];
|
||||
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written, ids2);
|
||||
Assert.Equal(OperationStatus.DestinationTooSmall, status);
|
||||
Assert.Equal(0, written);
|
||||
|
||||
ids1Span = stackalloc int[ids1.Count + ids2.Count + 3];
|
||||
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written, ids2);
|
||||
Assert.Equal(OperationStatus.Done, status);
|
||||
Assert.Equal(ids1Span.Length, written);
|
||||
Assert.Equal(
|
||||
new int[] { bertTokenizer.ClsTokenId, bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId, bertTokenizer.SepTokenId, bertTokenizer.ClsTokenId, 13, 14, 15, 5, bertTokenizer.SepTokenId, bertTokenizer.SepTokenId },
|
||||
ids1Span.ToArray());
|
||||
|
||||
ids1 = bertTokenizer.EncodeToIds(text1, addSpecialTokens: false);
|
||||
Assert.Equal([8, 6, 10, 11, 12, 7], ids1);
|
||||
|
||||
ids2 = bertTokenizer.EncodeToIds(text2, addSpecialTokens: false);
|
||||
Assert.Equal([13, 14, 15, 5], ids2);
|
||||
|
||||
Assert.Equal(
|
||||
[bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId],
|
||||
bertTokenizer.BuildInputsWithSpecialTokens(ids1));
|
||||
|
||||
ids1Span = stackalloc int[1];
|
||||
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written);
|
||||
Assert.Equal(OperationStatus.DestinationTooSmall, status);
|
||||
Assert.Equal(0, written);
|
||||
|
||||
ids1Span = stackalloc int[ids1.Count + 2];
|
||||
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written);
|
||||
Assert.Equal(OperationStatus.Done, status);
|
||||
Assert.Equal(ids1Span.Length, written);
|
||||
Assert.Equal(
|
||||
new int[] { bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId },
|
||||
ids1Span.ToArray());
|
||||
|
||||
Assert.Equal(
|
||||
[bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId, 13, 14, 15, 5, bertTokenizer.SepTokenId],
|
||||
bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids2));
|
||||
|
||||
ids1Span = stackalloc int[1];
|
||||
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written, ids2);
|
||||
Assert.Equal(OperationStatus.DestinationTooSmall, status);
|
||||
Assert.Equal(0, written);
|
||||
|
||||
ids1Span = stackalloc int[ids1.Count + ids2.Count + 3];
|
||||
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written, ids2);
|
||||
Assert.Equal(OperationStatus.Done, status);
|
||||
Assert.Equal(ids1Span.Length, written);
|
||||
Assert.Equal(
|
||||
new int[] { bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId, 13, 14, 15, 5, bertTokenizer.SepTokenId },
|
||||
ids1Span.ToArray());
|
||||
}
|
||||
finally
|
||||
{
|
||||
File.Delete(vocabFile);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestGetSpecialTokensMask()
|
||||
{
|
||||
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
|
||||
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "i", "am", "fine"];
|
||||
|
||||
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
|
||||
|
||||
try
|
||||
{
|
||||
using Stream vocabStream = File.OpenRead(vocabFile);
|
||||
BertTokenizer bertTokenizer = BertTokenizer.Create(vocabFile);
|
||||
|
||||
string text1 = "Hello, How are you?";
|
||||
string text2 = "I am fine!";
|
||||
|
||||
var ids1 = bertTokenizer.EncodeToIds(text1);
|
||||
Assert.Equal([bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId], ids1);
|
||||
|
||||
var ids2 = bertTokenizer.EncodeToIds(text2);
|
||||
Assert.Equal([bertTokenizer.ClsTokenId, 13, 14, 15, 5, bertTokenizer.SepTokenId], ids2);
|
||||
|
||||
Assert.Equal(
|
||||
[1, 0, 0, 0, 0, 0, 0, 1],
|
||||
bertTokenizer.GetSpecialTokensMask(ids1, tokenIds1: null, alreadyHasSpecialTokens: true));
|
||||
|
||||
Span<int> ids1Span = stackalloc int[1];
|
||||
OperationStatus status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out int written, alreadyHasSpecialTokens: true);
|
||||
Assert.Equal(OperationStatus.DestinationTooSmall, status);
|
||||
Assert.Equal(0, written);
|
||||
|
||||
ids1Span = stackalloc int[ids1.Count];
|
||||
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, alreadyHasSpecialTokens: true);
|
||||
Assert.Equal(OperationStatus.Done, status);
|
||||
Assert.Equal(ids1.Count, written);
|
||||
Assert.Equal(new int[] { 1, 0, 0, 0, 0, 0, 0, 1 }, ids1Span.ToArray());
|
||||
|
||||
Assert.Equal(
|
||||
[1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1],
|
||||
bertTokenizer.GetSpecialTokensMask(ids1, tokenIds1: ids2, alreadyHasSpecialTokens: true));
|
||||
|
||||
ids1Span = stackalloc int[1];
|
||||
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, ids2, alreadyHasSpecialTokens: true);
|
||||
Assert.Equal(OperationStatus.DestinationTooSmall, status);
|
||||
Assert.Equal(0, written);
|
||||
|
||||
ids1Span = stackalloc int[ids1.Count + ids2.Count];
|
||||
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, ids2, alreadyHasSpecialTokens: true);
|
||||
Assert.Equal(OperationStatus.Done, status);
|
||||
Assert.Equal(ids1.Count + ids2.Count, written);
|
||||
Assert.Equal(new int[] { 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1 }, ids1Span.ToArray());
|
||||
|
||||
ids1 = bertTokenizer.EncodeToIds(text1, addSpecialTokens: false);
|
||||
Assert.Equal([8, 6, 10, 11, 12, 7], ids1);
|
||||
|
||||
ids2 = bertTokenizer.EncodeToIds(text2, addSpecialTokens: false);
|
||||
Assert.Equal([13, 14, 15, 5], ids2);
|
||||
Assert.Equal(
|
||||
[1, 0, 0, 0, 0, 0, 0, 1],
|
||||
bertTokenizer.GetSpecialTokensMask(ids1, tokenIds1: null, alreadyHasSpecialTokens: false));
|
||||
|
||||
ids1Span = stackalloc int[1];
|
||||
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, alreadyHasSpecialTokens: false);
|
||||
Assert.Equal(OperationStatus.DestinationTooSmall, status);
|
||||
Assert.Equal(0, written);
|
||||
|
||||
ids1Span = stackalloc int[ids1.Count + 2];
|
||||
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, alreadyHasSpecialTokens: false);
|
||||
Assert.Equal(OperationStatus.Done, status);
|
||||
Assert.Equal(ids1.Count + 2, written);
|
||||
Assert.Equal(new int[] { 1, 0, 0, 0, 0, 0, 0, 1 }, ids1Span.ToArray());
|
||||
|
||||
Assert.Equal(
|
||||
[1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
|
||||
bertTokenizer.GetSpecialTokensMask(ids1, tokenIds1: ids2, alreadyHasSpecialTokens: false));
|
||||
|
||||
ids1Span = stackalloc int[1];
|
||||
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, ids2, alreadyHasSpecialTokens: false);
|
||||
Assert.Equal(OperationStatus.DestinationTooSmall, status);
|
||||
Assert.Equal(0, written);
|
||||
|
||||
ids1Span = stackalloc int[ids1.Count + ids2.Count + 3];
|
||||
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, ids2, alreadyHasSpecialTokens: false);
|
||||
Assert.Equal(OperationStatus.Done, status);
|
||||
Assert.Equal(ids1.Count + ids2.Count + 3, written);
|
||||
Assert.Equal(new int[] { 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1 }, ids1Span.ToArray());
|
||||
}
|
||||
finally
|
||||
{
|
||||
File.Delete(vocabFile);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestCreateTokenTypeIdsFromSequences()
|
||||
{
|
||||
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
|
||||
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "i", "am", "fine"];
|
||||
|
||||
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
|
||||
|
||||
try
|
||||
{
|
||||
using Stream vocabStream = File.OpenRead(vocabFile);
|
||||
BertTokenizer bertTokenizer = BertTokenizer.Create(vocabFile);
|
||||
|
||||
string text1 = "Hello, How are you?";
|
||||
string text2 = "I am fine!";
|
||||
|
||||
var ids1 = bertTokenizer.EncodeToIds(text1, addSpecialTokens: false);
|
||||
Assert.Equal([8, 6, 10, 11, 12, 7], ids1);
|
||||
|
||||
var ids2 = bertTokenizer.EncodeToIds(text2, addSpecialTokens: false);
|
||||
Assert.Equal([13, 14, 15, 5], ids2);
|
||||
|
||||
Assert.Equal(
|
||||
[0, 0, 0, 0, 0, 0, 0, 0],
|
||||
bertTokenizer.CreateTokenTypeIdsFromSequences(ids1));
|
||||
|
||||
Span<int> ids1Span = stackalloc int[1];
|
||||
OperationStatus status = bertTokenizer.CreateTokenTypeIdsFromSequences(ids1, ids1Span, out int written);
|
||||
Assert.Equal(OperationStatus.DestinationTooSmall, status);
|
||||
Assert.Equal(0, written);
|
||||
|
||||
ids1Span = stackalloc int[ids1.Count + 2];
|
||||
status = bertTokenizer.CreateTokenTypeIdsFromSequences(ids1, ids1Span, out written);
|
||||
Assert.Equal(OperationStatus.Done, status);
|
||||
Assert.Equal(ids1.Count + 2, written);
|
||||
Assert.Equal(new int[] { 0, 0, 0, 0, 0, 0, 0, 0 }, ids1Span.ToArray());
|
||||
|
||||
Assert.Equal(
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
|
||||
bertTokenizer.CreateTokenTypeIdsFromSequences(ids1, ids2));
|
||||
|
||||
ids1Span = stackalloc int[1];
|
||||
status = bertTokenizer.CreateTokenTypeIdsFromSequences(ids1, ids1Span, out written, ids2);
|
||||
Assert.Equal(OperationStatus.DestinationTooSmall, status);
|
||||
Assert.Equal(0, written);
|
||||
|
||||
ids1Span = stackalloc int[ids1.Count + ids2.Count + 3];
|
||||
status = bertTokenizer.CreateTokenTypeIdsFromSequences(ids1, ids1Span, out written, ids2);
|
||||
Assert.Equal(OperationStatus.Done, status);
|
||||
Assert.Equal(ids1Span.Length, written);
|
||||
Assert.Equal(new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1 }, ids1Span.ToArray());
|
||||
}
|
||||
finally
|
||||
{
|
||||
File.Delete(vocabFile);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -251,7 +251,7 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
|
||||
try
|
||||
{
|
||||
BpeTokenizer bpe = BpeTokenizer.Create(vocabFile: vocabFile, mergesFile: mergesFile, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, unknownToken: unknownToken,
|
||||
BpeTokenizer bpe = BpeTokenizer.Create(vocabFile: vocabFile, mergesFile: mergesFile, preTokenizer: PreTokenizer.CreateWordOrNonWordPreTokenizer(), normalizer: null, unknownToken: unknownToken,
|
||||
continuingSubwordPrefix: continuingSubwordPrefix, endOfWordSuffix: endOfWordSuffix, fuseUnknownTokens: fuseUnknownToken);
|
||||
Tokenizer tokenizer = bpe;
|
||||
IReadOnlyList<EncodedToken> encoding = tokenizer.EncodeToTokens(sentence, out _);
|
||||
|
@ -500,7 +500,7 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
using Stream vocabStream = File.OpenRead(Path.Combine(@"Gpt-2", "vocab.json"));
|
||||
using Stream mergesStream = File.OpenRead(Path.Combine(@"Gpt-2", "merges.txt"));
|
||||
|
||||
var bpeTokenizer = BpeTokenizer.Create(vocabStream, mergesStream, PreTokenizer.CreateWhiteSpace(addedTokens), normalizer: null, addedTokens: addedTokens, unknownToken: "<|endoftext|>");
|
||||
var bpeTokenizer = BpeTokenizer.Create(vocabStream, mergesStream, PreTokenizer.CreateWordOrNonWordPreTokenizer(addedTokens), normalizer: null, addedTokens: addedTokens, unknownToken: "<|endoftext|>");
|
||||
|
||||
string input = "Hello, y'all! <issue_comment>How are you 😁 ?<|endoftext|>";
|
||||
|
||||
|
@ -556,7 +556,7 @@ namespace Microsoft.ML.Tokenizers.Tests
|
|||
emptyVocabStream.Position = 0;
|
||||
|
||||
return BpeTokenizer.Create(
|
||||
vocabStream: emptyVocabStream, mergesStream: null, preTokenizer: preTokenizer ?? PreTokenizer.CreateWhiteSpace(), normalizer: normalizer, unknownToken: "Ukn");
|
||||
vocabStream: emptyVocabStream, mergesStream: null, preTokenizer: preTokenizer ?? PreTokenizer.CreateWordOrNonWordPreTokenizer(), normalizer: normalizer, unknownToken: "Ukn");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -18,18 +18,25 @@ namespace Microsoft.ML.Tokenizers.Tests
{
yield return new object[]
{
-    PreTokenizer.CreateWhiteSpace(),
+    PreTokenizer.CreateWordOrNonWordPreTokenizer(),
    "How are you doing?",
    new (int Offset, int Length)[] { (0, 3), (4, 3), (8, 3), (12, 5), (17, 1), }
};

yield return new object[]
{
-    PreTokenizer.CreateWhiteSpace(),
+    PreTokenizer.CreateWordOrNonWordPreTokenizer(),
    "I_am_Just_Fine!",
    new (int Offset, int Length)[] { (0, 14), (14, 1) }
};

+yield return new object[]
+{
+    PreTokenizer.CreateWhiteSpacePreTokenizer(),
+    "Hello, how are you doing?!",
+    new (int Offset, int Length)[] { (0, 6), (7, 3), (11, 3), (15, 3), (19, 7) }
+};
+
yield return new object[]
{
    new SpacePreTokenizer(),

@@ -61,9 +68,9 @@ namespace Microsoft.ML.Tokenizers.Tests
}

[Fact]
-public void TestWhiteSpacePreTokenizer()
+public void TestWordOrNonWordPreTokenizer()
{
-    Assert.Empty(PreTokenizer.CreateWhiteSpace().PreTokenize((string)null!));
+    Assert.Empty(PreTokenizer.CreateWordOrNonWordPreTokenizer().PreTokenize((string)null!));
}

public class SpacePreTokenizer : PreTokenizer
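The offsets in the test data above illustrate the difference between the two built-in pre-tokenizers: the word-or-non-word pre-tokenizer yields each run of word characters and each run of punctuation as a separate segment, while the whitespace pre-tokenizer only splits on whitespace. A rough sketch of that behavior with plain regexes (the library's actual patterns are not shown in this diff, so the expressions below are an approximation):

using System;
using System.Text.RegularExpressions;

class PreTokenizerSketch
{
    static void Main()
    {
        // Approximation of the word-or-non-word split: runs of word chars, or runs of other non-space chars.
        // "How are you doing?" -> (0, 3) (4, 3) (8, 3) (12, 5) (17, 1), matching the test data above.
        foreach (Match m in Regex.Matches("How are you doing?", @"\w+|[^\w\s]+"))
            Console.WriteLine($"({m.Index}, {m.Length})");

        // Approximation of the whitespace split: anything between whitespace stays together.
        // "Hello, how are you doing?!" -> (0, 6) (7, 3) (11, 3) (15, 3) (19, 7), matching the test data above.
        foreach (Match m in Regex.Matches("Hello, how are you doing?!", @"\S+"))
            Console.WriteLine($"({m.Index}, {m.Length})");
    }
}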
@@ -0,0 +1,221 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Xunit;

namespace Microsoft.ML.Tokenizers.Tests
{
    public class WordPieceTests
    {
        static string[] _vocabTokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"];

        internal static string CreateVocabFile(string[] vocabTokens)
        {
            string vocabFile = Path.GetTempFileName();
            File.WriteAllLines(vocabFile, vocabTokens);
            return vocabFile;
        }

        [Fact]
        public void TestCreation()
        {
            string vocabFile = CreateVocabFile(_vocabTokens);

            try
            {
                using Stream vocabStream = File.OpenRead(vocabFile);
                WordPieceTokenizer[] wordPieceTokenizers = [WordPieceTokenizer.Create(vocabFile), WordPieceTokenizer.Create(vocabStream)];

                foreach (var tokenizer in wordPieceTokenizers)
                {
                    Assert.NotNull(tokenizer.PreTokenizer);
                    Assert.Equal("[UNK]", tokenizer.UnknownToken);
                    Assert.Equal(0, tokenizer.UnknownTokenId);
                    Assert.Null(tokenizer.Normalizer);
                    Assert.Equal(100, tokenizer.MaxInputCharsPerWord);
                    Assert.Equal("##", tokenizer.ContinuingSubwordPrefix);
                }
            }
            finally
            {
                File.Delete(vocabFile);
            }
        }
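The vocabulary file exercised here is plain text with one token per line, and a token's id is its zero-based line number; that is why "[UNK]" resolves to UnknownTokenId 0 and "un" to id 7 in the tests. A minimal sketch of loading such a file into a token-to-id map (the tokenizer's internal representation may differ):

// Sketch: build a token -> id map from a WordPiece-style vocab file (one token per line, id = line index).
// Relies on the System.IO and System.Collections.Generic usings already present in this file.
static Dictionary<string, int> LoadVocab(string vocabFile)
{
    var vocab = new Dictionary<string, int>();
    int id = 0;
    foreach (string token in File.ReadLines(vocabFile))
    {
        vocab[token] = id++;
    }
    return vocab;
}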
        [Fact]
        public void TestTokenization()
        {
            string vocabFile = CreateVocabFile(_vocabTokens);

            try
            {
                WordPieceTokenizer tokenizer = WordPieceTokenizer.Create(vocabFile);

                Assert.Null(tokenizer.SpecialTokens);

                IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens("", out _);
                Assert.Empty(tokens);
                Assert.Equal(0, tokenizer.CountTokens(""));
                IReadOnlyList<int> ids = tokenizer.EncodeToIds("");
                Assert.Empty(ids);
                int index = tokenizer.GetIndexByTokenCount("", maxTokenCount: 10, normalizedString: out _, tokenCount: out int tokenCount);
                Assert.Equal(0, index);
                Assert.Equal(0, tokenCount);
                index = tokenizer.GetIndexByTokenCountFromEnd("", maxTokenCount: 10, normalizedString: out _, tokenCount: out tokenCount);
                Assert.Equal(0, index);
                Assert.Equal(0, tokenCount);

                string text = "unwanted running";
                tokens = tokenizer.EncodeToTokens(text, out _);
                Assert.Equal(
                    [
                        new EncodedToken(7, "un", new Range(0, 2)),
                        new EncodedToken(4, "##want", new Range(2, 6)),
                        new EncodedToken(5, "##ed", new Range(6, 8)),
                        new EncodedToken(8, "runn", new Range(9, 13)),
                        new EncodedToken(9, "##ing", new Range(13, 16))
                    ],
                    tokens
                );

                ids = tokenizer.EncodeToIds(text);
                Assert.Equal([7, 4, 5, 8, 9], ids);

                int[] expectedTokenCount = [0, 0, 3, 3, 5];
                for (int i = 1; i <= 5; i++)
                {
                    Assert.Equal(ids.Take(expectedTokenCount[i - 1]).ToArray(), tokenizer.EncodeToIds(text, maxTokenCount: i, normalizedText: out _, out tokenCount));
                }

                Assert.Equal(text, tokenizer.Decode(ids));

                Span<char> buffer = stackalloc char[text.Length];
                for (int i = 0; i < text.Length - 1; i++)
                {
                    Span<char> bufferSlice = buffer.Slice(0, i);
                    OperationStatus result = tokenizer.Decode(ids, bufferSlice, out int idsConsumed, out int charsWritten);
                    Assert.Equal(OperationStatus.DestinationTooSmall, result);

                    int j = 0;

                    while (i >= tokens[j].Offset.End.Value)
                    {
                        j++;
                    }

                    Assert.Equal(j, idsConsumed);
                    Assert.Equal(j == 0 ? 0 : tokens[j - 1].Offset.End.Value, charsWritten);
                    Assert.Equal(j == 0 ? "" : text.Substring(0, tokens[j - 1].Offset.End.Value), bufferSlice.Slice(0, charsWritten).ToString());
                }

                Assert.Equal(5, tokenizer.CountTokens(text));

                int[] expectedIndexes = [0, 0, 8, 9, 16];
                expectedTokenCount = [0, 0, 3, 3, 5];

                for (int i = 1; i <= 5; i++)
                {
                    index = tokenizer.GetIndexByTokenCount(text, maxTokenCount: i, normalizedString: out _, out tokenCount);
                    Assert.Equal(expectedTokenCount[i - 1], tokenCount);
                    Assert.Equal(expectedIndexes[i - 1], index);
                }

                expectedIndexes = [16, 9, 8, 8, 0];
                expectedTokenCount = [0, 2, 2, 2, 5];

                for (int i = 1; i <= 5; i++)
                {
                    index = tokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: i, normalizedString: out _, out tokenCount);
                    Assert.Equal(expectedTokenCount[i - 1], tokenCount);
                    Assert.Equal(expectedIndexes[i - 1], index);
                }
            }
            finally
            {
                File.Delete(vocabFile);
            }
        }
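The expected tokens above follow WordPiece's greedy longest-match-first rule: for each pre-tokenized word, repeatedly take the longest vocabulary entry that prefixes the remaining characters, with every piece after the first looked up under the continuing-subword prefix "##"; if at some point no prefix matches, the whole word falls back to the unknown token. A self-contained sketch of that rule (not the library's implementation) that reproduces the ids asserted above:

using System;
using System.Collections.Generic;

class WordPieceSketch
{
    static void Main()
    {
        string[] vocabTokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"];
        var vocab = new Dictionary<string, int>();
        for (int i = 0; i < vocabTokens.Length; i++) vocab[vocabTokens[i]] = i;

        var ids = new List<int>();
        foreach (string word in "unwanted running".Split(' '))
        {
            int start = 0;
            var pieces = new List<int>();
            while (start < word.Length)
            {
                // Greedy: try the longest remaining substring first, shrinking until a vocab hit.
                int end = word.Length;
                int pieceId = -1;
                while (end > start)
                {
                    string piece = (start > 0 ? "##" : "") + word[start..end];
                    if (vocab.TryGetValue(piece, out int id)) { pieceId = id; break; }
                    end--;
                }
                if (pieceId < 0) { pieces.Clear(); pieces.Add(vocab["[UNK]"]); break; } // whole word -> [UNK], as in the next test
                pieces.Add(pieceId);
                start = end;
            }
            ids.AddRange(pieces);
        }

        Console.WriteLine(string.Join(", ", ids)); // 7, 4, 5, 8, 9, matching the EncodeToIds assertion above
    }
}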
        [Fact]
        public void TestTokenizationWithUnknownTokens()
        {
            string vocabFile = CreateVocabFile(_vocabTokens);

            try
            {
                WordPieceTokenizer tokenizer = WordPieceTokenizer.Create(vocabFile);

                string text = "unwantedX running";

                IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens(text, out _);
                Assert.Equal(
                    [
                        new EncodedToken(0, "[UNK]", new Range(0, 9)),
                        new EncodedToken(8, "runn", new Range(10, 14)),
                        new EncodedToken(9, "##ing", new Range(14, 17))
                    ],
                    tokens
                );

                IReadOnlyList<int> ids = tokenizer.EncodeToIds(text);
                Assert.Equal([0, 8, 9], ids);

                Assert.Equal("[UNK] running", tokenizer.Decode(ids));
            }
            finally
            {
                File.Delete(vocabFile);
            }
        }

        [Fact]
        public void TestTokenizationWithSpecialTokens()
        {
            string vocabFile = CreateVocabFile(_vocabTokens);

            try
            {
                Dictionary<string, int> specialTokens = new Dictionary<string, int>
                {
                    { "[UNK]", 0 }, { "[CLS]", 1 }, { "[SEP]", 2 }
                };
                WordPieceTokenizer tokenizer = WordPieceTokenizer.Create(vocabFile, specialTokens: specialTokens);

                Assert.Equal(specialTokens, tokenizer.SpecialTokens);

                string text = "[UNK] unwanted [SEP][CLS] running [CLS]";

                IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens(text, out _);
                Assert.Equal(
                    [
                        new EncodedToken(0, "[UNK]", new Range(0, 5)),
                        new EncodedToken(7, "un", new Range(6, 8)),
                        new EncodedToken(4, "##want", new Range(8, 12)),
                        new EncodedToken(5, "##ed", new Range(12, 14)),
                        new EncodedToken(2, "[SEP]", new Range(15, 20)),
                        new EncodedToken(1, "[CLS]", new Range(20, 25)),
                        new EncodedToken(8, "runn", new Range(26, 30)),
                        new EncodedToken(9, "##ing", new Range(30, 33)),
                        new EncodedToken(1, "[CLS]", new Range(34, 39)),
                    ],
                    tokens
                );

                IReadOnlyList<int> ids = tokenizer.EncodeToIds(text);
                Assert.Equal([0, 7, 4, 5, 2, 1, 8, 9, 1], ids);

                Assert.Equal("[UNK] unwanted [SEP] [CLS] running [CLS]", tokenizer.Decode(ids));
            }
            finally
            {
                File.Delete(vocabFile);
            }
        }
    }
}
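The Decode assertions in the three tests above are consistent with a simple reconstruction rule: concatenate the token strings, attaching "##"-prefixed pieces to the previous piece and inserting a space before each token that starts a new word. A small sketch of that rule (the library's Decode may handle more cases):

using System;
using System.Collections.Generic;
using System.Text;

class WordPieceDecodeSketch
{
    // Joins WordPiece token strings back into text: "##"-prefixed pieces attach to the
    // previous piece, everything else starts a new space-separated word.
    static string Decode(IEnumerable<string> tokens)
    {
        var sb = new StringBuilder();
        foreach (string token in tokens)
        {
            if (token.StartsWith("##"))
            {
                sb.Append(token, 2, token.Length - 2);
            }
            else
            {
                if (sb.Length > 0) sb.Append(' ');
                sb.Append(token);
            }
        }
        return sb.ToString();
    }

    static void Main()
    {
        // These match the Decode assertions in the tests above.
        Console.WriteLine(Decode(["un", "##want", "##ed", "runn", "##ing"])); // unwanted running
        Console.WriteLine(Decode(["[UNK]", "runn", "##ing"]));                // [UNK] running
        Console.WriteLine(Decode(["[UNK]", "un", "##want", "##ed", "[SEP]", "[CLS]", "runn", "##ing", "[CLS]"])); // [UNK] unwanted [SEP] [CLS] running [CLS]
    }
}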