MBartTokenizer: add language codes (#3776)

Sam Shleifer 2020-06-11 13:02:33 -04:00 committed by GitHub
Parent 20451195f0
Commit 08b59d10e5
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files: 161 additions and 38 deletions

View file

@@ -14,8 +14,10 @@
# limitations under the License.
import logging
from typing import List, Optional
from .tokenization_roberta import RobertaTokenizer
from .tokenization_utils import BatchEncoding
from .tokenization_xlm_roberta import XLMRobertaTokenizer
@@ -47,6 +49,104 @@ SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-la
class MBartTokenizer(XLMRobertaTokenizer):
    """
    This inherits from XLMRobertaTokenizer. ``prepare_translation_batch`` should be used to encode inputs;
    other tokenizer methods, like ``encode``, do not work properly.
    The tokenization method is ``<tokens> <eos> <language code>``; there is no BOS token.

    Examples::

        from transformers import MBartTokenizer
        tokenizer = MBartTokenizer.from_pretrained('mbart-large-en-ro')
        example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
        expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
        batch: dict = tokenizer.prepare_translation_batch(
            [example_english_phrase], src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=[expected_translation_romanian]
        )
    """
    vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
    max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
    pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
    lang_code_to_id = {  # NOTE(SS): resize embeddings will break this
        "ar_AR": 250001,
        "cs_CZ": 250002,
        "de_DE": 250003,
        "en_XX": 250004,
        "es_XX": 250005,
        "et_EE": 250006,
        "fi_FI": 250007,
        "fr_XX": 250008,
        "gu_IN": 250009,
        "hi_IN": 250010,
        "it_IT": 250011,
        "ja_XX": 250012,
        "kk_KZ": 250013,
        "ko_KR": 250014,
        "lt_LT": 250015,
        "lv_LV": 250016,
        "my_MM": 250017,
        "ne_NP": 250018,
        "nl_XX": 250019,
        "ro_RO": 250020,
        "ru_RU": 250021,
        "si_LK": 250022,
        "tr_TR": 250023,
        "vi_VN": 250024,
        "zh_CN": 250025,
    }
    cur_lang_code = lang_code_to_id["en_XX"]
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
        """Build model inputs from a sequence by appending eos_token_id and the current language code."""
        special_tokens = [self.eos_token_id, self.cur_lang_code]
        if token_ids_1 is None:
            return token_ids_0 + special_tokens
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return token_ids_0 + token_ids_1 + special_tokens
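    # Illustration (not part of this commit): with cur_lang_code left at its
    # en_XX default, build_inputs_with_special_tokens([8274, 2071]) would
    # return [8274, 2071, 2, 250004], i.e. <tokens> <eos> <language code>
    # as described in the class docstring.
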
    def prepare_translation_batch(
        self,
        src_texts: List[str],
        src_lang: str = "en_XX",
        tgt_texts: Optional[List[str]] = None,
        tgt_lang: str = "ro_RO",
        max_length: Optional[int] = None,
        pad_to_max_length: bool = True,
        return_tensors: str = "pt",
    ) -> BatchEncoding:
        """
        Arguments:
            src_texts: list of source-language texts
            src_lang: source language code (default "en_XX", English)
            tgt_texts: optional list of target-language texts
            tgt_lang: target language code (default "ro_RO", Romanian)
            max_length: maximum length; if None, defer to config (1024 for mbart-large-en-ro)
            pad_to_max_length: whether to pad each example to max_length

        Returns:
            dict with keys input_ids, attention_mask, decoder_input_ids, where each value is a torch.Tensor.
        """
        if max_length is None:
            max_length = self.max_len
        self.cur_lang_code = self.lang_code_to_id[src_lang]
        model_inputs: BatchEncoding = self.batch_encode_plus(
            src_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            pad_to_max_length=pad_to_max_length,
        )
        if tgt_texts is None:
            return model_inputs
        self.cur_lang_code = self.lang_code_to_id[tgt_lang]
        decoder_inputs: BatchEncoding = self.batch_encode_plus(
            tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            pad_to_max_length=pad_to_max_length,
        )
        for k, v in decoder_inputs.items():
            model_inputs[f"decoder_{k}"] = v
        self.cur_lang_code = self.lang_code_to_id[src_lang]  # restore the source language code
        return model_inputs
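
For orientation between the two files, here is a minimal end-to-end sketch of the API this commit introduces. It is not part of the commit; it mirrors the class docstring and the test_enro_generate test below, and num_beams=5 is an illustrative choice.

from transformers import BartForConditionalGeneration, MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")

# prepare_translation_batch appends <eos> and the source language code to each example
batch = tokenizer.prepare_translation_batch(
    src_texts=[" UN Chief Says There Is No Military Solution in Syria"], src_lang="en_XX"
)
translated_tokens = model.generate(input_ids=batch["input_ids"], num_beams=5)
print(tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0])
# expected: "Şeful ONU declară că nu există o soluţie militară în Siria"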

View file

@@ -19,6 +19,7 @@ import unittest
import timeout_decorator # noqa
from transformers import is_torch_available
from transformers.file_utils import cached_property
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
@@ -37,6 +38,7 @@ if is_torch_available():
        BartConfig,
        BartTokenizer,
        MBartTokenizer,
        BatchEncoding,
    )
    from transformers.modeling_bart import (
        BART_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -197,15 +199,37 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
        tiny(**inputs_dict)
EN_CODE = 250004


@require_torch
class MBartIntegrationTests(unittest.TestCase):
    src_text = [
        " UN Chief Says There Is No Military Solution in Syria",
        " I ate lunch twice yesterday",
    ]
    tgt_text = ["Şeful ONU declară că nu există o soluţie militară în Siria", "to be padded"]
    expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]
    @classmethod
    def setUpClass(cls):
        checkpoint_name = "facebook/mbart-large-en-ro"
        cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
        cls.pad_token_id = 1
        return cls

    @cached_property
    def model(self):
        """Only load the model if needed."""
        model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
        if "cuda" in torch_device:
            model = model.half()  # run in fp16 on GPU
        return model
    @slow
    def test_enro_forward(self):
        model = self.model
        net_input = {
            "input_ids": _long_tensor(
                [
@@ -221,24 +245,9 @@
            ),
            "generation_mode": False,
        }
        net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
        with torch.no_grad():
            logits, *other_stuff = model(**net_input)
        expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device)
        result_slice = logits[0][0][:3]
@@ -246,19 +255,10 @@ class BartTranslationTests(unittest.TestCase):
    @slow
    def test_enro_generate(self):
        inputs: dict = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device)
        translated_tokens = self.model.generate(input_ids=inputs["input_ids"].to(torch_device))
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        self.assertEqual(self.tgt_text[0], decoded[0])
    def test_mbart_enro_config(self):
        mbart_models = ["facebook/mbart-large-en-ro"]
@@ -273,13 +273,6 @@ class BartTranslationTests(unittest.TestCase):
                    e.args += (name, k)
                    raise
    def test_mbart_fast_forward(self):
        config = BartConfig(
            vocab_size=99,
@@ -301,6 +294,36 @@ class BartTranslationTests(unittest.TestCase):
        self.assertEqual(logits.shape, expected_shape)


@require_torch
class MBartTokenizerTests(MBartIntegrationTests):
    def test_enro_tokenizer_prepare_translation_batch(self):
        batch = self.tokenizer.prepare_translation_batch(
            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
        )
        self.assertIsInstance(batch, BatchEncoding)
        self.assertEqual((2, 14), batch.input_ids.shape)
        self.assertEqual((2, 14), batch.attention_mask.shape)
        result = batch.input_ids.tolist()[0]
        self.assertListEqual(self.expected_src_tokens, result)
        self.assertEqual(2, batch.decoder_input_ids[0, -2])  # EOS
    def test_enro_tokenizer_batch_encode_plus(self):
        ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
        self.assertListEqual(self.expected_src_tokens, ids)
    def test_enro_tokenizer_truncation(self):
        src_text = ["this is gunna be a long sentence " * 20]
        assert isinstance(src_text[0], str)
        desired_max_length = 10
        ids = self.tokenizer.prepare_translation_batch(
            src_text, return_tensors=None, max_length=desired_max_length
        ).input_ids[0]
        self.assertEqual(ids[-2], 2)
        self.assertEqual(ids[-1], EN_CODE)
        self.assertEqual(len(ids), desired_max_length)
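
# A sketch (not part of this commit) of the truncation contract asserted in
# test_enro_tokenizer_truncation above: the 2-token suffix always survives
# truncation, so at most max_length - 2 content tokens remain. EOS=2 and
# EN_XX=250004 come from the tests; truncate_like_mbart is a hypothetical
# stand-in for what prepare_translation_batch does internally.
EOS, EN_XX = 2, 250004


def truncate_like_mbart(content_ids, max_length):
    # keep max_length - 2 content tokens, then append <eos> <language code>
    return content_ids[: max_length - 2] + [EOS, EN_XX]


ids = truncate_like_mbart(list(range(100)), max_length=10)
assert len(ids) == 10 and ids[-2:] == [EOS, EN_XX]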


@require_torch
class BartHeadTests(unittest.TestCase):
    vocab_size = 99