change the stopping condition to stop when the validation loss is small and starts to go up
Parent: fde487ea89
Commit: 786f8de629
@@ -5,7 +5,7 @@
         "lrate": 0.0001,
         "batch_size": 48,
         "n_gpus": 1,
-        "stop_patience": 10000
+        "stop_patience": 5
     },
     "management": {
         "monitor_loss": 9600,

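Together with the code change below, stop_patience is now interpreted in whole epochs rather than in raw update counts, which is why the value drops from 10000 to 5. A rough comparison of the two semantics (a sketch only; it assumes the batch_size and n_gpus values above and the `updates += batch_size * n_gpus` accounting that appears later in this diff):

# Sketch: compare the old (update-based) and new (epoch-based) patience.
batch_size, n_gpus = 48, 1

# Old semantics: patience counted "updates", which advance by batch_size * n_gpus
# per minibatch, so a patience of 10000 updates covered roughly 208 minibatches.
old_patience_updates = 10000
minibatches_of_patience = old_patience_updates / (batch_size * n_gpus)  # ~208.3

# New semantics: patience counts epochs without a new best validation loss.
new_patience_epochs = 5
print(minibatches_of_patience, new_patience_epochs)
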
@@ -381,20 +381,7 @@ def train(config, data_folder, learning_rate=0.0001):
             mbatch_times = []
             nli_losses = []
 
-            # if (
-            #     updates % config["management"]["checkpoint_freq"] == 0
-            #     and updates != 0
-            # ):
-            #     logging.info("Saving model ...")
-            #
-            #     torch.save(
-            #         model.state_dict(),
-            #         open(os.path.join(save_dir, "best_model.model"), "wb"),
-            #     )
-            #     # Let the training end.
-            #     break
-
-            if updates % config["management"]["eval_freq"] == 0:
+            if updates % (len(nli_iterator.train_lines)/(batch_size*n_gpus)) == 0:
                 logging.info("############################")
                 logging.info("##### Evaluating model #####")
                 logging.info("############################")

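The new guard appears intended to run the evaluation once per pass over the NLI training data instead of every eval_freq updates. One caveat worth noting: in Python 3 the `/` in the divisor produces a float, so the modulus is taken against a float and only lands exactly on zero when batch_size * n_gpus divides the number of training lines evenly. A small sketch of the same check with integer arithmetic (the helper and its parameters are illustrative, not part of the commit):

def should_evaluate(updates, n_train_lines, batch_size, n_gpus):
    # Same idea as the new condition above, but with integer division so the
    # modulus compares two ints and cannot miss on a floating-point remainder.
    interval = n_train_lines // (batch_size * n_gpus)
    return interval > 0 and updates % interval == 0
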
@@ -414,6 +401,24 @@ def train(config, data_folder, learning_rate=0.0001):
                 )
                 # log the best val accuracy to AML run
                 run.log("best_val_loss", np.float(validation_loss))
+
+                # If the validation loss is small enough, and it starts to go up.
+                # Should stop training.
+                # Small is defined by the number of epochs it lasts.
+                if validation_loss < min_val_loss:
+                    min_val_loss = validation_loss
+                    min_val_loss_epoch = nli_epoch
+
+                if nli_epoch - min_val_loss_epoch > config['training']['stop_patience']:
+                    print(nli_epoch, min_val_loss_epoch, min_val_loss)
+                    logging.info("Saving model ...")
+
+                    torch.save(
+                        model.state_dict(),
+                        open(os.path.join(save_dir, "best_model.model"), "wb"),
+                    )
+                    # Let the training end.
+                    break
                 logging.info("Evaluating on NLI")
                 n_correct = 0.0
                 n_wrong = 0.0

@@ -464,22 +469,6 @@ def train(config, data_folder, learning_rate=0.0001):
                 logging.info(
                     "******************************************************"
                 )
-                # If the validation loss is small enough, and it starts to go up. Should stop training.
-                # Small is defined by the number of epochs it lasts.
-                if validation_loss < min_val_loss:
-                    min_val_loss = validation_loss
-                    min_val_loss_epoch = updates
-
-                if updates - min_val_loss_epoch > config['training']['stop_patience']:
-                    print(updates, min_val_loss_epoch, min_val_loss)
-                    logging.info("Saving model ...")
-
-                    torch.save(
-                        model.state_dict(),
-                        open(os.path.join(save_dir, "best_model.model"), "wb"),
-                    )
-                    # Let the training end.
-                    break
 
             updates += batch_size * n_gpus
             nli_ctr += 1