change the stopping condition to when the validation loss is small
Parent: fde487ea89
Commit: 786f8de629
@@ -5,7 +5,7 @@
         "lrate": 0.0001,
         "batch_size": 48,
         "n_gpus": 1,
-        "stop_patience": 10000
+        "stop_patience": 5
     },
     "management": {
         "monitor_loss": 9600,
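With the stopping check now tied to epochs rather than raw update counts (see the hunks below), stop_patience is the number of epochs allowed without a new best validation loss, which is why it drops from 10000 to 5. A minimal sketch of how these values would be read, assuming the block above is the "training" section of the config (the training code below indexes config['training']['stop_patience']) and a hypothetical config.json path:

import json

# "config.json" is a placeholder path; train() receives the real config location.
with open("config.json") as f:
    config = json.load(f)

training_cfg = config["training"]               # section name assumed from the code below
batch_size = training_cfg["batch_size"]         # 48
n_gpus = training_cfg["n_gpus"]                 # 1
stop_patience = training_cfg["stop_patience"]   # now 5: epochs without improvement before stopping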
@@ -381,20 +381,7 @@ def train(config, data_folder, learning_rate=0.0001):
             mbatch_times = []
             nli_losses = []
 
-            # if (
-            #     updates % config["management"]["checkpoint_freq"] == 0
-            #     and updates != 0
-            # ):
-            #     logging.info("Saving model ...")
-            #
-            #     torch.save(
-            #         model.state_dict(),
-            #         open(os.path.join(save_dir, "best_model.model"), "wb"),
-            #     )
-            #     # Let the training end.
-            #     break
-
-            if updates % config["management"]["eval_freq"] == 0:
+            if updates % (len(nli_iterator.train_lines)/(batch_size*n_gpus)) == 0:
                 logging.info("############################")
                 logging.info("##### Evaluating model #####")
                 logging.info("############################")
@@ -414,6 +401,24 @@ def train(config, data_folder, learning_rate=0.0001):
                 )
                 # log the best val accuracy to AML run
                 run.log("best_val_loss", np.float(validation_loss))
+
+                # If the validation loss is small enough, and it starts to go up.
+                # Should stop training.
+                # Small is defined by the number of epochs it lasts.
+                if validation_loss < min_val_loss:
+                    min_val_loss = validation_loss
+                    min_val_loss_epoch = nli_epoch
+
+                if nli_epoch - min_val_loss_epoch > config['training']['stop_patience']:
+                    print(nli_epoch, min_val_loss_epoch, min_val_loss)
+                    logging.info("Saving model ...")
+
+                    torch.save(
+                        model.state_dict(),
+                        open(os.path.join(save_dir, "best_model.model"), "wb"),
+                    )
+                    # Let the training end.
+                    break
                 logging.info("Evaluating on NLI")
                 n_correct = 0.0
                 n_wrong = 0.0
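Condensed for context, a minimal standalone sketch of the patience rule this hunk adds, reusing the diff's names (min_val_loss, min_val_loss_epoch, nli_epoch, stop_patience); the model, the validation pass, and save_dir are stand-ins:

import logging
import os

import torch


def run_with_early_stopping(model, validate, n_epochs, stop_patience, save_dir):
    """Stop once the validation loss has not improved for more than stop_patience epochs."""
    min_val_loss = float("inf")
    min_val_loss_epoch = 0

    for nli_epoch in range(n_epochs):
        validation_loss = validate(model)      # stand-in for the script's evaluation pass

        if validation_loss < min_val_loss:     # new best: restart the patience window
            min_val_loss = validation_loss
            min_val_loss_epoch = nli_epoch

        if nli_epoch - min_val_loss_epoch > stop_patience:
            logging.info("Saving model ...")
            torch.save(
                model.state_dict(),
                open(os.path.join(save_dir, "best_model.model"), "wb"),
            )
            # Let the training end.
            break

As in the diff, the checkpoint written here is the model's state at the moment patience runs out, not a snapshot from min_val_loss_epoch.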
@@ -464,22 +469,6 @@ def train(config, data_folder, learning_rate=0.0001):
                 logging.info(
                     "******************************************************"
                 )
-                # If the validation loss is small enough, and it starts to go up. Should stop training.
-                # Small is defined by the number of epochs it lasts.
-                if validation_loss < min_val_loss:
-                    min_val_loss = validation_loss
-                    min_val_loss_epoch = updates
-
-                if updates - min_val_loss_epoch > config['training']['stop_patience']:
-                    print(updates, min_val_loss_epoch, min_val_loss)
-                    logging.info("Saving model ...")
-
-                    torch.save(
-                        model.state_dict(),
-                        open(os.path.join(save_dir, "best_model.model"), "wb"),
-                    )
-                    # Let the training end.
-                    break
 
             updates += batch_size * n_gpus
             nli_ctr += 1