ENH: Run extra validation epoch on a single gpu (#639)

Kill DDP processes after run_training and initialize a new trainer
instance with a single device for inference

Co-authored-by: Anton Schwaighofer <antonsc@microsoft.com>
Kenza Bouzid 2022-10-21 14:51:13 +01:00 committed by GitHub
Parent 048f9b2aa9
Commit 84e37efc59
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 27 deletions
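
The commit's overall pattern — finish DDP training, tear down the distributed state, then run validation/inference on a single device — can be sketched roughly as follows (assuming PyTorch Lightning and torch.distributed; the function and argument names are illustrative, not the exact hi-ml API):

import torch.distributed as dist
from pytorch_lightning import Trainer

def validate_on_single_device(model, data_module, checkpoint_path=None):
    # Tear down the process group left over from DDP training so that the new
    # trainer does not attempt to reuse the multi-GPU environment.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
    # A fresh trainer restricted to one device and one node: DistributedSampler
    # is then not used, so no samples are duplicated to even out per-device batches.
    trainer = Trainer(accelerator="gpu", devices=1, num_nodes=1)
    trainer.validate(model, datamodule=data_module, ckpt_path=checkpoint_path)
    return trainer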

View file

@@ -257,6 +257,25 @@ class MLRunner:
return self.container.crossval_index == 0
return True
def get_trainer_for_inference(self, checkpoint_path: Optional[Path] = None) -> Trainer:
# We run inference on a single device because distributed strategies such as DDP use DistributedSampler
# internally, which replicates some samples to make sure all devices have the same batch size in case of
# uneven inputs.
self.container.max_num_gpus = 1
if self.container.run_inference_only:
assert checkpoint_path is not None
else:
self.validate_model_weights()
trainer, _ = create_lightning_trainer(
container=self.container,
resume_from_checkpoint=checkpoint_path,
num_nodes=1,
azureml_run_for_logging=self.azureml_run_for_logging
)
return trainer
def run_training(self) -> None:
"""
The main training loop. It creates the Pytorch model based on the configuration options passed in,
@@ -278,9 +297,9 @@ class MLRunner:
Run validation on the validation set for all models to save time/memory-consuming outputs.
"""
self.container.on_run_extra_validation_epoch()
trainer = self.get_trainer_for_inference(checkpoint_path=None)
with change_working_directory(self.container.outputs_folder):
assert self.trainer, "Trainer should be initialized before validation. Call self.init_training() first."
self.trainer.validate(self.container.model, datamodule=self.data_module)
trainer.validate(self.container.model, datamodule=self.data_module)
def validate_model_weights(self) -> None:
logging.info("Validating model weights.")
@@ -300,27 +319,10 @@ class MLRunner:
if self.container.has_custom_test_step():
# Run Lightning's built-in test procedure if the `test_step` method has been overridden
logging.info("Running inference via the LightningModule.test_step method")
# We run inference on a single device because distributed strategies such as DDP use DistributedSampler
# internally, which replicates some samples to make sure all devices have some batch size in case of
# uneven inputs.
self.container.max_num_gpus = 1
checkpoint_path = (
self.checkpoint_handler.get_checkpoint_to_test() if self.container.run_inference_only else None
)
if self.container.run_inference_only:
assert checkpoint_path is not None
else:
self.validate_model_weights()
trainer, _ = create_lightning_trainer(
container=self.container,
resume_from_checkpoint=checkpoint_path,
num_nodes=1,
azureml_run_for_logging=self.azureml_run_for_logging
)
trainer = self.get_trainer_for_inference(checkpoint_path)
# Change to the outputs folder so that the model can write to the current working directory, and still
# everything is put into the right place in AzureML (there, only the contents of the "outputs" folder
# are retained)
@@ -382,6 +384,9 @@ class MLRunner:
with logging_section("Model training"):
self.run_training()
# Kill all processes besides rank 0
self.after_ddp_cleanup(old_environ)
# load model checkpoint for custom inference or additional validation step
if self.container.has_custom_test_step() or self.container.run_extra_val_epoch:
self.load_model_checkpoint()
@@ -391,9 +396,6 @@ class MLRunner:
with logging_section("Model Validation to save plots on validation set"):
self.run_validation()
# Kill all processes besides rank 0
self.after_ddp_cleanup(old_environ)
# Run inference on a single device
with logging_section("Model inference"):
self.run_inference()

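The single-device restriction added above is motivated by DistributedSampler's padding behaviour: with uneven inputs it repeats samples so that every rank sees the same number of batches. A small standalone illustration of that (plain PyTorch, independent of the hi-ml code):

import torch
from torch.utils.data import DistributedSampler, TensorDataset

# 10 samples split across 4 ranks: the sampler pads to 12 indices (3 per rank)
# by repeating samples, which would skew per-sample validation outputs.
dataset = TensorDataset(torch.arange(10))
for rank in range(4):
    sampler = DistributedSampler(dataset, num_replicas=4, rank=rank, shuffle=False)
    print(rank, list(sampler))
# rank 0 -> [0, 4, 8], rank 1 -> [1, 5, 9], rank 2 -> [2, 6, 0], rank 3 -> [3, 7, 1]
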
View file

@@ -181,7 +181,9 @@ def test_run_validation(run_extra_val_epoch: bool) -> None:
mock_trainer.validate = Mock()
if run_extra_val_epoch:
with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
runner.run_validation()
mock_validate_model_weights.assert_called_once()
assert mock_on_run_extra_validation_epoch.called == run_extra_val_epoch
assert hasattr(container.model, "on_run_extra_validation_epoch")
@@ -207,8 +209,9 @@ def test_model_extra_val_epoch(run_extra_val_epoch: bool) -> None:
mock_trainer.validate = Mock()
if run_extra_val_epoch:
with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
runner.run_validation()
mock_validate_model_weights.assert_called_once()
assert mock_on_run_extra_validation_epoch.called == run_extra_val_epoch
assert mock_trainer.validate.called == run_extra_val_epoch
@@ -229,7 +232,9 @@ def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
runner.setup()
mock_create_trainer.return_value = MagicMock(), MagicMock()
runner.init_training()
with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
runner.run_validation()
mock_validate_model_weights.assert_called_once()
latest_message = caplog.records[-1].getMessage()
assert "Hook `on_run_extra_validation_epoch` is not implemented by lightning module." in latest_message