saidbleik 2020-05-08 22:38:54 +00:00
Parent 8f965d44e5
Commit 78726ff471
3 changed files: 75 additions and 88 deletions

View file

@@ -5,34 +5,29 @@
# This script reuses some code from https://github.com/huggingface/transformers/
# Add to noticefile
from collections import namedtuple
import logging
import os
import pickle
from tqdm import tqdm
from collections import namedtuple
import torch
from torch.utils.data import (
DataLoader,
SequentialSampler,
RandomSampler,
)
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import BertModel
from tqdm import tqdm
from transformers import AutoTokenizer, BertModel
from utils_nlp.common.pytorch_utils import (
compute_training_steps,
get_device,
get_amp,
get_device,
move_model_to_device,
parallelize_model,
)
from utils_nlp.eval import compute_rouge_python
from utils_nlp.models.transformers.common import TOKENIZER_CLASS, Transformer
from utils_nlp.models.transformers.bertsum import model_builder
from utils_nlp.models.transformers.bertsum.model_builder import AbsSummarizer
from utils_nlp.models.transformers.bertsum.predictor import build_predictor
from utils_nlp.models.transformers.common import Transformer
MODEL_CLASS = {"bert-base-uncased": BertModel}
@@ -134,8 +129,11 @@ class BertSumAbsProcessor:
"""
self.model_name = model_name
self.tokenizer = TOKENIZER_CLASS[self.model_name].from_pretrained(
self.model_name, do_lower_case=to_lower, cache_dir=cache_dir
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
do_lower_case=to_lower,
cache_dir=cache_dir,
output_loading_info=False,
)
self.symbols = {
@@ -156,7 +154,7 @@ class BertSumAbsProcessor:
@staticmethod
def list_supported_models():
return list(MODEL_CLASS.keys())
return list(MODEL_CLASS)
@property
def model_name(self):
@@ -184,7 +182,7 @@ class BertSumAbsProcessor:
also contains the target ids and the number of tokens
in the target and target text.
device (torch.device): A PyTorch device.
model_name (bool, optional): Model name used to format the inputs.
model_name (bool): Model name used to format the inputs.
train_mode (bool, optional): Training mode flag.
Defaults to True.
@@ -403,7 +401,8 @@ class BertSumAbs(Transformer):
check MODEL_CLASS for supported models. Defaults to "bert-base-uncased".
finetune_bert (bool, optional): Whether the bert model in the encoder is
fine-tuned or not. Defaults to True.
cache_dir (str, optional): Directory to cache the tokenizer. Defaults to ".".
cache_dir (str, optional): Directory to cache the tokenizer.
Defaults to ".".
label_smoothing (float, optional): The amount of label smoothing.
Value range is [0, 1]. Defaults to 0.1.
test (bool, optional): Whether the class is initiated for test or not.
@@ -412,13 +411,11 @@ class BertSumAbs(Transformer):
max_pos_length (int, optional): Maximum positional embedding length for the
input. Defaults to 768.
"""
super().__init__(
model_class=MODEL_CLASS,
model_name=model_name,
num_labels=0,
cache_dir=cache_dir,
model = MODEL_CLASS[model_name].from_pretrained(
model_name, cache_dir=cache_dir, num_labels=0, output_loading_info=False
)
super().__init__(model_name=model_name, model=model, cache_dir=cache_dir)
if model_name not in self.list_supported_models():
raise ValueError(
"Model name {} is not supported by BertSumAbs. "
@@ -616,10 +613,7 @@ class BertSumAbs(Transformer):
)
train_dataloader = DataLoader(
train_dataset,
sampler=sampler,
batch_size=batch_size,
collate_fn=collate_fn,
train_dataset, sampler=sampler, batch_size=batch_size, collate_fn=collate_fn
)
# compute the max number of training steps
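
For orientation, a minimal standalone sketch of the tokenizer-loading and DataLoader patterns used in this file; it is not repository code. The model name, cache directory, batch size, toy dataset, and collate_fn below are illustrative assumptions:

import torch
from torch.utils.data import DataLoader, RandomSampler
from transformers import AutoTokenizer

# Load a tokenizer by name, mirroring BertSumAbsProcessor above.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased", do_lower_case=True, cache_dir="."
)

# Toy stand-in for a processed summarization dataset.
train_dataset = [
    {"src": "the cat sat on the mat ."},
    {"src": "dogs bark at night ."},
]

def collate_fn(batch):
    # Convert each source text to token ids and pad to the longest example.
    ids = [
        tokenizer.convert_tokens_to_ids(tokenizer.tokenize(ex["src"]))
        for ex in batch
    ]
    max_len = max(len(x) for x in ids)
    return torch.tensor([x + [0] * (max_len - len(x)) for x in ids])

sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(
    train_dataset, sampler=sampler, batch_size=2, collate_fn=collate_fn
)
for batch in train_dataloader:
    print(batch.shape)  # (batch_size, max_source_length)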

View file

@@ -1,36 +1,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import json
import logging
from tqdm import tqdm
import os
import random
import torch
from torch.utils.data import DataLoader, SequentialSampler, Dataset
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import BertConfig, RobertaConfig
from transformers import RobertaConfig, BertConfig
from utils_nlp.models.transformers.common import TOKENIZER_CLASS, Transformer
import s2s_ft
import s2s_ft.s2s_loader as seq2seq_loader
from s2s_ft.config import BertForSeq2SeqConfig
from s2s_ft.configuration_unilm import UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP, UnilmConfig
from s2s_ft.modeling import (
UNILM_PRETRAINED_MODEL_ARCHIVE_MAP,
BertForSequenceToSequence,
)
from s2s_ft.modeling_decoding import BertForSeq2SeqDecoder
from s2s_ft.tokenization_unilm import UnilmTokenizer
from s2s_ft.utils import Seq2seqDatasetForBert, batch_list_to_batch_tensors
from utils_nlp.common.pytorch_utils import (
get_device,
move_model_to_device,
parallelize_model,
)
import s2s_ft
from s2s_ft.utils import (
Seq2seqDatasetForBert,
batch_list_to_batch_tensors,
)
from s2s_ft.modeling import BertForSequenceToSequence
from s2s_ft.modeling import UNILM_PRETRAINED_MODEL_ARCHIVE_MAP
from s2s_ft.tokenization_unilm import UnilmTokenizer
from s2s_ft.configuration_unilm import UnilmConfig, UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP
from s2s_ft.config import BertForSeq2SeqConfig
import s2s_ft.s2s_loader as seq2seq_loader
from s2s_ft.modeling_decoding import BertForSeq2SeqDecoder
from utils_nlp.models.transformers.common import TOKENIZER_CLASS, Transformer
SUPPORTED_BERT_MODELS = ["bert-large-uncased", "bert-base-cased", "bert-large-cased"]
SUPPORTED_ROBERTA_MODELS = ["roberta-base", "roberta-large"]
@@ -115,9 +113,7 @@ class S2SAbsSumProcessor:
Defaults to ".".
"""
def __init__(
self, model_name="unilm-base-cased", to_lower=False, cache_dir=".",
):
def __init__(self, model_name="unilm-base-cased", to_lower=False, cache_dir="."):
self.tokenizer = TOKENIZER_CLASS[model_name].from_pretrained(
model_name, do_lower_case=to_lower, cache_dir=cache_dir
@@ -895,7 +891,7 @@ class S2SAbstractiveSummarizer(Transformer):
num_qkv=s2s_config.num_qkv,
seg_emb=s2s_config.seg_emb,
is_roberta=is_roberta,
no_segment_embedding=no_segment_embedding
no_segment_embedding=no_segment_embedding,
)
model = BertForSeq2SeqDecoder.from_pretrained(
@@ -1035,7 +1031,7 @@ class S2SAbstractiveSummarizer(Transformer):
if fp16:
optim_to_save["amp"] = self.amp_state_dict
torch.save(
optim_to_save, os.path.join(output_dir, "optim.{}.bin".format(global_step)),
optim_to_save, os.path.join(output_dir, "optim.{}.bin".format(global_step))
)
@@ -1087,7 +1083,7 @@ def load_and_cache_examples(
else:
source_tokens = tokenizer.tokenize(example["src"])
features.append(
{"source_ids": tokenizer.convert_tokens_to_ids(source_tokens),}
{"source_ids": tokenizer.convert_tokens_to_ids(source_tokens)}
)
if shuffle:
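
As a rough, self-contained illustration of the feature-building loop in load_and_cache_examples above (a stock BertTokenizer stands in for the UnilmTokenizer, and the example texts are made up):

import random
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

examples = [
    {"src": "machine learning makes summarization easier ."},
    {"src": "this is the second source document ."},
]

features = []
for example in examples:
    # Tokenize the source text and keep only its token ids, as above.
    source_tokens = tokenizer.tokenize(example["src"])
    features.append({"source_ids": tokenizer.convert_tokens_to_ids(source_tokens)})

shuffle = True
if shuffle:
    random.shuffle(features)

print(features[0]["source_ids"][:5])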

View file

@@ -12,14 +12,20 @@ from multiprocessing import Pool, cpu_count
import numpy as np
import torch
from torch.utils.data import (
DataLoader,
SequentialSampler,
RandomSampler,
)
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import BertModel, DistilBertModel
from transformers import AutoTokenizer, BertModel, DistilBertModel
from utils_nlp.common.pytorch_utils import (
compute_training_steps,
get_device,
move_model_to_device,
parallelize_model,
)
from utils_nlp.dataset.sentence_selection import combination_selection, greedy_selection
from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
fit_to_block_size,
)
from utils_nlp.models.transformers.bertsum import model_builder
from utils_nlp.models.transformers.bertsum.data_loader import (
@@ -32,17 +38,7 @@ from utils_nlp.models.transformers.bertsum.dataset import (
ExtSumProcessedIterableDataset,
)
from utils_nlp.models.transformers.bertsum.model_builder import BertSumExt
from utils_nlp.common.pytorch_utils import (
compute_training_steps,
get_device,
move_model_to_device,
parallelize_model,
)
from utils_nlp.dataset.sentence_selection import combination_selection, greedy_selection
from utils_nlp.models.transformers.common import TOKENIZER_CLASS, Transformer
from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
fit_to_block_size,
)
from utils_nlp.models.transformers.common import Transformer
MODEL_CLASS = {
"bert-base-uncased": BertModel,
@@ -302,7 +298,7 @@ def parallel_preprocess(input_data, preprocess, num_pool=-1):
p = Pool(num_pool)
results = p.map(
preprocess, input_data, chunksize=min(1, int(len(input_data) / num_pool)),
preprocess, input_data, chunksize=min(1, int(len(input_data) / num_pool))
)
p.close()
p.join()
@@ -347,8 +343,11 @@ class ExtSumProcessor:
"""
self.model_name = model_name
self.tokenizer = TOKENIZER_CLASS[self.model_name].from_pretrained(
self.model_name, do_lower_case=to_lower, cache_dir=cache_dir
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
do_lower_case=to_lower,
cache_dir=cache_dir,
output_loading_info=False,
)
self.sep_vid = self.tokenizer.vocab["[SEP]"]
self.cls_vid = self.tokenizer.vocab["[CLS]"]
@@ -361,7 +360,7 @@ class ExtSumProcessor:
@staticmethod
def list_supported_models():
return list(TOKENIZER_CLASS.keys())
return list(MODEL_CLASS)
@property
def model_name(self):
@@ -389,7 +388,7 @@ class ExtSumProcessor:
text. If train_model is True, it also contains the labels and target
text.
device (torch.device): A PyTorch device.
model_name (bool, optional): Model name used to format the inputs.
model_name (bool): Model name used to format the inputs.
train_mode (bool, optional): Training mode flag.
Defaults to True.
@@ -500,7 +499,6 @@ class ExtSumProcessor:
if len(src) == 0:
raise ValueError("source doesn't have any sentences")
return None
original_src_txt = [" ".join(s) for s in src]
# no filtering for prediction
@@ -588,12 +586,11 @@ class ExtractiveSummarizer(Transformer):
Defaults to ".".
"""
super().__init__(
model_class=MODEL_CLASS,
model_name=model_name,
num_labels=0,
cache_dir=cache_dir,
model = MODEL_CLASS[model_name].from_pretrained(
model_name, cache_dir=cache_dir, num_labels=0, output_loading_info=False
)
super().__init__(model_name=model_name, model=model, cache_dir=cache_dir)
if model_name not in self.list_supported_models():
raise ValueError(
"Model name {} is not supported by ExtractiveSummarizer. "
@@ -621,7 +618,7 @@ class ExtractiveSummarizer(Transformer):
@staticmethod
def list_supported_models():
return list(MODEL_CLASS.keys())
return list(MODEL_CLASS)
def fit(
self,
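
Finally, a small sketch of the MODEL_CLASS-based construction used by ExtractiveSummarizer and BertSumAbs. The dictionary below is abbreviated, the model name and cache directory are illustrative, and the dummy input ids are arbitrary:

import torch
from transformers import BertModel, DistilBertModel

# Abbreviated stand-in for the MODEL_CLASS lookup defined in this file.
MODEL_CLASS = {
    "bert-base-uncased": BertModel,
    "distilbert-base-uncased": DistilBertModel,
}

model_name = "distilbert-base-uncased"
model = MODEL_CLASS[model_name].from_pretrained(
    model_name, cache_dir=".", num_labels=0, output_loading_info=False
)

# Smoke test: encode one dummy sequence of token ids.
with torch.no_grad():
    outputs = model(torch.tensor([[101, 2023, 2003, 1037, 3231, 1012, 102]]))
print(outputs[0].shape)  # (1, sequence_length, hidden_size)

The diff suggests the resulting model instance is then handed to the Transformer base class via super().__init__(model_name=model_name, model=model, cache_dir=cache_dir) rather than being built by the base class itself.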