Liqun Shao 2019-06-05 17:30:06 -04:00
Parent 2c4cc44839
Commit 0c79d68381
1 changed file with 5 additions and 10 deletions

View file

@@ -183,14 +183,13 @@ def train(config, data_folder, learning_rate=0.0001):
"""Using Horovod"""
# Horovod: scale learning rate by the number of GPUs.
optimizer = optim.Adam(model.parameters(), lr=learning_rate * hvd.size())
# optimizer = optim.SGD(model.parameters(), lr=args.lr * hvd.size(), momentum=args.momentum)
# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
compression = hvd.Compression.fp16
# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(optimizer,
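
The hunk cuts off mid-call, so the remaining arguments to hvd.DistributedOptimizer are not shown. For orientation, a minimal sketch of the standard Horovod/PyTorch setup this block follows; the placeholder model and the named_parameters/compression arguments past the visible line are based on Horovod's documented API, not the repo's exact code.

import torch
import torch.nn as nn
import torch.optim as optim
import horovod.torch as hvd

hvd.init()
if torch.cuda.is_available():
    # Pin each worker to its local GPU.
    torch.cuda.set_device(hvd.local_rank())

model = nn.Linear(10, 2)   # placeholder model, not the repo's
learning_rate = 0.0001

# Horovod: scale the learning rate by the number of workers, since the
# effective batch size grows with hvd.size().
optimizer = optim.Adam(model.parameters(), lr=learning_rate * hvd.size())

# Horovod: start every worker from rank 0's weights and optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

# Horovod: fp16 compression halves allreduce traffic; this commit
# hard-codes it instead of gating it behind --fp16-allreduce.
compression = hvd.Compression.fp16

# Horovod: wrap the optimizer so each step averages gradients across workers.
optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters(),
                                     compression=compression)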
@@ -322,8 +321,6 @@ def train(config, data_folder, learning_rate=0.0001):
         torch.nn.utils.clip_grad_norm(model.parameters(), 1.0)
         optimizer.step()
         end = time.time()
         mbatch_times.append(end - start)
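
These context lines are the tail of a per-minibatch step that clips gradients, applies the update, and records wall-clock time. A self-contained sketch of that loop, with a placeholder model and data (assumptions, not the repo's); note that clip_grad_norm was deprecated in later PyTorch releases in favor of the in-place clip_grad_norm_.

import time
import torch
import torch.nn as nn
import torch.optim as optim

# Placeholder model, loss, optimizer, and data -- assumptions for the sketch.
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
loader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]

mbatch_times = []
for inputs, targets in loader:
    start = time.time()
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    # Cap the global gradient norm at 1.0, as the diff context does.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    end = time.time()
    mbatch_times.append(end - start)   # per-minibatch wall-clock time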
@@ -416,7 +413,8 @@ def train(config, data_folder, learning_rate=0.0001):
         if break_flag == 1:
             end_training = time.time()
             logging.info("##### Training stopped at ##### %f" % min_val_loss)
-            logging.info("##### Training Time ##### %f seconds" % (end_training-start_training))
+            logging.info(
+                "##### Training Time ##### %f seconds" % (end_training - start_training))
             break
         logging.info("Evaluating on NLI")
         n_correct = 0.0
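
break_flag and min_val_loss point to a patience-style early-stopping check made earlier in the epoch loop; only the exit path is visible in this hunk. A hedged sketch of how such a check typically wires together (the patience value and the run_validation helper are hypothetical, not the repo's code):

import logging
import random
import time

def run_validation():
    # Hypothetical stand-in for the real validation pass.
    return random.random()

patience = 3                  # assumed patience threshold
bad_epochs = 0
min_val_loss = float("inf")
break_flag = 0
start_training = time.time()

for epoch in range(100):
    val_loss = run_validation()
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        bad_epochs = 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break_flag = 1
    if break_flag == 1:
        end_training = time.time()
        logging.info("##### Training stopped at ##### %f" % min_val_loss)
        logging.info(
            "##### Training Time ##### %f seconds" % (end_training - start_training))
        break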
@@ -492,10 +490,7 @@ if __name__ == "__main__":
     parser.add_argument('--data_folder', type=str, help='data folder')
     # Add learning rate to tune model.
     parser.add_argument('--learning_rate', type=float, default=0.0001, help='learning rate')
-    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
-                        help='SGD momentum (default: 0.5)')
-    parser.add_argument('--fp16-allreduce', action='store_true', default=False,
-                        help='use fp16 compression during allreduce')
     args = parser.parse_args()
     data_folder = args.data_folder
     learning_rate = args.learning_rate
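
After this hunk, the entry point keeps only the config, data-folder, and learning-rate options; --momentum and --fp16-allreduce go away because the SGD alternative and the fp16 toggle were dropped above. A sketch of the trimmed parser (the --config argument is inferred from args.config below; the script name in the launch comment is an assumption):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='config file path')
parser.add_argument('--data_folder', type=str, help='data folder')
parser.add_argument('--learning_rate', type=float, default=0.0001,
                    help='learning rate')
args = parser.parse_args()

# Typical multi-GPU launch under Horovod, e.g.:
#   horovodrun -np 4 python train.py --config config.json --data_folder ./data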
@@ -503,4 +498,4 @@ if __name__ == "__main__":
     config_file_path = args.config
     config = read_config(config_file_path)
-    train(config, data_folder, learning_rate)
+    train(config, data_folder, learning_rate)