BUG: Rename pretraining_run_checkpoints (#841)
"Parameter `extra_downloaded_run_id` has been renamed to `pretraining_run_checkpoints`" (According to CHANGELOG.md) but is still in used in SSLClassifierContainer class (in InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py). It will raise an AttributeError: 'SSLClassifierCIFAR' object has no attribute 'extra_downloaded_run_id' when trying to run python InnerEyeML/runner.py --model=CXRImageClassifier --pretraining_run_recovery_id={THE_ID_TO_YOUR_SSL_TRAINING_JOB}. So renamed it to `pretraining_run_checkpoints` there too. <!-- ## Guidelines Please follow the guidelines for pull requests (PRs) in [CONTRIBUTING](/docs/contributing.md). Checklist: - Ensure that your PR is small, and implements one change - Give your PR title one of the prefixes ENH, BUG, STYLE, DOC, DEL to indicate what type of change that is (see [CONTRIBUTING](/docs/contributing.md)) - Link the correct GitHub issue for tracking - Add unit tests for all functions that you introduced or modified - Run automatic code formatting / linting on all files ("Format Document" Shift-Alt-F in VSCode) ## Change the default merge message When completing your PR, you will be asked for a title and an optional extended description. By default, the extended description will be a concatenation of the individual commit messages. Please DELETE/REPLACE that with a human readable extended description for non-trivial PRs. -->
This commit is contained in:
Родитель
d902e02fc6
Коммит
03b4cc3f2a
|
@ -35,12 +35,12 @@ class SSLClassifierContainer(SSLContainer):
|
|||
This method must create the actual Lightning model that will be trained.
|
||||
"""
|
||||
if self.local_ssl_weights_path is None:
|
||||
assert self.extra_downloaded_run_id is not None
|
||||
assert self.pretraining_run_checkpoints is not None
|
||||
try:
|
||||
path_to_checkpoint = self.extra_downloaded_run_id.get_best_checkpoint_paths()
|
||||
path_to_checkpoint = self.pretraining_run_checkpoints.get_best_checkpoint_paths()
|
||||
except FileNotFoundError:
|
||||
logging.info("Best checkpoint not found - using last recovery checkpoint instead")
|
||||
path_to_checkpoint = self.extra_downloaded_run_id.get_recovery_checkpoint_paths()
|
||||
path_to_checkpoint = self.pretraining_run_checkpoints.get_recovery_checkpoint_paths()
|
||||
path_to_checkpoint = path_to_checkpoint[0] # type: ignore
|
||||
else:
|
||||
path_to_checkpoint = self.local_ssl_weights_path
|
||||
|
|
Загрузка…
Ссылка в новой задаче