Explicitly set the starting epoch index to the restored one.

Miltos Allamanis 2021-07-21 10:23:09 +01:00 committed by Miltos
Parent 764a7038ff
Commit 03ab80499b
4 changed files with 18 additions and 7 deletions

View file

@@ -195,6 +195,7 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
device=None,
store_tensorized_data_in_memory: bool = False,
shuffle_training_data: bool = True,
start_epoch_idx: int = 0,
) -> None:
raise Exception("Use `distributed_train()` instead of calling `train()`.")
@@ -210,6 +211,7 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
parallelize: bool = True,
shuffle_training_data: bool = True,
worker_init: Optional[Callable[["DistributedModelTrainer", int, int], None]] = None,
start_epoch_idx: int = 0,
) -> None:
"""
The training-validation loop for `AbstractNeuralModel`s.
@@ -224,6 +226,7 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
assume that the model that is being trained has its metadata already initialized.
:param parallelize: Bool indicating whether to run in parallel
:param shuffle_training_data: shuffle the incoming data from `training_data`.
:param start_epoch_idx: the idx of the first epoch in this training loop (used for resuming).
"""
assert torch.distributed.is_available()
@@ -253,6 +256,7 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
shuffle_training_data,
validate_on_start,
worker_init,
start_epoch_idx,
),
nprocs=world_size,
join=True,
@@ -270,6 +274,7 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
shuffle_training_data,
validate_on_start: bool,
worker_init: Optional[Callable[["DistributedModelTrainer", int, int], None]] = None,
start_epoch_idx: int = 0,
):
assert torch.cuda.is_available(), "No CUDA available. Aborting training."
@@ -330,7 +335,7 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
num_epochs_not_improved: int = 0
for epoch in range(self._max_num_epochs):
for epoch in range(start_epoch_idx, self._max_num_epochs):
try:
self._run_training(
distributed_neural_module,
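
For context, here is a minimal sketch (not the trainer's actual worker signature) of how a restored epoch index can be forwarded to each spawned worker through `torch.multiprocessing.spawn` and used to skip already-completed epochs; `_worker` and the argument values are hypothetical placeholders:

import torch.multiprocessing as mp

def _worker(rank: int, start_epoch_idx: int, max_num_epochs: int) -> None:
    # spawn() calls fn(rank, *args), so the restored index arrives as a plain argument.
    for epoch in range(start_epoch_idx, max_num_epochs):
        print(f"rank {rank}: running epoch {epoch}")

if __name__ == "__main__":
    # Resume from epoch 2 of 5 on two processes.
    mp.spawn(_worker, args=(2, 5), nprocs=2, join=True)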

View file

@@ -328,6 +328,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
device=None,
store_tensorized_data_in_memory: bool = False,
shuffle_training_data: bool = True,
start_epoch_idx: int = 0,
) -> None:
"""
The training-validation loop for `AbstractNeuralModel`s.
@@ -346,6 +347,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
:param device: the target PyTorch device for training
:param store_tensorized_data_in_memory: store all tensorized data in memory instead of computing them on-line.
:param shuffle_training_data: shuffle the incoming data from `training_data`.
:param start_epoch_idx: the idx of the first epoch in this training loop (used for resuming).
"""
if initialize_metadata:
self.load_metadata_and_create_network(training_data, parallelize, show_progress_bar)
@@ -394,7 +396,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
best_target_metric = target_metric
num_epochs_not_improved: int = 0
for epoch in range(self._max_num_epochs):
for epoch in range(start_epoch_idx, self._max_num_epochs):
self._run_training(
training_tensors,
epoch,
@@ -408,11 +410,9 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
)
# Save optimizer and epoch id for scheduler
torch.save({
"optimizer_state_dict": optimizer.state_dict(),
"epoch": epoch + 1
},
self._checkpoint_location.with_suffix(".optimizerstate")
torch.save(
{"optimizer_state_dict": optimizer.state_dict(), "epoch": epoch + 1},
self._checkpoint_location.with_suffix(".optimizerstate"),
)
target_metric, target_metric_improved, validation_metrics = self._run_validation(
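
For context, a minimal sketch of the round trip this `.optimizerstate` checkpoint enables; the model, optimizer, and checkpoint path below are placeholders, not the trainer's internals:

from pathlib import Path
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00025)
checkpoint_location = Path("model.pkl.gz")  # placeholder checkpoint path
epoch = 3

# Save: keep the optimizer state next to the index of the epoch to resume from.
torch.save(
    {"optimizer_state_dict": optimizer.state_dict(), "epoch": epoch + 1},
    checkpoint_location.with_suffix(".optimizerstate"),
)

# Restore: reload the state and continue the loop from the stored epoch index.
opt_state = torch.load(checkpoint_location.with_suffix(".optimizerstate"))
optimizer.load_state_dict(opt_state["optimizer_state_dict"])
start_epoch_idx = opt_state["epoch"]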

View file

@@ -164,6 +164,7 @@ def run(arguments):
if arguments["--restore-optimizer"]:
opt_state = torch.load(arguments["--restore-optimizer"])
current_epoch_idx = opt_state["epoch"]
def create_optimizer(parameters):
opt = torch.optim.Adam(parameters, lr=0.00025)
@@ -173,6 +174,7 @@ def run(arguments):
return opt
else:
current_epoch_idx = 0
def create_optimizer(parameters):
return torch.optim.Adam(parameters, lr=0.00025)
@@ -206,6 +208,7 @@ def run(arguments):
parallelize=not arguments["--sequential-run"],
patience=10,
store_tensorized_data_in_memory=True,
start_epoch_idx=current_epoch_idx,
)
test_data_path = RichPath.create(arguments["TEST_DATA_PATH"], azure_info_path)

View file

@@ -134,9 +134,11 @@ def run(arguments):
if arguments["--restore-optimizer"]:
opt_state = torch.load(arguments["--restore-optimizer"])
current_epoch_idx = opt_state["epoch"]
create_optimizer_ = partial(create_optimizer, state=opt_state["optimizer_state_dict"])
else:
create_optimizer_ = create_optimizer
current_epoch_idx = 0
trainer = DistributedModelTrainer(
model,
@@ -166,6 +168,7 @@ def run(arguments):
shuffle_training_data=True,
patience=10,
worker_init=worker_init,
start_epoch_idx=current_epoch_idx,
)
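
For context, a minimal sketch of the resume wiring used by both `run()` scripts above: restore the optimizer state when a checkpoint path is given, otherwise start from epoch 0. The keyword through which the trainer receives `create_optimizer_` is not shown in these hunks and is assumed; only `start_epoch_idx` appears in the diff.

from functools import partial
import torch

def create_optimizer(parameters, state=None):
    opt = torch.optim.Adam(parameters, lr=0.00025)
    if state is not None:
        opt.load_state_dict(state)
    return opt

restore_path = None  # e.g. the value of --restore-optimizer
if restore_path is not None:
    opt_state = torch.load(restore_path)
    current_epoch_idx = opt_state["epoch"]
    create_optimizer_ = partial(create_optimizer, state=opt_state["optimizer_state_dict"])
else:
    current_epoch_idx = 0
    create_optimizer_ = create_optimizer

# create_optimizer_ and start_epoch_idx=current_epoch_idx are then handed to
# train()/distributed_train(), so training resumes at the stored epoch.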