This commit is contained in:
Daisy Deng 2020-04-10 21:51:05 +00:00
Родитель 806d5fb12e
Коммит 11551c5937
1 изменённых файлов: 13 добавлений и 0 удалений

Просмотреть файл

@@ -24,9 +24,14 @@ from s2s_ft.utils import (
batch_list_to_batch_tensors, batch_list_to_batch_tensors,
) )
from s2s_ft.modeling import BertForSequenceToSequence from s2s_ft.modeling import BertForSequenceToSequence
from s2s_ft.modeling import MINILM_PRETRAINED_MODEL_ARCHIVE_MAP
from s2s_ft.modeling import UNILM_PRETRAINED_MODEL_ARCHIVE_MAP from s2s_ft.modeling import UNILM_PRETRAINED_MODEL_ARCHIVE_MAP
from s2s_ft.tokenization_minilm import MinilmTokenizer
from s2s_ft.configuration_minilm import MinilmConfig, MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP
from s2s_ft.tokenization_unilm import UnilmTokenizer from s2s_ft.tokenization_unilm import UnilmTokenizer
from s2s_ft.configuration_unilm import UnilmConfig, UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP from s2s_ft.configuration_unilm import UnilmConfig, UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP
from s2s_ft.config import BertForSeq2SeqConfig from s2s_ft.config import BertForSeq2SeqConfig
import s2s_ft.s2s_loader as seq2seq_loader import s2s_ft.s2s_loader as seq2seq_loader
from s2s_ft.modeling_decoding import BertForSeq2SeqDecoder from s2s_ft.modeling_decoding import BertForSeq2SeqDecoder
@@ -42,14 +47,20 @@ MODEL_CLASS.update({k: BertForSequenceToSequence for k in SUPPORTED_ROBERTA_MODE
MODEL_CLASS.update( MODEL_CLASS.update(
{k: BertForSequenceToSequence for k in UNILM_PRETRAINED_MODEL_ARCHIVE_MAP} {k: BertForSequenceToSequence for k in UNILM_PRETRAINED_MODEL_ARCHIVE_MAP}
) )
MODEL_CLASS.update(
{k: BertForSequenceToSequence for k in MINILM_PRETRAINED_MODEL_ARCHIVE_MAP}
)
TOKENIZER_CLASS.update({k: UnilmTokenizer for k in UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP}) TOKENIZER_CLASS.update({k: UnilmTokenizer for k in UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
TOKENIZER_CLASS.update({k: MinilmTokenizer for k in MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
CONFIG_CLASS = {} CONFIG_CLASS = {}
CONFIG_CLASS.update({k: BertConfig for k in SUPPORTED_BERT_MODELS}) CONFIG_CLASS.update({k: BertConfig for k in SUPPORTED_BERT_MODELS})
CONFIG_CLASS.update({k: RobertaConfig for k in SUPPORTED_ROBERTA_MODELS}) CONFIG_CLASS.update({k: RobertaConfig for k in SUPPORTED_ROBERTA_MODELS})
CONFIG_CLASS.update({k: UnilmConfig for k in UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP}) CONFIG_CLASS.update({k: UnilmConfig for k in UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
CONFIG_CLASS.update({k: MinilmConfig for k in MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
# XLM_ROBERTA is for multilingual and is WIP in s2s-ft. # XLM_ROBERTA is for multilingual and is WIP in s2s-ft.
# We can add it when it's finished and validated # We can add it when it's finished and validated
@@ -68,6 +79,8 @@ def _get_model_type(model_name):
return "xlm-roberta" return "xlm-roberta"
elif model_name.startswith("unilm"): elif model_name.startswith("unilm"):
return "unilm" return "unilm"
elif model_name.startswith("minilm"):
return "minilm"
else: else:
return model_name.split("-")[0] return model_name.split("-")[0]