Mirror of https://github.com/microsoft/ptgnn.git
Explicitly set the starting epoch index to the restored one.
This commit is contained in:
Parent
764a7038ff
Commit
03ab80499b
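This change threads a `start_epoch_idx` argument through `ModelTrainer.train()` and `DistributedModelTrainer.distributed_train()`, and the run scripts restore it from the saved optimizer checkpoint. As a minimal, self-contained sketch of the save/restore pattern the hunks below implement (the model, checkpoint path, and epoch count here are hypothetical stand-ins, not ptgnn API):

    from pathlib import Path

    import torch

    model = torch.nn.Linear(4, 1)  # hypothetical stand-in for the real neural module
    optimizer = torch.optim.Adam(model.parameters(), lr=0.00025)
    checkpoint_location = Path("model.pkl.gz")  # hypothetical checkpoint path

    # Restore, mirroring the `--restore-optimizer` branches in the run scripts below.
    optimizer_state_path = checkpoint_location.with_suffix(".optimizerstate")
    if optimizer_state_path.exists():
        opt_state = torch.load(optimizer_state_path)
        optimizer.load_state_dict(opt_state["optimizer_state_dict"])
        current_epoch_idx = opt_state["epoch"]
    else:
        current_epoch_idx = 0

    max_num_epochs = 3
    for epoch in range(current_epoch_idx, max_num_epochs):
        # ... one epoch of training would run here ...
        # Save the optimizer state plus the *next* epoch index, as the trainer
        # does after each epoch, so a restarted run resumes where this one left off.
        torch.save(
            {"optimizer_state_dict": optimizer.state_dict(), "epoch": epoch + 1},
            optimizer_state_path,
        )

The `epoch + 1` stored at save time is what the run scripts read back into `current_epoch_idx`, so training resumes at the first epoch that did not complete.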
@@ -195,6 +195,7 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
         device=None,
         store_tensorized_data_in_memory: bool = False,
         shuffle_training_data: bool = True,
+        start_epoch_idx: int = 0,
     ) -> None:
         raise Exception("Use `distributed_train()` instead of calling `train()`.")
 
@@ -210,6 +211,7 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
         parallelize: bool = True,
         shuffle_training_data: bool = True,
         worker_init: Optional[Callable[["DistributedModelTrainer", int, int], None]] = None,
+        start_epoch_idx: int = 0,
     ) -> None:
         """
         The training-validation loop for `AbstractNeuralModel`s.
@@ -224,6 +226,7 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
             assume that the model that is being trained has its metadata already initialized.
         :param parallelize: Bool indicating whether to run in parallel
         :param shuffle_training_data: shuffle the incoming data from `training_data`.
+        :param start_epoch_idx: the idx of the first epoch in this training loop (used for resuming).
         """
         assert torch.distributed.is_available()
 
@@ -253,6 +256,7 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
                 shuffle_training_data,
                 validate_on_start,
                 worker_init,
+                start_epoch_idx,
             ),
             nprocs=world_size,
             join=True,
@@ -270,6 +274,7 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
         shuffle_training_data,
         validate_on_start: bool,
         worker_init: Optional[Callable[["DistributedModelTrainer", int, int], None]] = None,
+        start_epoch_idx: int = 0,
     ):
         assert torch.cuda.is_available(), "No CUDA available. Aborting training."
 
@@ -330,7 +335,7 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
 
         num_epochs_not_improved: int = 0
 
-        for epoch in range(self._max_num_epochs):
+        for epoch in range(start_epoch_idx, self._max_num_epochs):
             try:
                 self._run_training(
                     distributed_neural_module,

@@ -328,6 +328,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
         device=None,
         store_tensorized_data_in_memory: bool = False,
         shuffle_training_data: bool = True,
+        start_epoch_idx: int = 0,
     ) -> None:
         """
         The training-validation loop for `AbstractNeuralModel`s.
@@ -346,6 +347,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
         :param device: the target PyTorch device for training
         :param store_tensorized_data_in_memory: store all tensorized data in memory instead of computing them on-line.
         :param shuffle_training_data: shuffle the incoming data from `training_data`.
+        :param start_epoch_idx: the idx of the first epoch in this training loop (used for resuming).
         """
         if initialize_metadata:
             self.load_metadata_and_create_network(training_data, parallelize, show_progress_bar)
@@ -394,7 +396,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
             best_target_metric = target_metric
 
         num_epochs_not_improved: int = 0
-        for epoch in range(self._max_num_epochs):
+        for epoch in range(start_epoch_idx, self._max_num_epochs):
             self._run_training(
                 training_tensors,
                 epoch,
@@ -408,11 +410,9 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
             )
 
             # Save optimizer and epoch id for scheduler
-            torch.save({
-                "optimizer_state_dict": optimizer.state_dict(),
-                "epoch": epoch + 1
-            },
-                self._checkpoint_location.with_suffix(".optimizerstate")
+            torch.save(
+                {"optimizer_state_dict": optimizer.state_dict(), "epoch": epoch + 1},
+                self._checkpoint_location.with_suffix(".optimizerstate"),
             )
 
             target_metric, target_metric_improved, validation_metrics = self._run_validation(

@@ -164,6 +164,7 @@ def run(arguments):
 
     if arguments["--restore-optimizer"]:
         opt_state = torch.load(arguments["--restore-optimizer"])
+        current_epoch_idx = opt_state["epoch"]
 
         def create_optimizer(parameters):
             opt = torch.optim.Adam(parameters, lr=0.00025)
@@ -173,6 +174,7 @@ def run(arguments):
             return opt
 
     else:
+        current_epoch_idx = 0
 
         def create_optimizer(parameters):
             return torch.optim.Adam(parameters, lr=0.00025)
@@ -206,6 +208,7 @@ def run(arguments):
         parallelize=not arguments["--sequential-run"],
         patience=10,
         store_tensorized_data_in_memory=True,
+        start_epoch_idx=current_epoch_idx,
     )
 
     test_data_path = RichPath.create(arguments["TEST_DATA_PATH"], azure_info_path)

@@ -134,9 +134,11 @@ def run(arguments):
 
     if arguments["--restore-optimizer"]:
         opt_state = torch.load(arguments["--restore-optimizer"])
+        current_epoch_idx = opt_state["epoch"]
         create_optimizer_ = partial(create_optimizer, state=opt_state["optimizer_state_dict"])
     else:
         create_optimizer_ = create_optimizer
+        current_epoch_idx = 0
 
     trainer = DistributedModelTrainer(
         model,
@@ -166,6 +168,7 @@ def run(arguments):
         shuffle_training_data=True,
         patience=10,
         worker_init=worker_init,
+        start_epoch_idx=current_epoch_idx,
     )
 
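One detail the last file's hunks leave outside the context window: the restored state is bound with `partial(create_optimizer, state=opt_state["optimizer_state_dict"])`, but the body of `create_optimizer` is not shown. A plausible sketch of such a helper, under the assumption that the state is applied via `load_state_dict` (this body is illustrative, not quoted from the repository):

    from functools import partial
    from typing import Optional

    import torch

    def create_optimizer(parameters, state: Optional[dict] = None) -> torch.optim.Optimizer:
        # Fresh Adam optimizer, using the learning rate from the run scripts.
        opt = torch.optim.Adam(parameters, lr=0.00025)
        if state is not None:
            opt.load_state_dict(state)  # restore Adam's step counts and moment estimates
        return opt

    # With a restored checkpoint, the script binds the state once, up front:
    # create_optimizer_ = partial(create_optimizer, state=opt_state["optimizer_state_dict"])

Binding the state with `partial` keeps the trainer's factory interface (a callable taking only `parameters`) unchanged whether or not a checkpoint is being restored.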