MBartTokenizer:add language codes (#3776)
This commit is contained in:
Родитель
20451195f0
Коммит
08b59d10e5
|
@ -14,8 +14,10 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from .tokenization_roberta import RobertaTokenizer
|
||||
from .tokenization_utils import BatchEncoding
|
||||
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||
|
||||
|
||||
|
@ -47,6 +49,104 @@ SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-la
|
|||
|
||||
|
||||
class MBartTokenizer(XLMRobertaTokenizer):
|
||||
"""
|
||||
This inherits from XLMRobertaTokenizer. ``prepare_translation_batch`` should be used to encode inputs.
|
||||
Other tokenizer methods like encode do not work properly.
|
||||
The tokenization method is <tokens> <eos> <language code>. There is no BOS token.
|
||||
|
||||
Examples::
|
||||
from transformers import MBartTokenizer
|
||||
tokenizer = MBartTokenizer.from_pretrained('mbart-large-en-ro')
|
||||
example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
batch: dict = tokenizer.prepare_translation_batch(
|
||||
example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
|
||||
)
|
||||
"""
|
||||
|
||||
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
|
||||
max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
|
||||
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
|
||||
lang_code_to_id = { # NOTE(SS): resize embeddings will break this
|
||||
"ar_AR": 250001,
|
||||
"cs_CZ": 250002,
|
||||
"de_DE": 250003,
|
||||
"en_XX": 250004,
|
||||
"es_XX": 250005,
|
||||
"et_EE": 250006,
|
||||
"fi_FI": 250007,
|
||||
"fr_XX": 250008,
|
||||
"gu_IN": 250009,
|
||||
"hi_IN": 250010,
|
||||
"it_IT": 250011,
|
||||
"ja_XX": 250012,
|
||||
"kk_KZ": 250013,
|
||||
"ko_KR": 250014,
|
||||
"lt_LT": 250015,
|
||||
"lv_LV": 250016,
|
||||
"my_MM": 250017,
|
||||
"ne_NP": 250018,
|
||||
"nl_XX": 250019,
|
||||
"ro_RO": 250020,
|
||||
"ru_RU": 250021,
|
||||
"si_LK": 250022,
|
||||
"tr_TR": 250023,
|
||||
"vi_VN": 250024,
|
||||
"zh_CN": 250025,
|
||||
}
|
||||
cur_lang_code = lang_code_to_id["en_XX"]
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
|
||||
"""Build model inputs from a sequence by appending eos_token_id."""
|
||||
special_tokens = [self.eos_token_id, self.cur_lang_code]
|
||||
if token_ids_1 is None:
|
||||
return token_ids_0 + special_tokens
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return token_ids_0 + token_ids_1 + special_tokens
|
||||
|
||||
def prepare_translation_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
src_lang: str = "en_XX",
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
tgt_lang: str = "ro_RO",
|
||||
max_length: Optional[int] = None,
|
||||
pad_to_max_length: bool = True,
|
||||
return_tensors: str = "pt",
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Arguments:
|
||||
src_texts: list of src language texts
|
||||
src_lang: default en_XX (english)
|
||||
tgt_texts: list of tgt language texts
|
||||
tgt_lang: default ro_RO (romanian)
|
||||
max_length: (None) defer to config (1024 for mbart-large-en-ro)
|
||||
pad_to_max_length: (bool)
|
||||
|
||||
Returns:
|
||||
dict with keys input_ids, attention_mask, decoder_input_ids, each value is a torch.Tensor.
|
||||
"""
|
||||
if max_length is None:
|
||||
max_length = self.max_len
|
||||
self.cur_lang_code = self.lang_code_to_id[src_lang]
|
||||
model_inputs: BatchEncoding = self.batch_encode_plus(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
pad_to_max_length=pad_to_max_length,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
self.cur_lang_code = self.lang_code_to_id[tgt_lang]
|
||||
decoder_inputs: BatchEncoding = self.batch_encode_plus(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
pad_to_max_length=pad_to_max_length,
|
||||
)
|
||||
for k, v in decoder_inputs.items():
|
||||
model_inputs[f"decoder_{k}"] = v
|
||||
self.cur_lang_code = self.lang_code_to_id[src_lang]
|
||||
return model_inputs
|
||||
|
|
|
@ -19,6 +19,7 @@ import unittest
|
|||
import timeout_decorator # noqa
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
@ -37,6 +38,7 @@ if is_torch_available():
|
|||
BartConfig,
|
||||
BartTokenizer,
|
||||
MBartTokenizer,
|
||||
BatchEncoding,
|
||||
)
|
||||
from transformers.modeling_bart import (
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
|
@ -197,15 +199,37 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
tiny(**inputs_dict)
|
||||
|
||||
|
||||
EN_CODE = 250004
|
||||
|
||||
|
||||
@require_torch
|
||||
class BartTranslationTests(unittest.TestCase):
|
||||
_model = None
|
||||
class MBartIntegrationTests(unittest.TestCase):
|
||||
src_text = [
|
||||
" UN Chief Says There Is No Military Solution in Syria",
|
||||
" I ate lunch twice yesterday",
|
||||
]
|
||||
tgt_text = ["Şeful ONU declară că nu există o soluţie militară în Siria", "to be padded"]
|
||||
expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
checkpoint_name = "facebook/mbart-large-en-ro"
|
||||
cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
|
||||
cls.pad_token_id = 1
|
||||
return cls
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
"""Only load the model if needed."""
|
||||
|
||||
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
|
||||
if "cuda" in torch_device:
|
||||
model = model.half()
|
||||
return model
|
||||
|
||||
@slow
|
||||
def test_enro_forward(self):
|
||||
model = self.model
|
||||
net_input = {
|
||||
"input_ids": _long_tensor(
|
||||
[
|
||||
|
@ -221,24 +245,9 @@ class BartTranslationTests(unittest.TestCase):
|
|||
),
|
||||
"generation_mode": False,
|
||||
}
|
||||
net_input["attention_mask"] = net_input["input_ids"].ne(cls.pad_token_id)
|
||||
cls.net_input = net_input
|
||||
|
||||
return cls
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""Only load the model if needed."""
|
||||
if self._model is None:
|
||||
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
|
||||
self._model = model.to(torch_device)
|
||||
return self._model
|
||||
|
||||
@slow
|
||||
def test_enro_forward(self):
|
||||
model = self.model
|
||||
net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
|
||||
with torch.no_grad():
|
||||
logits, *other_stuff = model(**self.net_input)
|
||||
logits, *other_stuff = model(**net_input)
|
||||
|
||||
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device)
|
||||
result_slice = logits[0][0][:3]
|
||||
|
@ -246,19 +255,10 @@ class BartTranslationTests(unittest.TestCase):
|
|||
|
||||
@slow
|
||||
def test_enro_generate(self):
|
||||
model = self.model
|
||||
# example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
# inputs: dict = tokenizer.batch_encode_plus([example_english_phrase], return_tensors="pt",)
|
||||
expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
|
||||
inputs = {
|
||||
"input_ids": torch.LongTensor(
|
||||
[[8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]] # 250004
|
||||
)
|
||||
}
|
||||
translated_tokens = model.generate(input_ids=inputs["input_ids"].to(torch_device), num_beams=5,)
|
||||
inputs: dict = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device)
|
||||
translated_tokens = self.model.generate(input_ids=inputs["input_ids"].to(torch_device))
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
self.assertEqual(expected_translation_romanian, decoded[0])
|
||||
self.assertEqual(self.tgt_text[0], decoded[0])
|
||||
|
||||
def test_mbart_enro_config(self):
|
||||
mbart_models = ["facebook/mbart-large-en-ro"]
|
||||
|
@ -273,13 +273,6 @@ class BartTranslationTests(unittest.TestCase):
|
|||
e.args += (name, k)
|
||||
raise
|
||||
|
||||
def test_enro_tokenizer(self):
|
||||
raw = "UN Chief Says There Is No Military Solution in Syria"
|
||||
ids = self.tokenizer.batch_encode_plus([raw])["input_ids"][0]
|
||||
expected_result = [0, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
|
||||
# TODO(SS): should be [8274, ..., 2, 250020]
|
||||
self.assertListEqual(expected_result, ids)
|
||||
|
||||
def test_mbart_fast_forward(self):
|
||||
config = BartConfig(
|
||||
vocab_size=99,
|
||||
|
@ -301,6 +294,36 @@ class BartTranslationTests(unittest.TestCase):
|
|||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
class MBartTokenizerTests(MBartIntegrationTests):
|
||||
def test_enro_tokenizer_prepare_translation_batch(self):
|
||||
batch = self.tokenizer.prepare_translation_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
|
||||
)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
self.assertEqual((2, 14), batch.input_ids.shape)
|
||||
self.assertEqual((2, 14), batch.attention_mask.shape)
|
||||
result = batch.input_ids.tolist()[0]
|
||||
self.assertListEqual(self.expected_src_tokens, result)
|
||||
self.assertEqual(2, batch.decoder_input_ids[0, -2]) # EOS
|
||||
|
||||
def test_enro_tokenizer_batch_encode_plus(self):
|
||||
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
|
||||
self.assertListEqual(self.expected_src_tokens, ids)
|
||||
|
||||
def test_enro_tokenizer_truncation(self):
|
||||
src_text = ["this is gunna be a long sentence " * 20]
|
||||
assert isinstance(src_text[0], str)
|
||||
desired_max_length = 10
|
||||
ids = self.tokenizer.prepare_translation_batch(
|
||||
src_text, return_tensors=None, max_length=desired_max_length
|
||||
).input_ids[0]
|
||||
self.assertEqual(ids[-2], 2)
|
||||
self.assertEqual(ids[-1], EN_CODE)
|
||||
self.assertEqual(len(ids), desired_max_length)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BartHeadTests(unittest.TestCase):
|
||||
vocab_size = 99
|
||||
|
|
Загрузка…
Ссылка в новой задаче