diff --git a/examples/text_summarization/extractive_summarization_cnndm_distributed_train.py b/examples/text_summarization/extractive_summarization_cnndm_distributed_train.py
index 3eeac37..b7d7e7a 100644
--- a/examples/text_summarization/extractive_summarization_cnndm_distributed_train.py
+++ b/examples/text_summarization/extractive_summarization_cnndm_distributed_train.py
@@ -86,8 +86,8 @@ parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning
 parser.add_argument(
     "--batch_size",
     type=int,
-    default=3000,
-    help="batch size in terms of input token numbers in training",
+    default=5,
+    help="batch size in terms of the number of samples in training",
 )
 parser.add_argument(
     "--max_steps",
@@ -208,7 +210,7 @@ def main_worker(local_rank, ngpus_per_node, summarizer, args):
     summarizer.fit(
         ext_sum_train,
         num_gpus=world_size,
-        batch_size=5,  # args.batch_size,
+        batch_size=args.batch_size,
         gradient_accumulation_steps=1,
         max_steps=MAX_STEPS / world_size,
         learning_rate=args.learning_rate,