Daisy Deng 2020-03-20 19:58:53 +00:00
Parent 18c9e661b2
Commit dafc09cc63
1 changed file: 14 additions and 11 deletions


@@ -3,7 +3,6 @@
 import argparse
 import os
 import pickle
 import sys
 import time
 import torch
@@ -11,12 +10,12 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 # torch.set_printoptions(threshold=5000)
 from tempfile import TemporaryDirectory
 nlp_path = os.path.abspath("../../")
 if nlp_path not in sys.path:
     sys.path.insert(0, nlp_path)
 sys.path.insert(0, "./")
 from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
     BertSumAbs,
@@ -40,7 +39,7 @@ parser.add_argument(
 parser.add_argument(
     "--dist_url",
     type=str,
-    default="tcp://127.0.0.1:29500",
+    default="tcp://127.0.0.1:29507",
     help="URL specifying how to initialize the process group.",
 )
 parser.add_argument(
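
Note on the change above: --dist_url is only the rendezvous address for torch.distributed, so moving the default port from 29500 to 29507 presumably just avoids a collision with another job on the same host. A minimal sketch of how such a URL is typically consumed (the helper name and the example values are illustrative, not from this repo):

import torch.distributed as dist

def init_distributed(dist_url, world_size, rank, backend="nccl"):
    # dist_url points at the rank-0 process (host:port); every worker connects
    # to it to form the process group, so the port only has to be free there.
    dist.init_process_group(
        backend=backend, init_method=dist_url, world_size=world_size, rank=rank
    )

# e.g. init_distributed("tcp://127.0.0.1:29507", world_size=4, rank=0)
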
@@ -56,7 +55,7 @@ parser.add_argument(
 parser.add_argument(
     "--data_dir",
     type=str,
-    default="./",
+    default="./abstemp",
     help="Directory to download the preprocessed data.",
 )
 parser.add_argument(
@@ -101,8 +100,8 @@ parser.add_argument(
     "--max_steps",
     type=int,
     default=5e4,
-    help="Maximum number of training steps run in training. \
-    If quick_run is set, it's not used.",
+    help="""Maximum number of training steps run in training.
+    If quick_run is set, it's not used.""",
 )
 parser.add_argument(
     "--warmup_steps_bert",
@@ -176,8 +175,11 @@ def main():
     print("data_dir is {}".format(args.data_dir))
     print("cache_dir is {}".format(args.cache_dir))
+    TOP_N = -1
+    if args.quick_run.lower() == "false":
+        TOP_N = 10
     train_dataset, test_dataset = CNNDMSummarizationDataset(
-        top_n=-1, local_cache_path=args.data_dir, prepare_extractive=False
+        top_n=TOP_N, local_cache_path=args.data_dir, prepare_extractive=False
     )
     ngpus_per_node = torch.cuda.device_count()
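
ngpus_per_node feeds the usual one-process-per-GPU launch pattern. Below is a hypothetical stand-in for how main() typically hands off to main_worker via torch.multiprocessing; the real worker signature is defined further down in the script, and demo_worker plus node_rank=0 here are assumptions:

import torch
import torch.multiprocessing as mp

def demo_worker(local_rank, ngpus_per_node, node_rank):
    # mp.spawn passes the per-node process index as the first argument;
    # the global rank combines the node rank with that index.
    rank = node_rank * ngpus_per_node + local_rank
    print("local_rank {}, global rank {}".format(local_rank, rank))

if __name__ == "__main__":
    ngpus_per_node = torch.cuda.device_count()
    # one process per visible GPU (falls back to a single process on CPU-only machines)
    mp.spawn(demo_worker, args=(ngpus_per_node, 0), nprocs=max(ngpus_per_node, 1))
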
@@ -212,6 +214,7 @@
         checkpoint = os.path.join(args.cache_dir, args.checkpoint_filename)
     else:
         checkpoint = None
+    # train_sum_dataset, test_sum_dataset = load_processed_cnndm_abs(args.data_dir)
     def this_validate(class_obj):
         return validate(class_obj, test_dataset)
@@ -225,8 +228,8 @@ def main_worker(
     fp16 = args.fp16.lower() == "true"
     print("fp16 is {}".format(fp16))
     # total number of steps for training
-    MAX_STEPS = 50
-    SAVE_EVERY = 50
+    MAX_STEPS = 10
+    SAVE_EVERY = 10
     REPORT_EVERY = 10
     # number of steps for warm up
     WARMUP_STEPS_BERT = MAX_STEPS
@@ -235,7 +238,7 @@
         MAX_STEPS = args.max_steps
         WARMUP_STEPS_BERT = args.warmup_steps_bert
         WARMUP_STEPS_DEC = args.warmup_steps_dec
-        SAVE_EVERY = args.save_every
+        SAVE_EVERY = save_every
         REPORT_EVERY = args.report_every
     print("max steps is {}".format(MAX_STEPS))
@@ -266,7 +269,7 @@
     end = time.time()
     print("rank {0}, duration {1:.6f}s".format(rank, end - start))
-    if rank == 0 or local_rank == -1:
+    if local_rank in [0, -1] and args.rank == 0:
         saved_model_path = os.path.join(
             args.output_dir, "{}_step{}".format(args.model_filename, MAX_STEPS)
         )
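
The tightened condition above appears intended to let exactly one process write the checkpoint: local_rank in [0, -1] picks the first process on a node (or the non-distributed case), and args.rank == 0 further restricts that to one node, presumably the node rank. A hedged sketch of the same pattern in plain torch.distributed terms (save_on_main_process is an illustrative helper, not this repo's API):

import os
import torch
import torch.distributed as dist

def save_on_main_process(model, output_dir, filename, distributed=True):
    # only the globally rank-0 process writes, so multiple nodes (or GPUs)
    # never race to produce the same file on shared storage
    if not distributed or dist.get_rank() == 0:
        os.makedirs(output_dir, exist_ok=True)
        to_save = model.module if hasattr(model, "module") else model
        torch.save(to_save.state_dict(), os.path.join(output_dir, filename))
    if distributed:
        dist.barrier()  # other ranks wait here until the checkpoint exists
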