diff --git a/utils_nlp/models/transformers/abstractive_summarization_seq2seq.py b/utils_nlp/models/transformers/abstractive_summarization_seq2seq.py index 4640451..c63a8c0 100644 --- a/utils_nlp/models/transformers/abstractive_summarization_seq2seq.py +++ b/utils_nlp/models/transformers/abstractive_summarization_seq2seq.py @@ -24,9 +24,14 @@ from s2s_ft.utils import ( batch_list_to_batch_tensors, ) from s2s_ft.modeling import BertForSequenceToSequence +from s2s_ft.modeling import MINILM_PRETRAINED_MODEL_ARCHIVE_MAP from s2s_ft.modeling import UNILM_PRETRAINED_MODEL_ARCHIVE_MAP + +from s2s_ft.tokenization_minilm import MinilmTokenizer +from s2s_ft.configuration_minilm import MinilmConfig, MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP from s2s_ft.tokenization_unilm import UnilmTokenizer from s2s_ft.configuration_unilm import UnilmConfig, UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP + from s2s_ft.config import BertForSeq2SeqConfig import s2s_ft.s2s_loader as seq2seq_loader from s2s_ft.modeling_decoding import BertForSeq2SeqDecoder @@ -42,14 +47,20 @@ MODEL_CLASS.update({k: BertForSequenceToSequence for k in SUPPORTED_ROBERTA_MODE MODEL_CLASS.update( {k: BertForSequenceToSequence for k in UNILM_PRETRAINED_MODEL_ARCHIVE_MAP} ) +MODEL_CLASS.update( + {k: BertForSequenceToSequence for k in MINILM_PRETRAINED_MODEL_ARCHIVE_MAP} +) + TOKENIZER_CLASS.update({k: UnilmTokenizer for k in UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP}) +TOKENIZER_CLASS.update({k: MinilmTokenizer for k in MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP}) CONFIG_CLASS = {} CONFIG_CLASS.update({k: BertConfig for k in SUPPORTED_BERT_MODELS}) CONFIG_CLASS.update({k: RobertaConfig for k in SUPPORTED_ROBERTA_MODELS}) CONFIG_CLASS.update({k: UnilmConfig for k in UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP}) +CONFIG_CLASS.update({k: MinilmConfig for k in MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP}) # XLM_ROBERTA is for multilingual and is WIP in s2s-ft. # We can add it when it's finished and validated @@ -68,6 +79,8 @@ def _get_model_type(model_name): return "xlm-roberta" elif model_name.startswith("unilm"): return "unilm" + elif model_name.startswith("minilm"): + return "minilm" else: return model_name.split("-")[0]