Address the feedback regarding Bert tokenizer (#7280)

* Address the feedback regarding Bert tokenizer

* Small fix
Tarek Mahmoud Sayed 2024-10-26 16:07:20 -07:00 committed by GitHub
Parent a7a6d88b05
Commit a9b4212eb3
No key found matching this signature
GPG key ID: B5690EEEBB952194
3 changed files with 151 additions and 42 deletions

View file

@@ -290,9 +290,25 @@ namespace Microsoft.ML.Tokenizers
                 throw new ArgumentNullException(nameof(tokenIds0));
             }
-            // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
-            int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
-            List<int> ids = new List<int>(capacity: capacity) { ClsTokenId };
+            List<int> ids;
+            if (tokenIds0 is ICollection<int> c1)
+            {
+                int capacity = c1.Count + 2; // Add 2 for [CLS] and two [SEP] tokens.
+                if (tokenIds1 is not null)
+                {
+                    capacity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
+                }
+                ids = new(capacity) { ClsTokenId };
+            }
+            else
+            {
+                // slow path
+                ids = new List<int>(10) { ClsTokenId };
+            }
             ids.AddRange(tokenIds0);
             ids.Add(SepTokenId);
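
The fast path above exists because ICollection<int>.Count is O(1), while Enumerable.Count() on an arbitrary IEnumerable<int> may walk the whole sequence, and a lazy sequence would then be enumerated a second time by AddRange. A minimal standalone sketch of the same sizing pattern (the method and names here are illustrative, not part of the PR):

    // Size the list cheaply when the input is a materialized collection;
    // fall back to a small default capacity for lazy sequences.
    static List<int> CollectWithExtras(IEnumerable<int> source, int extraCount)
    {
        // ICollection<int>.Count is O(1); Enumerable.Count() may re-enumerate.
        int capacity = source is ICollection<int> c ? c.Count + extraCount : 10;
        List<int> result = new List<int>(capacity);
        result.AddRange(source); // a lazy source is enumerated exactly once here
        return result;
    }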
@@ -323,29 +339,48 @@ namespace Microsoft.ML.Tokenizers
                 throw new ArgumentNullException(nameof(tokenIds0));
             }
-            // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
-            int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
-            if (buffer.Length < capacity)
+            written = 0;
+            if (buffer.Length < 1)
             {
-                written = 0;
                 return OperationStatus.DestinationTooSmall;
             }
-            written = 0;
             buffer[written++] = ClsTokenId;
             foreach (int id in tokenIds0)
             {
+                if (buffer.Length <= written)
+                {
+                    written = 0;
+                    return OperationStatus.DestinationTooSmall;
+                }
                 buffer[written++] = id;
             }
+            if (buffer.Length <= written)
+            {
+                written = 0;
+                return OperationStatus.DestinationTooSmall;
+            }
             buffer[written++] = SepTokenId;
             if (tokenIds1 is not null)
             {
                 foreach (int id in tokenIds1)
                 {
+                    if (buffer.Length <= written)
+                    {
+                        written = 0;
+                        return OperationStatus.DestinationTooSmall;
+                    }
                    buffer[written++] = id;
                 }
+                if (buffer.Length <= written)
+                {
+                    written = 0;
+                    return OperationStatus.DestinationTooSmall;
+                }
                 buffer[written++] = SepTokenId;
             }
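
Since the element count is no longer computed up front, this overload can only discover mid-write that the destination span is too small; hence the check before every write, with written reset to 0 so the caller sees a clean failure and can retry with a larger buffer. A condensed, self-contained sketch of that Try-style pattern (illustrative method name, assuming only System.Buffers.OperationStatus):

    using System.Buffers;

    // Copy a sequence into a caller-supplied span, failing cleanly when full.
    static OperationStatus TryCopy(IEnumerable<int> source, Span<int> destination, out int written)
    {
        written = 0;
        foreach (int value in source)
        {
            if (destination.Length <= written)
            {
                written = 0; // report nothing written on failure
                return OperationStatus.DestinationTooSmall;
            }
            destination[written++] = value;
        }
        return OperationStatus.Done;
    }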
@@ -367,11 +402,22 @@ namespace Microsoft.ML.Tokenizers
                 throw new ArgumentNullException(nameof(tokenIds0));
             }
-            int capacity = alreadyHasSpecialTokens ?
-                tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
-                tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
+            List<int> mask;
+            if (tokenIds0 is ICollection<int> c1)
+            {
+                int capacity = c1.Count + 2;
-            List<int> mask = new List<int>(capacity: capacity);
+                if (tokenIds1 is not null)
+                {
+                    capacity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
+                }
+                mask = new List<int>(capacity);
+            }
+            else
+            {
+                mask = new List<int>(10);
+            }
             if (!alreadyHasSpecialTokens)
             {
@@ -420,31 +466,49 @@ namespace Microsoft.ML.Tokenizers
                 throw new ArgumentNullException(nameof(tokenIds0));
             }
-            int capacity = alreadyHasSpecialTokens ?
-                tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
-                tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
             written = 0;
-            if (buffer.Length < capacity)
-            {
-                return OperationStatus.DestinationTooSmall;
-            }
             if (!alreadyHasSpecialTokens)
             {
+                if (buffer.Length < 1)
+                {
+                    return OperationStatus.DestinationTooSmall;
+                }
                 buffer[written++] = 1; // CLS
                 foreach (int id in tokenIds0)
                 {
+                    if (buffer.Length <= written)
+                    {
+                        written = 0;
+                        return OperationStatus.DestinationTooSmall;
+                    }
                     buffer[written++] = 0;
                 }
+                if (buffer.Length <= written)
+                {
+                    written = 0;
+                    return OperationStatus.DestinationTooSmall;
+                }
                 buffer[written++] = 1; // SEP
                 if (tokenIds1 is not null)
                 {
                     foreach (int id in tokenIds1)
                     {
+                        if (buffer.Length <= written)
+                        {
+                            written = 0;
+                            return OperationStatus.DestinationTooSmall;
+                        }
                         buffer[written++] = 0;
                     }
+                    if (buffer.Length <= written)
+                    {
+                        written = 0;
+                        return OperationStatus.DestinationTooSmall;
+                    }
                     buffer[written++] = 1; // SEP
                 }
@@ -453,6 +517,11 @@ namespace Microsoft.ML.Tokenizers
                 foreach (int id in tokenIds0)
                 {
+                    if (buffer.Length <= written)
+                    {
+                        written = 0;
+                        return OperationStatus.DestinationTooSmall;
+                    }
                     buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
                 }
@@ -460,6 +529,11 @@ namespace Microsoft.ML.Tokenizers
                 {
                     foreach (int id in tokenIds1)
                     {
+                        if (buffer.Length <= written)
+                        {
+                            written = 0;
+                            return OperationStatus.DestinationTooSmall;
+                        }
                         buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
                     }
                 }
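
For reference, the mask built by these branches marks special-token positions ([CLS], [SEP], [PAD], [MASK], unknown) with 1 and regular tokens with 0. A hypothetical run (token ids invented for illustration), with alreadyHasSpecialTokens == false:

    // tokenIds0 (three regular tokens): t1 t2 t3
    // sequence with specials:           [CLS] t1 t2 t3 [SEP]
    // special tokens mask:              1     0  0  0  1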
@@ -484,21 +558,38 @@ namespace Microsoft.ML.Tokenizers
                 throw new ArgumentNullException(nameof(tokenIds0));
             }
-            // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
-            int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
+            List<int> typeIds;
+            if (tokenIds0 is ICollection<int> c1)
+            {
+                int capacity = c1.Count + 2; // Add 2 for [CLS] and [SEP] tokens.
-            List<int> typeIds = new List<int>(capacity);
-            for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
+                if (tokenIds1 is not null)
+                {
+                    capacity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
+                }
+                typeIds = new List<int>(capacity);
+            }
+            else
+            {
+                typeIds = new List<int>(10);
+            }
+            foreach (var id in tokenIds0)
             {
                 typeIds.Add(0);
             }
+            typeIds.Add(0); // [CLS]
+            typeIds.Add(0); // [SEP]
             if (tokenIds1 is not null)
             {
-                for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
+                foreach (int id in tokenIds1)
                 {
                     typeIds.Add(1);
                 }
+                typeIds.Add(1); // [SEP]
             }
             return typeIds;
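
Token type ids (segment ids) tell a BERT model which sequence each position belongs to: the first sequence, including its [CLS] and [SEP], is segment 0, and the optional second sequence, including its trailing [SEP], is segment 1. A hypothetical pair of three and two tokens:

    // combined input: [CLS] a1 a2 a3 [SEP] b1 b2 [SEP]
    // token type ids: 0     0  0  0  0     1  1  1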
@@ -515,22 +606,40 @@ namespace Microsoft.ML.Tokenizers
-            // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
-            int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
-            if (buffer.Length < capacity)
+            if (buffer.Length < 2)
             {
                 return OperationStatus.DestinationTooSmall;
             }
+            buffer[written++] = 0; // [CLS]
+            buffer[written++] = 0; // [SEP]
-            for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
+            foreach (int id in tokenIds0)
             {
+                if (buffer.Length <= written)
+                {
+                    written = 0;
+                    return OperationStatus.DestinationTooSmall;
+                }
                 buffer[written++] = 0;
             }
             if (tokenIds1 is not null)
             {
-                for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
+                foreach (int id in tokenIds1)
                 {
+                    if (buffer.Length <= written)
+                    {
+                        written = 0;
+                        return OperationStatus.DestinationTooSmall;
+                    }
                     buffer[written++] = 1;
                 }
+                if (buffer.Length <= written)
+                {
+                    written = 0;
+                    return OperationStatus.DestinationTooSmall;
+                }
+                buffer[written++] = 1; // [SEP]
             }
             return OperationStatus.Done;

View file

@@ -233,7 +233,7 @@ namespace Microsoft.ML.Tokenizers
                 continuingSubwordPrefix,
                 maxInputCharsPerWord,
                 cancellationToken,
-                disposeStream: true);
+                disposeStream: true).ConfigureAwait(false);
         /// <summary>
         /// Create a new instance of the <see cref="WordPieceTokenizer"/> class asynchronously.
@@ -259,7 +259,7 @@ namespace Microsoft.ML.Tokenizers
             string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
             int maxInputCharsPerWord = DefaultMaxInputCharsPerWord,
             CancellationToken cancellationToken = default) =>
-            await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false);
+            await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false).ConfigureAwait(false);
         private static async Task<WordPieceTokenizer> CreateAsync(
             Stream vocabStream,
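
Appending ConfigureAwait(false) is the standard library-side discipline: the continuation does not marshal back onto the caller's captured SynchronizationContext, which avoids deadlocks when a caller blocks on the returned task and removes an unnecessary context hop. A minimal sketch of the pattern (illustrative helper, not from this file):

    using System.IO;
    using System.Threading.Tasks;

    static async Task<long> CountBytesAsync(Stream stream)
    {
        byte[] scratch = new byte[8192];
        long total = 0;
        int read;
        // ConfigureAwait(false): resume on any thread-pool thread.
        while ((read = await stream.ReadAsync(scratch, 0, scratch.Length).ConfigureAwait(false)) > 0)
        {
            total += read;
        }
        return total;
    }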

View file

@@ -69,7 +69,7 @@ namespace Microsoft.ML.Tokenizers
                 if (category == UnicodeCategory.SpaceSeparator)
                 {
-                    InsertChar(ref buffer, ref index, ' ');
+                    AddChar(ref buffer, ref index, ' ');
                     i += inc;
                     continue;
                 }
@@ -85,7 +85,7 @@ namespace Microsoft.ML.Tokenizers
                     int length = original.AsSpan().Slice(i, inc + 1).ToLowerInvariant(casingBuffer);
                     Debug.Assert(length > 0);
-                    InsertSpan(ref buffer, ref index, casingBuffer.Slice(0, length));
+                    AddSpan(ref buffer, ref index, casingBuffer.Slice(0, length));
                     i += inc;
                     continue;
@@ -93,22 +93,22 @@ namespace Microsoft.ML.Tokenizers
                 if (_tokenizeChineseChars && IsChineseChar(codePoint))
                 {
-                    InsertChar(ref buffer, ref index, ' ');
-                    InsertChar(ref buffer, ref index, c);
+                    AddChar(ref buffer, ref index, ' ');
+                    AddChar(ref buffer, ref index, c);
                     if (inc > 0)
                     {
-                        InsertChar(ref buffer, ref index, original[i + 1]);
+                        AddChar(ref buffer, ref index, original[i + 1]);
                     }
-                    InsertChar(ref buffer, ref index, ' ');
+                    AddChar(ref buffer, ref index, ' ');
                     i += inc;
                     continue;
                 }
-                InsertChar(ref buffer, ref index, c);
+                AddChar(ref buffer, ref index, c);
                 if (inc > 0)
                 {
-                    InsertChar(ref buffer, ref index, original[i + 1]);
+                    AddChar(ref buffer, ref index, original[i + 1]);
                 }
                 i += inc;
             }
@@ -147,7 +147,7 @@ namespace Microsoft.ML.Tokenizers
         }
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static void InsertChar(ref char[] buffer, ref int index, char c)
+        private static void AddChar(ref char[] buffer, ref int index, char c)
         {
             if (index >= buffer.Length)
             {
@@ -158,9 +158,9 @@ namespace Microsoft.ML.Tokenizers
         }
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static void InsertSpan(ref char[] buffer, ref int index, Span<char> chars)
+        private static void AddSpan(ref char[] buffer, ref int index, Span<char> chars)
         {
-            if (index + buffer.Length >= buffer.Length)
+            if (index + chars.Length >= buffer.Length)
             {
                 Helpers.ArrayPoolGrow(ref buffer, index + buffer.Length + 10);
             }
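
The renamed AddSpan also fixes its growth check: the old condition index + buffer.Length >= buffer.Length is true for any non-negative index, so the buffer was regrown on every call; comparing against index + chars.Length grows only when the span actually does not fit. Helpers.ArrayPoolGrow is internal to the repo, but a conventional pool-backed grow looks roughly like this (an assumed sketch, not the library's implementation):

    using System;
    using System.Buffers;

    // Grow a pooled buffer to at least the requested capacity,
    // returning the old array to the shared pool.
    static void ArrayPoolGrow(ref char[] buffer, int requiredCapacity)
    {
        char[] newBuffer = ArrayPool<char>.Shared.Rent(Math.Max(requiredCapacity, buffer.Length * 2));
        buffer.AsSpan().CopyTo(newBuffer);
        ArrayPool<char>.Shared.Return(buffer);
        buffer = newBuffer;
    }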