minor changes
This commit is contained in:
Родитель
18c9e661b2
Коммит
dafc09cc63
|
@ -3,7 +3,6 @@
|
|||
|
||||
import argparse
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
|
@ -11,12 +10,12 @@ import torch.distributed as dist
|
|||
import torch.multiprocessing as mp
|
||||
|
||||
# torch.set_printoptions(threshold=5000)
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
nlp_path = os.path.abspath("../../")
|
||||
if nlp_path not in sys.path:
|
||||
sys.path.insert(0, nlp_path)
|
||||
|
||||
sys.path.insert(0, "./")
|
||||
|
||||
from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
|
||||
BertSumAbs,
|
||||
|
@ -40,7 +39,7 @@ parser.add_argument(
|
|||
parser.add_argument(
|
||||
"--dist_url",
|
||||
type=str,
|
||||
default="tcp://127.0.0.1:29500",
|
||||
default="tcp://127.0.0.1:29507",
|
||||
help="URL specifying how to initialize the process groupi.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -56,7 +55,7 @@ parser.add_argument(
|
|||
parser.add_argument(
|
||||
"--data_dir",
|
||||
type=str,
|
||||
default="./",
|
||||
default="./abstemp",
|
||||
help="Directory to download the preprocessed data.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -101,8 +100,8 @@ parser.add_argument(
|
|||
"--max_steps",
|
||||
type=int,
|
||||
default=5e4,
|
||||
help="Maximum number of training steps run in training. \
|
||||
If quick_run is set, it's not used.",
|
||||
help="""Maximum number of training steps run in training.
|
||||
If quick_run is set, it's not used.""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup_steps_bert",
|
||||
|
@ -176,8 +175,11 @@ def main():
|
|||
print("data_dir is {}".format(args.data_dir))
|
||||
print("cache_dir is {}".format(args.cache_dir))
|
||||
|
||||
TOP_N = -1
|
||||
if args.quick_run.lower() == "false":
|
||||
TOP_N = 10
|
||||
train_dataset, test_dataset = CNNDMSummarizationDataset(
|
||||
top_n=-1, local_cache_path=args.data_dir, prepare_extractive=False
|
||||
top_n=TOP_N, local_cache_path=args.data_dir, prepare_extractive=False
|
||||
)
|
||||
|
||||
ngpus_per_node = torch.cuda.device_count()
|
||||
|
@ -212,6 +214,7 @@ def main_worker(
|
|||
checkpoint = os.path.join(args.cache_dir, args.checkpoint_filename)
|
||||
else:
|
||||
checkpoint = None
|
||||
|
||||
# train_sum_dataset, test_sum_dataset = load_processed_cnndm_abs(args.data_dir)
|
||||
def this_validate(class_obj):
|
||||
return validate(class_obj, test_dataset)
|
||||
|
@ -225,8 +228,8 @@ def main_worker(
|
|||
fp16 = args.fp16.lower() == "true"
|
||||
print("fp16 is {}".format(fp16))
|
||||
# total number of steps for training
|
||||
MAX_STEPS = 50
|
||||
SAVE_EVERY = 50
|
||||
MAX_STEPS = 10
|
||||
SAVE_EVERY = 10
|
||||
REPORT_EVERY = 10
|
||||
# number of steps for warm up
|
||||
WARMUP_STEPS_BERT = MAX_STEPS
|
||||
|
@ -235,7 +238,7 @@ def main_worker(
|
|||
MAX_STEPS = args.max_steps
|
||||
WARMUP_STEPS_BERT = args.warmup_steps_bert
|
||||
WARMUP_STEPS_DEC = args.warmup_steps_dec
|
||||
SAVE_EVERY = args.save_every
|
||||
SAVE_EVERY = save_every
|
||||
REPORT_EVERY = args.report_every
|
||||
|
||||
print("max steps is {}".format(MAX_STEPS))
|
||||
|
@ -266,7 +269,7 @@ def main_worker(
|
|||
|
||||
end = time.time()
|
||||
print("rank {0}, duration {1:.6f}s".format(rank, end - start))
|
||||
if rank == 0 or local_rank == -1:
|
||||
if local_rank in [0, -1] and args.rank == 0:
|
||||
saved_model_path = os.path.join(
|
||||
args.output_dir, "{}_step{}".format(args.model_filename, MAX_STEPS)
|
||||
)
|
||||
|
|
Загрузка…
Ссылка в новой задаче