Mirror of https://github.com/microsoft/ptgnn.git
Fix bugs in distributed saving of optimizer state.
This commit is contained in:
Parent
3bd5966e88
Commit
764a7038ff
@@ -346,13 +346,13 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
                 self.LOGGER.exception("Error during training", exc_info=e)
                 raise e
 
+            # Save optimizer and epoch id for scheduler
+            optimizer_state = optimizer.state_dict()
             if dist.get_rank() == 0:
-                # Save optimizer and epoch id for scheduler
-                torch.save({
-                    "optimizer_state_dict": optimizer.state_dict(),
-                    "epoch": epoch + 1
-                    },
-                    self._checkpoint_location.with_suffix(".optimizerstate")
+                # Save only on rank 0
+                torch.save(
+                    {"optimizer_state_dict": optimizer_state, "epoch": epoch + 1},
+                    self._checkpoint_location.with_suffix(".optimizerstate"),
                 )
 
             target_metric, target_metric_improved, validation_metrics = self._run_validation(
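The reason `optimizer.state_dict()` moves outside the rank check is that ZeroRedundancyOptimizer shards optimizer state across ranks, and gathering it back together (`consolidate_state_dict`, see the wrapper class further down) is a collective call: every rank has to participate, even though only rank 0 ends up writing the file. A minimal sketch of that save pattern outside the trainer; the function name and arguments are illustrative, and an initialized process group plus a ZeroRedundancyOptimizer instance are assumed:

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer


def save_optimizer_state(optimizer: ZeroRedundancyOptimizer, epoch: int, path: str) -> None:
    # Collective step: every rank must reach this call; it gathers the
    # sharded optimizer state onto rank 0.
    optimizer.consolidate_state_dict()
    if dist.get_rank() == 0:
        # Only rank 0 holds the consolidated state, so only rank 0 writes the file.
        torch.save(
            {"optimizer_state_dict": optimizer.state_dict(), "epoch": epoch + 1},
            path,
        )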
@@ -20,10 +20,13 @@ Options:
 import logging
 import random
 import torch
+import torch.distributed as dist
 from docopt import docopt
 from dpu_utils.utils import RichPath, run_and_debug
+from functools import partial
 from pathlib import Path
 from torch.distributed.optim import ZeroRedundancyOptimizer
+from typing import Optional
 
 from ptgnn.baseneuralmodel.distributedtrainer import DistributedModelTrainer
 from ptgnn.baseneuralmodel.utils.amlutils import configure_logging, log_run
@@ -45,11 +48,29 @@ def load_from_folder(path: RichPath, shuffle: bool, rank: int, world_size):
         yield from file.read_as_jsonl()
 
 
-def create_optimizer(parameters):
-    from torch.distributed.optim import ZeroRedundancyOptimizer
-    # return torch.optim.Adam(parameters, lr=0.01)
-    return ZeroRedundancyOptimizer(parameters, optimizer_class=torch.optim.Adam, lr=0.001)
+class ZeroRedundancyOptimizer_(ZeroRedundancyOptimizer):
+    """
+    ZeroRedundancyOptimizer has a different interface from other optimizers, and `consolidate_state_dict`
+    needs to be invoked.
+
+    This class wraps `ZeroRedundancyOptimizer` and ensures that `state_dict()` works as for all other optimizers
+    (at least for rank 0).
+    """
+
+    def state_dict(self):
+        self.consolidate_state_dict()
+        if dist.get_rank() == 0:
+            return ZeroRedundancyOptimizer.state_dict(self)
+        return None
+
+
+def create_optimizer(parameters, state: Optional = None):
+    optimizer = ZeroRedundancyOptimizer_(parameters, optimizer_class=torch.optim.Adam, lr=0.001)
+    if state is not None:
+        optimizer.load_state_dict(state)
+
+    return optimizer
 
 
 def log_run_lambda(aml_ctx, fold, model, nn, epoch, metrics):
@@ -75,7 +96,10 @@ def worker_init(trainer: DistributedModelTrainer, rank, world_isze):
     def upload_hook(model, nn, epoch, metrics):
         aml_ctx.upload_file(name="model.pkl.gz", path_or_stream=str(trainer._checkpoint_location))
         aml_ctx.upload_file(name="full.log", path_or_stream=log_path)
-        aml_ctx.upload_file(name="optimizer_state.pt", path_or_stream=str(trainer._checkpoint_location.with_suffix(".optimizerstate")))
+        aml_ctx.upload_file(
+            name="optimizer_state.pt",
+            path_or_stream=str(trainer._checkpoint_location.with_suffix(".optimizerstate")),
+        )
 
     if rank == 0 and aml_ctx is not None:
         trainer.register_epoch_improved_end_hook(upload_hook)
@@ -110,23 +134,16 @@ def run(arguments):
 
     if arguments["--restore-optimizer"]:
         opt_state = torch.load(arguments["--restore-optimizer"])
-
-        def create_optimizer(parameters):
-            opt = torch.optim.Adam(parameters, lr=0.00025)
-            logging.info("Restoring optimizer state from `%s`.", arguments["--restore-optimizer"])
-            opt.load_state_dict(opt_state["optimizer_state_dict"])
-            return opt
-
+        create_optimizer_ = partial(create_optimizer, state=opt_state["optimizer_state_dict"])
     else:
-        def create_optimizer(parameters):
-            return torch.optim.Adam(parameters, lr=0.00025)
+        create_optimizer_ = create_optimizer
 
     trainer = DistributedModelTrainer(
         model,
         model_path,
         max_num_epochs=int(arguments["--max-num-epochs"]),
         minibatch_size=int(arguments["--minibatch-size"]),
-        optimizer_creator=create_optimizer,
+        optimizer_creator=create_optimizer_,
         clip_gradient_norm=1,
         target_validation_metric="Accuracy",
         target_validation_metric_higher_is_better=True,
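The `functools.partial` call keeps the `optimizer_creator` callback single-argument: the trainer only ever passes the model parameters, and any restored state is pre-bound into the factory. A small sketch of the equivalence; the variable names mirror this script, but the snippet itself is only illustrative:

from functools import partial

# Fresh run: the factory is used as-is; create_optimizer_(params) builds a new optimizer.
create_optimizer_ = create_optimizer

# Restored run: partial() pre-binds the keyword argument, so the trainer can still
# call the factory with parameters only.
create_optimizer_ = partial(create_optimizer, state=opt_state["optimizer_state_dict"])
# create_optimizer_(params) is equivalent to
# create_optimizer(params, state=opt_state["optimizer_state_dict"])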