Mirror of https://github.com/microsoft/hi-ml.git
ENH: Run extra validation epoch on a single gpu (#639)
Kill DDP processes after run_training and initialize a new trainer instance with a single device for inference.

Co-authored-by: Anton Schwaighofer <antonsc@microsoft.com>
Parent: 048f9b2aa9
Commit: 84e37efc59
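The motivation for moving inference and the extra validation epoch to a single device is spelled out in the comments in the diff below: DDP's DistributedSampler pads datasets whose length is not divisible by the world size by repeating samples, so some samples can be counted twice in validation or test metrics. A minimal standalone sketch with plain PyTorch (the dataset size and world size below are illustrative, not taken from this PR) demonstrates that padding:

# Illustrative only: DistributedSampler duplicates samples for uneven inputs.
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))   # 10 samples, not divisible by 4
num_replicas = 4                            # pretend world size of 4 GPUs

all_indices = []
for rank in range(num_replicas):
    sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=False)
    all_indices.extend(list(sampler))

print(len(all_indices))     # 12: each rank gets 3 indices, so 2 samples are duplicated
print(sorted(all_indices))  # indices 0 and 1 each appear twice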
@@ -257,6 +257,25 @@ class MLRunner:
             return self.container.crossval_index == 0
         return True
 
+    def get_trainer_for_inference(self, checkpoint_path: Optional[Path] = None) -> Trainer:
+        # We run inference on a single device because distributed strategies such as DDP use DistributedSampler
+        # internally, which replicates some samples to make sure all devices have the same batch size in case of
+        # uneven inputs.
+        self.container.max_num_gpus = 1
+
+        if self.container.run_inference_only:
+            assert checkpoint_path is not None
+        else:
+            self.validate_model_weights()
+
+        trainer, _ = create_lightning_trainer(
+            container=self.container,
+            resume_from_checkpoint=checkpoint_path,
+            num_nodes=1,
+            azureml_run_for_logging=self.azureml_run_for_logging
+        )
+        return trainer
+
     def run_training(self) -> None:
         """
         The main training loop. It creates the Pytorch model based on the configuration options passed in,
@@ -278,9 +297,9 @@ class MLRunner:
         Run validation on the validation set for all models to save time/memory consuming outputs.
         """
         self.container.on_run_extra_validation_epoch()
+        trainer = self.get_trainer_for_inference(checkpoint_path=None)
         with change_working_directory(self.container.outputs_folder):
-            assert self.trainer, "Trainer should be initialized before validation. Call self.init_training() first."
-            self.trainer.validate(self.container.model, datamodule=self.data_module)
+            trainer.validate(self.container.model, datamodule=self.data_module)
 
     def validate_model_weights(self) -> None:
         logging.info("Validating model weights.")
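The hunk above replaces the DDP-attached self.trainer with a fresh trainer obtained from get_trainer_for_inference when running the extra validation epoch. In plain PyTorch Lightning terms, validating on a freshly built single-device trainer looks roughly like the sketch below (a hedged illustration, not hi-ml's create_lightning_trainer; the helper name is made up):

# Hedged sketch: single-device validation with a fresh Trainer after DDP training.
from typing import Optional

import pytorch_lightning as pl


def validate_on_single_device(model: pl.LightningModule,
                              datamodule: pl.LightningDataModule,
                              ckpt_path: Optional[str] = None) -> None:
    # One device, one node: no DistributedSampler, hence no duplicated samples.
    trainer = pl.Trainer(accelerator="auto", devices=1, num_nodes=1, logger=False)
    trainer.validate(model, datamodule=datamodule, ckpt_path=ckpt_path)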
@@ -300,27 +319,10 @@ class MLRunner:
         if self.container.has_custom_test_step():
             # Run Lightning's built-in test procedure if the `test_step` method has been overridden
             logging.info("Running inference via the LightningModule.test_step method")
-            # We run inference on a single device because distributed strategies such as DDP use DistributedSampler
-            # internally, which replicates some samples to make sure all devices have some batch size in case of
-            # uneven inputs.
-            self.container.max_num_gpus = 1
-
             checkpoint_path = (
                 self.checkpoint_handler.get_checkpoint_to_test() if self.container.run_inference_only else None
             )
-
-            if self.container.run_inference_only:
-                assert checkpoint_path is not None
-            else:
-                self.validate_model_weights()
-
-            trainer, _ = create_lightning_trainer(
-                container=self.container,
-                resume_from_checkpoint=checkpoint_path,
-                num_nodes=1,
-                azureml_run_for_logging=self.azureml_run_for_logging
-            )
-
+            trainer = self.get_trainer_for_inference(checkpoint_path)
             # Change to the outputs folder so that the model can write to current working directory, and still
             # everything is put into the right place in AzureML (there, only the contents of the "outputs" folder
             # retained)
@@ -382,6 +384,9 @@ class MLRunner:
         with logging_section("Model training"):
             self.run_training()
 
+        # Kill all processes besides rank 0
+        self.after_ddp_cleanup(old_environ)
+
         # load model checkpoint for custom inference or additional validation step
         if self.container.has_custom_test_step() or self.container.run_extra_val_epoch:
             self.load_model_checkpoint()
@@ -391,9 +396,6 @@ class MLRunner:
         with logging_section("Model Validation to save plots on validation set"):
             self.run_validation()
 
-        # Kill all processes besides rank 0
-        self.after_ddp_cleanup(old_environ)
-
         # Run inference on a single device
         with logging_section("Model inference"):
             self.run_inference()
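The two hunks above move the call to after_ddp_cleanup from after run_validation to directly after run_training, so that both the extra validation epoch and inference run in the single remaining rank-0 process. As a rough, generic illustration of what such a cleanup step usually involves (an assumption for context, not hi-ml's actual after_ddp_cleanup implementation):

# Generic sketch of tearing down DDP before single-process validation/inference.
import os
import sys
from typing import Dict

import torch.distributed as dist


def teardown_ddp(old_environ: Dict[str, str]) -> None:
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        dist.destroy_process_group()
        if rank != 0:
            # Non-zero ranks stop here, so later stages run exactly once.
            sys.exit(0)
    # Restore environment variables that the DDP launcher may have modified.
    os.environ.clear()
    os.environ.update(old_environ)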
@@ -181,7 +181,9 @@ def test_run_validation(run_extra_val_epoch: bool) -> None:
     mock_trainer.validate = Mock()
 
     if run_extra_val_epoch:
-        runner.run_validation()
+        with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
+            runner.run_validation()
+            mock_validate_model_weights.assert_called_once()
 
     assert mock_on_run_extra_validation_epoch.called == run_extra_val_epoch
     assert hasattr(container.model, "on_run_extra_validation_epoch")
@@ -207,8 +209,9 @@ def test_model_extra_val_epoch(run_extra_val_epoch: bool) -> None:
     mock_trainer.validate = Mock()
 
     if run_extra_val_epoch:
-        runner.run_validation()
-
+        with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
+            runner.run_validation()
+            mock_validate_model_weights.assert_called_once()
     assert mock_on_run_extra_validation_epoch.called == run_extra_val_epoch
     assert mock_trainer.validate.called == run_extra_val_epoch
@@ -229,7 +232,9 @@ def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
     runner.setup()
     mock_create_trainer.return_value = MagicMock(), MagicMock()
     runner.init_training()
-    runner.run_validation()
+    with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
+        runner.run_validation()
+        mock_validate_model_weights.assert_called_once()
     latest_message = caplog.records[-1].getMessage()
     assert "Hook `on_run_extra_validation_epoch` is not implemented by lightning module." in latest_message