add minilm components
This commit is contained in:
Родитель
806d5fb12e
Коммит
11551c5937
|
@ -24,9 +24,14 @@ from s2s_ft.utils import (
|
|||
batch_list_to_batch_tensors,
|
||||
)
|
||||
from s2s_ft.modeling import BertForSequenceToSequence
|
||||
from s2s_ft.modeling import MINILM_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from s2s_ft.modeling import UNILM_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from s2s_ft.tokenization_minilm import MinilmTokenizer
|
||||
from s2s_ft.configuration_minilm import MinilmConfig, MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
from s2s_ft.tokenization_unilm import UnilmTokenizer
|
||||
from s2s_ft.configuration_unilm import UnilmConfig, UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
from s2s_ft.config import BertForSeq2SeqConfig
|
||||
import s2s_ft.s2s_loader as seq2seq_loader
|
||||
from s2s_ft.modeling_decoding import BertForSeq2SeqDecoder
|
||||
|
@ -42,14 +47,20 @@ MODEL_CLASS.update({k: BertForSequenceToSequence for k in SUPPORTED_ROBERTA_MODE
|
|||
MODEL_CLASS.update(
|
||||
{k: BertForSequenceToSequence for k in UNILM_PRETRAINED_MODEL_ARCHIVE_MAP}
|
||||
)
|
||||
MODEL_CLASS.update(
|
||||
{k: BertForSequenceToSequence for k in MINILM_PRETRAINED_MODEL_ARCHIVE_MAP}
|
||||
)
|
||||
|
||||
|
||||
|
||||
TOKENIZER_CLASS.update({k: UnilmTokenizer for k in UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
|
||||
TOKENIZER_CLASS.update({k: MinilmTokenizer for k in MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
|
||||
|
||||
CONFIG_CLASS = {}
|
||||
CONFIG_CLASS.update({k: BertConfig for k in SUPPORTED_BERT_MODELS})
|
||||
CONFIG_CLASS.update({k: RobertaConfig for k in SUPPORTED_ROBERTA_MODELS})
|
||||
CONFIG_CLASS.update({k: UnilmConfig for k in UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
|
||||
CONFIG_CLASS.update({k: MinilmConfig for k in MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
|
||||
|
||||
# XLM_ROBERTA is for multilingual and is WIP in s2s-ft.
|
||||
# We can add it when it's finished and validated
|
||||
|
@ -68,6 +79,8 @@ def _get_model_type(model_name):
|
|||
return "xlm-roberta"
|
||||
elif model_name.startswith("unilm"):
|
||||
return "unilm"
|
||||
elif model_name.startswith("minilm"):
|
||||
return "minilm"
|
||||
else:
|
||||
return model_name.split("-")[0]
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче