MbartTokenizer: do not hardcode vocab size (#5998)
This commit is contained in:
Родитель
6e16195510
Коммит
9827d666eb
|
@ -58,6 +58,34 @@ class BartTokenizerFast(RobertaTokenizerFast):
|
|||
_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
|
||||
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
|
||||
|
||||
FAIRSEQ_LANGUAGE_CODES = [
|
||||
"ar_AR",
|
||||
"cs_CZ",
|
||||
"de_DE",
|
||||
"en_XX",
|
||||
"es_XX",
|
||||
"et_EE",
|
||||
"fi_FI",
|
||||
"fr_XX",
|
||||
"gu_IN",
|
||||
"hi_IN",
|
||||
"it_IT",
|
||||
"ja_XX",
|
||||
"kk_KZ",
|
||||
"ko_KR",
|
||||
"lt_LT",
|
||||
"lv_LV",
|
||||
"my_MM",
|
||||
"ne_NP",
|
||||
"nl_XX",
|
||||
"ro_RO",
|
||||
"ru_RU",
|
||||
"si_LK",
|
||||
"tr_TR",
|
||||
"vi_VN",
|
||||
"zh_CN",
|
||||
]
|
||||
|
||||
|
||||
class MBartTokenizer(XLMRobertaTokenizer):
|
||||
"""
|
||||
|
@ -81,40 +109,20 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
|
||||
max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
|
||||
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
|
||||
lang_code_to_id = { # NOTE(SS): resize embeddings will break this
|
||||
"ar_AR": 250001,
|
||||
"cs_CZ": 250002,
|
||||
"de_DE": 250003,
|
||||
"en_XX": 250004,
|
||||
"es_XX": 250005,
|
||||
"et_EE": 250006,
|
||||
"fi_FI": 250007,
|
||||
"fr_XX": 250008,
|
||||
"gu_IN": 250009,
|
||||
"hi_IN": 250010,
|
||||
"it_IT": 250011,
|
||||
"ja_XX": 250012,
|
||||
"kk_KZ": 250013,
|
||||
"ko_KR": 250014,
|
||||
"lt_LT": 250015,
|
||||
"lv_LV": 250016,
|
||||
"my_MM": 250017,
|
||||
"ne_NP": 250018,
|
||||
"nl_XX": 250019,
|
||||
"ro_RO": 250020,
|
||||
"ru_RU": 250021,
|
||||
"si_LK": 250022,
|
||||
"tr_TR": 250023,
|
||||
"vi_VN": 250024,
|
||||
"zh_CN": 250025,
|
||||
}
|
||||
id_to_lang_code = {v: k for k, v in lang_code_to_id.items()}
|
||||
cur_lang_code = lang_code_to_id["en_XX"]
|
||||
|
||||
prefix_tokens: List[int] = []
|
||||
suffix_tokens: List[int] = []
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.sp_model_size = len(self.sp_model)
|
||||
self.lang_code_to_id = {
|
||||
code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
|
||||
}
|
||||
self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
|
||||
self.cur_lang_code = self.lang_code_to_id["en_XX"]
|
||||
|
||||
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
|
||||
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
|
||||
self._additional_special_tokens = list(self.lang_code_to_id.keys())
|
||||
|
|
|
@ -113,10 +113,15 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
|||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
|
||||
cls.tokenizer: MBartTokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
|
||||
cls.pad_token_id = 1
|
||||
return cls
|
||||
|
||||
def check_language_codes(self):
|
||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ar_AR"], 250001)
|
||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
|
||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
|
||||
|
||||
def test_enro_tokenizer_prepare_translation_batch(self):
|
||||
batch = self.tokenizer.prepare_translation_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
|
||||
|
|
Загрузка…
Ссылка в новой задаче