nlp-recipes/examples/text_summarization/extractive_summarization_cn...

166 строки
6.3 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import os
import sys
import time
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
# Put the directory two levels above the current working directory at the
# front of sys.path so that the ``utils_nlp`` imports that follow resolve
# against the local checkout. NOTE: this is relative to the CWD, not to this
# file, so the script is presumably meant to be launched from its own folder.
nlp_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)
from utils_nlp.dataset.cnndm import CNNDMSummarizationDataset
from utils_nlp.models.transformers.extractive_summarization import (
ExtractiveSummarizer,
ExtSumProcessedData,
ExtSumProcessor,
)
# NCCL settings, applied at import time so they are in place before any
# process group is created. Per the NCCL docs, NCCL_IB_DISABLE="0" leaves the
# InfiniBand transport enabled ("1" would disable it).
# Optional, currently disabled: os.environ["NCCL_BLOCKING_WAIT"] = "1"
os.environ.update({"NCCL_IB_DISABLE": "0"})
# Command-line interface of the distributed extractive-summarization script.
# NOTE: argparse applies ``type`` only to *string* defaults; any other default
# is used as-is. Numeric defaults below must therefore already have the
# declared type (e.g. ints for the step counts, not 1e4-style floats).
parser = argparse.ArgumentParser()
parser.add_argument("--rank", type=int, default=0,
                    help="The rank of the current node in the cluster")
parser.add_argument("--dist_url", type=str, default="tcp://127.0.0.1:29500",
                    help="URL specifying how to initialize the process group.")
parser.add_argument("--node_count", type=int, default=1,
                    help="Number of nodes in the cluster.")
parser.add_argument("--cache_dir", type=str, default="./",
                    help="Directory to cache the tokenizer.")
parser.add_argument("--data_dir", type=str, default="./",
                    help="Directory to download the preprocessed data.")
parser.add_argument("--output_dir", type=str, default="./",
                    help="Directory to save the output model and prediction results.")
# str.lower makes the choice check case-insensitive ("TRUE" -> "true").
parser.add_argument("--quick_run", type=str.lower, default='false', choices=['true', 'false'],
                    help="Whether to have a quick run")
parser.add_argument("--model_name", type=str, default="distilbert-base-uncased",
                    help="Transformer model used in the extractive summarization, only "
                         "\"bert-uncased\" and \"distilbert-base-uncased\" are supported.")
parser.add_argument("--encoder", type=str.lower, default='transformer',
                    choices=['baseline', 'classifier', 'transformer', 'rnn'],
                    help="Encoder types in the extractive summarizer.")
parser.add_argument("--learning_rate", type=float, default=1e-3,
                    help="Learning rate.")
parser.add_argument("--batch_size", type=int, default=3000,
                    help="batch size in terms of input token numbers in training")
# Integer defaults (previously 1e4 / 5e3, which argparse left as floats).
parser.add_argument("--max_steps", type=int, default=10000,
                    help="Maximum number of training steps run in training. "
                         "If quick_run is set, it's not used.")
parser.add_argument("--warmup_steps", type=int, default=5000,
                    help="Warm-up number of training steps run in training. "
                         "If quick_run is set, it's not used.")
parser.add_argument("--top_n", type=int, default=3,
                    help="Number of sentences selected in prediction for evaluation.")
parser.add_argument("--summary_filename", type=str, default="generated_summaries.txt",
                    help="Summary file name generated by prediction for evaluation.")
parser.add_argument("--model_filename", type=str, default="dist_extsum_model.pt",
                    help="model file name saved for evaluation.")
def cleanup():
    """Destroy the default ``torch.distributed`` process group.

    Only needed when running on your own cluster; see the note at the end of
    ``main_worker``.
    """
    torch.distributed.destroy_process_group()
# Interval, in training steps, between the statistics reports shown during training.
REPORT_EVERY = 100
# Checkpoint interval passed to fit() as ``save_every`` on the main process
# (unit presumably also training steps; confirm against ExtractiveSummarizer.fit).
SAVE_EVERY = 1_000
def main():
    """Parse command-line arguments and launch one training worker per local GPU.

    Creates the output and cache directories, builds the ExtractiveSummarizer
    once in the parent process, and hands it to ``mp.spawn``. Each spawned
    process then runs ``main_worker(local_rank, ngpus_per_node, summarizer, args)``.

    Raises:
        RuntimeError: if no CUDA device is visible. The workers join an NCCL
            process group, and ``mp.spawn`` with ``nprocs=0`` would start no
            workers, so without this check the script would silently do nothing.
    """
    print("NCCL_IB_DISABLE: {}".format(os.getenv("NCCL_IB_DISABLE")))
    args = parser.parse_args()
    print("quick_run is {}".format(args.quick_run))
    print("output_dir is {}".format(args.output_dir))
    print("data_dir is {}".format(args.data_dir))
    print("cache_dir is {}".format(args.cache_dir))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.cache_dir, exist_ok=True)

    ngpus_per_node = torch.cuda.device_count()
    # Fail fast, before the (potentially expensive) summarizer construction.
    if ngpus_per_node == 0:
        raise RuntimeError(
            "No CUDA device found; this script spawns one NCCL worker per GPU "
            "and needs at least one GPU."
        )

    summarizer = ExtractiveSummarizer(args.model_name, args.encoder, args.cache_dir)
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, summarizer, args))
def _training_schedule(args):
    """Return the ``(max_steps, warmup_steps)`` pair to train with.

    When ``args.quick_run`` is "false", the values come from ``--max_steps``
    and ``--warmup_steps``. Otherwise a short fixed schedule of 1000 total and
    500 warm-up steps is used.
    """
    if args.quick_run.lower() == "false":
        return args.max_steps, args.warmup_steps
    return 1000, 500


def _write_list_to_file(list_items, filename):
    """Write ``list_items`` to ``filename`` as UTF-8, one ``str(item)`` per line."""
    with open(filename, "w", encoding="utf-8") as filehandle:
        # str.format (unlike "%s" % item) also handles tuple items correctly.
        filehandle.writelines("{}\n".format(item) for item in list_items)


def main_worker(local_rank, ngpus_per_node, summarizer, args):
    """Run distributed training in one spawned process; the main process also predicts.

    Joins the NCCL process group and fine-tunes ``summarizer`` on the processed
    training split. On the main process (global rank 0) it then saves the
    model, predicts summaries for the test split and writes them to
    ``args.output_dir/args.summary_filename``.

    Args:
        local_rank: index of this process on its node (passed first by ``mp.spawn``).
        ngpus_per_node: number of processes launched on each node.
        summarizer: the ``ExtractiveSummarizer`` created in ``main``.
        args: parsed command-line arguments.
    """
    # Global rank and world size, assuming every node runs ngpus_per_node processes.
    rank = args.rank * ngpus_per_node + local_rank
    world_size = args.node_count * ngpus_per_node
    print("init_method: {}".format(args.dist_url))
    print("ngpus_per_node: {}".format(ngpus_per_node))
    print("rank: {}".format(rank))
    print("local_rank: {}".format(local_rank))
    print("world_size: {}".format(world_size))
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=world_size,
        rank=rank,
    )

    train_dataset, test_dataset = ExtSumProcessedData().splits(root=args.data_dir)

    max_steps, warmup_steps = _training_schedule(args)
    print("max steps is {}".format(max_steps))
    print("warmup steps is {}".format(warmup_steps))

    is_main_process = rank in (-1, 0)
    # Only the main process checkpoints; other ranks pass -1 (presumably
    # "never save" inside fit(); confirm).
    save_every = SAVE_EVERY if is_main_process else -1

    start = time.time()
    summarizer.fit(
        train_dataset,
        num_gpus=world_size,
        batch_size=args.batch_size,
        gradient_accumulation_steps=2,
        # The step budget is divided across workers, presumably because each
        # worker counts only its own steps; confirm against fit().
        max_steps=max_steps / world_size,
        learning_rate=args.learning_rate,
        warmup_steps=warmup_steps,
        verbose=True,
        report_every=REPORT_EVERY,
        clip_grad_norm=False,
        # NOTE(review): the *global* rank is passed as fit()'s local_rank. It
        # equals the per-node process index only on the node with --rank 0;
        # confirm fit() does not use it as a CUDA device index on multi-node runs.
        local_rank=rank,
        save_every=save_every,
        world_size=world_size,
    )
    end = time.time()
    print("rank {0}, duration {1:.6f}s".format(rank, end - start))

    if is_main_process:
        summarizer.save_model(os.path.join(args.output_dir, args.model_filename))
        prediction = summarizer.predict(test_dataset, num_gpus=ngpus_per_node, batch_size=128)
        print("writing generated summaries")
        _write_list_to_file(prediction, os.path.join(args.output_dir, args.summary_filename))

    # Only call cleanup() when running on your own cluster; AML distributed
    # training runs clean up the process group for you.
    # cleanup()
if __name__ == "__main__":
    # Script entry point; nothing runs when this module is merely imported.
    main()