зеркало из https://github.com/microsoft/hi-ml.git
ENH: Validate model weights before inference (#635)
Add extra validation of model weights with checkpoint
This commit is contained in:
Родитель
7e7d3751bf
Коммит
fa4e0984d0
|
@ -164,7 +164,7 @@ def test_get_dataset() -> None:
|
|||
Test if a dataset that does not yet exist can be created from a folder in blob storage
|
||||
"""
|
||||
# A folder with a single tiny file
|
||||
tiny_dataset = "himl-tiny_dataset"
|
||||
tiny_dataset = "himl_tiny_dataset"
|
||||
workspace = DEFAULT_WORKSPACE.workspace
|
||||
# When creating a dataset, we need a non-empty name
|
||||
with pytest.raises(ValueError) as ex:
|
||||
|
|
|
@ -282,6 +282,16 @@ class MLRunner:
|
|||
assert self.trainer, "Trainer should be initialized before validation. Call self.init_training() first."
|
||||
self.trainer.validate(self.container.model, datamodule=self.data_module)
|
||||
|
||||
def validate_model_weights(self) -> None:
    """Compare the in-memory model parameters against the checkpoint chosen for testing.

    Loads the checkpoint returned by ``self.checkpoint_handler.get_checkpoint_to_test()``
    and checks every named parameter of ``self.container.model`` against the matching
    tensor in the checkpoint's ``state_dict``. Each mismatching (or missing) parameter
    is logged as a warning, and a summary count is logged at the end. This is a
    diagnostic only: no exception is raised when parameters differ.
    """
    logging.info("Validating model weights.")
    # map_location="cpu" so a checkpoint written on a GPU machine can be inspected
    # on a CPU-only host without a device error.
    weights = torch.load(self.checkpoint_handler.get_checkpoint_to_test(), map_location="cpu")["state_dict"]
    number_mismatch = 0
    for name, param in self.container.model.named_parameters():
        if name not in weights:
            # A parameter absent from the checkpoint counts as a mismatch rather
            # than raising KeyError mid-validation.
            logging.warning(f"Parameter {name} is missing from the checkpoint.")
            number_mismatch += 1
        # Move the live parameter to CPU too: allclose raises on cross-device tensors.
        elif not torch.allclose(weights[name].cpu(), param.detach().cpu()):
            logging.warning(f"Parameter {name} does not match between model and checkpoint.")
            number_mismatch += 1
    logging.info(f"Number of mismatched parameters: {number_mismatch}")
|
||||
|
||||
def run_inference(self) -> None:
|
||||
"""
|
||||
Run inference on the test set for all models.
|
||||
|
@ -296,8 +306,14 @@ class MLRunner:
|
|||
self.container.max_num_gpus = 1
|
||||
|
||||
checkpoint_path = (
|
||||
self.checkpoint_handler.get_checkpoint_to_test() if self.container.src_checkpoint else None
|
||||
self.checkpoint_handler.get_checkpoint_to_test() if self.container.run_inference_only else None
|
||||
)
|
||||
|
||||
if self.container.run_inference_only:
|
||||
assert checkpoint_path is not None
|
||||
else:
|
||||
self.validate_model_weights()
|
||||
|
||||
trainer, _ = create_lightning_trainer(
|
||||
container=self.container,
|
||||
resume_from_checkpoint=checkpoint_path,
|
||||
|
|
|
@ -266,7 +266,9 @@ def test_run_inference(ml_runner_with_container: MLRunner, tmp_path: Path) -> No
|
|||
|
||||
actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
|
||||
assert actual_train_ckpt_path is None
|
||||
ml_runner_with_container.run()
|
||||
with patch.object(ml_runner_with_container, "validate_model_weights") as mock_validate_model_weights:
|
||||
ml_runner_with_container.run()
|
||||
mock_validate_model_weights.assert_called_once()
|
||||
actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
|
||||
assert actual_train_ckpt_path == expected_ckpt_path
|
||||
|
||||
|
@ -315,7 +317,12 @@ def test_run_inference_only(ml_runner_with_run_id: MLRunner) -> None:
|
|||
ml_runner_with_run_id.container.run_inference_only = True
|
||||
assert ml_runner_with_run_id.checkpoint_handler.trained_weights_path
|
||||
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
|
||||
with patch.multiple(ml_runner_with_run_id, run_training=DEFAULT, run_validation=DEFAULT) as mocks:
|
||||
with patch.multiple(
|
||||
ml_runner_with_run_id,
|
||||
run_training=DEFAULT,
|
||||
run_validation=DEFAULT,
|
||||
validate_model_weights=DEFAULT
|
||||
) as mocks:
|
||||
mock_trainer = MagicMock()
|
||||
mock_create_trainer.return_value = mock_trainer, MagicMock()
|
||||
ml_runner_with_run_id.run()
|
||||
|
@ -324,6 +331,7 @@ def test_run_inference_only(ml_runner_with_run_id: MLRunner) -> None:
|
|||
assert recovery_checkpoint == ml_runner_with_run_id.checkpoint_handler.trained_weights_path
|
||||
mocks["run_training"].assert_not_called()
|
||||
mocks["run_validation"].assert_not_called()
|
||||
mocks["validate_model_weights"].assert_not_called()
|
||||
mock_trainer.test.assert_called_once()
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче