# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import os
import sys
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

# Add the directory two levels up to sys.path so the utils_nlp package can be imported.
nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.dataset.cnndm import CNNDMSummarizationDataset
from utils_nlp.models.transformers.extractive_summarization import (
    ExtractiveSummarizer,
    ExtSumProcessedData,
    ExtSumProcessor,
)

# os.environ["NCCL_BLOCKING_WAIT"] = "1"
os.environ["NCCL_IB_DISABLE"] = "0"

parser = argparse.ArgumentParser()
parser.add_argument("--rank", type=int, default=0,
                    help="The rank of the current node in the cluster.")
parser.add_argument("--dist_url", type=str, default="tcp://127.0.0.1:29500",
                    help="URL specifying how to initialize the process group.")
parser.add_argument("--node_count", type=int, default=1,
                    help="Number of nodes in the cluster.")
parser.add_argument("--cache_dir", type=str, default="./",
                    help="Directory to cache the tokenizer.")
parser.add_argument("--data_dir", type=str, default="./",
                    help="Directory to download the preprocessed data to.")
parser.add_argument("--output_dir", type=str, default="./",
                    help="Directory to save the output model and prediction results.")
parser.add_argument("--quick_run", type=str.lower, default="false", choices=["true", "false"],
                    help="Whether to do a quick run with a reduced number of training steps.")
parser.add_argument("--model_name", type=str, default="distilbert-base-uncased",
                    help='Transformer model used in the extractive summarization; only '
                         '"bert-base-uncased" and "distilbert-base-uncased" are supported.')
parser.add_argument("--encoder", type=str.lower, default="transformer",
                    choices=["baseline", "classifier", "transformer", "rnn"],
                    help="Encoder type used in the extractive summarizer.")
parser.add_argument("--learning_rate", type=float, default=1e-3,
                    help="Learning rate.")
parser.add_argument("--batch_size", type=int, default=3000,
                    help="Batch size in terms of the number of input tokens used in training.")
parser.add_argument("--max_steps", type=int, default=10000,
                    help="Maximum number of training steps. Not used if quick_run is true.")
parser.add_argument("--warmup_steps", type=int, default=5000,
                    help="Number of warm-up training steps. Not used if quick_run is true.")
parser.add_argument("--top_n", type=int, default=3,
                    help="Number of sentences selected in prediction for evaluation.")
parser.add_argument("--summary_filename", type=str, default="generated_summaries.txt",
                    help="Name of the summary file generated by prediction for evaluation.")
parser.add_argument("--model_filename", type=str, default="dist_extsum_model.pt",
                    help="Name of the model file saved for evaluation.")

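# Illustrative launch commands for a two-node cluster (not taken from this repo; the
# script name, node IP, port, and directory paths below are placeholders):
#
#   # on the node hosting the rendezvous endpoint (node rank 0)
#   python <this_script>.py --rank 0 --node_count 2 \
#       --dist_url tcp://10.0.0.1:29500 --data_dir /mnt/data --output_dir /mnt/output
#
#   # on the second node (node rank 1)
#   python <this_script>.py --rank 1 --node_count 2 \
#       --dist_url tcp://10.0.0.1:29500 --data_dir /mnt/data --output_dir /mnt/output

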
def cleanup():
    dist.destroy_process_group()


# How often the statistics reports show up in training, unit is step.
REPORT_EVERY = 100
SAVE_EVERY = 1000


def main():
    print("NCCL_IB_DISABLE: {}".format(os.getenv("NCCL_IB_DISABLE")))
    args = parser.parse_args()
    print("quick_run is {}".format(args.quick_run))
    print("output_dir is {}".format(args.output_dir))
    print("data_dir is {}".format(args.data_dir))
    print("cache_dir is {}".format(args.cache_dir))

    # shutil.rmtree(args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.cache_dir, exist_ok=True)

    ngpus_per_node = torch.cuda.device_count()

    summarizer = ExtractiveSummarizer(args.model_name, args.encoder, args.cache_dir)

    # Spawn one training process per GPU on this node; mp.spawn passes the process
    # index (used here as local_rank) as the first argument to main_worker.
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, summarizer, args))


def main_worker(local_rank, ngpus_per_node, summarizer, args):
    # Global rank of this process across all nodes, and the total number of
    # processes (one per GPU) participating in distributed training.
    rank = args.rank * ngpus_per_node + local_rank
    world_size = args.node_count * ngpus_per_node
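    # Worked example of the rank arithmetic above (illustrative values): with
    # node_count=2 and ngpus_per_node=4, world_size = 2 * 4 = 8; on the node
    # started with --rank 1, the process with local_rank 2 gets global rank
    # 1 * 4 + 2 = 6.
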
print("init_method: {}".format(args.dist_url))
|
|
print("ngpus_per_node: {}".format(ngpus_per_node))
|
|
print("rank: {}".format(rank))
|
|
print("local_rank: {}".format(local_rank))
|
|
print("world_size: {}".format(world_size))
|
|
|
|
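    # This script uses one process per GPU with the NCCL backend; the tcp:// address
    # in --dist_url must be reachable from every node (typically an IP of the node
    # launched with --rank 0).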
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=world_size,
        rank=rank,
    )

    train_dataset, test_dataset = ExtSumProcessedData().splits(root=args.data_dir)

    # Total number of training steps and number of warm-up steps; these small
    # values are only used for a quick run.
    MAX_STEPS = 1e3
    WARMUP_STEPS = 5e2
    if args.quick_run.lower() == "false":
        MAX_STEPS = args.max_steps
        WARMUP_STEPS = args.warmup_steps

print("max steps is {}".format(MAX_STEPS))
|
|
print("warmup steps is {}".format(WARMUP_STEPS))
|
|
start = time.time()
|
|
|
|
if rank not in [-1, 0]:
|
|
save_every = -1
|
|
else:
|
|
save_every = SAVE_EVERY
|
|
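    # Note (an inference from the code, not documented here): max_steps is divided by
    # world_size below, presumably to keep the total amount of training work roughly
    # constant as GPUs are added, since each distributed step consumes world_size
    # batches in parallel.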
    summarizer.fit(
        train_dataset,
        num_gpus=world_size,
        batch_size=args.batch_size,
        gradient_accumulation_steps=2,
        max_steps=MAX_STEPS / world_size,
        learning_rate=args.learning_rate,
        warmup_steps=WARMUP_STEPS,
        verbose=True,
        report_every=REPORT_EVERY,
        clip_grad_norm=False,
        local_rank=rank,
        save_every=save_every,
        world_size=world_size,
    )

    end = time.time()
    print("rank {0}, duration {1:.6f}s".format(rank, end - start))

    # Only the rank-0 process saves the final model and runs prediction on the test set.
    if rank in [-1, 0]:
        summarizer.save_model(os.path.join(args.output_dir, args.model_filename))
        prediction = summarizer.predict(test_dataset, num_gpus=ngpus_per_node, batch_size=128)

        def _write_list_to_file(list_items, filename):
            with open(filename, "w") as filehandle:
                for item in list_items:
                    filehandle.write("%s\n" % item)

        print("writing generated summaries")
        _write_list_to_file(prediction, os.path.join(args.output_dir, args.summary_filename))

    # Only use the following line when you run on your own cluster;
    # AML distributed training runs clean up the process group for you.
    # cleanup()


if __name__ == "__main__":
    main()