Change the stopping condition to early-stop when the validation loss stops improving

This commit is contained in:
Liqun Shao 2019-05-31 17:22:11 -04:00 committed by Liqun Shao
Parent fde487ea89
Commit 786f8de629
2 changed files: 20 additions and 31 deletions

View file

@@ -5,7 +5,7 @@
 "lrate": 0.0001,
 "batch_size": 48,
 "n_gpus": 1,
-"stop_patience": 10000
+"stop_patience": 5
 },
 "management": {
 "monitor_loss": 9600,

View file

@@ -381,20 +381,7 @@ def train(config, data_folder, learning_rate=0.0001):
 mbatch_times = []
 nli_losses = []
-# if (
-#     updates % config["management"]["checkpoint_freq"] == 0
-#     and updates != 0
-# ):
-#     logging.info("Saving model ...")
-#
-#     torch.save(
-#         model.state_dict(),
-#         open(os.path.join(save_dir, "best_model.model"), "wb"),
-#     )
-#     # Let the training end.
-#     break
-if updates % config["management"]["eval_freq"] == 0:
+if updates % (len(nli_iterator.train_lines)/(batch_size*n_gpus)) == 0:
 logging.info("############################")
 logging.info("##### Evaluating model #####")
 logging.info("############################")
@@ -414,6 +401,24 @@ def train(config, data_folder, learning_rate=0.0001):
 )
 # log the best val accuracy to AML run
 run.log("best_val_loss", np.float(validation_loss))
+# If the validation loss is small enough, and it starts to go up.
+# Should stop training.
+# Small is defined by the number of epochs it lasts.
+if validation_loss < min_val_loss:
+    min_val_loss = validation_loss
+    min_val_loss_epoch = nli_epoch
+if nli_epoch - min_val_loss_epoch > config['training']['stop_patience']:
+    print(nli_epoch, min_val_loss_epoch, min_val_loss)
+    logging.info("Saving model ...")
+    torch.save(
+        model.state_dict(),
+        open(os.path.join(save_dir, "best_model.model"), "wb"),
+    )
+    # Let the training end.
+    break
 logging.info("Evaluating on NLI")
 n_correct = 0.0
 n_wrong = 0.0
@@ -464,22 +469,6 @@ def train(config, data_folder, learning_rate=0.0001):
 logging.info(
     "******************************************************"
 )
-# If the validation loss is small enough, and it starts to go up. Should stop training.
-# Small is defined by the number of epochs it lasts.
-if validation_loss < min_val_loss:
-    min_val_loss = validation_loss
-    min_val_loss_epoch = updates
-if updates - min_val_loss_epoch > config['training']['stop_patience']:
-    print(updates, min_val_loss_epoch, min_val_loss)
-    logging.info("Saving model ...")
-    torch.save(
-        model.state_dict(),
-        open(os.path.join(save_dir, "best_model.model"), "wb"),
-    )
-    # Let the training end.
-    break
 updates += batch_size * n_gpus
 nli_ctr += 1
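Taken together, the two hunks move the early-stopping check next to the AML logging and count patience in NLI epochs (`nli_epoch`) rather than in raw `updates`. A self-contained sketch of the same min-tracking-plus-patience pattern; `evaluate` and `save_model` are hypothetical callables standing in for the repo's own code:

import logging

def train_until_plateau(model, evaluate, save_model,
                        stop_patience=5, max_epochs=100):
    """Stop once the best validation loss is more than
    `stop_patience` epochs old, then save and exit.
    `evaluate` and `save_model` are hypothetical stand-ins."""
    min_val_loss = float("inf")
    min_val_loss_epoch = 0
    for epoch in range(1, max_epochs + 1):
        validation_loss = evaluate(model)
        if validation_loss < min_val_loss:
            min_val_loss = validation_loss
            min_val_loss_epoch = epoch
        if epoch - min_val_loss_epoch > stop_patience:
            logging.info("Saving model ...")
            save_model(model)
            break

One design note: as committed, the weights written to best_model.model are the ones at the stopping epoch, not the epoch where the loss was lowest; checkpointing at each new minimum instead would preserve the actual best model.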