minor changes
This commit is contained in:
Родитель
18c9e661b2
Коммит
dafc09cc63
|
@ -3,7 +3,6 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import pickle
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
|
@ -11,12 +10,12 @@ import torch.distributed as dist
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
# torch.set_printoptions(threshold=5000)
|
# torch.set_printoptions(threshold=5000)
|
||||||
from tempfile import TemporaryDirectory
|
|
||||||
|
|
||||||
nlp_path = os.path.abspath("../../")
|
nlp_path = os.path.abspath("../../")
|
||||||
if nlp_path not in sys.path:
|
if nlp_path not in sys.path:
|
||||||
sys.path.insert(0, nlp_path)
|
sys.path.insert(0, nlp_path)
|
||||||
|
|
||||||
|
sys.path.insert(0, "./")
|
||||||
|
|
||||||
from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
|
from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
|
||||||
BertSumAbs,
|
BertSumAbs,
|
||||||
|
@ -40,7 +39,7 @@ parser.add_argument(
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dist_url",
|
"--dist_url",
|
||||||
type=str,
|
type=str,
|
||||||
default="tcp://127.0.0.1:29500",
|
default="tcp://127.0.0.1:29507",
|
||||||
help="URL specifying how to initialize the process groupi.",
|
help="URL specifying how to initialize the process groupi.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -56,7 +55,7 @@ parser.add_argument(
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data_dir",
|
"--data_dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="./",
|
default="./abstemp",
|
||||||
help="Directory to download the preprocessed data.",
|
help="Directory to download the preprocessed data.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -101,8 +100,8 @@ parser.add_argument(
|
||||||
"--max_steps",
|
"--max_steps",
|
||||||
type=int,
|
type=int,
|
||||||
default=5e4,
|
default=5e4,
|
||||||
help="Maximum number of training steps run in training. \
|
help="""Maximum number of training steps run in training.
|
||||||
If quick_run is set, it's not used.",
|
If quick_run is set, it's not used.""",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--warmup_steps_bert",
|
"--warmup_steps_bert",
|
||||||
|
@ -176,8 +175,11 @@ def main():
|
||||||
print("data_dir is {}".format(args.data_dir))
|
print("data_dir is {}".format(args.data_dir))
|
||||||
print("cache_dir is {}".format(args.cache_dir))
|
print("cache_dir is {}".format(args.cache_dir))
|
||||||
|
|
||||||
|
TOP_N = -1
|
||||||
|
if args.quick_run.lower() == "false":
|
||||||
|
TOP_N = 10
|
||||||
train_dataset, test_dataset = CNNDMSummarizationDataset(
|
train_dataset, test_dataset = CNNDMSummarizationDataset(
|
||||||
top_n=-1, local_cache_path=args.data_dir, prepare_extractive=False
|
top_n=TOP_N, local_cache_path=args.data_dir, prepare_extractive=False
|
||||||
)
|
)
|
||||||
|
|
||||||
ngpus_per_node = torch.cuda.device_count()
|
ngpus_per_node = torch.cuda.device_count()
|
||||||
|
@ -212,6 +214,7 @@ def main_worker(
|
||||||
checkpoint = os.path.join(args.cache_dir, args.checkpoint_filename)
|
checkpoint = os.path.join(args.cache_dir, args.checkpoint_filename)
|
||||||
else:
|
else:
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
|
|
||||||
# train_sum_dataset, test_sum_dataset = load_processed_cnndm_abs(args.data_dir)
|
# train_sum_dataset, test_sum_dataset = load_processed_cnndm_abs(args.data_dir)
|
||||||
def this_validate(class_obj):
|
def this_validate(class_obj):
|
||||||
return validate(class_obj, test_dataset)
|
return validate(class_obj, test_dataset)
|
||||||
|
@ -225,8 +228,8 @@ def main_worker(
|
||||||
fp16 = args.fp16.lower() == "true"
|
fp16 = args.fp16.lower() == "true"
|
||||||
print("fp16 is {}".format(fp16))
|
print("fp16 is {}".format(fp16))
|
||||||
# total number of steps for training
|
# total number of steps for training
|
||||||
MAX_STEPS = 50
|
MAX_STEPS = 10
|
||||||
SAVE_EVERY = 50
|
SAVE_EVERY = 10
|
||||||
REPORT_EVERY = 10
|
REPORT_EVERY = 10
|
||||||
# number of steps for warm up
|
# number of steps for warm up
|
||||||
WARMUP_STEPS_BERT = MAX_STEPS
|
WARMUP_STEPS_BERT = MAX_STEPS
|
||||||
|
@ -235,7 +238,7 @@ def main_worker(
|
||||||
MAX_STEPS = args.max_steps
|
MAX_STEPS = args.max_steps
|
||||||
WARMUP_STEPS_BERT = args.warmup_steps_bert
|
WARMUP_STEPS_BERT = args.warmup_steps_bert
|
||||||
WARMUP_STEPS_DEC = args.warmup_steps_dec
|
WARMUP_STEPS_DEC = args.warmup_steps_dec
|
||||||
SAVE_EVERY = args.save_every
|
SAVE_EVERY = save_every
|
||||||
REPORT_EVERY = args.report_every
|
REPORT_EVERY = args.report_every
|
||||||
|
|
||||||
print("max steps is {}".format(MAX_STEPS))
|
print("max steps is {}".format(MAX_STEPS))
|
||||||
|
@ -266,7 +269,7 @@ def main_worker(
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print("rank {0}, duration {1:.6f}s".format(rank, end - start))
|
print("rank {0}, duration {1:.6f}s".format(rank, end - start))
|
||||||
if rank == 0 or local_rank == -1:
|
if local_rank in [0, -1] and args.rank == 0:
|
||||||
saved_model_path = os.path.join(
|
saved_model_path = os.path.join(
|
||||||
args.output_dir, "{}_step{}".format(args.model_filename, MAX_STEPS)
|
args.output_dir, "{}_step{}".format(args.model_filename, MAX_STEPS)
|
||||||
)
|
)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче