update summarization models
This commit is contained in:
Родитель
8f965d44e5
Коммит
78726ff471
|
@ -5,34 +5,29 @@
|
|||
# This script reuses some code from https://github.com/huggingface/transformers/
|
||||
# Add to noticefile
|
||||
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from torch.utils.data import (
|
||||
DataLoader,
|
||||
SequentialSampler,
|
||||
RandomSampler,
|
||||
)
|
||||
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from transformers import BertModel
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, BertModel
|
||||
|
||||
from utils_nlp.common.pytorch_utils import (
|
||||
compute_training_steps,
|
||||
get_device,
|
||||
get_amp,
|
||||
get_device,
|
||||
move_model_to_device,
|
||||
parallelize_model,
|
||||
)
|
||||
from utils_nlp.eval import compute_rouge_python
|
||||
from utils_nlp.models.transformers.common import TOKENIZER_CLASS, Transformer
|
||||
from utils_nlp.models.transformers.bertsum import model_builder
|
||||
from utils_nlp.models.transformers.bertsum.model_builder import AbsSummarizer
|
||||
from utils_nlp.models.transformers.bertsum.predictor import build_predictor
|
||||
from utils_nlp.models.transformers.common import Transformer
|
||||
|
||||
MODEL_CLASS = {"bert-base-uncased": BertModel}
|
||||
|
||||
|
@ -134,8 +129,11 @@ class BertSumAbsProcessor:
|
|||
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.tokenizer = TOKENIZER_CLASS[self.model_name].from_pretrained(
|
||||
self.model_name, do_lower_case=to_lower, cache_dir=cache_dir
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
do_lower_case=to_lower,
|
||||
cache_dir=cache_dir,
|
||||
output_loading_info=False,
|
||||
)
|
||||
|
||||
self.symbols = {
|
||||
|
@ -156,7 +154,7 @@ class BertSumAbsProcessor:
|
|||
|
||||
@staticmethod
|
||||
def list_supported_models():
|
||||
return list(MODEL_CLASS.keys())
|
||||
return list(MODEL_CLASS)
|
||||
|
||||
@property
|
||||
def model_name(self):
|
||||
|
@ -184,7 +182,7 @@ class BertSumAbsProcessor:
|
|||
also contains the target ids and the number of tokens
|
||||
in the target and target text.
|
||||
device (torch.device): A PyTorch device.
|
||||
model_name (bool, optional): Model name used to format the inputs.
|
||||
model_name (bool): Model name used to format the inputs.
|
||||
train_mode (bool, optional): Training mode flag.
|
||||
Defaults to True.
|
||||
|
||||
|
@ -403,7 +401,8 @@ class BertSumAbs(Transformer):
|
|||
check MODEL_CLASS for supported models. Defaults to "bert-base-uncased".
|
||||
finetune_bert (bool, option): Whether the bert model in the encoder is
|
||||
finetune or not. Defaults to True.
|
||||
cache_dir (str, optional): Directory to cache the tokenizer. Defaults to ".".
|
||||
cache_dir (str, optional): Directory to cache the tokenizer.
|
||||
Defaults to ".".
|
||||
label_smoothing (float, optional): The amount of label smoothing.
|
||||
Value range is [0, 1]. Defaults to 0.1.
|
||||
test (bool, optional): Whether the class is initiated for test or not.
|
||||
|
@ -412,13 +411,11 @@ class BertSumAbs(Transformer):
|
|||
max_pos_length (int, optional): maximum postional embedding length for the
|
||||
input. Defaults to 768.
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
model_class=MODEL_CLASS,
|
||||
model_name=model_name,
|
||||
num_labels=0,
|
||||
cache_dir=cache_dir,
|
||||
model = MODEL_CLASS[model_name].from_pretrained(
|
||||
model_name, cache_dir=cache_dir, num_labels=0, output_loading_info=False
|
||||
)
|
||||
super().__init__(model_name=model_name, model=model, cache_dir=cache_dir)
|
||||
|
||||
if model_name not in self.list_supported_models():
|
||||
raise ValueError(
|
||||
"Model name {} is not supported by BertSumAbs. "
|
||||
|
@ -616,10 +613,7 @@ class BertSumAbs(Transformer):
|
|||
)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
collate_fn=collate_fn,
|
||||
train_dataset, sampler=sampler, batch_size=batch_size, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
# compute the max number of training steps
|
||||
|
|
|
@ -1,36 +1,34 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, SequentialSampler, Dataset
|
||||
from torch.utils.data import DataLoader, Dataset, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm
|
||||
from transformers import BertConfig, RobertaConfig
|
||||
|
||||
from transformers import RobertaConfig, BertConfig
|
||||
|
||||
from utils_nlp.models.transformers.common import TOKENIZER_CLASS, Transformer
|
||||
import s2s_ft
|
||||
import s2s_ft.s2s_loader as seq2seq_loader
|
||||
from s2s_ft.config import BertForSeq2SeqConfig
|
||||
from s2s_ft.configuration_unilm import UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP, UnilmConfig
|
||||
from s2s_ft.modeling import (
|
||||
UNILM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BertForSequenceToSequence,
|
||||
)
|
||||
from s2s_ft.modeling_decoding import BertForSeq2SeqDecoder
|
||||
from s2s_ft.tokenization_unilm import UnilmTokenizer
|
||||
from s2s_ft.utils import Seq2seqDatasetForBert, batch_list_to_batch_tensors
|
||||
from utils_nlp.common.pytorch_utils import (
|
||||
get_device,
|
||||
move_model_to_device,
|
||||
parallelize_model,
|
||||
)
|
||||
import s2s_ft
|
||||
from s2s_ft.utils import (
|
||||
Seq2seqDatasetForBert,
|
||||
batch_list_to_batch_tensors,
|
||||
)
|
||||
from s2s_ft.modeling import BertForSequenceToSequence
|
||||
from s2s_ft.modeling import UNILM_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from s2s_ft.tokenization_unilm import UnilmTokenizer
|
||||
from s2s_ft.configuration_unilm import UnilmConfig, UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
from s2s_ft.config import BertForSeq2SeqConfig
|
||||
import s2s_ft.s2s_loader as seq2seq_loader
|
||||
from s2s_ft.modeling_decoding import BertForSeq2SeqDecoder
|
||||
from utils_nlp.models.transformers.common import TOKENIZER_CLASS, Transformer
|
||||
|
||||
SUPPORTED_BERT_MODELS = ["bert-large-uncased", "bert-base-cased", "bert-large-cased"]
|
||||
SUPPORTED_ROBERTA_MODELS = ["roberta-base", "roberta-large"]
|
||||
|
@ -115,9 +113,7 @@ class S2SAbsSumProcessor:
|
|||
Defaults to ".".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, model_name="unilm-base-cased", to_lower=False, cache_dir=".",
|
||||
):
|
||||
def __init__(self, model_name="unilm-base-cased", to_lower=False, cache_dir="."):
|
||||
|
||||
self.tokenizer = TOKENIZER_CLASS[model_name].from_pretrained(
|
||||
model_name, do_lower_case=to_lower, cache_dir=cache_dir
|
||||
|
@ -895,7 +891,7 @@ class S2SAbstractiveSummarizer(Transformer):
|
|||
num_qkv=s2s_config.num_qkv,
|
||||
seg_emb=s2s_config.seg_emb,
|
||||
is_roberta=is_roberta,
|
||||
no_segment_embedding=no_segment_embedding
|
||||
no_segment_embedding=no_segment_embedding,
|
||||
)
|
||||
|
||||
model = BertForSeq2SeqDecoder.from_pretrained(
|
||||
|
@ -1035,7 +1031,7 @@ class S2SAbstractiveSummarizer(Transformer):
|
|||
if fp16:
|
||||
optim_to_save["amp"] = self.amp_state_dict
|
||||
torch.save(
|
||||
optim_to_save, os.path.join(output_dir, "optim.{}.bin".format(global_step)),
|
||||
optim_to_save, os.path.join(output_dir, "optim.{}.bin".format(global_step))
|
||||
)
|
||||
|
||||
|
||||
|
@ -1087,7 +1083,7 @@ def load_and_cache_examples(
|
|||
else:
|
||||
source_tokens = tokenizer.tokenize(example["src"])
|
||||
features.append(
|
||||
{"source_ids": tokenizer.convert_tokens_to_ids(source_tokens),}
|
||||
{"source_ids": tokenizer.convert_tokens_to_ids(source_tokens)}
|
||||
)
|
||||
|
||||
if shuffle:
|
||||
|
|
|
@ -12,14 +12,20 @@ from multiprocessing import Pool, cpu_count
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (
|
||||
DataLoader,
|
||||
SequentialSampler,
|
||||
RandomSampler,
|
||||
)
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from transformers import BertModel, DistilBertModel
|
||||
from transformers import AutoTokenizer, BertModel, DistilBertModel
|
||||
|
||||
from utils_nlp.common.pytorch_utils import (
|
||||
compute_training_steps,
|
||||
get_device,
|
||||
move_model_to_device,
|
||||
parallelize_model,
|
||||
)
|
||||
from utils_nlp.dataset.sentence_selection import combination_selection, greedy_selection
|
||||
from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
|
||||
fit_to_block_size,
|
||||
)
|
||||
|
||||
from utils_nlp.models.transformers.bertsum import model_builder
|
||||
from utils_nlp.models.transformers.bertsum.data_loader import (
|
||||
|
@ -32,17 +38,7 @@ from utils_nlp.models.transformers.bertsum.dataset import (
|
|||
ExtSumProcessedIterableDataset,
|
||||
)
|
||||
from utils_nlp.models.transformers.bertsum.model_builder import BertSumExt
|
||||
from utils_nlp.common.pytorch_utils import (
|
||||
compute_training_steps,
|
||||
get_device,
|
||||
move_model_to_device,
|
||||
parallelize_model,
|
||||
)
|
||||
from utils_nlp.dataset.sentence_selection import combination_selection, greedy_selection
|
||||
from utils_nlp.models.transformers.common import TOKENIZER_CLASS, Transformer
|
||||
from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
|
||||
fit_to_block_size,
|
||||
)
|
||||
from utils_nlp.models.transformers.common import Transformer
|
||||
|
||||
MODEL_CLASS = {
|
||||
"bert-base-uncased": BertModel,
|
||||
|
@ -302,7 +298,7 @@ def parallel_preprocess(input_data, preprocess, num_pool=-1):
|
|||
p = Pool(num_pool)
|
||||
|
||||
results = p.map(
|
||||
preprocess, input_data, chunksize=min(1, int(len(input_data) / num_pool)),
|
||||
preprocess, input_data, chunksize=min(1, int(len(input_data) / num_pool))
|
||||
)
|
||||
p.close()
|
||||
p.join()
|
||||
|
@ -347,8 +343,11 @@ class ExtSumProcessor:
|
|||
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.tokenizer = TOKENIZER_CLASS[self.model_name].from_pretrained(
|
||||
self.model_name, do_lower_case=to_lower, cache_dir=cache_dir
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
do_lower_case=to_lower,
|
||||
cache_dir=cache_dir,
|
||||
output_loading_info=False,
|
||||
)
|
||||
self.sep_vid = self.tokenizer.vocab["[SEP]"]
|
||||
self.cls_vid = self.tokenizer.vocab["[CLS]"]
|
||||
|
@ -361,7 +360,7 @@ class ExtSumProcessor:
|
|||
|
||||
@staticmethod
|
||||
def list_supported_models():
|
||||
return list(TOKENIZER_CLASS.keys())
|
||||
return list(MODEL_CLASS)
|
||||
|
||||
@property
|
||||
def model_name(self):
|
||||
|
@ -389,7 +388,7 @@ class ExtSumProcessor:
|
|||
text. If train_model is True, it also contains the labels and target
|
||||
text.
|
||||
device (torch.device): A PyTorch device.
|
||||
model_name (bool, optional): Model name used to format the inputs.
|
||||
model_name (bool): Model name used to format the inputs.
|
||||
train_mode (bool, optional): Training mode flag.
|
||||
Defaults to True.
|
||||
|
||||
|
@ -500,7 +499,6 @@ class ExtSumProcessor:
|
|||
|
||||
if len(src) == 0:
|
||||
raise ValueError("source doesn't have any sentences")
|
||||
return None
|
||||
|
||||
original_src_txt = [" ".join(s) for s in src]
|
||||
# no filtering for prediction
|
||||
|
@ -588,12 +586,11 @@ class ExtractiveSummarizer(Transformer):
|
|||
Defaults to ".".
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
model_class=MODEL_CLASS,
|
||||
model_name=model_name,
|
||||
num_labels=0,
|
||||
cache_dir=cache_dir,
|
||||
model = MODEL_CLASS[model_name].from_pretrained(
|
||||
model_name, cache_dir=cache_dir, num_labels=0, output_loading_info=False
|
||||
)
|
||||
super().__init__(model_name=model_name, model=model, cache_dir=cache_dir)
|
||||
|
||||
if model_name not in self.list_supported_models():
|
||||
raise ValueError(
|
||||
"Model name {} is not supported by ExtractiveSummarizer. "
|
||||
|
@ -621,7 +618,7 @@ class ExtractiveSummarizer(Transformer):
|
|||
|
||||
@staticmethod
|
||||
def list_supported_models():
|
||||
return list(MODEL_CLASS.keys())
|
||||
return list(MODEL_CLASS)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
|
|
Загрузка…
Ссылка в новой задаче