Commit dafc09cc63 (parent 18c9e661b2)
Daisy Deng  2020-03-20 19:58:53 +00:00
1 changed file with 14 additions and 11 deletions


@@ -3,7 +3,6 @@
 import argparse
 import os
-import pickle
 import sys
 import time
 import torch
@@ -11,12 +10,12 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 # torch.set_printoptions(threshold=5000)
-from tempfile import TemporaryDirectory
 nlp_path = os.path.abspath("../../")
 if nlp_path not in sys.path:
     sys.path.insert(0, nlp_path)
+sys.path.insert(0, "./")
 from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
     BertSumAbs,
@@ -40,7 +39,7 @@ parser.add_argument(
 parser.add_argument(
     "--dist_url",
     type=str,
-    default="tcp://127.0.0.1:29500",
+    default="tcp://127.0.0.1:29507",
     help="URL specifying how to initialize the process group.",
 )
 parser.add_argument(
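The --dist_url change above only moves the rendezvous port from 29500 to 29507. For context, a URL like this is normally handed to torch.distributed.init_process_group when each spawned worker joins the process group; the sketch below is illustrative only (the helper name init_distributed and its arguments are not taken from this script):

import torch.distributed as dist

def init_distributed(dist_url, world_size, rank):
    # Illustrative only: NCCL is the usual backend for multi-GPU training,
    # and the tcp:// URL (e.g. tcp://127.0.0.1:29507) tells every worker
    # which host and port to use as the rendezvous point.
    dist.init_process_group(
        backend="nccl", init_method=dist_url, world_size=world_size, rank=rank
    )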
@@ -56,7 +55,7 @@ parser.add_argument(
 parser.add_argument(
     "--data_dir",
     type=str,
-    default="./",
+    default="./abstemp",
     help="Directory to download the preprocessed data.",
 )
 parser.add_argument(
@@ -101,8 +100,8 @@ parser.add_argument(
     "--max_steps",
     type=int,
     default=5e4,
-    help="Maximum number of training steps run in training. \
-        If quick_run is set, it's not used.",
+    help="""Maximum number of training steps run in training.
+        If quick_run is set, it's not used.""",
 )
 parser.add_argument(
     "--warmup_steps_bert",
@@ -176,8 +175,11 @@ def main():
     print("data_dir is {}".format(args.data_dir))
     print("cache_dir is {}".format(args.cache_dir))
+    TOP_N = -1
+    if args.quick_run.lower() == "false":
+        TOP_N = 10
     train_dataset, test_dataset = CNNDMSummarizationDataset(
-        top_n=-1, local_cache_path=args.data_dir, prepare_extractive=False
+        top_n=TOP_N, local_cache_path=args.data_dir, prepare_extractive=False
     )
     ngpus_per_node = torch.cuda.device_count()
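The new TOP_N logic above passes either -1 or 10 to CNNDMSummarizationDataset. It appears to follow the usual top_n convention, where -1 loads the full CNN/DailyMail dataset and a positive value keeps only the first N examples so a run finishes quickly. The snippet below is only a sketch of that convention; cap_examples is a made-up helper, not the loader's real code:

def cap_examples(examples, top_n):
    # A negative top_n keeps everything; a positive top_n keeps only the
    # first top_n examples, which is what makes a quick run cheap.
    return list(examples) if top_n < 0 else list(examples)[:top_n]

assert len(cap_examples(range(100), -1)) == 100
assert len(cap_examples(range(100), 10)) == 10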
@@ -212,6 +214,7 @@ def main_worker(
         checkpoint = os.path.join(args.cache_dir, args.checkpoint_filename)
     else:
         checkpoint = None
     # train_sum_dataset, test_sum_dataset = load_processed_cnndm_abs(args.data_dir)
     def this_validate(class_obj):
         return validate(class_obj, test_dataset)
@@ -225,8 +228,8 @@ def main_worker(
     fp16 = args.fp16.lower() == "true"
     print("fp16 is {}".format(fp16))
     # total number of steps for training
-    MAX_STEPS = 50
-    SAVE_EVERY = 50
+    MAX_STEPS = 10
+    SAVE_EVERY = 10
     REPORT_EVERY = 10
     # number of steps for warm up
     WARMUP_STEPS_BERT = MAX_STEPS
@@ -235,7 +238,7 @@ def main_worker(
         MAX_STEPS = args.max_steps
         WARMUP_STEPS_BERT = args.warmup_steps_bert
         WARMUP_STEPS_DEC = args.warmup_steps_dec
-        SAVE_EVERY = args.save_every
+        SAVE_EVERY = save_every
         REPORT_EVERY = args.report_every
     print("max steps is {}".format(MAX_STEPS))
@@ -266,7 +269,7 @@ def main_worker(
     end = time.time()
     print("rank {0}, duration {1:.6f}s".format(rank, end - start))
-    if rank == 0 or local_rank == -1:
+    if local_rank in [0, -1] and args.rank == 0:
         saved_model_path = os.path.join(
             args.output_dir, "{}_step{}".format(args.model_filename, MAX_STEPS)
         )
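The rewritten guard above is about letting exactly one process write the checkpoint in a multi-node run. A common shape of that pattern is sketched below; save_on_main_process and its arguments are illustrative, not this script's actual code:

import os
import torch

def save_on_main_process(model, output_dir, filename, global_rank):
    # global_rank 0 is the single main process across all nodes; -1 is the
    # usual convention for a non-distributed run. Every other rank skips the
    # write so processes do not race on the same file.
    if global_rank in (0, -1):
        path = os.path.join(output_dir, filename)
        torch.save(model.state_dict(), path)
        return path
    return None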