Introducing WordPiece and Bert tokenizers (#7275)

* Introducing WordPiece and Bert tokenizers

* Fix corner case in WordPiece
This commit is contained in:
Tarek Mahmoud Sayed 2024-10-22 12:33:58 -07:00 коммит произвёл GitHub
Родитель 32bac5e395
Коммит 81122c4c48
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
11 изменённых файлов: 2618 добавлений и 16 удалений

Просмотреть файл

@ -152,6 +152,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
License notice for WordPiece and Bert tokenizers
------------------------------------------------
https://github.com/huggingface/transformers/blob/8e3e145b427196e014f37aa42ba890b9bc94275e/src/transformers/models/bert/tokenization_bert.py#L2
Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
License notice for BitUtility
------------------------------------------

Просмотреть файл

@ -10,7 +10,7 @@ namespace Microsoft.ML.Tokenizers
/// Represent the token produced from the tokenization process containing the token substring,
/// the id associated to the token substring, and the offset mapping to the original string.
/// </summary>
public readonly struct EncodedToken
public readonly struct EncodedToken : IEquatable<EncodedToken>
{
/// <summary>
/// Gets the Id value associated to the token.
@ -39,5 +39,8 @@ namespace Microsoft.ML.Tokenizers
Offset = offset;
Value = value;
}
/// inherited
public bool Equals(EncodedToken other) => Id == other.Id && Value == other.Value && Offset.Equals(other.Offset);
}
}

Просмотреть файл

@ -87,7 +87,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergesFile">The file path containing the tokens's pairs list.</param>
public static BpeTokenizer Create(string vocabFile, string? mergesFile)
=> Create(vocabFile, mergesFile, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
=> Create(vocabFile, mergesFile, preTokenizer: PreTokenizer.CreateWordOrNonWordPreTokenizer(), normalizer: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
/// <summary>
/// Create a new Bpe tokenizer object to use for text encoding.
@ -131,7 +131,7 @@ namespace Microsoft.ML.Tokenizers
/// <param name="vocabStream">The JSON stream containing the dictionary of string keys and their ids.</param>
/// <param name="mergesStream">The stream containing the tokens's pairs list.</param>
public static BpeTokenizer Create(Stream vocabStream, Stream? mergesStream)
=> Create(vocabStream, mergesStream, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, addedTokens: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
=> Create(vocabStream, mergesStream, preTokenizer: PreTokenizer.CreateWordOrNonWordPreTokenizer(), normalizer: null, addedTokens: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
/// <summary>
/// Create a new Bpe tokenizer object to use for text encoding.
@ -225,7 +225,7 @@ namespace Microsoft.ML.Tokenizers
FuseUnknownTokens = fuseUnknownTokens;
ContinuingSubwordPrefix = continuingSubwordPrefix;
EndOfWordSuffix = endOfWordSuffix;
_preTokenizer = preTokenizer ?? PreTokenizer.CreateWhiteSpace(); // Default to WhiteSpace pre-tokenizer
_preTokenizer = preTokenizer ?? PreTokenizer.CreateWordOrNonWordPreTokenizer(); // Default to WordOrNonWord pre-tokenizer
_normalizer = normalizer;
_vocab = vocab ?? new Dictionary<StringSpanOrdinalKey, int>();

Просмотреть файл

@ -0,0 +1,729 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// Tokenizer for Bert model.
/// </summary>
/// <remarks>
/// The BertTokenizer is a based on the WordPieceTokenizer and is used to tokenize text for Bert models.
/// The implementation of the BertTokenizer is based on the original Bert implementation in the Hugging Face Transformers library.
/// https://huggingface.co/transformers/v3.0.2/model_doc/bert.html?highlight=berttokenizerfast#berttokenizer
/// </remarks>
public sealed partial class BertTokenizer : WordPieceTokenizer
{
internal BertTokenizer(
Dictionary<StringSpanOrdinalKey, int> vocab,
Dictionary<int, string> vocabReverse,
PreTokenizer? preTokenizer,
Normalizer? normalizer,
IReadOnlyDictionary<string, int>? specialTokens,
bool doLowerCase,
bool doBasicTokenization,
bool splitOnSpecialTokens,
string unknownToken,
string sepToken,
string padToken,
string clsToken,
string maskToken,
bool tokenizeChineseChars,
bool stripAccents) : base(vocab, vocabReverse, preTokenizer, normalizer, specialTokens, unknownToken)
{
DoLowerCase = doLowerCase;
DoBasicTokenization = doBasicTokenization;
SplitOnSpecialTokens = splitOnSpecialTokens;
SepToken = sepToken;
SepTokenId = vocab[new StringSpanOrdinalKey(sepToken)];
PadToken = padToken;
PadTokenId = vocab[new StringSpanOrdinalKey(padToken)];
ClsToken = clsToken;
ClsTokenId = vocab[new StringSpanOrdinalKey(clsToken)];
MaskToken = maskToken;
MaskTokenId = vocab[new StringSpanOrdinalKey(maskToken)];
TokenizeChineseChars = tokenizeChineseChars;
StripAccents = stripAccents;
}
/// <summary>
/// Gets a value indicating whether the tokenizer should lowercase the input text.
/// </summary>
public bool DoLowerCase { get; }
/// <summary>
/// Gets a value indicating whether the tokenizer should do basic tokenization. Like clean text, normalize it, lowercasing, etc.
/// </summary>
public bool DoBasicTokenization { get; }
/// <summary>
/// Gets a value indicating whether the tokenizer should split on the special tokens or treat special tokens as normal text.
/// </summary>
public bool SplitOnSpecialTokens { get; }
/// <summary>
/// Gets the separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification or for a text and a question for question answering.
/// It is also used as the last token of a sequence built with special tokens.
/// </summary>
public string SepToken { get; }
/// <summary>
/// Gets the separator token Id
/// </summary>
public int SepTokenId { get; }
/// <summary>
/// Gets the token used for padding, for example when batching sequences of different lengths
/// </summary>
public string PadToken { get; }
/// <summary>
/// Gets padding token Id
/// </summary>
public int PadTokenId { get; }
/// <summary>
/// Gets the classifier token which is used when doing sequence classification (classification of the whole sequence instead of per-token classification).
/// It is the first token of the sequence when built with special tokens.
/// </summary>
public string ClsToken { get; }
/// <summary>
/// Gets the classifier token Id
/// </summary>
public int ClsTokenId { get; }
/// <summary>
/// Gets the mask token used for masking values. This is the token used when training this model with masked language modeling.
/// This is the token which the model will try to predict.
/// </summary>
public string MaskToken { get; }
/// <summary>
/// Gets the mask token Id
/// </summary>
public int MaskTokenId { get; }
/// <summary>
/// Gets a value indicating whether the tokenizer should split the Chinese characters into tokens.
/// </summary>
public bool TokenizeChineseChars { get; }
/// <summary>
/// Gets a value indicating whether the tokenizer should strip accents characters.
/// </summary>
public bool StripAccents { get; }
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public new IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) =>
EncodeToIds(text, ReadOnlySpan<char>.Empty, addSpecialTokens: true, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public new IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true) =>
EncodeToIds(null, text, addSpecialTokens: true, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addSpecialTokens">Indicate whether to add special tokens to the encoded Ids.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(string text, bool addSpecialTokens, bool considerPreTokenization = true, bool considerNormalization = true) =>
EncodeToIds(text, ReadOnlySpan<char>.Empty, addSpecialTokens, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="addSpecialTokens">Indicate whether to add special tokens to the encoded Ids.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addSpecialTokens, bool considerPreTokenization = true, bool considerNormalization = true) =>
EncodeToIds(null, text, addSpecialTokens, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to return.</param>
/// <param name="normalizedText">The normalized text.</param>
/// <param name="charsConsumed">The number of characters consumed from the input text.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public new IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true) =>
EncodeToIds(text, ReadOnlySpan<char>.Empty, maxTokenCount, addSpecialTokens: true, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to return.</param>
/// <param name="normalizedText">The normalized text.</param>
/// <param name="charsConsumed">The number of characters consumed from the input text.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public new IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true) =>
EncodeToIds(null, text, maxTokenCount, addSpecialTokens: true, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to return.</param>
/// <param name="addSpecialTokens">Indicate whether to add special tokens to the encoded Ids.</param>
/// <param name="normalizedText">The normalized text.</param>
/// <param name="charsConsumed">The number of characters consumed from the input text.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, bool addSpecialTokens, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true) =>
EncodeToIds(text, ReadOnlySpan<char>.Empty, maxTokenCount, addSpecialTokens, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization);
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum number of tokens to return.</param>
/// <param name="addSpecialTokens">Indicate whether to add special tokens to the encoded Ids.</param>
/// <param name="normalizedText">The normalized text.</param>
/// <param name="charsConsumed">The number of characters consumed from the input text.</param>
/// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
/// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
/// <returns>The list of encoded Ids.</returns>
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, bool addSpecialTokens, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true) =>
EncodeToIds(null, text, maxTokenCount, addSpecialTokens, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization);
private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool addSpecialTokens, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true)
{
if (addSpecialTokens)
{
if (maxTokenCount < 2)
{
charsConsumed = 0;
normalizedText = null;
return Array.Empty<int>();
}
IReadOnlyList<int> ids = text is null ?
base.EncodeToIds(textSpan, maxTokenCount - 2, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization) :
base.EncodeToIds(text, maxTokenCount - 2, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization);
if (ids is not List<int> list)
{
list = new List<int>(ids);
}
list.Insert(0, ClsTokenId);
list.Add(SepTokenId);
return list;
}
return text is null ?
base.EncodeToIds(textSpan, maxTokenCount, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization) :
base.EncodeToIds(text, maxTokenCount, out normalizedText, out charsConsumed, considerPreTokenization, considerNormalization);
}
private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool addSpecialTokens, bool considerPreTokenization = true, bool considerNormalization = true)
{
IReadOnlyList<int> ids = text is null ? base.EncodeToIds(textSpan, considerPreTokenization, considerNormalization) : base.EncodeToIds(text, considerPreTokenization, considerNormalization);
if (addSpecialTokens)
{
if (ids is not List<int> list)
{
list = new List<int>(ids);
}
list.Insert(0, ClsTokenId);
list.Add(SepTokenId);
return list;
}
return ids;
}
/// <summary>
/// Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and adding special tokens. A BERT sequence has the following format:
/// - single sequence: `[CLS] tokenIds0 [SEP]`
/// - pair of sequences: `[CLS] tokenIds0 [SEP] tokenIds1 [SEP]`
/// </summary>
/// <param name="tokenIds0">List of IDs to which the special tokens will be added.</param>
/// <param name="tokenIds1">Optional second list of IDs for sequence pairs.</param>
/// <returns>The list of IDs with special tokens added.</returns>
/// <exception cref="ArgumentNullException">When <paramref name="tokenIds0"/> is null.</exception>
public IReadOnlyList<int> BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0, IEnumerable<int>? tokenIds1 = null)
{
if (tokenIds0 is null)
{
throw new ArgumentNullException(nameof(tokenIds0));
}
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
List<int> ids = new List<int>(capacity: capacity) { ClsTokenId };
ids.AddRange(tokenIds0);
ids.Add(SepTokenId);
if (tokenIds1 is not null)
{
ids.AddRange(tokenIds1);
ids.Add(SepTokenId);
}
return ids;
}
/// <summary>
/// Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and adding special tokens. A BERT sequence has the following format:
/// - single sequence: `[CLS] tokenIds0 [SEP]`
/// - pair of sequences: `[CLS] tokenIds0 [SEP] tokenIds1 [SEP]`
/// </summary>
/// <param name="tokenIds0">List of IDs to which the special tokens will be added.</param>
/// <param name="buffer">The buffer to write the token IDs with special tokens added.</param>
/// <param name="written">The number of elements written to the buffer.</param>
/// <param name="tokenIds1">Optional second list of IDs for sequence pairs.</param>
/// <returns>The status of the operation.</returns>
/// <exception cref="ArgumentNullException">When <paramref name="tokenIds0"/> is null.</exception>
public OperationStatus BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0, Span<int> buffer, out int written, IEnumerable<int>? tokenIds1 = null)
{
if (tokenIds0 is null)
{
throw new ArgumentNullException(nameof(tokenIds0));
}
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
if (buffer.Length < capacity)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
written = 0;
buffer[written++] = ClsTokenId;
foreach (int id in tokenIds0)
{
buffer[written++] = id;
}
buffer[written++] = SepTokenId;
if (tokenIds1 is not null)
{
foreach (int id in tokenIds1)
{
buffer[written++] = id;
}
buffer[written++] = SepTokenId;
}
return OperationStatus.Done;
}
/// <summary>
/// Retrieve sequence tokens mask from a IDs list.
/// </summary>
/// <param name="tokenIds0">List of IDs.</param>
/// <param name="tokenIds1">Optional second list of IDs for sequence pairs.</param>
/// <param name="alreadyHasSpecialTokens">Indicate whether or not the token list is already formatted with special tokens for the model.</param>
/// <returns>A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.</returns>
/// <exception cref="ArgumentNullException"></exception>
public IReadOnlyList<int> GetSpecialTokensMask(IEnumerable<int> tokenIds0, IEnumerable<int>? tokenIds1 = null, bool alreadyHasSpecialTokens = false)
{
if (tokenIds0 is null)
{
throw new ArgumentNullException(nameof(tokenIds0));
}
int capacity = alreadyHasSpecialTokens ?
tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
List<int> mask = new List<int>(capacity: capacity);
if (!alreadyHasSpecialTokens)
{
mask.Add(1); // CLS
mask.AddRange(Enumerable.Repeat(0, tokenIds0.Count()));
mask.Add(1); // SEP
if (tokenIds1 is not null)
{
mask.AddRange(Enumerable.Repeat(0, tokenIds1.Count()));
mask.Add(1); // SEP
}
return mask;
}
foreach (int id in tokenIds0)
{
mask.Add(id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0);
}
if (tokenIds1 is not null)
{
foreach (int id in tokenIds1)
{
mask.Add(id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0);
}
}
return mask;
}
/// <summary>
/// Retrieve sequence tokens mask from a IDs list.
/// </summary>
/// <param name="tokenIds0">List of IDs.</param>
/// <param name="buffer">The buffer to write the mask. The integers written values are in the range [0, 1]: 1 for a special token, 0 for a sequence token.</param>
/// <param name="written">The number of elements written to the buffer.</param>
/// <param name="tokenIds1">Optional second list of IDs for sequence pairs.</param>
/// <param name="alreadyHasSpecialTokens">Indicate whether or not the token list is already formatted with special tokens for the model.</param>
/// <returns>The status of the operation.</returns>
/// <exception cref="ArgumentNullException"></exception>
public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int> buffer, out int written, IEnumerable<int>? tokenIds1 = null, bool alreadyHasSpecialTokens = false)
{
if (tokenIds0 is null)
{
throw new ArgumentNullException(nameof(tokenIds0));
}
int capacity = alreadyHasSpecialTokens ?
tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
written = 0;
if (buffer.Length < capacity)
{
return OperationStatus.DestinationTooSmall;
}
if (!alreadyHasSpecialTokens)
{
buffer[written++] = 1; // CLS
foreach (int id in tokenIds0)
{
buffer[written++] = 0;
}
buffer[written++] = 1; // SEP
if (tokenIds1 is not null)
{
foreach (int id in tokenIds1)
{
buffer[written++] = 0;
}
buffer[written++] = 1; // SEP
}
return OperationStatus.Done;
}
foreach (int id in tokenIds0)
{
buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
}
if (tokenIds1 is not null)
{
foreach (int id in tokenIds1)
{
buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
}
}
return OperationStatus.Done;
}
/// <summary>
/// Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence pair mask has the following format:
/// 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
/// | first sequence | second sequence |
/// If <paramref name="tokenIds1"/> is null, this method only returns the first portion of the type ids (0s).
/// </summary>
/// <param name="tokenIds0">List of token IDs for the first sequence.</param>
/// <param name="tokenIds1">Optional list of token IDs for the second sequence.</param>
/// <returns>List of token type IDs according to the given sequence(s).</returns>
/// <exception cref="ArgumentNullException">When <paramref name="tokenIds0"/> is null.</exception>
public IReadOnlyList<int> CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds0, IEnumerable<int>? tokenIds1 = null)
{
if (tokenIds0 is null)
{
throw new ArgumentNullException(nameof(tokenIds0));
}
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
List<int> typeIds = new List<int>(capacity);
for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
{
typeIds.Add(0);
}
if (tokenIds1 is not null)
{
for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
{
typeIds.Add(1);
}
}
return typeIds;
}
public OperationStatus CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds0, Span<int> buffer, out int written, IEnumerable<int>? tokenIds1 = null)
{
if (tokenIds0 is null)
{
throw new ArgumentNullException(nameof(tokenIds0));
}
written = 0;
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
if (buffer.Length < capacity)
{
return OperationStatus.DestinationTooSmall;
}
for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
{
buffer[written++] = 0;
}
if (tokenIds1 is not null)
{
for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
{
buffer[written++] = 1;
}
}
return OperationStatus.Done;
}
/// <summary>
/// Create a new instance of the <see cref="BertTokenizer"/> class.
/// </summary>
/// <param name="vocabFilePath">The path to the vocabulary file.</param>
/// <param name="doLowerCase">A value indicating whether the tokenizer should lowercase the input text.</param>
/// <param name="doBasicTokenization">A value indicating whether the tokenizer should do basic tokenization. Like clean text, normalize it, lowercasing, etc.</param>
/// <param name="splitOnSpecialTokens">A value indicating whether the tokenizer should split on special tokens.</param>
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
/// <param name="sepToken">The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification or for a text and a question for question answering. It is also used as the last token of a sequence built with special tokens.</param>
/// <param name="padToken">The token used for padding, for example when batching sequences of different lengths.</param>
/// <param name="clsToken">The classifier token which is used when doing sequence classification (classification of the whole sequence instead of per-token classification). It is the first token of the sequence when built with special tokens.</param>
/// <param name="maskToken">The token used for masking values. This is the token used when training this model with masked language modeling. This is the token which the model will try to predict.</param>
/// <param name="tokenizeChineseChars">A value indicating whether the tokenizer should split the Chinese characters into tokens.</param>
/// <param name="stripAccents">A value indicating whether the tokenizer should strip accents characters.</param>
/// <returns></returns>
/// <exception cref="ArgumentNullException"></exception>
public static BertTokenizer Create(
string vocabFilePath,
bool doLowerCase = true,
bool doBasicTokenization = true,
bool splitOnSpecialTokens = true,
string unknownToken = "[UNK]",
string sepToken = "[SEP]",
string padToken = "[PAD]",
string clsToken = "[CLS]",
string maskToken = "[MASK]",
bool tokenizeChineseChars = true,
bool stripAccents = false) =>
Create(
string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath),
doLowerCase, doBasicTokenization, splitOnSpecialTokens, unknownToken, sepToken, padToken, clsToken, maskToken, tokenizeChineseChars, stripAccents, disposeStream: true);
/// <summary>
/// Create a new instance of the <see cref="BertTokenizer"/> class.
/// </summary>
/// <param name="vocabStream">The stream containing the vocabulary file.</param>
/// <param name="doLowerCase">A value indicating whether the tokenizer should lowercase the input text.</param>
/// <param name="doBasicTokenization">A value indicating whether the tokenizer should do basic tokenization. Like clean text, normalize it, lowercasing, etc.</param>
/// <param name="splitOnSpecialTokens">A value indicating whether the tokenizer should split on special tokens.</param>
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
/// <param name="sepToken">The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification or for a text and a question for question answering. It is also used as the last token of a sequence built with special tokens.</param>
/// <param name="padToken">The token used for padding, for example when batching sequences of different lengths.</param>
/// <param name="clsToken">The classifier token which is used when doing sequence classification (classification of the whole sequence instead of per-token classification). It is the first token of the sequence when built with special tokens.</param>
/// <param name="maskToken">The token used for masking values. This is the token used when training this model with masked language modeling. This is the token which the model will try to predict.</param>
/// <param name="tokenizeChineseChars">A value indicating whether the tokenizer should split the Chinese characters into tokens.</param>
/// <param name="stripAccents">A value indicating whether the tokenizer should strip accents characters.</param>
/// <returns></returns>
/// <exception cref="ArgumentNullException"></exception>
public static BertTokenizer Create(
Stream vocabStream,
bool doLowerCase = true,
bool doBasicTokenization = true,
bool splitOnSpecialTokens = true,
string unknownToken = "[UNK]",
string sepToken = "[SEP]",
string padToken = "[PAD]",
string clsToken = "[CLS]",
string maskToken = "[MASK]",
bool tokenizeChineseChars = true,
bool stripAccents = false) =>
Create(vocabStream, doLowerCase, doBasicTokenization, splitOnSpecialTokens, unknownToken, sepToken, padToken, clsToken, maskToken, tokenizeChineseChars, stripAccents, disposeStream: false);
/// <summary>
/// Create a new instance of the <see cref="BertTokenizer"/> class asynchronously.
/// </summary>
/// <param name="vocabStream">The stream containing the vocabulary file.</param>
/// <param name="doLowerCase">A value indicating whether the tokenizer should lowercase the input text.</param>
/// <param name="doBasicTokenization">A value indicating whether the tokenizer should do basic tokenization. Like clean text, normalize it, lowercasing, etc.</param>
/// <param name="splitOnSpecialTokens">A value indicating whether the tokenizer should split on special tokens.</param>
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
/// <param name="sepToken">The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification or for a text and a question for question answering. It is also used as the last token of a sequence built with special tokens.</param>
/// <param name="padToken">The token used for padding, for example when batching sequences of different lengths.</param>
/// <param name="clsToken">The classifier token which is used when doing sequence classification (classification of the whole sequence instead of per-token classification). It is the first token of the sequence when built with special tokens.</param>
/// <param name="maskToken">The token used for masking values. This is the token used when training this model with masked language modeling. This is the token which the model will try to predict.</param>
/// <param name="tokenizeChineseChars">A value indicating whether the tokenizer should split the Chinese characters into tokens.</param>
/// <param name="stripAccents">A value indicating whether the tokenizer should strip accents characters.</param>
/// <returns></returns>
/// <exception cref="ArgumentNullException"></exception>
public static async Task<BertTokenizer> CreateAsync(
Stream vocabStream,
bool doLowerCase = true,
bool doBasicTokenization = true,
bool splitOnSpecialTokens = true,
string unknownToken = "[UNK]",
string sepToken = "[SEP]",
string padToken = "[PAD]",
string clsToken = "[CLS]",
string maskToken = "[MASK]",
bool tokenizeChineseChars = true,
bool stripAccents = false)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, string> vocabReverse) = await LoadVocabAsync(vocabStream, useAsync: true).ConfigureAwait(false);
return Create(vocab, vocabReverse, doLowerCase, doBasicTokenization, splitOnSpecialTokens, unknownToken, sepToken, padToken, clsToken, maskToken, tokenizeChineseChars, stripAccents);
}
private static BertTokenizer Create(
Stream vocabStream,
bool doLowerCase,
bool doBasicTokenization,
bool splitOnSpecialTokens,
string unknownToken,
string sepToken,
string padToken,
string clsToken,
string maskToken,
bool tokenizeChineseChars,
bool stripAccents,
bool disposeStream)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
try
{
(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, string> vocabReverse) = LoadVocabAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
return Create(vocab, vocabReverse, doLowerCase, doBasicTokenization, splitOnSpecialTokens, unknownToken, sepToken, padToken, clsToken, maskToken, tokenizeChineseChars, stripAccents);
}
finally
{
if (disposeStream)
{
vocabStream.Dispose();
}
}
}
private static BertTokenizer Create(
Dictionary<StringSpanOrdinalKey, int> vocab,
Dictionary<int, string> vocabReverse,
bool doLowerCase,
bool doBasicTokenization,
bool splitOnSpecialTokens,
string unknownToken,
string sepToken,
string padToken,
string clsToken,
string maskToken,
bool tokenizeChineseChars,
bool stripAccents)
{
Normalizer? normalizer = doBasicTokenization ? new BertNormalizer(doLowerCase, tokenizeChineseChars, stripAccents) : null;
Dictionary<string, int>? specialTokens = new();
bool lowerCase = doBasicTokenization && doLowerCase && splitOnSpecialTokens;
AddSpecialToken(vocab, specialTokens, unknownToken, lowerCase);
AddSpecialToken(vocab, specialTokens, sepToken, lowerCase);
AddSpecialToken(vocab, specialTokens, padToken, lowerCase);
AddSpecialToken(vocab, specialTokens, clsToken, lowerCase);
AddSpecialToken(vocab, specialTokens, maskToken, lowerCase);
PreTokenizer? preTokenizer = doBasicTokenization ?
PreTokenizer.CreateWhiteSpaceOrPunctuationPreTokenizer(splitOnSpecialTokens ? specialTokens : null) :
PreTokenizer.CreateWhiteSpacePreTokenizer();
return new BertTokenizer(vocab, vocabReverse, preTokenizer, normalizer, specialTokens, doLowerCase, doBasicTokenization,
splitOnSpecialTokens, unknownToken, sepToken, padToken, clsToken, maskToken, tokenizeChineseChars, stripAccents);
}
private static void AddSpecialToken(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<string, int> specialTokens, string token, bool lowerCase)
{
if (token is null || !vocab.TryGetValue(new StringSpanOrdinalKey(token), out int id))
{
throw new ArgumentException($"The special token '{token}' is not in the vocabulary.");
}
string normalizedToken = token;
if (lowerCase)
{
// Lowercase the special tokens to have the pre-tokenization can find them as we lowercase the input text.
// we don't even need to do case-insensitive comparisons as we are lowercasing the input text.
normalizedToken = token.ToLowerInvariant();
// Add lowercased special tokens to the vocab if they are not already there.
// This will allow matching during the encoding process.
vocab[new StringSpanOrdinalKey(normalizedToken)] = id;
}
specialTokens[normalizedToken] = id;
}
}
}

Просмотреть файл

@ -0,0 +1,858 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// Represent the WordPiece tokenizer.
/// </summary>
/// <remarks>
/// The WordPiece tokenizer is a sub-word tokenizer that is used in BERT and other transformer models.
/// The implementation is based on the Hugging Face WordPiece tokenizer https://huggingface.co/docs/tokenizers/api/models#tokenizers.models.WordPiece.
/// </remarks>
public partial class WordPieceTokenizer : Tokenizer
{
private readonly PreTokenizer? _preTokenizer;
private readonly Normalizer? _normalizer;
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
private readonly Dictionary<int, string> _vocabReverse;
internal const string DefaultContinuingSubwordPrefix = "##";
internal const int DefaultMaxInputCharsPerWord = 100;
internal WordPieceTokenizer(
Dictionary<StringSpanOrdinalKey, int> vocab,
Dictionary<int, string> vocabReverse,
PreTokenizer? preTokenizer,
Normalizer? normalizer,
IReadOnlyDictionary<string, int>? specialTokens,
string unknownToken,
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord)
{
Debug.Assert(vocab is not null);
Debug.Assert(vocabReverse is not null);
_vocab = vocab!;
_vocabReverse = vocabReverse!;
SpecialTokens = specialTokens;
SpecialTokensReverse = specialTokens is not null ? specialTokens.ToDictionary(kvp => kvp.Value, kvp => kvp.Key) : null;
if (unknownToken is null)
{
throw new ArgumentNullException(nameof(unknownToken));
}
if (continuingSubwordPrefix is null)
{
throw new ArgumentNullException(nameof(continuingSubwordPrefix));
}
if (maxInputCharsPerWord <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxInputCharsPerWord), "The maximum number of characters per word must be greater than zero.");
}
if (!vocab!.TryGetValue(unknownToken, out int id))
{
throw new ArgumentException($"The unknown token '{unknownToken}' is not in the vocabulary.");
}
UnknownToken = unknownToken;
UnknownTokenId = id;
ContinuingSubwordPrefix = continuingSubwordPrefix;
MaxInputCharsPerWord = maxInputCharsPerWord;
_preTokenizer = preTokenizer ?? PreTokenizer.CreateWhiteSpacePreTokenizer(specialTokens);
_normalizer = normalizer;
}
/// <summary>
/// Gets the unknown token ID.
/// A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.
/// </summary>
public int UnknownTokenId { get; }
/// <summary>
/// Gets the prefix to use for sub-words that are not the first part of a word.
/// </summary>
public string ContinuingSubwordPrefix { get; }
/// <summary>
/// Gets the maximum number of characters to authorize in a single word.
/// </summary>
public int MaxInputCharsPerWord { get; }
internal static async ValueTask<(Dictionary<StringSpanOrdinalKey, int>, Dictionary<int, string>)> LoadVocabAsync(Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
Dictionary<StringSpanOrdinalKey, int> vocab = new Dictionary<StringSpanOrdinalKey, int>();
Dictionary<int, string> vocabReverse = new Dictionary<int, string>();
StreamReader reader = new StreamReader(vocabStream);
string? line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
int lineNumber = 0;
while (line is not null)
{
if (line.Length != 0)
{
vocab.Add(new StringSpanOrdinalKey(line), lineNumber);
vocabReverse.Add(lineNumber, line);
}
lineNumber++;
line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
}
return (vocab, vocabReverse);
}
/// <summary>
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class.
/// </summary>
/// <param name="vocabFilePath">The path to the WordPiece vocab file.</param>
/// <param name="preTokenizer">The PreTokenizer to use.</param>
/// <param name="normalizer">The Normalizer to use.</param>
/// <param name="specialTokens">The dictionary containing the special tokens and their corresponding ids.</param>
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
/// <param name="continuingSubwordPrefix">The prefix to use for sub-words that are not the first part of a word.</param>
/// <param name="maxInputCharsPerWord">The maximum number of characters to authorize in a single word.</param>
/// <returns>A new instance of the <see cref="WordPieceTokenizer"/> class.</returns>
/// <remarks>
/// If the <paramref name="preTokenizer"/> is null, the whitespace pre-tokenizer will be used.
/// </remarks>
public static WordPieceTokenizer Create(
string vocabFilePath,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
IReadOnlyDictionary<string, int>? specialTokens = null,
string unknownToken = "[UNK]",
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord) =>
Create(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, disposeStream: true);
/// <summary>
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class.
/// </summary>
/// <param name="vocabStream">The path to the WordPiece vocab file.</param>
/// <param name="preTokenizer">The PreTokenizer to use.</param>
/// <param name="normalizer">The Normalizer to use.</param>
/// <param name="specialTokens">The dictionary containing the special tokens and their corresponding ids.</param>
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
/// <param name="continuingSubwordPrefix">The prefix to use for sub-words that are not the first part of a word.</param>
/// <param name="maxInputCharsPerWord">The maximum number of characters to authorize in a single word.</param>
/// <returns>A new instance of the <see cref="WordPieceTokenizer"/> class.</returns>
/// <remarks>
/// If the <paramref name="preTokenizer"/> is null, the whitespace pre-tokenizer will be used.
/// </remarks>
public static WordPieceTokenizer Create(
Stream vocabStream,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
IReadOnlyDictionary<string, int>? specialTokens = null,
string unknownToken = "[UNK]",
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord) => Create(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, disposeStream: false);
private static WordPieceTokenizer Create(
Stream vocabStream,
PreTokenizer? preTokenizer,
Normalizer? normalizer,
IReadOnlyDictionary<string, int>? specialTokens,
string unknownToken,
string continuingSubwordPrefix,
int maxInputCharsPerWord,
bool disposeStream)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
try
{
(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, string> vocabReverse) = LoadVocabAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
return new WordPieceTokenizer(vocab, vocabReverse, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord);
}
finally
{
if (disposeStream)
{
vocabStream.Dispose();
}
}
}
/// <summary>
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class asynchronously.
/// </summary>
/// <param name="vocabFilePath">The path to the WordPiece vocab file.</param>
/// <param name="preTokenizer">The PreTokenizer to use.</param>
/// <param name="normalizer">The Normalizer to use.</param>
/// <param name="specialTokens">The dictionary containing the special tokens and their corresponding ids.</param>
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
/// <param name="continuingSubwordPrefix">The prefix to use for sub-words that are not the first part of a word.</param>
/// <param name="maxInputCharsPerWord">The maximum number of characters to authorize in a single word.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A new instance of the <see cref="WordPieceTokenizer"/> class.</returns>
/// <remarks>
/// If the <paramref name="preTokenizer"/> is null, the whitespace pre-tokenizer will be used.
/// </remarks>
public static async Task<WordPieceTokenizer> CreateAsync(
string vocabFilePath,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
IReadOnlyDictionary<string, int>? specialTokens = null,
string unknownToken = "[UNK]",
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord,
CancellationToken cancellationToken = default) =>
await CreateAsync(
string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath),
preTokenizer,
normalizer,
specialTokens,
unknownToken,
continuingSubwordPrefix,
maxInputCharsPerWord,
cancellationToken,
disposeStream: true);
/// <summary>
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class asynchronously.
/// </summary>
/// <param name="vocabStream">The path to the WordPiece vocab file.</param>
/// <param name="preTokenizer">The PreTokenizer to use.</param>
/// <param name="normalizer">The Normalizer to use.</param>
/// <param name="specialTokens">The dictionary containing the special tokens and their corresponding ids.</param>
/// <param name="unknownToken">The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.</param>
/// <param name="continuingSubwordPrefix">The prefix to use for sub-words that are not the first part of a word.</param>
/// <param name="maxInputCharsPerWord">The maximum number of characters to authorize in a single word.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A new instance of the <see cref="WordPieceTokenizer"/> class.</returns>
/// <remarks>
/// If the <paramref name="preTokenizer"/> is null, the whitespace pre-tokenizer will be used.
/// </remarks>
public static async Task<WordPieceTokenizer> CreateAsync(
Stream vocabStream,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
IReadOnlyDictionary<string, int>? specialTokens = null,
string unknownToken = "[UNK]",
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord,
CancellationToken cancellationToken = default) =>
await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false);
private static async Task<WordPieceTokenizer> CreateAsync(
Stream vocabStream,
PreTokenizer? preTokenizer,
Normalizer? normalizer,
IReadOnlyDictionary<string, int>? specialTokens,
string unknownToken,
string continuingSubwordPrefix,
int maxInputCharsPerWord,
CancellationToken cancellationToken,
bool disposeStream)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
try
{
(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, string> vocabReverse) = await LoadVocabAsync(vocabStream, useAsync: true, cancellationToken);
return new WordPieceTokenizer(vocab, vocabReverse, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord);
}
finally
{
if (disposeStream)
{
vocabStream.Dispose();
}
}
}
/// <summary>
/// Gets the PreTokenizer used by the Tokenizer.
/// </summary>
public override PreTokenizer? PreTokenizer => _preTokenizer;
/// <summary>
/// Gets the Normalizer in use by the Tokenizer.
/// </summary>
public override Normalizer? Normalizer => _normalizer;
/// <summary>
/// Gets the unknown token.
/// A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.
/// </summary>
public string UnknownToken { get; }
/// <summary>
/// Gets the special tokens and their corresponding ids.
/// </summary>
public IReadOnlyDictionary<string, int>? SpecialTokens { get; }
/// <summary>
/// Gets the Ids to tokens mapping for special tokens.
/// </summary>
internal IReadOnlyDictionary<int, string>? SpecialTokensReverse { get; }
/// <summary>
/// Encodes input text to a list of <see cref="EncodedToken" />s.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
protected override EncodeResults<EncodedToken> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
{
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
return new EncodeResults<EncodedToken> { NormalizedText = null, Tokens = [], CharsConsumed = 0 };
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
settings.ConsiderPreTokenization,
settings.ConsiderNormalization,
_normalizer,
_preTokenizer,
out string? normalizedString,
out ReadOnlySpan<char> textSpanToEncode,
out int charsConsumed);
List<EncodedToken> tokens = new();
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
EncodeToTokens(textSpanToEncode.Slice(split.Offset, split.Length), tokens, split.Offset);
}
}
else
{
EncodeToTokens(textSpanToEncode, tokens, 0);
}
return new EncodeResults<EncodedToken> { NormalizedText = normalizedString, Tokens = tokens, CharsConsumed = charsConsumed };
}
/// <summary>
/// Encode text to a list of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="tokens">The list of tokens to populate.</param>
/// <param name="offset">The offset to start encoding from.</param>
private void EncodeToTokens(ReadOnlySpan<char> text, List<EncodedToken> tokens, int offset)
{
Debug.Assert(!text.IsEmpty);
if (text.Length > MaxInputCharsPerWord)
{
tokens.Add(new EncodedToken(UnknownTokenId, UnknownToken, new Range(offset, offset + text.Length)));
return;
}
int maxLength = MaxInputCharsPerWord + ContinuingSubwordPrefix.Length;
char[]? arrayPool = maxLength <= 250 ? null : ArrayPool<char>.Shared.Rent(maxLength);
Span<char> buffer = arrayPool is null ? stackalloc char[maxLength] : arrayPool;
ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
int initialTokensCount = tokens.Count;
int textLength = text.Length;
bool isBad = false;
int start = 0;
while (start < textLength)
{
int end = textLength;
EncodedToken curToken = default;
while (start < end)
{
scoped ReadOnlySpan<char> subStr = text.Slice(start, end - start);
if (start > 0)
{
subStr.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
subStr = buffer.Slice(0, ContinuingSubwordPrefix.Length + subStr.Length);
}
if (_vocab.TryGetValue(subStr, out int id))
{
Debug.Assert(_vocabReverse.ContainsKey(id));
curToken = new EncodedToken(id, _vocabReverse[id], new Range(offset + start, offset + end));
break;
}
end -= 1;
}
if (curToken.Value is null)
{
isBad = true;
break;
}
tokens.Add(curToken);
start = end;
}
if (isBad)
{
// remove previously added tokens and add the unknown token
tokens.RemoveRange(initialTokensCount, tokens.Count - initialTokensCount);
tokens.Add(new EncodedToken(UnknownTokenId, UnknownToken, new Range(offset, offset + textLength)));
}
if (arrayPool is not null)
{
ArrayPool<char>.Shared.Return(arrayPool);
}
}
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <returns>The encoded results containing the list of encoded Ids.</returns>
protected override EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
{
int maxTokenCount = settings.MaxTokenCount;
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(settings.MaxTokenCount), "The maximum number of tokens must be greater than zero.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
return new EncodeResults<int> { NormalizedText = null, Tokens = [], CharsConsumed = 0 };
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
settings.ConsiderPreTokenization,
settings.ConsiderNormalization,
_normalizer,
_preTokenizer,
out string? normalizedString,
out ReadOnlySpan<char> textSpanToEncode,
out int charsConsumed);
List<int> ids = new();
if (splits is not null)
{
charsConsumed = 0;
foreach ((int Offset, int Length) split in splits)
{
EncodeToIds(textSpanToEncode.Slice(split.Offset, split.Length), ids, out int length, maxTokenCount - ids.Count);
if (length < split.Length || ids.Count >= maxTokenCount)
{
break;
}
charsConsumed = split.Offset + length;
}
}
else
{
EncodeToIds(textSpanToEncode, ids, out charsConsumed);
}
return new EncodeResults<int> { NormalizedText = normalizedString, Tokens = ids, CharsConsumed = charsConsumed };
}
/// <summary>
/// Encode text to a list of Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="accumulatedIds">The list of accumulated Ids.</param>
/// <param name="charsConsumed">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
private int EncodeToIds(ReadOnlySpan<char> text, List<int>? accumulatedIds, out int charsConsumed, int maxTokenCount = int.MaxValue)
{
Debug.Assert(maxTokenCount > 0);
if (text.IsEmpty)
{
charsConsumed = 0;
return 0;
}
if (text.Length > MaxInputCharsPerWord)
{
accumulatedIds?.Add(UnknownTokenId);
charsConsumed = text.Length;
return 1;
}
int maxLength = MaxInputCharsPerWord + ContinuingSubwordPrefix.Length;
char[]? arrayPool = maxLength <= 250 ? null : ArrayPool<char>.Shared.Rent(maxLength);
Span<char> buffer = arrayPool is null ? stackalloc char[maxLength] : arrayPool;
ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
int addedIds = 0;
int textLength = text.Length;
bool isBad = false;
int start = 0;
while (start < textLength)
{
int end = textLength;
int curId = 0;
bool found = false;
while (start < end)
{
scoped ReadOnlySpan<char> subStr = text.Slice(start, end - start);
if (start > 0)
{
subStr.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
subStr = buffer.Slice(0, ContinuingSubwordPrefix.Length + subStr.Length);
}
if (_vocab.TryGetValue(subStr, out curId))
{
found = true;
break;
}
end -= 1;
}
if (!found)
{
isBad = true;
break;
}
accumulatedIds?.Add(curId);
addedIds++;
start = end;
}
charsConsumed = textLength;
if (addedIds > maxTokenCount)
{
// not enough space to hold added ids. Remove previously added ids
accumulatedIds?.RemoveRange(accumulatedIds.Count - addedIds, addedIds);
addedIds = 0;
charsConsumed = 0;
}
else if (isBad)
{
// remove previously added ids and add the unknown token id
accumulatedIds?.RemoveRange(accumulatedIds.Count - addedIds, addedIds);
accumulatedIds?.Add(UnknownTokenId);
addedIds = 1;
}
if (arrayPool is not null)
{
ArrayPool<char>.Shared.Return(arrayPool);
}
return addedIds;
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
protected override int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
{
int maxTokenCount = settings.MaxTokenCount;
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(settings.MaxTokenCount), "The maximum number of tokens must be greater than zero.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
settings.ConsiderPreTokenization,
settings.ConsiderNormalization,
_normalizer,
_preTokenizer,
out string? normalizedString,
out ReadOnlySpan<char> textSpanToEncode,
out int charsConsumed);
int count = 0;
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
count += EncodeToIds(textSpanToEncode.Slice(split.Offset, split.Length), accumulatedIds: null, out int length, maxTokenCount - count);
if (length < split.Length || count >= maxTokenCount)
{
break;
}
}
}
else
{
count = EncodeToIds(textSpanToEncode, accumulatedIds: null, out charsConsumed, maxTokenCount);
}
return count;
}
/// <summary>
/// Find the index of the maximum encoding capacity without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <param name="fromEnd">Indicate whether to find the index from the end of the text.</param>
/// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramRef name="settings" /> has <see cref="EncodeSettings.ConsiderNormalization"/> is <see langword="false"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// If <paramRef name="fromEnd" /> is <see langword="false"/>, it represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the input text or the <paramref name="normalizedString"/> if the normalization is enabled.
/// If <paramRef name="fromEnd" /> is <see langword="true"/>, it represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely,
/// if all tokens fit, the result will be zero.
/// </returns>
protected override int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedString, out int tokenCount)
{
if (settings.MaxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(settings.MaxTokenCount), "The max token count must be greater than 0.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedString = null;
tokenCount = 0;
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
settings.ConsiderNormalization,
settings.ConsiderNormalization,
_normalizer,
_preTokenizer,
out normalizedString,
out ReadOnlySpan<char> textSpanToEncode,
out _);
int charsConsumed;
if (splits is null)
{
tokenCount = EncodeToIds(textSpanToEncode, accumulatedIds: null, out charsConsumed, settings.MaxTokenCount);
if (charsConsumed != textSpanToEncode.Length)
{
tokenCount = 0;
return fromEnd ? textSpanToEncode.Length : 0;
}
return fromEnd ? 0 : textSpanToEncode.Length;
}
if (fromEnd)
{
splits = splits.Reverse();
}
tokenCount = 0;
foreach ((int Offset, int Length) split in splits)
{
int count = EncodeToIds(textSpanToEncode.Slice(split.Offset, split.Length), accumulatedIds: null, out charsConsumed, settings.MaxTokenCount - tokenCount);
if (charsConsumed != split.Length)
{
return fromEnd ? split.Offset + split.Length : split.Offset;
}
tokenCount += count;
if (count >= settings.MaxTokenCount)
{
return fromEnd ? split.Offset : split.Offset + split.Length;
}
}
return fromEnd ? 0 : textSpanToEncode.Length;
}
/// <summary>
/// Decode the given ids, back to a String.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <returns>The decoded string.</returns>
public override string Decode(IEnumerable<int> ids) => Decode(ids, skipSpecialTokens: false);
/// <summary>
/// Decode the given ids, back to a String.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="skipSpecialTokens">Indicate whether to skip the special tokens during the decoding.</param>
/// <returns>The decoded string.</returns>
public string Decode(IEnumerable<int> ids, bool skipSpecialTokens)
{
ValueStringBuilder sb = new ValueStringBuilder();
bool first = true;
bool ignoreSpecialTokens = skipSpecialTokens && SpecialTokensReverse is not null;
foreach (int id in ids)
{
if (ignoreSpecialTokens && SpecialTokensReverse!.TryGetValue(id, out _))
{
continue;
}
if (_vocabReverse.TryGetValue(id, out string? token))
{
if (token.StartsWith(ContinuingSubwordPrefix))
{
sb.Append(token.AsSpan().Slice(ContinuingSubwordPrefix.Length));
}
else
{
if (!first && token[0] is not ('.' or ',' or '!' or '?' or '\''))
{
sb.Append(' ');
}
sb.Append(token);
}
}
first = false;
}
return sb.ToString();
}
/// <summary>
/// Decode the given ids back to text and store the result in the <paramref name="destination"/> span.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="destination">The span to store the decoded text.</param>
/// <param name="idsConsumed">The number of ids consumed during the decoding.</param>
/// <param name="charsWritten">The number of characters written to the destination span.</param>
/// <returns>The operation status indicates whether all IDs were successfully decoded or if the <paramref name="destination"/> is too small to contain the entire decoded result.</returns>
public override OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten) =>
Decode(ids, destination, skipSpecialTokens: false, out idsConsumed, out charsWritten);
/// <summary>
/// Decode the given ids back to text and store the result in the <paramref name="destination"/> span.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="destination">The span to store the decoded text.</param>
/// <param name="skipSpecialTokens">Indicate whether to skip the special tokens during the decoding.</param>
/// <param name="idsConsumed">The number of ids consumed during the decoding.</param>
/// <param name="charsWritten">The number of characters written to the destination span.</param>
/// <returns>The operation status indicates whether all IDs were successfully decoded or if the <paramref name="destination"/> is too small to contain the entire decoded result.</returns>
public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool skipSpecialTokens, out int idsConsumed, out int charsWritten)
{
charsWritten = 0;
idsConsumed = 0;
Span<char> buffer = destination;
bool first = true;
bool ignoreSpecialTokens = SpecialTokensReverse is not null && skipSpecialTokens;
foreach (int id in ids)
{
if (ignoreSpecialTokens && SpecialTokensReverse!.TryGetValue(id, out _))
{
continue;
}
if (_vocabReverse.TryGetValue(id, out string? token))
{
if (token.StartsWith(ContinuingSubwordPrefix, StringComparison.Ordinal))
{
if (token.Length - ContinuingSubwordPrefix.Length > buffer.Length)
{
return OperationStatus.DestinationTooSmall;
}
token.AsSpan().Slice(ContinuingSubwordPrefix.Length).CopyTo(buffer);
buffer = buffer.Slice(token.Length - ContinuingSubwordPrefix.Length);
charsWritten += token.Length - ContinuingSubwordPrefix.Length;
}
else
{
if (!first)
{
if (token.Length + 1 > buffer.Length)
{
return OperationStatus.DestinationTooSmall;
}
buffer[0] = ' ';
token.AsSpan().CopyTo(buffer.Slice(1));
buffer = buffer.Slice(token.Length + 1);
charsWritten += token.Length + 1;
}
else
{
if (token.Length > buffer.Length)
{
return OperationStatus.DestinationTooSmall;
}
token.AsSpan().CopyTo(buffer);
buffer = buffer.Slice(token.Length);
charsWritten += token.Length;
}
}
first = false;
idsConsumed++;
}
else
{
return OperationStatus.InvalidData;
}
}
return OperationStatus.Done;
}
}
}

Просмотреть файл

@ -0,0 +1,200 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Diagnostics;
using System.Globalization;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// Normalizer that performs the Bert model normalization.
/// </summary>
internal sealed class BertNormalizer : Normalizer
{
private readonly bool _doLowerCase;
private readonly bool _tokenizeChineseChars;
private readonly bool _stripAccents;
/// <summary>
/// Normalize the input string.
/// </summary>
/// <param name="original">The input string to normalize.</param>
/// <returns>The normalized string.</returns>
public override string Normalize(string original)
{
if (string.IsNullOrEmpty(original))
{
return string.Empty;
}
if (_stripAccents)
{
original = original.Normalize(NormalizationForm.FormD);
}
Span<char> casingBuffer = stackalloc char[10];
char[] buffer = ArrayPool<char>.Shared.Rent(original.Length);
int index = 0;
for (int i = 0; i < original.Length; i++)
{
char c = original[i];
if (c == '\u0000' || c == '\uFFFD')
{
continue;
}
int inc = 0;
int codePoint = (int)c;
if (char.IsHighSurrogate(c) && i + 1 < original.Length && char.IsLowSurrogate(original[i + 1]))
{
codePoint = char.ConvertToUtf32(c, original[i + 1]);
inc = 1;
}
UnicodeCategory category = CharUnicodeInfo.GetUnicodeCategory(original, i);
if (category == UnicodeCategory.Control)
{
i += inc;
continue;
}
if (category == UnicodeCategory.SpaceSeparator)
{
InsertChar(ref buffer, ref index, ' ');
i += inc;
continue;
}
if (_stripAccents && category is UnicodeCategory.NonSpacingMark or UnicodeCategory.SpacingCombiningMark)
{
i += inc;
continue;
}
if (_doLowerCase && category == UnicodeCategory.UppercaseLetter)
{
int length = original.AsSpan().Slice(i, inc + 1).ToLowerInvariant(casingBuffer);
Debug.Assert(length > 0);
InsertSpan(ref buffer, ref index, casingBuffer.Slice(0, length));
i += inc;
continue;
}
if (_tokenizeChineseChars && IsChineseChar(codePoint))
{
InsertChar(ref buffer, ref index, ' ');
InsertChar(ref buffer, ref index, c);
if (inc > 0)
{
InsertChar(ref buffer, ref index, original[i + 1]);
}
InsertChar(ref buffer, ref index, ' ');
i += inc;
continue;
}
InsertChar(ref buffer, ref index, c);
if (inc > 0)
{
InsertChar(ref buffer, ref index, original[i + 1]);
}
i += inc;
}
string result = index == 0 ? string.Empty : new string(buffer, 0, index).Normalize(NormalizationForm.FormC);
ArrayPool<char>.Shared.Return(buffer);
return result;
}
/// <summary>
/// Normalize the input character span.
/// </summary>
/// <param name="original">The input character span to normalize.</param>
/// <returns>The normalized string.</returns>
public override string Normalize(ReadOnlySpan<char> original)
{
if (original.IsEmpty)
{
return string.Empty;
}
return Normalize(original.ToString());
}
/// <summary>
/// Initializes a new instance of the <see cref="BertNormalizer"/> class.
/// </summary>
/// <param name="doLowerCase">Whether to lowercase the input.</param>
/// <param name="tokenizeChineseChars">Whether to tokenize Chinese characters.</param>
/// <param name="stripAccents">Whether to strip accents from the input.</param>
public BertNormalizer(bool doLowerCase, bool tokenizeChineseChars, bool stripAccents)
{
_doLowerCase = doLowerCase;
_tokenizeChineseChars = tokenizeChineseChars;
_stripAccents = stripAccents;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void InsertChar(ref char[] buffer, ref int index, char c)
{
if (index >= buffer.Length)
{
Helpers.ArrayPoolGrow(ref buffer, index + 40);
}
buffer[index++] = c;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void InsertSpan(ref char[] buffer, ref int index, Span<char> chars)
{
if (index + buffer.Length >= buffer.Length)
{
Helpers.ArrayPoolGrow(ref buffer, index + buffer.Length + 10);
}
chars.CopyTo(buffer.AsSpan(index));
index += chars.Length;
}
/// <summary>
/// Checks whether CP is the codepoint of a CJK character.
/// This defines a "chinese character" as anything in the CJK Unicode block:
/// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
/// </summary>
/// <param name="codePoint">The codepoint to check.</param>
/// <remarks>
/// The CJK Unicode block is NOT all Japanese and Korean characters,
/// despite its name. The modern Korean Hangul alphabet is a different block,
/// as is Japanese Hiragana and Katakana. Those alphabets are used to write
/// space-separated words, so they are not treated specially and handled
/// like the all of the other languages.
/// </remarks>
/// <returns>True if the codepoint is a CJK character, false otherwise.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool IsChineseChar(int codePoint)
{
return (codePoint > 0x3400) && // Quick check to exit early if the codepoint is outside of the CJK range
(((uint)(codePoint - 0x3400) <= (uint)(0x4DBF - 0x3400)) ||
((uint)(codePoint - 0xF900) <= (uint)(0xFAFF - 0xF900)) ||
((uint)(codePoint - 0x4E00) <= (uint)(0x9FFF - 0x4E00)) ||
((uint)(codePoint - 0x20000) <= (uint)(0x2A6DF - 0x20000)) ||
((uint)(codePoint - 0x2A700) <= (uint)(0x2B73F - 0x2A700)) ||
((uint)(codePoint - 0x2B740) <= (uint)(0x2B81F - 0x2B740)) ||
((uint)(codePoint - 0x2B820) <= (uint)(0x2CEAF - 0x2B820)) ||
((uint)(codePoint - 0x2F800) <= (uint)(0x2FA1F - 0x2F800)));
}
}
}

Просмотреть файл

@ -40,8 +40,61 @@ namespace Microsoft.ML.Tokenizers
}
}
private const string WhiteSpacePattern = /*lang=regex*/ @"\w+|[^\w\s]+";
private const string WhiteSpaceOrPunctuationPattern = @"\w+|[\p{P}]";
private static PreTokenizer? _whiteSpaceOrPunctuationPreTokenizer;
#if NET7_0_OR_GREATER
[GeneratedRegex(WhiteSpaceOrPunctuationPattern)]
private static partial Regex WhiteSpaceOrPunctuationRegex();
#else
private static Regex WhiteSpaceOrPunctuationRegex() => new Regex(WhiteSpaceOrPunctuationPattern, RegexOptions.Compiled);
#endif
/// <summary>
/// Create a new instance of the <see cref="PreTokenizer"/> class which split the text at the whitespace or punctuation characters.
/// </summary>
/// <param name="specialTokensEncoder">The dictionary containing the special tokens and their corresponding ids.</param>
/// <returns>The pre-tokenizer that splits the text at the whitespace or punctuation characters.</returns>
public static PreTokenizer CreateWhiteSpaceOrPunctuationPreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
{
if (specialTokensEncoder is null)
{
// return a singleton instance of the WhiteSpace pre-tokenizer
return _whiteSpaceOrPunctuationPreTokenizer ??= new RegexPreTokenizer(WhiteSpaceOrPunctuationRegex(), null);
}
return new RegexPreTokenizer(WhiteSpaceOrPunctuationRegex(), specialTokensEncoder);
}
private const string WordOrNonWordPattern = /*lang=regex*/ @"\w+|[^\w\s]+";
private static PreTokenizer? _wordOrNonWordPreTokenizer;
#if NET7_0_OR_GREATER
[GeneratedRegex(WordOrNonWordPattern)]
private static partial Regex WordOrNonWordRegex();
#else
private static Regex WordOrNonWordRegex() => new Regex(WordOrNonWordPattern, RegexOptions.Compiled);
#endif
/// <summary>
/// Create a new instance of the <see cref="PreTokenizer"/> class which split the text at the word or non-word boundary.
/// The word is a set of alphabet, numeric, and underscore characters.
/// </summary>
/// <param name="specialTokensEncoder">The dictionary containing the special tokens and their corresponding ids.</param>
/// <returns>The pre-tokenizer that splits the text at the word boundary.</returns>
public static PreTokenizer CreateWordOrNonWordPreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
{
if (specialTokensEncoder is null)
{
// return a singleton instance of the WhiteSpace pre-tokenizer
return _wordOrNonWordPreTokenizer ??= new RegexPreTokenizer(WordOrNonWordRegex(), null);
}
return new RegexPreTokenizer(WordOrNonWordRegex(), specialTokensEncoder);
}
private const string WhiteSpacePattern = @"\S+";
private static PreTokenizer? _whiteSpacePreTokenizer;
#if NET7_0_OR_GREATER
[GeneratedRegex(WhiteSpacePattern)]
private static partial Regex WhiteSpaceRegex();
@ -50,12 +103,11 @@ namespace Microsoft.ML.Tokenizers
#endif
/// <summary>
/// Create a new instance of the <see cref="PreTokenizer"/> class which split the text at the word boundary.
/// The word is a set of alphabet, numeric, and underscore characters.
/// Create a new instance of the <see cref="PreTokenizer"/> class which split the text at the white spaces.
/// </summary>
/// <param name="specialTokensEncoder">The dictionary containing the special tokens and their corresponding ids.</param>
/// <returns>The pre-tokenizer that splits the text at the word boundary.</returns>
public static PreTokenizer CreateWhiteSpace(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
/// <returns>The pre-tokenizer that splits the text at the white spaces.</returns>
public static PreTokenizer CreateWhiteSpacePreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
{
if (specialTokensEncoder is null)
{

Просмотреть файл

@ -0,0 +1,513 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics.Tracing;
using System.IO;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.ML.Tokenizers.Tests
{
public class BertTokenizerTests
{
[Fact]
public void TestWithLowerCasing()
{
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you"];
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
try
{
using Stream vocabStream = File.OpenRead(vocabFile);
BertTokenizer[] bertTokenizers = [BertTokenizer.Create(vocabFile), BertTokenizer.Create(vocabStream)];
foreach (var tokenizer in bertTokenizers)
{
Assert.NotNull(tokenizer.PreTokenizer);
Assert.Equal("[UNK]", tokenizer.UnknownToken);
Assert.Equal(1, tokenizer.UnknownTokenId);
Assert.NotNull(tokenizer.Normalizer);
Assert.NotNull(tokenizer.PreTokenizer);
string text = "Hello, How are you?";
var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
Assert.Equal("hello, how are you?", normalizedText);
Assert.Equal(
[
new EncodedToken(8, "hello", new Range(0, 5)),
new EncodedToken(6, ",", new Range(5, 6)),
new EncodedToken(10, "how", new Range(7, 10)),
new EncodedToken(11, "are", new Range(11, 14)),
new EncodedToken(12, "you", new Range(15, 18)),
new EncodedToken(7, "?", new Range(18, 19))
],
tokens);
var ids = tokenizer.EncodeToIds(text);
Assert.Equal([tokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, tokenizer.SepTokenId], ids);
Assert.Equal("[CLS] hello, how are you? [SEP]", tokenizer.Decode(ids));
Assert.Equal("hello, how are you?", tokenizer.Decode(ids, skipSpecialTokens: true));
tokens = tokenizer.EncodeToTokens(tokenizer.Decode(ids), out normalizedText);
Assert.Equal("[cls] hello, how are you? [sep]", normalizedText);
Assert.Equal(
[
new EncodedToken(2, "[CLS]", new Range(0, 5)),
new EncodedToken(8, "hello", new Range(6, 11)),
new EncodedToken(6, ",", new Range(11, 12)),
new EncodedToken(10, "how", new Range(13, 16)),
new EncodedToken(11, "are", new Range(17, 20)),
new EncodedToken(12, "you", new Range(21, 24)),
new EncodedToken(7, "?", new Range(24, 25)),
new EncodedToken(3, "[SEP]", new Range(26, 31))
],
tokens);
ids = tokenizer.EncodeToIds(normalizedText!);
Assert.Equal([tokenizer.ClsTokenId, tokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, tokenizer.SepTokenId, tokenizer.SepTokenId], ids);
}
}
finally
{
File.Delete(vocabFile);
}
}
[Fact]
public void TestWithNoLowerCasing()
{
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you"];
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
try
{
using Stream vocabStream = File.OpenRead(vocabFile);
BertTokenizer[] bertTokenizers = [BertTokenizer.Create(vocabFile, doLowerCase: false), BertTokenizer.Create(vocabStream, doLowerCase: false)];
foreach (var tokenizer in bertTokenizers)
{
Assert.NotNull(tokenizer.PreTokenizer);
Assert.Equal("[UNK]", tokenizer.UnknownToken);
Assert.Equal(1, tokenizer.UnknownTokenId);
Assert.NotNull(tokenizer.Normalizer);
Assert.NotNull(tokenizer.PreTokenizer);
string text = "Hello, How are you?";
var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
Assert.Equal("Hello, How are you?", normalizedText);
Assert.Equal(
[
new EncodedToken(1, "[UNK]", new Range(0, 5)),
new EncodedToken(6, ",", new Range(5, 6)),
new EncodedToken(1, "[UNK]", new Range(7, 10)),
new EncodedToken(11, "are", new Range(11, 14)),
new EncodedToken(12, "you", new Range(15, 18)),
new EncodedToken(7, "?", new Range(18, 19))
],
tokens);
var ids = tokenizer.EncodeToIds(text);
Assert.Equal([tokenizer.ClsTokenId, 1, 6, 1, 11, 12, 7, tokenizer.SepTokenId], ids);
Assert.Equal("[CLS] [UNK], [UNK] are you? [SEP]", tokenizer.Decode(ids));
Assert.Equal(", are you?", tokenizer.Decode(ids, skipSpecialTokens: true));
}
}
finally
{
File.Delete(vocabFile);
}
}
[Fact]
public async Task TestWithAccentMarks()
{
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "Café", "cafe", "café", "Über", "über", "uber", "Ångström", "ångström", "angstrom", "Résumé", "résumé", "resume",
// Ids: 20 21 22 23
"Cafe", "Uber", "Angstrom", "Resume"];
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
try
{
using Stream vocabStream = File.OpenRead(vocabFile);
BertTokenizer bertTokenizer = await BertTokenizer.CreateAsync(vocabStream); // lowercasing and no accent stripping
string text = "Café Über Ångström Résumé!";
var tokens = bertTokenizer.EncodeToTokens(text, out string? normalizedText);
Assert.Equal(
[
new EncodedToken(10, "café", new Range(0, 4)),
new EncodedToken(12, "über", new Range(5, 9)),
new EncodedToken(15, "ångström", new Range(10, 18)),
new EncodedToken(18, "résumé", new Range(19, 25)),
new EncodedToken(5, "!", new Range(25, 26)),
],
tokens);
Assert.Equal("café über ångström résumé!", normalizedText);
vocabStream.Position = 0;
bertTokenizer = await BertTokenizer.CreateAsync(vocabStream, doLowerCase: false); // no lowercasing and no accent stripping
tokens = bertTokenizer.EncodeToTokens(text, out normalizedText);
Assert.Equal(
[
new EncodedToken(8, "Café", new Range(0, 4)),
new EncodedToken(11, "Über", new Range(5, 9)),
new EncodedToken(14, "Ångström", new Range(10, 18)),
new EncodedToken(17, "Résumé", new Range(19, 25)),
new EncodedToken(5, "!", new Range(25, 26)),
],
tokens);
Assert.Equal("Café Über Ångström Résumé!", normalizedText);
vocabStream.Position = 0;
bertTokenizer = await BertTokenizer.CreateAsync(vocabStream, stripAccents: true); // lowercasing and accent stripping
tokens = bertTokenizer.EncodeToTokens(text, out normalizedText);
Assert.Equal("cafe uber angstrom resume!", normalizedText);
Assert.Equal(
[
new EncodedToken(9, "cafe", new Range(0, 4)),
new EncodedToken(13, "uber", new Range(5, 9)),
new EncodedToken(16, "angstrom", new Range(10, 18)),
new EncodedToken(19, "resume", new Range(19, 25)),
new EncodedToken(5, "!", new Range(25, 26)),
],
tokens);
vocabStream.Position = 0;
bertTokenizer = await BertTokenizer.CreateAsync(vocabStream, doLowerCase: false, stripAccents: true); // no lowercasing and accent stripping
tokens = bertTokenizer.EncodeToTokens(text, out normalizedText);
Assert.Equal("Cafe Uber Angstrom Resume!", normalizedText);
Assert.Equal(
[
new EncodedToken(20, "Cafe", new Range(0, 4)),
new EncodedToken(21, "Uber", new Range(5, 9)),
new EncodedToken(22, "Angstrom", new Range(10, 18)),
new EncodedToken(23, "Resume", new Range(19, 25)),
new EncodedToken(5, "!", new Range(25, 26)),
],
tokens);
}
finally
{
File.Delete(vocabFile);
}
}
[Fact]
public async Task TestChineseCharacters()
{
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", "##驷", "##驸", "受", "叟", "叢", "驷", "驸"];
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
try
{
using Stream vocabStream = File.OpenRead(vocabFile);
BertTokenizer bertTokenizer = await BertTokenizer.CreateAsync(vocabStream); // tokenize Chinese characters
string text = "叟驷 叢驸!";
var tokens = bertTokenizer.EncodeToTokens(text, out string? normalizedText);
Assert.Equal(" 叟 驷 叢 驸 !", normalizedText);
Assert.Equal(
[
new EncodedToken(9, "叟", new Range(1, 2)),
new EncodedToken(11, "驷", new Range(4, 5)),
new EncodedToken(10, "叢", new Range(8, 9)),
new EncodedToken(12, "驸", new Range(11, 12)),
new EncodedToken(5, "!", new Range(13, 14))
],
tokens);
IReadOnlyList<int> ids = bertTokenizer.EncodeToIds(text);
Assert.Equal("[CLS] 叟 驷 叢 驸! [SEP]", bertTokenizer.Decode(bertTokenizer.EncodeToIds(text)));
Assert.Equal("叟 驷 叢 驸!", bertTokenizer.Decode(bertTokenizer.EncodeToIds(text), skipSpecialTokens: true));
vocabStream.Position = 0;
bertTokenizer = await BertTokenizer.CreateAsync(vocabStream, tokenizeChineseChars: false); // do not tokenize Chinese characters
tokens = bertTokenizer.EncodeToTokens(text, out normalizedText);
Assert.Equal("叟驷 叢驸!", normalizedText);
Assert.Equal(
[
new EncodedToken(9, "叟", new Range(0, 1)),
new EncodedToken(6, "##驷", new Range(1, 2)),
new EncodedToken(10, "叢", new Range(3, 4)),
new EncodedToken(7, "##驸", new Range(4, 5)),
new EncodedToken(5, "!", new Range(5, 6))
],
tokens);
ids = bertTokenizer.EncodeToIds(text);
Assert.Equal("[CLS] 叟驷 叢驸! [SEP]", bertTokenizer.Decode(bertTokenizer.EncodeToIds(text)));
Assert.Equal("叟驷 叢驸!", bertTokenizer.Decode(bertTokenizer.EncodeToIds(text), skipSpecialTokens: true));
}
finally
{
File.Delete(vocabFile);
}
}
[Fact]
public void TestBuildInputsWithSpecialTokens()
{
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "i", "am", "fine"];
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
try
{
using Stream vocabStream = File.OpenRead(vocabFile);
BertTokenizer bertTokenizer = BertTokenizer.Create(vocabFile);
string text1 = "Hello, How are you?";
string text2 = "I am fine!";
var ids1 = bertTokenizer.EncodeToIds(text1);
Assert.Equal([bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId], ids1);
var ids2 = bertTokenizer.EncodeToIds(text2);
Assert.Equal([bertTokenizer.ClsTokenId, 13, 14, 15, 5, bertTokenizer.SepTokenId], ids2);
Assert.Equal(
[bertTokenizer.ClsTokenId, bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId, bertTokenizer.SepTokenId],
bertTokenizer.BuildInputsWithSpecialTokens(ids1));
Span<int> ids1Span = stackalloc int[1];
OperationStatus status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out int written);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.Equal(0, written);
ids1Span = stackalloc int[ids1.Count + 2];
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids1.Count + 2, written);
Assert.Equal(new int[] { bertTokenizer.ClsTokenId, bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId, bertTokenizer.SepTokenId }, ids1Span.ToArray());
Assert.Equal(
[bertTokenizer.ClsTokenId, bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId, bertTokenizer.SepTokenId, bertTokenizer.ClsTokenId, 13, 14, 15, 5, bertTokenizer.SepTokenId, bertTokenizer.SepTokenId],
bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids2));
ids1Span = stackalloc int[1];
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written, ids2);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.Equal(0, written);
ids1Span = stackalloc int[ids1.Count + ids2.Count + 3];
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written, ids2);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids1Span.Length, written);
Assert.Equal(
new int[] { bertTokenizer.ClsTokenId, bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId, bertTokenizer.SepTokenId, bertTokenizer.ClsTokenId, 13, 14, 15, 5, bertTokenizer.SepTokenId, bertTokenizer.SepTokenId },
ids1Span.ToArray());
ids1 = bertTokenizer.EncodeToIds(text1, addSpecialTokens: false);
Assert.Equal([8, 6, 10, 11, 12, 7], ids1);
ids2 = bertTokenizer.EncodeToIds(text2, addSpecialTokens: false);
Assert.Equal([13, 14, 15, 5], ids2);
Assert.Equal(
[bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId],
bertTokenizer.BuildInputsWithSpecialTokens(ids1));
ids1Span = stackalloc int[1];
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.Equal(0, written);
ids1Span = stackalloc int[ids1.Count + 2];
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids1Span.Length, written);
Assert.Equal(
new int[] { bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId },
ids1Span.ToArray());
Assert.Equal(
[bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId, 13, 14, 15, 5, bertTokenizer.SepTokenId],
bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids2));
ids1Span = stackalloc int[1];
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written, ids2);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.Equal(0, written);
ids1Span = stackalloc int[ids1.Count + ids2.Count + 3];
status = bertTokenizer.BuildInputsWithSpecialTokens(ids1, ids1Span, out written, ids2);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids1Span.Length, written);
Assert.Equal(
new int[] { bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId, 13, 14, 15, 5, bertTokenizer.SepTokenId },
ids1Span.ToArray());
}
finally
{
File.Delete(vocabFile);
}
}
[Fact]
public void TestGetSpecialTokensMask()
{
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "i", "am", "fine"];
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
try
{
using Stream vocabStream = File.OpenRead(vocabFile);
BertTokenizer bertTokenizer = BertTokenizer.Create(vocabFile);
string text1 = "Hello, How are you?";
string text2 = "I am fine!";
var ids1 = bertTokenizer.EncodeToIds(text1);
Assert.Equal([bertTokenizer.ClsTokenId, 8, 6, 10, 11, 12, 7, bertTokenizer.SepTokenId], ids1);
var ids2 = bertTokenizer.EncodeToIds(text2);
Assert.Equal([bertTokenizer.ClsTokenId, 13, 14, 15, 5, bertTokenizer.SepTokenId], ids2);
Assert.Equal(
[1, 0, 0, 0, 0, 0, 0, 1],
bertTokenizer.GetSpecialTokensMask(ids1, tokenIds1: null, alreadyHasSpecialTokens: true));
Span<int> ids1Span = stackalloc int[1];
OperationStatus status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out int written, alreadyHasSpecialTokens: true);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.Equal(0, written);
ids1Span = stackalloc int[ids1.Count];
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, alreadyHasSpecialTokens: true);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids1.Count, written);
Assert.Equal(new int[] { 1, 0, 0, 0, 0, 0, 0, 1 }, ids1Span.ToArray());
Assert.Equal(
[1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1],
bertTokenizer.GetSpecialTokensMask(ids1, tokenIds1: ids2, alreadyHasSpecialTokens: true));
ids1Span = stackalloc int[1];
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, ids2, alreadyHasSpecialTokens: true);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.Equal(0, written);
ids1Span = stackalloc int[ids1.Count + ids2.Count];
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, ids2, alreadyHasSpecialTokens: true);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids1.Count + ids2.Count, written);
Assert.Equal(new int[] { 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1 }, ids1Span.ToArray());
ids1 = bertTokenizer.EncodeToIds(text1, addSpecialTokens: false);
Assert.Equal([8, 6, 10, 11, 12, 7], ids1);
ids2 = bertTokenizer.EncodeToIds(text2, addSpecialTokens: false);
Assert.Equal([13, 14, 15, 5], ids2);
Assert.Equal(
[1, 0, 0, 0, 0, 0, 0, 1],
bertTokenizer.GetSpecialTokensMask(ids1, tokenIds1: null, alreadyHasSpecialTokens: false));
ids1Span = stackalloc int[1];
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, alreadyHasSpecialTokens: false);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.Equal(0, written);
ids1Span = stackalloc int[ids1.Count + 2];
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, alreadyHasSpecialTokens: false);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids1.Count + 2, written);
Assert.Equal(new int[] { 1, 0, 0, 0, 0, 0, 0, 1 }, ids1Span.ToArray());
Assert.Equal(
[1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
bertTokenizer.GetSpecialTokensMask(ids1, tokenIds1: ids2, alreadyHasSpecialTokens: false));
ids1Span = stackalloc int[1];
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, ids2, alreadyHasSpecialTokens: false);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.Equal(0, written);
ids1Span = stackalloc int[ids1.Count + ids2.Count + 3];
status = bertTokenizer.GetSpecialTokensMask(ids1, ids1Span, out written, ids2, alreadyHasSpecialTokens: false);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids1.Count + ids2.Count + 3, written);
Assert.Equal(new int[] { 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1 }, ids1Span.ToArray());
}
finally
{
File.Delete(vocabFile);
}
}
[Fact]
public void TestCreateTokenTypeIdsFromSequences()
{
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "i", "am", "fine"];
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
try
{
using Stream vocabStream = File.OpenRead(vocabFile);
BertTokenizer bertTokenizer = BertTokenizer.Create(vocabFile);
string text1 = "Hello, How are you?";
string text2 = "I am fine!";
var ids1 = bertTokenizer.EncodeToIds(text1, addSpecialTokens: false);
Assert.Equal([8, 6, 10, 11, 12, 7], ids1);
var ids2 = bertTokenizer.EncodeToIds(text2, addSpecialTokens: false);
Assert.Equal([13, 14, 15, 5], ids2);
Assert.Equal(
[0, 0, 0, 0, 0, 0, 0, 0],
bertTokenizer.CreateTokenTypeIdsFromSequences(ids1));
Span<int> ids1Span = stackalloc int[1];
OperationStatus status = bertTokenizer.CreateTokenTypeIdsFromSequences(ids1, ids1Span, out int written);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.Equal(0, written);
ids1Span = stackalloc int[ids1.Count + 2];
status = bertTokenizer.CreateTokenTypeIdsFromSequences(ids1, ids1Span, out written);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids1.Count + 2, written);
Assert.Equal(new int[] { 0, 0, 0, 0, 0, 0, 0, 0 }, ids1Span.ToArray());
Assert.Equal(
[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
bertTokenizer.CreateTokenTypeIdsFromSequences(ids1, ids2));
ids1Span = stackalloc int[1];
status = bertTokenizer.CreateTokenTypeIdsFromSequences(ids1, ids1Span, out written, ids2);
Assert.Equal(OperationStatus.DestinationTooSmall, status);
Assert.Equal(0, written);
ids1Span = stackalloc int[ids1.Count + ids2.Count + 3];
status = bertTokenizer.CreateTokenTypeIdsFromSequences(ids1, ids1Span, out written, ids2);
Assert.Equal(OperationStatus.Done, status);
Assert.Equal(ids1Span.Length, written);
Assert.Equal(new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1 }, ids1Span.ToArray());
}
finally
{
File.Delete(vocabFile);
}
}
}
}

Просмотреть файл

@ -251,7 +251,7 @@ namespace Microsoft.ML.Tokenizers.Tests
try
{
BpeTokenizer bpe = BpeTokenizer.Create(vocabFile: vocabFile, mergesFile: mergesFile, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, unknownToken: unknownToken,
BpeTokenizer bpe = BpeTokenizer.Create(vocabFile: vocabFile, mergesFile: mergesFile, preTokenizer: PreTokenizer.CreateWordOrNonWordPreTokenizer(), normalizer: null, unknownToken: unknownToken,
continuingSubwordPrefix: continuingSubwordPrefix, endOfWordSuffix: endOfWordSuffix, fuseUnknownTokens: fuseUnknownToken);
Tokenizer tokenizer = bpe;
IReadOnlyList<EncodedToken> encoding = tokenizer.EncodeToTokens(sentence, out _);
@ -500,7 +500,7 @@ namespace Microsoft.ML.Tokenizers.Tests
using Stream vocabStream = File.OpenRead(Path.Combine(@"Gpt-2", "vocab.json"));
using Stream mergesStream = File.OpenRead(Path.Combine(@"Gpt-2", "merges.txt"));
var bpeTokenizer = BpeTokenizer.Create(vocabStream, mergesStream, PreTokenizer.CreateWhiteSpace(addedTokens), normalizer: null, addedTokens: addedTokens, unknownToken: "<|endoftext|>");
var bpeTokenizer = BpeTokenizer.Create(vocabStream, mergesStream, PreTokenizer.CreateWordOrNonWordPreTokenizer(addedTokens), normalizer: null, addedTokens: addedTokens, unknownToken: "<|endoftext|>");
string input = "Hello, y'all! <issue_comment>How are you 😁 ?<|endoftext|>";
@ -556,7 +556,7 @@ namespace Microsoft.ML.Tokenizers.Tests
emptyVocabStream.Position = 0;
return BpeTokenizer.Create(
vocabStream: emptyVocabStream, mergesStream: null, preTokenizer: preTokenizer ?? PreTokenizer.CreateWhiteSpace(), normalizer: normalizer, unknownToken: "Ukn");
vocabStream: emptyVocabStream, mergesStream: null, preTokenizer: preTokenizer ?? PreTokenizer.CreateWordOrNonWordPreTokenizer(), normalizer: normalizer, unknownToken: "Ukn");
}
}
}

Просмотреть файл

@ -18,18 +18,25 @@ namespace Microsoft.ML.Tokenizers.Tests
{
yield return new object[]
{
PreTokenizer.CreateWhiteSpace(),
PreTokenizer.CreateWordOrNonWordPreTokenizer(),
"How are you doing?",
new (int Offset, int Length)[] { (0, 3), (4, 3), (8, 3), (12, 5), (17, 1), }
};
yield return new object[]
{
PreTokenizer.CreateWhiteSpace(),
PreTokenizer.CreateWordOrNonWordPreTokenizer(),
"I_am_Just_Fine!",
new (int Offset, int Length)[] { (0, 14), (14, 1) }
};
yield return new object[]
{
PreTokenizer.CreateWhiteSpacePreTokenizer(),
"Hello, how are you doing?!",
new (int Offset, int Length)[] { (0, 6), (7, 3), (11, 3), (15, 3), (19, 7) }
};
yield return new object[]
{
new SpacePreTokenizer(),
@ -61,9 +68,9 @@ namespace Microsoft.ML.Tokenizers.Tests
}
[Fact]
public void TestWhiteSpacePreTokenizer()
public void TestWordOrNonWordPreTokenizer()
{
Assert.Empty(PreTokenizer.CreateWhiteSpace().PreTokenize((string)null!));
Assert.Empty(PreTokenizer.CreateWordOrNonWordPreTokenizer().PreTokenize((string)null!));
}
public class SpacePreTokenizer : PreTokenizer

Просмотреть файл

@ -0,0 +1,221 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Xunit;
namespace Microsoft.ML.Tokenizers.Tests
{
public class WordPieceTests
{
static string[] _vocabTokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"];
internal static string CreateVocabFile(string[] vocabTokens)
{
string vocabFile = Path.GetTempFileName();
File.WriteAllLines(vocabFile, vocabTokens);
return vocabFile;
}
[Fact]
public void TestCreation()
{
string vocabFile = CreateVocabFile(_vocabTokens);
try
{
using Stream vocabStream = File.OpenRead(vocabFile);
WordPieceTokenizer[] wordPieceTokenizers = [WordPieceTokenizer.Create(vocabFile), WordPieceTokenizer.Create(vocabStream)];
foreach (var tokenizer in wordPieceTokenizers)
{
Assert.NotNull(tokenizer.PreTokenizer);
Assert.Equal("[UNK]", tokenizer.UnknownToken);
Assert.Equal(0, tokenizer.UnknownTokenId);
Assert.Null(tokenizer.Normalizer);
Assert.Equal(100, tokenizer.MaxInputCharsPerWord);
Assert.Equal("##", tokenizer.ContinuingSubwordPrefix);
}
}
finally
{
File.Delete(vocabFile);
}
}
[Fact]
public void TestTokenization()
{
string vocabFile = CreateVocabFile(_vocabTokens);
try
{
WordPieceTokenizer tokenizer = WordPieceTokenizer.Create(vocabFile);
Assert.Null(tokenizer.SpecialTokens);
IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens("", out _);
Assert.Empty(tokens);
Assert.Equal(0, tokenizer.CountTokens(""));
IReadOnlyList<int> ids = tokenizer.EncodeToIds("");
Assert.Empty(ids);
int index = tokenizer.GetIndexByTokenCount("", maxTokenCount: 10, normalizedString: out _, tokenCount: out int tokenCount);
Assert.Equal(0, index);
Assert.Equal(0, tokenCount);
index = tokenizer.GetIndexByTokenCountFromEnd("", maxTokenCount: 10, normalizedString: out _, tokenCount: out tokenCount);
Assert.Equal(0, index);
Assert.Equal(0, tokenCount);
string text = "unwanted running";
tokens = tokenizer.EncodeToTokens(text, out _);
Assert.Equal(
[
new EncodedToken(7, "un", new Range(0, 2)),
new EncodedToken(4, "##want", new Range(2, 6)),
new EncodedToken(5, "##ed", new Range(6, 8)),
new EncodedToken(8, "runn", new Range(9, 13)),
new EncodedToken(9, "##ing", new Range(13, 16))
],
tokens
);
ids = tokenizer.EncodeToIds(text);
Assert.Equal([7, 4, 5, 8, 9], ids);
int[] expectedTokenCount = [0, 0, 3, 3, 5];
for (int i = 1; i <= 5; i++)
{
Assert.Equal(ids.Take(expectedTokenCount[i - 1]).ToArray(), tokenizer.EncodeToIds(text, maxTokenCount: i, normalizedText: out _, out tokenCount));
}
Assert.Equal(text, tokenizer.Decode(ids));
Span<char> buffer = stackalloc char[text.Length];
for (int i = 0; i < text.Length - 1; i++)
{
Span<char> bufferSlice = buffer.Slice(0, i);
OperationStatus result = tokenizer.Decode(ids, bufferSlice, out int idsConsumed, out int charsWritten);
Assert.Equal(OperationStatus.DestinationTooSmall, result);
int j = 0;
while (i >= tokens[j].Offset.End.Value)
{
j++;
}
Assert.Equal(j, idsConsumed);
Assert.Equal(j == 0 ? 0 : tokens[j - 1].Offset.End.Value, charsWritten);
Assert.Equal(j == 0 ? "" : text.Substring(0, tokens[j - 1].Offset.End.Value), bufferSlice.Slice(0, charsWritten).ToString());
}
Assert.Equal(5, tokenizer.CountTokens(text));
int[] expectedIndexes = [0, 0, 8, 9, 16];
expectedTokenCount = [0, 0, 3, 3, 5];
for (int i = 1; i <= 5; i++)
{
index = tokenizer.GetIndexByTokenCount(text, maxTokenCount: i, normalizedString: out _, out tokenCount);
Assert.Equal(expectedTokenCount[i - 1], tokenCount);
Assert.Equal(expectedIndexes[i - 1], index);
}
expectedIndexes = [16, 9, 8, 8, 0];
expectedTokenCount = [0, 2, 2, 2, 5];
for (int i = 1; i <= 5; i++)
{
index = tokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: i, normalizedString: out _, out tokenCount);
Assert.Equal(expectedTokenCount[i - 1], tokenCount);
Assert.Equal(expectedIndexes[i - 1], index);
}
}
finally
{
File.Delete(vocabFile);
}
}
[Fact]
public void TestTokenizationWithUnknownTokens()
{
string vocabFile = CreateVocabFile(_vocabTokens);
try
{
WordPieceTokenizer tokenizer = WordPieceTokenizer.Create(vocabFile);
string text = "unwantedX running";
IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens(text, out _);
Assert.Equal(
[
new EncodedToken(0, "[UNK]", new Range(0, 9)),
new EncodedToken(8, "runn", new Range(10, 14)),
new EncodedToken(9, "##ing", new Range(14, 17))
],
tokens
);
IReadOnlyList<int> ids = tokenizer.EncodeToIds(text);
Assert.Equal([0, 8, 9], ids);
Assert.Equal("[UNK] running", tokenizer.Decode(ids));
}
finally
{
File.Delete(vocabFile);
}
}
[Fact]
public void TestTokenizationWithSpecialTokens()
{
string vocabFile = CreateVocabFile(_vocabTokens);
try
{
Dictionary<string, int> specialTokens = new Dictionary<string, int>
{
{ "[UNK]", 0 }, { "[CLS]", 1 }, { "[SEP]", 2 }
};
WordPieceTokenizer tokenizer = WordPieceTokenizer.Create(vocabFile, specialTokens: specialTokens);
Assert.Equal(specialTokens, tokenizer.SpecialTokens);
string text = "[UNK] unwanted [SEP][CLS] running [CLS]";
IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens(text, out _);
Assert.Equal(
[
new EncodedToken(0, "[UNK]", new Range(0, 5)),
new EncodedToken(7, "un", new Range(6, 8)),
new EncodedToken(4, "##want", new Range(8, 12)),
new EncodedToken(5, "##ed", new Range(12, 14)),
new EncodedToken(2, "[SEP]", new Range(15, 20)),
new EncodedToken(1, "[CLS]", new Range(20, 25)),
new EncodedToken(8, "runn", new Range(26, 30)),
new EncodedToken(9, "##ing", new Range(30, 33)),
new EncodedToken(1, "[CLS]", new Range(34, 39)),
],
tokens
);
IReadOnlyList<int> ids = tokenizer.EncodeToIds(text);
Assert.Equal([0, 7, 4, 5, 2, 1, 8, 9, 1], ids);
Assert.Equal("[UNK] unwanted [SEP] [CLS] running [CLS]", tokenizer.Decode(ids));
}
finally
{
File.Delete(vocabFile);
}
}
}
}