Address the feedback regarding Bert tokenizer (#7280)
* Address the feedback regarding Bert tokenizer * Small fix
This commit is contained in:
Родитель
a7a6d88b05
Коммит
a9b4212eb3
|
@ -290,9 +290,25 @@ namespace Microsoft.ML.Tokenizers
|
|||
throw new ArgumentNullException(nameof(tokenIds0));
|
||||
}
|
||||
|
||||
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
|
||||
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
|
||||
List<int> ids = new List<int>(capacity: capacity) { ClsTokenId };
|
||||
List<int> ids;
|
||||
|
||||
if (tokenIds0 is ICollection<int> c1)
|
||||
{
|
||||
int capacity = c1.Count + 2; // Add 2 for [CLS] and [SEP] tokens.
|
||||
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
capacity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
|
||||
}
|
||||
|
||||
ids = new(capacity) { ClsTokenId };
|
||||
}
|
||||
else
|
||||
{
|
||||
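// Slow path: tokenIds0 is not an ICollection<int>, so its count is not available up front; start with a small capacity and let the list grow.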
|
||||
ids = new List<int>(10) { ClsTokenId };
|
||||
}
|
||||
|
||||
ids.AddRange(tokenIds0);
|
||||
ids.Add(SepTokenId);
|
||||
|
||||
|
@ -323,29 +339,48 @@ namespace Microsoft.ML.Tokenizers
|
|||
throw new ArgumentNullException(nameof(tokenIds0));
|
||||
}
|
||||
|
||||
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
|
||||
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
|
||||
if (buffer.Length < capacity)
|
||||
written = 0;
|
||||
if (buffer.Length < 1)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
|
||||
written = 0;
|
||||
buffer[written++] = ClsTokenId;
|
||||
foreach (int id in tokenIds0)
|
||||
{
|
||||
if (buffer.Length <= written)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
|
||||
buffer[written++] = id;
|
||||
}
|
||||
|
||||
if (buffer.Length <= written)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = SepTokenId;
|
||||
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
foreach (int id in tokenIds1)
|
||||
{
|
||||
if (buffer.Length <= written)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = id;
|
||||
}
|
||||
|
||||
if (buffer.Length <= written)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = SepTokenId;
|
||||
}
|
||||
|
||||
|
@ -367,11 +402,22 @@ namespace Microsoft.ML.Tokenizers
|
|||
throw new ArgumentNullException(nameof(tokenIds0));
|
||||
}
|
||||
|
||||
int capacity = alreadyHasSpecialTokens ?
|
||||
tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
|
||||
tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
|
||||
List<int> mask;
|
||||
if (tokenIds0 is ICollection<int> c1)
|
||||
{
|
||||
int capcity = c1.Count + 2;
|
||||
|
||||
List<int> mask = new List<int>(capacity: capacity);
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
capcity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
|
||||
}
|
||||
|
||||
mask = new List<int>(capcity);
|
||||
}
|
||||
else
|
||||
{
|
||||
mask = new List<int>(10);
|
||||
}
|
||||
|
||||
if (!alreadyHasSpecialTokens)
|
||||
{
|
||||
|
@ -420,31 +466,49 @@ namespace Microsoft.ML.Tokenizers
|
|||
throw new ArgumentNullException(nameof(tokenIds0));
|
||||
}
|
||||
|
||||
int capacity = alreadyHasSpecialTokens ?
|
||||
tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
|
||||
tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
|
||||
|
||||
written = 0;
|
||||
if (buffer.Length < capacity)
|
||||
{
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
|
||||
if (!alreadyHasSpecialTokens)
|
||||
{
|
||||
if (buffer.Length < 1)
|
||||
{
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = 1; // CLS
|
||||
|
||||
foreach (int id in tokenIds0)
|
||||
{
|
||||
if (buffer.Length <= written)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = 0;
|
||||
}
|
||||
|
||||
if (buffer.Length <= written)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = 1; // SEP
|
||||
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
foreach (int id in tokenIds1)
|
||||
{
|
||||
if (buffer.Length <= written)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = 0;
|
||||
}
|
||||
|
||||
if (buffer.Length <= written)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = 1; // SEP
|
||||
}
|
||||
|
||||
|
@ -453,6 +517,11 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
foreach (int id in tokenIds0)
|
||||
{
|
||||
if (buffer.Length <= written)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
|
||||
}
|
||||
|
||||
|
@ -460,6 +529,11 @@ namespace Microsoft.ML.Tokenizers
|
|||
{
|
||||
foreach (int id in tokenIds1)
|
||||
{
|
||||
if (buffer.Length <= written)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
|
||||
}
|
||||
}
|
||||
|
@ -484,21 +558,38 @@ namespace Microsoft.ML.Tokenizers
|
|||
throw new ArgumentNullException(nameof(tokenIds0));
|
||||
}
|
||||
|
||||
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
|
||||
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
|
||||
List<int> typeIds;
|
||||
if (tokenIds0 is ICollection<int> c1)
|
||||
{
|
||||
int capacity = c1.Count + 2; // Add 2 for [CLS] and [SEP] tokens.
|
||||
|
||||
List<int> typeIds = new List<int>(capacity);
|
||||
for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
capacity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
|
||||
}
|
||||
|
||||
typeIds = new List<int>(capacity);
|
||||
}
|
||||
else
|
||||
{
|
||||
typeIds = new List<int>(10);
|
||||
}
|
||||
|
||||
foreach (var id in tokenIds0)
|
||||
{
|
||||
typeIds.Add(0);
|
||||
}
|
||||
typeIds.Add(0); // [CLS]
|
||||
typeIds.Add(0); // [SEP]
|
||||
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
|
||||
foreach (int id in tokenIds1)
|
||||
{
|
||||
typeIds.Add(1);
|
||||
}
|
||||
|
||||
typeIds.Add(1); // [SEP]
|
||||
}
|
||||
|
||||
return typeIds;
|
||||
|
@ -515,22 +606,40 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
|
||||
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
|
||||
if (buffer.Length < capacity)
|
||||
if (buffer.Length < 2)
|
||||
{
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = 0; // [CLS]
|
||||
buffer[written++] = 0; // [SEP]
|
||||
|
||||
for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
|
||||
foreach (int id in tokenIds0)
|
||||
{
|
||||
if (buffer.Length <= written)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = 0;
|
||||
}
|
||||
|
||||
if (tokenIds1 is not null)
|
||||
{
|
||||
for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
|
||||
foreach (int id in tokenIds1)
|
||||
{
|
||||
if (buffer.Length <= written)
|
||||
{
|
||||
written = 0;
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = 1;
|
||||
}
|
||||
|
||||
if (buffer.Length < written)
|
||||
{
|
||||
return OperationStatus.DestinationTooSmall;
|
||||
}
|
||||
buffer[written++] = 1; // [SEP]
|
||||
}
|
||||
|
||||
return OperationStatus.Done;
|
||||
|
|
|
@ -233,7 +233,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
continuingSubwordPrefix,
|
||||
maxInputCharsPerWord,
|
||||
cancellationToken,
|
||||
disposeStream: true);
|
||||
disposeStream: true).ConfigureAwait(false);
|
||||
|
||||
/// <summary>
|
||||
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class asynchronously.
|
||||
|
@ -259,7 +259,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
|
||||
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord,
|
||||
CancellationToken cancellationToken = default) =>
|
||||
await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false);
|
||||
await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false).ConfigureAwait(false);
|
||||
|
||||
private static async Task<WordPieceTokenizer> CreateAsync(
|
||||
Stream vocabStream,
|
||||
|
|
|
@ -69,7 +69,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
if (category == UnicodeCategory.SpaceSeparator)
|
||||
{
|
||||
InsertChar(ref buffer, ref index, ' ');
|
||||
AddChar(ref buffer, ref index, ' ');
|
||||
i += inc;
|
||||
continue;
|
||||
}
|
||||
|
@ -85,7 +85,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
int length = original.AsSpan().Slice(i, inc + 1).ToLowerInvariant(casingBuffer);
|
||||
Debug.Assert(length > 0);
|
||||
|
||||
InsertSpan(ref buffer, ref index, casingBuffer.Slice(0, length));
|
||||
AddSpan(ref buffer, ref index, casingBuffer.Slice(0, length));
|
||||
|
||||
i += inc;
|
||||
continue;
|
||||
|
@ -93,22 +93,22 @@ namespace Microsoft.ML.Tokenizers
|
|||
|
||||
if (_tokenizeChineseChars && IsChineseChar(codePoint))
|
||||
{
|
||||
InsertChar(ref buffer, ref index, ' ');
|
||||
InsertChar(ref buffer, ref index, c);
|
||||
AddChar(ref buffer, ref index, ' ');
|
||||
AddChar(ref buffer, ref index, c);
|
||||
if (inc > 0)
|
||||
{
|
||||
InsertChar(ref buffer, ref index, original[i + 1]);
|
||||
AddChar(ref buffer, ref index, original[i + 1]);
|
||||
}
|
||||
InsertChar(ref buffer, ref index, ' ');
|
||||
AddChar(ref buffer, ref index, ' ');
|
||||
|
||||
i += inc;
|
||||
continue;
|
||||
}
|
||||
|
||||
InsertChar(ref buffer, ref index, c);
|
||||
AddChar(ref buffer, ref index, c);
|
||||
if (inc > 0)
|
||||
{
|
||||
InsertChar(ref buffer, ref index, original[i + 1]);
|
||||
AddChar(ref buffer, ref index, original[i + 1]);
|
||||
}
|
||||
i += inc;
|
||||
}
|
||||
|
@ -147,7 +147,7 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private static void InsertChar(ref char[] buffer, ref int index, char c)
|
||||
private static void AddChar(ref char[] buffer, ref int index, char c)
|
||||
{
|
||||
if (index >= buffer.Length)
|
||||
{
|
||||
|
@ -158,9 +158,9 @@ namespace Microsoft.ML.Tokenizers
|
|||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private static void InsertSpan(ref char[] buffer, ref int index, Span<char> chars)
|
||||
private static void AddSpan(ref char[] buffer, ref int index, Span<char> chars)
|
||||
{
|
||||
if (index + buffer.Length >= buffer.Length)
|
||||
if (index + chars.Length >= buffer.Length)
|
||||
{
|
||||
Helpers.ArrayPoolGrow(ref buffer, index + buffer.Length + 10);
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче