add minilm components
Parent: 806d5fb12e
Commit: 11551c5937
@@ -24,9 +24,14 @@ from s2s_ft.utils import (
     batch_list_to_batch_tensors,
 )
 from s2s_ft.modeling import BertForSequenceToSequence
+from s2s_ft.modeling import MINILM_PRETRAINED_MODEL_ARCHIVE_MAP
 from s2s_ft.modeling import UNILM_PRETRAINED_MODEL_ARCHIVE_MAP
+
+from s2s_ft.tokenization_minilm import MinilmTokenizer
+from s2s_ft.configuration_minilm import MinilmConfig, MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP
 from s2s_ft.tokenization_unilm import UnilmTokenizer
 from s2s_ft.configuration_unilm import UnilmConfig, UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP
+
 from s2s_ft.config import BertForSeq2SeqConfig
 import s2s_ft.s2s_loader as seq2seq_loader
 from s2s_ft.modeling_decoding import BertForSeq2SeqDecoder
@@ -42,14 +47,20 @@ MODEL_CLASS.update({k: BertForSequenceToSequence for k in SUPPORTED_ROBERTA_MODELS})
 MODEL_CLASS.update(
     {k: BertForSequenceToSequence for k in UNILM_PRETRAINED_MODEL_ARCHIVE_MAP}
 )
+MODEL_CLASS.update(
+    {k: BertForSequenceToSequence for k in MINILM_PRETRAINED_MODEL_ARCHIVE_MAP}
+)
+

 
 TOKENIZER_CLASS.update({k: UnilmTokenizer for k in UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
+TOKENIZER_CLASS.update({k: MinilmTokenizer for k in MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
 
 CONFIG_CLASS = {}
 CONFIG_CLASS.update({k: BertConfig for k in SUPPORTED_BERT_MODELS})
 CONFIG_CLASS.update({k: RobertaConfig for k in SUPPORTED_ROBERTA_MODELS})
 CONFIG_CLASS.update({k: UnilmConfig for k in UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
+CONFIG_CLASS.update({k: MinilmConfig for k in MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
 
 # XLM_ROBERTA is for multilingual and is WIP in s2s-ft.
 # We can add it when it's finished and validated
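For context (not part of the diff): the three registries above are keyed by pretrained checkpoint names, so registering the MiniLM archive-map keys is what lets a MiniLM checkpoint resolve to MinilmConfig, MinilmTokenizer, and BertForSequenceToSequence. A minimal sketch of that lookup follows; the example checkpoint name and the bare from_pretrained calls are assumptions for illustration, not taken from this commit.

# Sketch only: how the registries built above are consumed downstream.
model_name = "minilm-l12-h384-uncased"  # assumed example key; see MINILM_PRETRAINED_MODEL_ARCHIVE_MAP

config_class = CONFIG_CLASS[model_name]        # -> MinilmConfig
tokenizer_class = TOKENIZER_CLASS[model_name]  # -> MinilmTokenizer
model_class = MODEL_CLASS[model_name]          # -> BertForSequenceToSequence

# Standard transformers-style loading; the real script passes extra
# task-specific arguments when it constructs the model itself.
config = config_class.from_pretrained(model_name)
tokenizer = tokenizer_class.from_pretrained(model_name)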
@@ -68,6 +79,8 @@ def _get_model_type(model_name):
         return "xlm-roberta"
     elif model_name.startswith("unilm"):
         return "unilm"
+    elif model_name.startswith("minilm"):
+        return "minilm"
     else:
         return model_name.split("-")[0]
 