[Tokenization] Fix #5181 - make #5155 more explicit - move back the default logging level in tests to WARNING (#5252)

* fix-5181

Padding to the maximum sequence length while truncating to a different length was wrong on slow tokenizers (see the sketch below)

* clean up and fix #5155

* fix XLM test

* Fix tests for Transfo-XL

* log only at WARNING level and above in tests

* switch slow tokenizer tests to @slow

* fix Marian truncation tokenization test

* style and quality

* make the tests a lot faster by limiting the sequence lengths used in tests
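A minimal sketch of the behaviour the first fix targets (hypothetical example, not part of the commit; it assumes the `bert-base-uncased` slow tokenizer and the v3-style `__call__` API with `padding`/`truncation` arguments): truncating a batch to a fixed length while padding it to that same length should now behave consistently on slow tokenizers, because padding is applied to the whole batch after truncation.

from transformers import BertTokenizer

# Slow (pure Python) tokenizer, i.e. the code path changed by this commit.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

batch = tokenizer(
    ["a short sentence", "a noticeably longer sentence that will be truncated"],
    truncation=True,
    max_length=8,
    padding="max_length",  # pad every example up to max_length
)

# Every row ends up with the same length, and the attention mask matches
# the padded input ids.
assert all(len(ids) == 8 for ids in batch["input_ids"])
assert all(len(mask) == 8 for mask in batch["attention_mask"])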
Thomas Wolf 2020-06-25 17:24:28 +02:00 committed by GitHub
Parent e008d520bb
Commit 27cf1d97f0
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
9 changed files: 134 additions and 75 deletions

View file

@ -576,18 +576,6 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
Args:
batch_ids_pairs: list of tokenized input ids or input ids pairs
"""
if padding_strategy == PaddingStrategy.LONGEST:
# For simplicity we keep the single sentence path here
def total_sequence_length(input_pairs):
first_ids, second_ids = input_pairs
return len(first_ids) + (
self.num_special_tokens_to_add()
if second_ids is None
else (len(second_ids) + self.num_special_tokens_to_add(pair=True))
)
max_length = max([total_sequence_length(input_pairs) for input_pairs in batch_ids_pairs])
padding_strategy = PaddingStrategy.MAX_LENGTH
batch_outputs = {}
for first_ids, second_ids in batch_ids_pairs:
@ -595,16 +583,16 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
first_ids,
second_ids,
add_special_tokens=add_special_tokens,
padding_strategy=padding_strategy,
padding_strategy=PaddingStrategy.DO_NOT_PAD, # we pad in batch afterward
truncation_strategy=truncation_strategy,
max_length=max_length,
stride=stride,
return_attention_mask=return_attention_mask,
return_attention_mask=False, # we pad in batch afterward
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_length=return_length,
return_tensors=None, # We will convert the whole batch to tensors at the end
return_tensors=None, # We convert the whole batch to tensors at the end
prepend_batch_axis=False,
verbose=verbose,
)
@ -614,6 +602,13 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
batch_outputs[key] = []
batch_outputs[key].append(value)
batch_outputs = self.pad(
batch_outputs,
padding=padding_strategy.value,
max_length=max_length,
return_attention_mask=return_attention_mask,
)
batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
return batch_outputs
@ -700,12 +695,13 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
)
# Padding
encoded_inputs = self.pad(
encoded_inputs,
max_length=max_length,
padding=padding_strategy.value,
return_attention_mask=return_attention_mask,
)
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
encoded_inputs = self.pad(
encoded_inputs,
max_length=max_length,
padding=padding_strategy.value,
return_attention_mask=return_attention_mask,
)
if return_length:
encoded_inputs["length"] = len(encoded_inputs["input_ids"])
@ -768,15 +764,29 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
else:
pair_ids = pair_ids[:-1]
elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
assert len(ids) > num_tokens_to_remove
window_len = min(len(ids), stride + num_tokens_to_remove)
overflowing_tokens = ids[-window_len:]
ids = ids[:-num_tokens_to_remove]
elif truncation_strategy == TruncationStrategy.ONLY_SECOND:
assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
overflowing_tokens = pair_ids[-window_len:]
pair_ids = pair_ids[:-num_tokens_to_remove]
if len(ids) > num_tokens_to_remove:
window_len = min(len(ids), stride + num_tokens_to_remove)
overflowing_tokens = ids[-window_len:]
ids = ids[:-num_tokens_to_remove]
else:
logger.error(
f"We need to remove {num_tokens_to_remove} to truncate the input"
f"but the first sequence has a length {len(ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
f"for instance 'longest_first' or 'only_second'."
)
elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
if len(pair_ids) > num_tokens_to_remove:
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
overflowing_tokens = pair_ids[-window_len:]
pair_ids = pair_ids[:-num_tokens_to_remove]
else:
logger.error(
f"We need to remove {num_tokens_to_remove} to truncate the input"
f"but the second sequence has a length {len(pair_ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
f"for instance 'longest_first' or 'only_first'."
)
return (ids, pair_ids, overflowing_tokens)
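A hedged sketch of the new error path above (hypothetical example, assuming the `bert-base-uncased` slow tokenizer; the input strings are illustrative only): when `only_first` truncation is requested but the first sequence is too short to absorb the removal, the tokenizer now logs an error instead of failing on an assert.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# The pair exceeds max_length, but the first sequence alone is too short to
# remove the excess tokens from, so 'only_first' cannot truncate it.
encoded = tokenizer(
    "hi",
    "a much longer second sequence that pushes the pair over the limit",
    truncation="only_first",
    max_length=6,
)
# Previously this hit an AssertionError; now an error is logged suggesting
# 'longest_first' or 'only_second', and the pair is returned untruncated.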

View file

@ -1890,7 +1890,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
if return_attention_mask is None:
return_attention_mask = "attention_mask" in self.model_input_names
if padding_strategy == PaddingStrategy.LONGEST and max_length is None:
if padding_strategy == PaddingStrategy.LONGEST:
max_length = len(encoded_inputs["input_ids"])
needs_to_be_padded = (
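A small hedged sketch of longest-padding as used by the batch path (hypothetical example, again assuming the `bert-base-uncased` slow tokenizer): with `padding="longest"`, the padding length is taken from the encoded inputs rather than from a user-supplied `max_length`.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Encode without padding, then pad the whole batch to its longest example,
# mirroring how _batch_prepare_for_model now calls self.pad(...).
encoded = tokenizer(["a short sentence", "a noticeably longer example sentence"])
padded = tokenizer.pad(encoded, padding="longest", return_attention_mask=True)

lengths = [len(ids) for ids in padded["input_ids"]]
assert len(set(lengths)) == 1  # every row padded to the longest example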

View file

@ -14,7 +14,6 @@
# limitations under the License.
import logging
import unittest
from transformers import is_torch_available
@ -67,7 +66,6 @@ if is_torch_available():
class AutoModelTest(unittest.TestCase):
@slow
def test_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
@ -82,7 +80,6 @@ class AutoModelTest(unittest.TestCase):
@slow
def test_model_for_pretraining_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
@ -98,7 +95,6 @@ class AutoModelTest(unittest.TestCase):
@slow
def test_lmhead_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
@ -111,7 +107,6 @@ class AutoModelTest(unittest.TestCase):
@slow
def test_model_for_causal_lm(self):
logging.basicConfig(level=logging.INFO)
for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
@ -124,7 +119,6 @@ class AutoModelTest(unittest.TestCase):
@slow
def test_model_for_masked_lm(self):
logging.basicConfig(level=logging.INFO)
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
@ -137,7 +131,6 @@ class AutoModelTest(unittest.TestCase):
@slow
def test_model_for_encoder_decoder_lm(self):
logging.basicConfig(level=logging.INFO)
for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
@ -150,7 +143,6 @@ class AutoModelTest(unittest.TestCase):
@slow
def test_sequence_classification_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
@ -165,7 +157,6 @@ class AutoModelTest(unittest.TestCase):
@slow
def test_question_answering_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
@ -178,7 +169,6 @@ class AutoModelTest(unittest.TestCase):
@slow
def test_token_classification_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
@ -190,14 +180,12 @@ class AutoModelTest(unittest.TestCase):
self.assertIsInstance(model, BertForTokenClassification)
def test_from_pretrained_identifier(self):
logging.basicConfig(level=logging.INFO)
model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, BertForMaskedLM)
self.assertEqual(model.num_parameters(), 14830)
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
def test_from_identifier_from_model_type(self):
logging.basicConfig(level=logging.INFO)
model = AutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
self.assertIsInstance(model, RobertaForMaskedLM)
self.assertEqual(model.num_parameters(), 14830)

View file

@ -14,7 +14,6 @@
# limitations under the License.
import copy
import logging
import os.path
import random
import tempfile
@ -855,7 +854,6 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
class ModelUtilsTest(unittest.TestCase):
@slow
def test_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
config = BertConfig.from_pretrained(model_name)
self.assertIsNotNone(config)

View file

@ -14,7 +14,6 @@
# limitations under the License.
import logging
import unittest
from transformers import is_tf_available
@ -48,7 +47,6 @@ class TFAutoModelTest(unittest.TestCase):
self.assertTrue(h5py.version.hdf5_version.startswith("1.10"))
logging.basicConfig(level=logging.INFO)
# for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
for model_name in ["bert-base-uncased"]:
config = AutoConfig.from_pretrained(model_name)
@ -65,7 +63,6 @@ class TFAutoModelTest(unittest.TestCase):
self.assertTrue(h5py.version.hdf5_version.startswith("1.10"))
logging.basicConfig(level=logging.INFO)
# for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
for model_name in ["bert-base-uncased"]:
config = AutoConfig.from_pretrained(model_name)
@ -78,7 +75,6 @@ class TFAutoModelTest(unittest.TestCase):
@slow
def test_lmhead_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
# for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
for model_name in ["bert-base-uncased"]:
config = AutoConfig.from_pretrained(model_name)
@ -91,7 +87,6 @@ class TFAutoModelTest(unittest.TestCase):
@slow
def test_sequence_classification_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
# for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
for model_name in ["bert-base-uncased"]:
config = AutoConfig.from_pretrained(model_name)
@ -104,7 +99,6 @@ class TFAutoModelTest(unittest.TestCase):
@slow
def test_question_answering_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
# for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
for model_name in ["bert-base-uncased"]:
config = AutoConfig.from_pretrained(model_name)
@ -116,14 +110,12 @@ class TFAutoModelTest(unittest.TestCase):
self.assertIsInstance(model, TFBertForQuestionAnswering)
def test_from_pretrained_identifier(self):
logging.basicConfig(level=logging.INFO)
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, TFBertForMaskedLM)
self.assertEqual(model.num_parameters(), 14830)
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
def test_from_identifier_from_model_type(self):
logging.basicConfig(level=logging.INFO)
model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
self.assertIsInstance(model, TFRobertaForMaskedLM)
self.assertEqual(model.num_parameters(), 14830)

View file

@ -14,7 +14,6 @@
# limitations under the License.
import logging
import unittest
from transformers import (
@ -36,7 +35,6 @@ from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, slow # noq
class AutoTokenizerTest(unittest.TestCase):
# @slow
def test_tokenizer_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in (x for x in BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys() if "japanese" not in x):
tokenizer = AutoTokenizer.from_pretrained(model_name)
self.assertIsNotNone(tokenizer)
@ -50,19 +48,16 @@ class AutoTokenizerTest(unittest.TestCase):
self.assertGreater(len(tokenizer), 0)
def test_tokenizer_from_pretrained_identifier(self):
logging.basicConfig(level=logging.INFO)
tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
self.assertEqual(tokenizer.vocab_size, 12)
def test_tokenizer_from_model_type(self):
logging.basicConfig(level=logging.INFO)
tokenizer = AutoTokenizer.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
self.assertIsInstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
self.assertEqual(tokenizer.vocab_size, 20)
def test_tokenizer_identifier_with_correct_config(self):
logging.basicConfig(level=logging.INFO)
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
@ -75,7 +70,6 @@ class AutoTokenizerTest(unittest.TestCase):
self.assertEqual(tokenizer.max_len, 512)
def test_tokenizer_identifier_non_existent(self):
logging.basicConfig(level=logging.INFO)
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
with self.assertRaises(EnvironmentError):
_ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")

View file

@ -22,7 +22,7 @@ import tempfile
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from tests.utils import require_tf, require_torch
from tests.utils import require_tf, require_torch, slow
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
@ -71,7 +71,7 @@ class TokenizerTesterMixin:
input_txt = self.get_clean_sequence(tokenizer)[0]
return input_txt, input_txt
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=None) -> Tuple[str, list]:
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20) -> Tuple[str, list]:
toks = [(i, tokenizer.decode([i], clean_up_tokenization_spaces=False)) for i in range(len(tokenizer))]
toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
@ -436,17 +436,51 @@ class TokenizerTesterMixin:
)
def test_maximum_encoding_length_single_input(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
seq_0, ids = self.get_clean_sequence(tokenizer)
stride = 2
seq_0, ids = self.get_clean_sequence(tokenizer, max_length=20)
sequence = tokenizer.encode(seq_0, add_special_tokens=False)
# self.assertEqual(sequence, ids)
total_length = len(sequence)
information = tokenizer.encode_plus(
assert total_length > 1, "Issue with the testing sequence, please update it it's too short"
# Test with max model input length
model_max_length = tokenizer.model_max_length
self.assertEqual(model_max_length, 100)
seq_1 = seq_0 * model_max_length
sequence1 = tokenizer(seq_1, add_special_tokens=False)
total_length1 = len(sequence1["input_ids"])
assert (
total_length1 > model_max_length
), "Issue with the testing sequence, please update it it's too short"
# Simple
padding_strategies = (
[False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False]
)
for padding_state in padding_strategies:
with self.subTest(f"Padding: {padding_state}"):
for truncation_state in [True, "longest_first", "only_first"]:
with self.subTest(f"Truncation: {truncation_state}"):
output = tokenizer(seq_1, padding=padding_state, truncation=truncation_state)
self.assertEqual(len(output["input_ids"]), model_max_length)
output = tokenizer([seq_1], padding=padding_state, truncation=truncation_state)
self.assertEqual(len(output["input_ids"][0]), model_max_length)
# Simple with no truncation
output = tokenizer(seq_1, padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"]), model_max_length)
output = tokenizer([seq_1], padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
# Overflowing tokens
stride = 2
information = tokenizer(
seq_0,
max_length=total_length - 2,
add_special_tokens=False,
@ -479,22 +513,22 @@ class TokenizerTesterMixin:
) # No overflowing tokens when using 'longest' in python tokenizers
def test_maximum_encoding_length_pair_input(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
# Build a sequence from our model's vocabulary
stride = 2
seq_0, ids = self.get_clean_sequence(tokenizer)
seq_0, ids = self.get_clean_sequence(tokenizer, max_length=20)
if len(ids) <= 2 + stride:
seq_0 = [s for s in seq_0 for _ in range(2 + stride)]
ids = [i for i in ids for _ in range(2 + stride)]
seq_0 = (seq_0 + " ") * (2 + stride)
ids = None
seq0_tokens = tokenizer.encode(seq_0, add_special_tokens=False)
assert len(seq0_tokens) > 2 + stride
seq_1 = "This is another sentence to be encoded."
seq1_tokens = tokenizer.encode(seq_1, add_special_tokens=False)
if len(seq0_tokens) == len(seq1_tokens):
if abs(len(seq0_tokens) - len(seq1_tokens)) <= 2:
seq1_tokens = seq1_tokens + seq1_tokens
seq_1 = tokenizer.decode(seq1_tokens, clean_up_tokenization_spaces=False)
seq1_tokens = tokenizer.encode(seq_1, add_special_tokens=False)
@ -506,6 +540,49 @@ class TokenizerTesterMixin:
# We are not using the special tokens - a bit too hard to test all the tokenizers with this
# TODO try this again later
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=False) # , add_prefix_space=False)
# Test with max model input length
model_max_length = tokenizer.model_max_length
self.assertEqual(model_max_length, 100)
seq_2 = seq_0 * model_max_length
sequence1 = tokenizer(seq_1, add_special_tokens=False)
total_length1 = len(sequence1["input_ids"])
sequence2 = tokenizer(seq_2, seq_1, add_special_tokens=False)
total_length2 = len(sequence2["input_ids"])
assert total_length1 < model_max_length - 10, "Issue with the testing sequence, please update it."
assert total_length2 > model_max_length, "Issue with the testing sequence, please update it."
# Simple
padding_strategies = (
[False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False]
)
for padding_state in padding_strategies:
with self.subTest(f"Padding: {padding_state}"):
for truncation_state in [True, "longest_first", "only_first"]:
with self.subTest(f"Truncation: {truncation_state}"):
output = tokenizer(seq_2, seq_1, padding=padding_state, truncation=truncation_state)
self.assertEqual(len(output["input_ids"]), model_max_length)
output = tokenizer(
[seq_2], [seq_1], padding=padding_state, truncation=truncation_state
)
self.assertEqual(len(output["input_ids"][0]), model_max_length)
# Simple
output = tokenizer(seq_1, seq_2, padding=padding_state, truncation="only_second")
self.assertEqual(len(output["input_ids"]), model_max_length)
output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation="only_second")
self.assertEqual(len(output["input_ids"][0]), model_max_length)
# Simple with no truncation
output = tokenizer(seq_1, seq_2, padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"]), model_max_length)
output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
truncated_first_sequence = tokenizer.encode(seq_0, add_special_tokens=False)[:-2] + tokenizer.encode(
seq_1, add_special_tokens=False
)
@ -1229,6 +1306,7 @@ class TokenizerTesterMixin:
# add pad_token_id to pass subsequent tests
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
@slow
@require_torch
def test_torch_encode_plus_sent_to_model(self):
import torch
@ -1278,6 +1356,7 @@ class TokenizerTesterMixin:
# model(**encoded_sequence_fast)
# model(**batch_encoded_sequence_fast)
@slow
@require_tf
def test_tf_encode_plus_sent_to_model(self):
from transformers import TF_MODEL_MAPPING, TOKENIZER_MAPPING
@ -1312,6 +1391,7 @@ class TokenizerTesterMixin:
model(batch_encoded_sequence)
# TODO: Check if require_torch is the best to test for numpy here ... Maybe move to require_flax when available
@slow
@require_torch
def test_np_encode_plus_sent_to_model(self):
from transformers import MODEL_MAPPING, TOKENIZER_MAPPING

View file

@ -22,8 +22,6 @@ from transformers.tokenization_roberta import RobertaTokenizerFast
from transformers.tokenization_transfo_xl import TransfoXLTokenizerFast
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]

View file

@ -62,7 +62,6 @@ if __name__ == "__main__":
parser = HfArgumentParser((TrainingArguments,))
training_args = parser.parse_args_into_dataclasses(sys.argv + ["--output_dir", "./examples"])[0]
logging.basicConfig(level=logging.INFO)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
training_args.local_rank,