Fix bugs in distributed saving of optimizer state.

Miltos Allamanis 2021-07-20 14:19:35 +00:00, committed by Miltos
Parent 3bd5966e88
Commit 764a7038ff
2 changed files with 37 additions and 20 deletions

View file

@@ -346,13 +346,13 @@ class DistributedModelTrainer(ModelTrainer[TRawDatapoint, TTensorizedDatapoint,
                 self.LOGGER.exception("Error during training", exc_info=e)
                 raise e
 
+            # Save optimizer and epoch id for scheduler
+            optimizer_state = optimizer.state_dict()
             if dist.get_rank() == 0:
-                # Save optimizer and epoch id for scheduler
-                torch.save({
-                    "optimizer_state_dict": optimizer.state_dict(),
-                    "epoch": epoch + 1
-                },
-                self._checkpoint_location.with_suffix(".optimizerstate")
+                # Save only on rank 0
+                torch.save(
+                    {"optimizer_state_dict": optimizer_state, "epoch": epoch + 1},
+                    self._checkpoint_location.with_suffix(".optimizerstate"),
                 )
 
             target_metric, target_metric_improved, validation_metrics = self._run_validation(
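
This hunk is the core of the fix: `optimizer.state_dict()` is now called on every rank before the rank check, because with the `ZeroRedundancyOptimizer` wrapper introduced in the second file it triggers a collective consolidation that all ranks must join, while `torch.save` still runs only on rank 0. Below is a minimal sketch of the same pattern outside ptgnn; it assumes a `torch.distributed` process group is already initialized, and the helper name `save_optimizer_checkpoint` and its `checkpoint_path` parameter are illustrative, not part of the repository.

import torch
import torch.distributed as dist


def save_optimizer_checkpoint(optimizer, epoch, checkpoint_path):
    # Every rank must call state_dict(): with the ZeroRedundancyOptimizer wrapper this
    # runs the collective consolidate_state_dict(), and calling it only inside the
    # rank-0 branch (as the removed lines did) would leave the other ranks hanging.
    optimizer_state = optimizer.state_dict()
    if dist.get_rank() == 0:
        # Only rank 0 receives the consolidated state and writes it to disk.
        torch.save(
            {"optimizer_state_dict": optimizer_state, "epoch": epoch + 1},
            checkpoint_path,
        )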

View file

@@ -20,10 +20,13 @@ Options:
 import logging
 import random
 import torch
+import torch.distributed as dist
 from docopt import docopt
 from dpu_utils.utils import RichPath, run_and_debug
 from functools import partial
 from pathlib import Path
+from torch.distributed.optim import ZeroRedundancyOptimizer
+from typing import Optional
 
 from ptgnn.baseneuralmodel.distributedtrainer import DistributedModelTrainer
 from ptgnn.baseneuralmodel.utils.amlutils import configure_logging, log_run
@@ -45,11 +48,29 @@ def load_from_folder(path: RichPath, shuffle: bool, rank: int, world_size):
         yield from file.read_as_jsonl()
 
 
-def create_optimizer(parameters):
-    from torch.distributed.optim import ZeroRedundancyOptimizer
+class ZeroRedundancyOptimizer_(ZeroRedundancyOptimizer):
+    """
+    `ZeroRedundancyOptimizer` has a different interface than other optimizers, and
+    `consolidate_state_dict` needs to be invoked before saving.
+    This class wraps `ZeroRedundancyOptimizer` and ensures that `state_dict()` works
+    like all other optimizers (at least for rank 0).
+    """
+
+    def state_dict(self):
+        self.consolidate_state_dict()
+        if dist.get_rank() == 0:
+            return ZeroRedundancyOptimizer.state_dict(self)
+        return None
+
+
+def create_optimizer(parameters, state: Optional = None):
     # return torch.optim.Adam(parameters, lr=0.01)
-    return ZeroRedundancyOptimizer(parameters, optimizer_class=torch.optim.Adam, lr=0.001)
+    optimizer = ZeroRedundancyOptimizer_(parameters, optimizer_class=torch.optim.Adam, lr=0.001)
+    if state is not None:
+        optimizer.load_state_dict(state)
+    return optimizer
 
 
 def log_run_lambda(aml_ctx, fold, model, nn, epoch, metrics):
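
For orientation, this is roughly how the wrapper and the new `create_optimizer` fit together at save and restore time. It is a sketch rather than code from the diff: it assumes `torch.distributed` is initialized and that `model` is an `nn.Module` already constructed on every rank; the file name is illustrative.

# Saving: state_dict() is collective, so every rank calls it; only rank 0 gets the dict.
optimizer = create_optimizer(model.parameters())
state = optimizer.state_dict()  # full consolidated state on rank 0, None elsewhere
if dist.get_rank() == 0:
    torch.save({"optimizer_state_dict": state, "epoch": 0}, "model.optimizerstate")

# Restoring: each rank loads the consolidated dict and passes it back through
# create_optimizer, which forwards it to ZeroRedundancyOptimizer.load_state_dict.
saved = torch.load("model.optimizerstate")
optimizer = create_optimizer(model.parameters(), state=saved["optimizer_state_dict"])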
@@ -75,7 +96,10 @@ def worker_init(trainer: DistributedModelTrainer, rank, world_isze):
     def upload_hook(model, nn, epoch, metrics):
         aml_ctx.upload_file(name="model.pkl.gz", path_or_stream=str(trainer._checkpoint_location))
         aml_ctx.upload_file(name="full.log", path_or_stream=log_path)
-        aml_ctx.upload_file(name="optimizer_state.pt", path_or_stream=str(trainer._checkpoint_location.with_suffix(".optimizerstate")))
+        aml_ctx.upload_file(
+            name="optimizer_state.pt",
+            path_or_stream=str(trainer._checkpoint_location.with_suffix(".optimizerstate")),
+        )
 
     if rank == 0 and aml_ctx is not None:
         trainer.register_epoch_improved_end_hook(upload_hook)
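
One detail worth noting about the uploaded path: `Path.with_suffix` replaces only the final suffix, so for a checkpoint such as `model.pkl.gz` (the name used in the upload above) the optimizer state lands next to it as `model.pkl.optimizerstate`. A tiny illustration with a hypothetical checkpoint name:

from pathlib import Path

checkpoint = Path("model.pkl.gz")                 # hypothetical checkpoint location
print(checkpoint.with_suffix(".optimizerstate"))  # -> model.pkl.optimizerstate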
@@ -110,23 +134,16 @@ def run(arguments):
 
     if arguments["--restore-optimizer"]:
         opt_state = torch.load(arguments["--restore-optimizer"])
-
-        def create_optimizer(parameters):
-            opt = torch.optim.Adam(parameters, lr=0.00025)
-            logging.info("Restoring optimizer state from `%s`.", arguments["--restore-optimizer"])
-            opt.load_state_dict(opt_state["optimizer_state_dict"])
-            return opt
+        create_optimizer_ = partial(create_optimizer, state=opt_state["optimizer_state_dict"])
     else:
-
-        def create_optimizer(parameters):
-            return torch.optim.Adam(parameters, lr=0.00025)
+        create_optimizer_ = create_optimizer
 
     trainer = DistributedModelTrainer(
         model,
         model_path,
         max_num_epochs=int(arguments["--max-num-epochs"]),
         minibatch_size=int(arguments["--minibatch-size"]),
-        optimizer_creator=create_optimizer,
+        optimizer_creator=create_optimizer_,
         clip_gradient_norm=1,
         target_validation_metric="Accuracy",
         target_validation_metric_higher_is_better=True,