remove all unnecessary labels
This commit is contained in:
Родитель
2c4cc44839
Коммит
0c79d68381
|
@ -183,14 +183,13 @@ def train(config, data_folder, learning_rate=0.0001):
|
|||
"""Using Horovod"""
|
||||
# Horovod: scale learning rate by the number of GPUs.
|
||||
optimizer = optim.Adam(model.parameters(), lr=learning_rate * hvd.size())
|
||||
# optimizer = optim.SGD(model.parameters(), lr=args.lr * hvd.size(), momentum=args.momentum)
|
||||
|
||||
# Horovod: broadcast parameters & optimizer state.
|
||||
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
|
||||
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
|
||||
|
||||
# Horovod: (optional) compression algorithm.
|
||||
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
|
||||
compression = hvd.Compression.fp16
|
||||
|
||||
# Horovod: wrap optimizer with DistributedOptimizer.
|
||||
optimizer = hvd.DistributedOptimizer(optimizer,
|
||||
|
@ -322,8 +321,6 @@ def train(config, data_folder, learning_rate=0.0001):
|
|||
torch.nn.utils.clip_grad_norm(model.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
|
||||
|
||||
|
||||
end = time.time()
|
||||
mbatch_times.append(end - start)
|
||||
|
||||
|
@ -416,7 +413,8 @@ def train(config, data_folder, learning_rate=0.0001):
|
|||
if break_flag == 1:
|
||||
end_training = time.time()
|
||||
logging.info("##### Training stopped at ##### %f" % min_val_loss)
|
||||
logging.info("##### Training Time ##### %f seconds" % (end_training-start_training))
|
||||
logging.info(
|
||||
"##### Training Time ##### %f seconds" % (end_training - start_training))
|
||||
break
|
||||
logging.info("Evaluating on NLI")
|
||||
n_correct = 0.0
|
||||
|
@ -492,10 +490,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument('--data_folder', type=str, help='data folder')
|
||||
# Add learning rate to tune model.
|
||||
parser.add_argument('--learning_rate', type=float, default=0.0001, help='learning rate')
|
||||
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
|
||||
help='SGD momentum (default: 0.5)')
|
||||
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
|
||||
help='use fp16 compression during allreduce')
|
||||
|
||||
args = parser.parse_args()
|
||||
data_folder = args.data_folder
|
||||
learning_rate = args.learning_rate
|
||||
|
@ -503,4 +498,4 @@ if __name__ == "__main__":
|
|||
|
||||
config_file_path = args.config
|
||||
config = read_config(config_file_path)
|
||||
train(config, data_folder, learning_rate)
|
||||
train(config, data_folder, learning_rate)
|
||||
|
|
Загрузка…
Ссылка в новой задаче