зеркало из https://github.com/microsoft/hi-ml.git
BUG: Fix extra validation epoch hook call order (#873)
Fix the order of hook calls: `on_run_extra_validation_epoch` needs to be run before `super().init_inference()` to account for any changes in datamodule parameters before the data module is re-instantiated.
This commit is contained in:
Родитель
5746042f68
Коммит
122cd12215
|
@ -203,6 +203,7 @@ class RunnerBase:
|
|||
3. Create a new data module instance for inference to account for any requested changes in the dataloading
|
||||
parameters (e.g. batch_size, max_num_workers, etc) as part of on_run_extra_validation_epoch.
|
||||
"""
|
||||
logging.info("Preparing runner for inference.")
|
||||
self.inference_checkpoint = str(self.checkpoint_handler.get_checkpoint_to_test())
|
||||
self.set_trainer_for_inference()
|
||||
self.data_module = self.get_data_module()
|
||||
|
|
|
@ -164,12 +164,14 @@ class TrainingRunner(RunnerBase):
|
|||
def init_inference(self) -> None:
    """
    Prepare the trainer for running inference on the validation and test sets.

    The hook for running an extra epoch on the validation set
    (`LightningContainer.on_run_extra_validation_epoch`) is called FIRST, so that any
    changes it makes to the model or datamodule state (e.g. batch_size, max_num_workers)
    are reflected when `super().init_inference()` re-instantiates the data module,
    chooses the inference checkpoint, and sets up the PL Trainer object.
    """
    if self.container.run_extra_val_epoch:
        logging.info("Preparing to run an extra validation epoch to evaluate the model on the validation set.")
        # Must run before super().init_inference(): the hook may change dataloading
        # parameters that super() uses when it calls get_data_module().
        self.container.on_run_extra_validation_epoch()
    super().init_inference()
|
||||
|
||||
def run_validation(self) -> None:
|
||||
"""Run validation on the validation set for all models to save time/memory consuming outputs. This is done in
|
||||
|
|
|
@ -291,6 +291,7 @@ def test_init_inference(
|
|||
run_extra_val_epoch: bool,
|
||||
max_num_gpus_inf: int,
|
||||
training_runner_hello_world_with_checkpoint: TrainingRunner,
|
||||
caplog: LogCaptureFixture,
|
||||
) -> None:
|
||||
training_runner_hello_world_with_checkpoint.container.run_inference_only = run_inference_only
|
||||
training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
|
||||
|
@ -324,6 +325,12 @@ def test_init_inference(
|
|||
assert not training_runner_hello_world_with_checkpoint.container.model._on_extra_val_epoch
|
||||
|
||||
training_runner_hello_world_with_checkpoint.init_inference()
|
||||
if run_extra_val_epoch:
|
||||
assert (
|
||||
caplog.messages[-3]
|
||||
== "Preparing to run an extra validation epoch to evaluate the model on the validation set."
|
||||
)
|
||||
assert caplog.messages[-2] == "Preparing runner for inference."
|
||||
|
||||
expected_ckpt = str(training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path)
|
||||
expected_ckpt = expected_ckpt if run_inference_only else str(mock_checkpoint)
|
||||
|
@ -405,8 +412,7 @@ def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
|
|||
mock_get_checkpoint_to_test.return_value = MagicMock(is_file=MagicMock(return_value=True))
|
||||
runner.init_inference()
|
||||
runner.run_validation()
|
||||
latest_message = caplog.records[-1].getMessage()
|
||||
assert "Hook `on_run_extra_validation_epoch` is not implemented" in latest_message
|
||||
assert "Hook `on_run_extra_validation_epoch` is not implemented" in caplog.messages[-3]
|
||||
|
||||
|
||||
def test_run_inference(training_runner_hello_world: TrainingRunner, regression_datadir: Path) -> None:
|
||||
|
|
Загрузка…
Ссылка в новой задаче