BUG: Fix extra validation epoch hook call order (#873)

Fix the order of the hook call: `on_run_extra_validation_epoch` needs to run
before `super().init_inference()` so that any changes it makes to the datamodule
params are taken into account when the data module is re-instantiated.
This commit is contained in:
Kenza Bouzid 2023-04-18 13:01:38 +01:00 коммит произвёл GitHub
Родитель 5746042f68
Коммит 122cd12215
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 14 добавлений и 5 удалений

Просмотреть файл

@ -203,6 +203,7 @@ class RunnerBase:
3. Create a new data module instance for inference to account for any requested changes in the dataloading
parameters (e.g. batch_size, max_num_workers, etc) as part of on_run_extra_validation_epoch.
"""
logging.info("Preparing runner for inference.")
self.inference_checkpoint = str(self.checkpoint_handler.get_checkpoint_to_test())
self.set_trainer_for_inference()
self.data_module = self.get_data_module()

Просмотреть файл

@ -164,12 +164,14 @@ class TrainingRunner(RunnerBase):
def init_inference(self) -> None:
"""
Prepare the trainer for running inference on the validation and test set. This chooses a checkpoint,
initializes the PL Trainer object, and chooses the right data module. Afterwards, the hook for running
inference on the validation set is run (`LightningContainer.on_run_extra_validation_epoch`)
initializes the PL Trainer object, and chooses the right data module. The hook for running
inference on the validation set (`LightningContainer.on_run_extra_validation_epoch`) is called first to
reflect any changes to the model or datamodule states before running inference.
"""
super().init_inference()
if self.container.run_extra_val_epoch:
logging.info("Preparing to run an extra validation epoch to evaluate the model on the validation set.")
self.container.on_run_extra_validation_epoch()
super().init_inference()
def run_validation(self) -> None:
"""Run validation on the validation set for all models to save time/memory consuming outputs. This is done in

Просмотреть файл

@ -291,6 +291,7 @@ def test_init_inference(
run_extra_val_epoch: bool,
max_num_gpus_inf: int,
training_runner_hello_world_with_checkpoint: TrainingRunner,
caplog: LogCaptureFixture,
) -> None:
training_runner_hello_world_with_checkpoint.container.run_inference_only = run_inference_only
training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
@ -324,6 +325,12 @@ def test_init_inference(
assert not training_runner_hello_world_with_checkpoint.container.model._on_extra_val_epoch
training_runner_hello_world_with_checkpoint.init_inference()
if run_extra_val_epoch:
assert (
caplog.messages[-3]
== "Preparing to run an extra validation epoch to evaluate the model on the validation set."
)
assert caplog.messages[-2] == "Preparing runner for inference."
expected_ckpt = str(training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path)
expected_ckpt = expected_ckpt if run_inference_only else str(mock_checkpoint)
@ -405,8 +412,7 @@ def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
mock_get_checkpoint_to_test.return_value = MagicMock(is_file=MagicMock(return_value=True))
runner.init_inference()
runner.run_validation()
latest_message = caplog.records[-1].getMessage()
assert "Hook `on_run_extra_validation_epoch` is not implemented" in latest_message
assert "Hook `on_run_extra_validation_epoch` is not implemented" in caplog.messages[-3]
def test_run_inference(training_runner_hello_world: TrainingRunner, regression_datadir: Path) -> None: