diff --git a/hi-ml/src/health_ml/run_ml.py b/hi-ml/src/health_ml/run_ml.py
index 28756c4c..4ab98eaa 100644
--- a/hi-ml/src/health_ml/run_ml.py
+++ b/hi-ml/src/health_ml/run_ml.py
@@ -257,6 +257,25 @@ class MLRunner:
             return self.container.crossval_index == 0
         return True
 
+    def get_trainer_for_inference(self, checkpoint_path: Optional[Path] = None) -> Trainer:
+        # We run inference on a single device because distributed strategies such as DDP use DistributedSampler
+        # internally, which replicates some samples to make sure all devices have the same batch size in case of
+        # uneven inputs.
+        self.container.max_num_gpus = 1
+
+        if self.container.run_inference_only:
+            assert checkpoint_path is not None
+        else:
+            self.validate_model_weights()
+
+        trainer, _ = create_lightning_trainer(
+            container=self.container,
+            resume_from_checkpoint=checkpoint_path,
+            num_nodes=1,
+            azureml_run_for_logging=self.azureml_run_for_logging
+        )
+        return trainer
+
     def run_training(self) -> None:
         """
         The main training loop. It creates the Pytorch model based on the configuration options passed in,
@@ -278,9 +297,9 @@ class MLRunner:
         Run validation on the validation set for all models to save time/memory consuming outputs.
         """
         self.container.on_run_extra_validation_epoch()
+        trainer = self.get_trainer_for_inference(checkpoint_path=None)
         with change_working_directory(self.container.outputs_folder):
-            assert self.trainer, "Trainer should be initialized before validation. Call self.init_training() first."
-            self.trainer.validate(self.container.model, datamodule=self.data_module)
+            trainer.validate(self.container.model, datamodule=self.data_module)
 
     def validate_model_weights(self) -> None:
         logging.info("Validating model weights.")
@@ -300,27 +319,10 @@ class MLRunner:
         if self.container.has_custom_test_step():
             # Run Lightning's built-in test procedure if the `test_step` method has been overridden
            logging.info("Running inference via the LightningModule.test_step method")
-            # We run inference on a single device because distributed strategies such as DDP use DistributedSampler
-            # internally, which replicates some samples to make sure all devices have some batch size in case of
-            # uneven inputs.
-            self.container.max_num_gpus = 1
-
             checkpoint_path = (
                 self.checkpoint_handler.get_checkpoint_to_test() if self.container.run_inference_only else None
             )
-
-            if self.container.run_inference_only:
-                assert checkpoint_path is not None
-            else:
-                self.validate_model_weights()
-
-            trainer, _ = create_lightning_trainer(
-                container=self.container,
-                resume_from_checkpoint=checkpoint_path,
-                num_nodes=1,
-                azureml_run_for_logging=self.azureml_run_for_logging
-            )
-
+            trainer = self.get_trainer_for_inference(checkpoint_path)
             # Change to the outputs folder so that the model can write to current working directory, and still
             # everything is put into the right place in AzureML (there, only the contents of the "outputs" folder
             # retained)
@@ -382,6 +384,9 @@ class MLRunner:
         with logging_section("Model training"):
             self.run_training()
 
+        # Kill all processes besides rank 0
+        self.after_ddp_cleanup(old_environ)
+
         # load model checkpoint for custom inference or additional validation step
         if self.container.has_custom_test_step() or self.container.run_extra_val_epoch:
             self.load_model_checkpoint()
@@ -391,9 +396,6 @@ class MLRunner:
         with logging_section("Model Validation to save plots on validation set"):
             self.run_validation()
 
-        # Kill all processes besides rank 0
-        self.after_ddp_cleanup(old_environ)
-
         # Run inference on a single device
         with logging_section("Model inference"):
             self.run_inference()
diff --git a/hi-ml/testhiml/testhiml/test_run_ml.py b/hi-ml/testhiml/testhiml/test_run_ml.py
index 31e5515f..5394120d 100644
--- a/hi-ml/testhiml/testhiml/test_run_ml.py
+++ b/hi-ml/testhiml/testhiml/test_run_ml.py
@@ -181,7 +181,9 @@ def test_run_validation(run_extra_val_epoch: bool) -> None:
         mock_trainer.validate = Mock()
 
         if run_extra_val_epoch:
-            runner.run_validation()
+            with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
+                runner.run_validation()
+                mock_validate_model_weights.assert_called_once()
 
         assert mock_on_run_extra_validation_epoch.called == run_extra_val_epoch
         assert hasattr(container.model, "on_run_extra_validation_epoch")
@@ -207,8 +209,9 @@ def test_model_extra_val_epoch(run_extra_val_epoch: bool) -> None:
         mock_trainer.validate = Mock()
 
         if run_extra_val_epoch:
-            runner.run_validation()
-
+            with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
+                runner.run_validation()
+                mock_validate_model_weights.assert_called_once()
         assert mock_on_run_extra_validation_epoch.called == run_extra_val_epoch
         assert mock_trainer.validate.called == run_extra_val_epoch
 
@@ -229,7 +232,9 @@ def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
         runner.setup()
         mock_create_trainer.return_value = MagicMock(), MagicMock()
         runner.init_training()
-        runner.run_validation()
+        with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
+            runner.run_validation()
+            mock_validate_model_weights.assert_called_once()
 
         latest_message = caplog.records[-1].getMessage()
         assert "Hook `on_run_extra_validation_epoch` is not implemented by lightning module." in latest_message
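Note (not part of the patch above): the comment moved into `get_trainer_for_inference` refers to the padding behaviour of PyTorch's `DistributedSampler`. A minimal standalone sketch of that behaviour, using a toy dataset of 10 items split across 4 hypothetical ranks, is:

    import torch
    from torch.utils.data import DistributedSampler, TensorDataset

    # 10 samples over 4 ranks: the sampler pads to 12 indices so every rank gets 3,
    # which duplicates samples 0 and 1 and would distort single-pass inference metrics.
    dataset = TensorDataset(torch.arange(10))
    for rank in range(4):
        sampler = DistributedSampler(dataset, num_replicas=4, rank=rank, shuffle=False)
        print(rank, list(sampler))  # rank 2 -> [2, 6, 0], rank 3 -> [3, 7, 1]

This sample duplication is why the new helper forces `max_num_gpus = 1` for the extra validation epoch and for inference.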