ENH: Validate model weights before inference (#635)

Add extra validation of model weights with checkpoint
This commit is contained in:
Kenza Bouzid 2022-10-19 14:29:10 +01:00 committed by GitHub
Parent 7e7d3751bf
Commit fa4e0984d0
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 4 deletions

View file

@@ -164,7 +164,7 @@ def test_get_dataset() -> None:
Test if a dataset that does not yet exist can be created from a folder in blob storage
"""
# A folder with a single tiny file
tiny_dataset = "himl-tiny_dataset"
tiny_dataset = "himl_tiny_dataset"
workspace = DEFAULT_WORKSPACE.workspace
# When creating a dataset, we need a non-empty name
with pytest.raises(ValueError) as ex:

View file

@@ -282,6 +282,16 @@ class MLRunner:
assert self.trainer, "Trainer should be initialized before validation. Call self.init_training() first."
self.trainer.validate(self.container.model, datamodule=self.data_module)
def validate_model_weights(self) -> int:
    """Compare the in-memory model parameters against the checkpoint on disk.

    Loads the checkpoint returned by the checkpoint handler and checks, for every
    named parameter of the model, that the stored tensor matches the live one via
    ``torch.allclose``. Each mismatch is logged as a warning; the total count is
    logged and returned (callers that ignored the previous ``None`` return are
    unaffected).

    :return: The number of parameters that differ between model and checkpoint.
    """
    logging.info("Validating model weights.")
    # map_location="cpu" so a checkpoint saved on GPU can be loaded on a
    # CPU-only host (torch.load would otherwise try to restore CUDA tensors).
    checkpoint = torch.load(self.checkpoint_handler.get_checkpoint_to_test(), map_location="cpu")
    weights = checkpoint["state_dict"]
    number_mismatch = 0
    for name, param in self.container.model.named_parameters():
        # Compare both tensors on CPU: allclose raises if the model parameter
        # lives on a different device than the (CPU) checkpoint tensor.
        if not torch.allclose(weights[name].cpu(), param.detach().cpu()):
            logging.warning(f"Parameter {name} does not match between model and checkpoint.")
            number_mismatch += 1
    logging.info(f"Number of mismatched parameters: {number_mismatch}")
    return number_mismatch
def run_inference(self) -> None:
"""
Run inference on the test set for all models.
@@ -296,8 +306,14 @@ class MLRunner:
self.container.max_num_gpus = 1
checkpoint_path = (
self.checkpoint_handler.get_checkpoint_to_test() if self.container.src_checkpoint else None
self.checkpoint_handler.get_checkpoint_to_test() if self.container.run_inference_only else None
)
if self.container.run_inference_only:
assert checkpoint_path is not None
else:
self.validate_model_weights()
trainer, _ = create_lightning_trainer(
container=self.container,
resume_from_checkpoint=checkpoint_path,

View file

@@ -266,7 +266,9 @@ def test_run_inference(ml_runner_with_container: MLRunner, tmp_path: Path) -> No
actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
assert actual_train_ckpt_path is None
ml_runner_with_container.run()
with patch.object(ml_runner_with_container, "validate_model_weights") as mock_validate_model_weights:
ml_runner_with_container.run()
mock_validate_model_weights.assert_called_once()
actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
assert actual_train_ckpt_path == expected_ckpt_path
@@ -315,7 +317,12 @@ def test_run_inference_only(ml_runner_with_run_id: MLRunner) -> None:
ml_runner_with_run_id.container.run_inference_only = True
assert ml_runner_with_run_id.checkpoint_handler.trained_weights_path
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
with patch.multiple(ml_runner_with_run_id, run_training=DEFAULT, run_validation=DEFAULT) as mocks:
with patch.multiple(
ml_runner_with_run_id,
run_training=DEFAULT,
run_validation=DEFAULT,
validate_model_weights=DEFAULT
) as mocks:
mock_trainer = MagicMock()
mock_create_trainer.return_value = mock_trainer, MagicMock()
ml_runner_with_run_id.run()
@@ -324,6 +331,7 @@ def test_run_inference_only(ml_runner_with_run_id: MLRunner) -> None:
assert recovery_checkpoint == ml_runner_with_run_id.checkpoint_handler.trained_weights_path
mocks["run_training"].assert_not_called()
mocks["run_validation"].assert_not_called()
mocks["validate_model_weights"].assert_not_called()
mock_trainer.test.assert_called_once()