From 9827d666ebdf959aa9dfe3627ccb80592b378b77 Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Thu, 23 Jul 2020 15:41:14 -0400
Subject: [PATCH] MbartTokenizer: do not hardcode vocab size (#5998)

---
 src/transformers/tokenization_bart.py | 66 +++++++++++++++------------
 tests/test_tokenization_mbart.py      |  7 ++-
 2 files changed, 43 insertions(+), 30 deletions(-)

diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py
index e23abcb82..90353ddd3 100644
--- a/src/transformers/tokenization_bart.py
+++ b/src/transformers/tokenization_bart.py
@@ -58,6 +58,34 @@ class BartTokenizerFast(RobertaTokenizerFast):
 _all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
 SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
 
+FAIRSEQ_LANGUAGE_CODES = [
+    "ar_AR",
+    "cs_CZ",
+    "de_DE",
+    "en_XX",
+    "es_XX",
+    "et_EE",
+    "fi_FI",
+    "fr_XX",
+    "gu_IN",
+    "hi_IN",
+    "it_IT",
+    "ja_XX",
+    "kk_KZ",
+    "ko_KR",
+    "lt_LT",
+    "lv_LV",
+    "my_MM",
+    "ne_NP",
+    "nl_XX",
+    "ro_RO",
+    "ru_RU",
+    "si_LK",
+    "tr_TR",
+    "vi_VN",
+    "zh_CN",
+]
+
 
 class MBartTokenizer(XLMRobertaTokenizer):
     """
@@ -81,40 +109,20 @@ class MBartTokenizer(XLMRobertaTokenizer):
     vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
     max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
     pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
-    lang_code_to_id = {  # NOTE(SS): resize embeddings will break this
-        "ar_AR": 250001,
-        "cs_CZ": 250002,
-        "de_DE": 250003,
-        "en_XX": 250004,
-        "es_XX": 250005,
-        "et_EE": 250006,
-        "fi_FI": 250007,
-        "fr_XX": 250008,
-        "gu_IN": 250009,
-        "hi_IN": 250010,
-        "it_IT": 250011,
-        "ja_XX": 250012,
-        "kk_KZ": 250013,
-        "ko_KR": 250014,
-        "lt_LT": 250015,
-        "lv_LV": 250016,
-        "my_MM": 250017,
-        "ne_NP": 250018,
-        "nl_XX": 250019,
-        "ro_RO": 250020,
-        "ru_RU": 250021,
-        "si_LK": 250022,
-        "tr_TR": 250023,
-        "vi_VN": 250024,
-        "zh_CN": 250025,
-    }
-    id_to_lang_code = {v: k for k, v in lang_code_to_id.items()}
-    cur_lang_code = lang_code_to_id["en_XX"]
+
     prefix_tokens: List[int] = []
     suffix_tokens: List[int] = []
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+
+        self.sp_model_size = len(self.sp_model)
+        self.lang_code_to_id = {
+            code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
+        }
+        self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
+        self.cur_lang_code = self.lang_code_to_id["en_XX"]
+
         self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
         self._additional_special_tokens = list(self.lang_code_to_id.keys())
diff --git a/tests/test_tokenization_mbart.py b/tests/test_tokenization_mbart.py
index 84bc3ba8b..d45a5ee60 100644
--- a/tests/test_tokenization_mbart.py
+++ b/tests/test_tokenization_mbart.py
@@ -113,10 +113,15 @@ class MBartEnroIntegrationTest(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
+        cls.tokenizer: MBartTokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
         cls.pad_token_id = 1
         return cls
 
+    def check_language_codes(self):
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ar_AR"], 250001)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
+
     def test_enro_tokenizer_prepare_translation_batch(self):
         batch = self.tokenizer.prepare_translation_batch(
             self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
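The id arithmetic this patch introduces can be sanity-checked outside the tokenizer. The sketch below mirrors the comprehension added to __init__, with sp_model_size and fairseq_offset filled in as assumed stand-ins chosen so that sp_model_size + fairseq_offset == 250001, the expected id for "ar_AR" in the test; in MBartTokenizer itself both values come from the loaded SentencePiece model and the XLMRobertaTokenizer base class, not from constants.

# Illustration only: dynamic language-code id assignment, outside the tokenizer class.
# sp_model_size and fairseq_offset are assumed stand-ins; in MBartTokenizer they come
# from len(self.sp_model) and the XLMRobertaTokenizer base class, respectively.
FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX"]  # abridged; full list in the hunk above

sp_model_size = 250000  # assumed, so that sp_model_size + fairseq_offset == 250001
fairseq_offset = 1      # assumed offset applied to raw sentencepiece ids

# Language codes are appended after the sentencepiece vocabulary instead of being
# pinned to hard-coded ids, so a differently sized vocab no longer breaks the mapping.
lang_code_to_id = {code: sp_model_size + i + fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)}

assert lang_code_to_id["ar_AR"] == 250001
assert lang_code_to_id["en_XX"] == 250004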