ENH: Upgrading package versions for security patches (#757)

* 🚧 💥 Update to secure versions of packages

* ⬆️ Upgrade hi-ml version

* 🏷️ Fix mypy and update windows env

*  Update current epoch value in lightning tests

* 🐛 Fix Windows line endings

*  Remove checkpoint load epoch check

* 📌 Lock env anew

* 🎨 🐛 Add merge_conda_files() to common_util

*  Add logging to tests

* 🚧 Update VarINetWithImageLogging logger syntax

*  Fix Train2Nodes tests

*  Remove cwd change, update CIFAR SSL metrics

* ⚰️ Remove unnecessary PL backwards compatibility

* 📌 Lock env

* 📌 Upgrade to hi-ml v0.2.5

* 📌 Testing lightning 1.6.5

* ♻️ Resolve PR comments
This commit is contained in:
Peter Hessey 2022-09-14 17:33:17 +01:00 коммит произвёл GitHub
Родитель 7894498635
Коммит 5b21840df4
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
15 изменённых файлов: 210 добавлений и 142 удалений

Просмотреть файл

@ -12,7 +12,14 @@ from contextlib import contextmanager
from enum import Enum
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Generator, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
import conda_merge
import ruamel.yaml
from health_azure.utils import (
CONDA_CHANNELS, CONDA_DEPENDENCIES, CONDA_NAME, CONDA_PIP, CondaDependencies, PinnedOperator,
_log_conda_dependencies_stats, _retrieve_unique_deps, is_conda_file_with_pip_include, is_pip_include_dependency
)
from InnerEye.Common.fixed_paths import repository_root_directory
from InnerEye.Common.type_annotations import PathOrString
@ -427,3 +434,75 @@ def change_working_directory(path_or_str: PathOrString) -> Generator:
os.chdir(new_path)
yield
os.chdir(old_path)
def merge_conda_files(
    conda_files: List[Path],
    result_file: Path,
    pip_files: Optional[List[Path]] = None,
) -> None:
    """
    Merges the given Conda environment files using the conda_merge package, optionally adds any
    dependencies from pip requirements files, and writes the merged file to disk.

    :param conda_files: The Conda environment files to read.
    :param result_file: The location where the merge results should be written.
    :param pip_files: An optional list of one or more pip requirements files including extra dependencies.
    :raises conda_merge.MergeError: If the channels or the dependencies of the files cannot be merged.
        The error is logged and then re-raised unchanged.
    :raises ValueError: If the merged dependencies do not contain a 'pip' section.
    """
    env_definitions: List[Any] = []
    for file in conda_files:
        # Only the second return value is used (as the environment definition); the first is ignored.
        _, pip_without_include = is_conda_file_with_pip_include(file)
        env_definitions.append(pip_without_include)
    unified_definition: Dict[str, Any] = {}

    extra_pip_deps: List[str] = []
    for pip_file in pip_files or []:
        # splitlines() + strip() copes with Windows (CRLF) line endings and whitespace-only lines, which
        # a plain split("\n") would pass on as "package==1.0\r" or as blank requirements.
        requirement_lines = (line.strip() for line in pip_file.read_text().splitlines())
        extra_pip_deps.extend(d for d in requirement_lines if d and not is_pip_include_dependency(d))

    name = conda_merge.merge_names(env.get(CONDA_NAME) for env in env_definitions)
    if name:
        unified_definition[CONDA_NAME] = name

    try:
        channels = conda_merge.merge_channels(env.get(CONDA_CHANNELS) for env in env_definitions)
    except conda_merge.MergeError:
        logging.error("Failed to merge channel priorities.")
        raise
    if channels:
        unified_definition[CONDA_CHANNELS] = channels

    try:
        deps_to_merge = [env.get(CONDA_DEPENDENCIES) for env in env_definitions]
        if extra_pip_deps:
            # Present the extra pip requirements as one more dependency list with only a pip section
            deps_to_merge.append([{CONDA_PIP: extra_pip_deps}])
        deps = conda_merge.merge_dependencies(deps_to_merge)

        # Get conda dependencies and pip dependencies from specification
        pip_deps_entries = [d for d in deps if isinstance(d, dict) and CONDA_PIP in d]  # type: ignore
        if not pip_deps_entries:
            raise ValueError("Didn't find a dictionary with the key 'pip' in the list of dependencies")
        pip_deps_entry: Dict[str, List[str]] = pip_deps_entries[0]
        pip_deps = pip_deps_entry[CONDA_PIP]
        # temporarily remove pip dependencies from deps to be added back after deduplication
        deps.remove(pip_deps_entry)
        # remove all non-pip duplicates from the list of dependencies
        unique_deps = _retrieve_unique_deps(deps, PinnedOperator.CONDA)
        unique_pip_deps = sorted(_retrieve_unique_deps(pip_deps, PinnedOperator.PIP))
        # finally add back the deduplicated list of dependencies
        unique_deps.append({CONDA_PIP: unique_pip_deps})  # type: ignore
    except conda_merge.MergeError:
        logging.error("Failed to merge dependencies.")
        raise

    # NOTE(review): the pip entry was appended just above, so unique_deps looks like it can never be
    # empty here and the else branch may be unreachable - confirm before relying on this error.
    if unique_deps:
        unified_definition[CONDA_DEPENDENCIES] = unique_deps
    else:
        raise ValueError("No dependencies found in any of the conda files.")

    with result_file.open("w") as f:
        ruamel.yaml.dump(unified_definition, f, indent=2, default_flow_style=False)

    # Re-read the file that was just written and log statistics about the merged environment
    _log_conda_dependencies_stats(CondaDependencies(result_file), "Merged Conda environment")

Просмотреть файл

@ -97,16 +97,10 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
p=self.drop_p,
n_hidden=self.hidden_dim)
self.evaluator.to(pl_module.device)
if hasattr(trainer, "accelerator_connector"):
# This works with Lightning 1.3.8
accelerator = trainer.accelerator_connector
elif hasattr(trainer, "_accelerator_connector"):
# This works with Lightning 1.5.5
accelerator = trainer._accelerator_connector
else:
raise ValueError("Unable to retrieve the accelerator information")
accelerator = trainer._accelerator_connector
if accelerator.is_distributed:
if accelerator.use_ddp:
if accelerator.strategy.strategy_name == "ddp":
self.evaluator = SyncBatchNorm.convert_sync_batchnorm(self.evaluator)
self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device]) # type: ignore
else:

Просмотреть файл

@ -32,8 +32,8 @@ class VarNetWithImageLogging(VarNetModule):
"""
def log_image(self, name: str, image: torch.Tensor) -> None:
experiments = self.logger.experiment if isinstance(self.logger.experiment, list) \
else [self.logger.experiment]
experiments = self.loggers[0].experiment if isinstance(self.loggers[0].experiment, list) \
else [self.loggers[0].experiment]
for experiment in experiments:
if isinstance(experiment, SummaryWriter):
experiment.add_image(name, image, global_step=self.global_step)

Просмотреть файл

@ -289,7 +289,7 @@ class InnerEyeLightning(LightningModule):
This hook is called at the very end of training. Use that to write the very last set of training and
validation metrics from the StoringLogger to disk.
"""
self.read_epoch_results_from_logger_and_store(epoch=self.current_epoch)
self.read_epoch_results_from_logger_and_store(epoch=self.current_epoch-1)
@rank_zero_only
def read_epoch_results_from_logger_and_store(self, epoch: int) -> None:

Просмотреть файл

@ -120,7 +120,7 @@ def create_lightning_trainer(container: LightningContainer,
save_top_k=0)
recovery_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
filename=AUTOSAVE_CHECKPOINT_FILE_NAME,
every_n_val_epochs=container.autosave_every_n_val_epochs,
every_n_epochs=container.autosave_every_n_val_epochs,
save_last=False)
callbacks: List[Callback] = [
last_checkpoint_callback,
@ -264,11 +264,10 @@ def model_train(checkpoint_path: Optional[Path],
lightning_model.storing_logger = storing_logger
logging.info("Starting training")
# When training models that are not built-in InnerEye models, we have no guarantee that they write
# files to the right folder. Best guess is to change the current working directory to where files should go.
with change_working_directory(container.outputs_folder):
trainer.fit(lightning_model, datamodule=data_module)
trainer.logger.close() # type: ignore
trainer.fit(lightning_model, datamodule=data_module)
trainer.logger.close() # type: ignore
world_size = getattr(trainer, "world_size", 0)
is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
# Per-subject model outputs for regression models are written per rank, and need to be aggregated here.

Просмотреть файл

@ -15,6 +15,8 @@ import stopit
import torch.multiprocessing
from azureml._restclient.constants import RunStatus
from azureml.core import Model, Run, model
from health_azure import AzureRunInfo
from health_azure.utils import ENVIRONMENT_VERSION, create_run_recovery_id, is_global_rank_zero
from pytorch_lightning import LightningModule, seed_everything
from pytorch_lightning.core.datamodule import LightningDataModule
from torch.utils.data import DataLoader
@ -22,26 +24,28 @@ from torch.utils.data import DataLoader
from InnerEye.Azure import azure_util
from InnerEye.Azure.azure_config import AzureConfig
from InnerEye.Azure.azure_runner import ENV_OMPI_COMM_WORLD_RANK, get_git_tags
from InnerEye.Azure.azure_util import CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, \
EFFECTIVE_RANDOM_SEED_KEY_NAME, IS_ENSEMBLE_KEY_NAME, MODEL_ID_KEY_NAME, PARENT_RUN_CONTEXT, \
PARENT_RUN_ID_KEY_NAME, RUN_CONTEXT, RUN_RECOVERY_FROM_ID_KEY_NAME, RUN_RECOVERY_ID_KEY_NAME, \
get_all_environment_files, is_offline_run_context
from InnerEye.Azure.azure_util import (
CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, EFFECTIVE_RANDOM_SEED_KEY_NAME,
IS_ENSEMBLE_KEY_NAME, MODEL_ID_KEY_NAME, PARENT_RUN_CONTEXT, PARENT_RUN_ID_KEY_NAME, RUN_CONTEXT,
RUN_RECOVERY_FROM_ID_KEY_NAME, RUN_RECOVERY_ID_KEY_NAME, get_all_environment_files, is_offline_run_context
)
from InnerEye.Common import fixed_paths
from InnerEye.Common.common_util import (BASELINE_COMPARISONS_FOLDER, BASELINE_WILCOXON_RESULTS_FILE,
CROSSVAL_RESULTS_FOLDER, ENSEMBLE_SPLIT_NAME, FULL_METRICS_DATAFRAME_FILE,
METRICS_AGGREGATES_FILE, ModelProcessing,
OTHER_RUNS_SUBDIR_NAME, SCATTERPLOTS_SUBDIR_NAME, SUBJECT_METRICS_FILE_NAME,
change_working_directory, get_best_epoch_results_path, is_windows,
logging_section, print_exception, remove_file_or_directory)
from InnerEye.Common.common_util import (
BASELINE_COMPARISONS_FOLDER, BASELINE_WILCOXON_RESULTS_FILE, CROSSVAL_RESULTS_FOLDER, ENSEMBLE_SPLIT_NAME,
FULL_METRICS_DATAFRAME_FILE, METRICS_AGGREGATES_FILE, OTHER_RUNS_SUBDIR_NAME, SCATTERPLOTS_SUBDIR_NAME,
SUBJECT_METRICS_FILE_NAME, ModelProcessing, change_working_directory, get_best_epoch_results_path,
is_windows, logging_section, merge_conda_files, print_exception, remove_file_or_directory
)
from InnerEye.Common.fixed_paths import INNEREYE_PACKAGE_NAME, PYTHON_ENVIRONMENT_NAME
from InnerEye.Common.type_annotations import PathOrString
from InnerEye.ML.baselines_util import compare_folders_and_run_outputs
from InnerEye.ML.common import CHECKPOINT_FOLDER, EXTRA_RUN_SUBFOLDER, FINAL_ENSEMBLE_MODEL_FOLDER, \
FINAL_MODEL_FOLDER, \
ModelExecutionMode
from InnerEye.ML.common import (
CHECKPOINT_FOLDER, EXTRA_RUN_SUBFOLDER, FINAL_ENSEMBLE_MODEL_FOLDER, FINAL_MODEL_FOLDER, ModelExecutionMode
)
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.deep_learning_config import DeepLearningConfig, ModelCategory, MultiprocessingStartMethod, \
load_checkpoint
from InnerEye.ML.deep_learning_config import (
DeepLearningConfig, ModelCategory, MultiprocessingStartMethod, load_checkpoint
)
from InnerEye.ML.lightning_base import InnerEyeContainer
from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer
from InnerEye.ML.lightning_loggers import StoringLogger
@ -50,16 +54,16 @@ from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.model_inference_config import ModelInferenceConfig
from InnerEye.ML.model_testing import model_test
from InnerEye.ML.model_training import create_lightning_trainer, model_train
from InnerEye.ML.reports.notebook_report import generate_classification_crossval_notebook, \
generate_classification_multilabel_notebook, generate_classification_notebook, generate_segmentation_notebook, \
get_ipynb_report_name, reports_folder
from InnerEye.ML.reports.notebook_report import (
generate_classification_crossval_notebook, generate_classification_multilabel_notebook,
generate_classification_notebook, generate_segmentation_notebook, get_ipynb_report_name, reports_folder
)
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler, download_all_checkpoints_from_run
from InnerEye.ML.visualizers import activation_maps
from InnerEye.ML.visualizers.plot_cross_validation import \
from InnerEye.ML.visualizers.plot_cross_validation import (
get_config_and_results_for_offline_runs, plot_cross_validation_from_files
from health_azure import AzureRunInfo
from health_azure.utils import ENVIRONMENT_VERSION, create_run_recovery_id, is_global_rank_zero, merge_conda_files
)
ModelDeploymentHookSignature = Callable[[LightningContainer, AzureConfig, Model, ModelProcessing], Any]
PostCrossValidationHookSignature = Callable[[ModelConfigBase, Path], None]
@ -797,8 +801,10 @@ class MLRunner:
remove_file_or_directory(other_runs_dir)
def plot_cross_validation_and_upload_results(self) -> Path:
from InnerEye.ML.visualizers.plot_cross_validation import crossval_config_from_model_config, \
plot_cross_validation, unroll_aggregate_metrics
from InnerEye.ML.visualizers.plot_cross_validation import (
crossval_config_from_model_config, plot_cross_validation, unroll_aggregate_metrics
)
# perform aggregation as cross val splits are now ready
plot_crossval_config = crossval_config_from_model_config(self.innereye_config)
plot_crossval_config.run_recovery_id = PARENT_RUN_CONTEXT.tags[RUN_RECOVERY_ID_KEY_NAME]

Просмотреть файл

@ -24,36 +24,36 @@ from InnerEye.Common import fixed_paths
# in a submodule
fixed_paths.add_submodules_to_path()
import matplotlib
from azureml._base_sdk_common import user_agent
from azureml._restclient.constants import RunStatus
from azureml.core import Run, ScriptRunConfig
from health_azure import AzureRunInfo, submit_to_azure_if_needed
from health_azure.utils import create_run_recovery_id, is_global_rank_zero, is_local_rank_zero, merge_conda_files, \
to_azure_friendly_string
import matplotlib
from health_azure.utils import create_run_recovery_id, is_global_rank_zero, is_local_rank_zero, to_azure_friendly_string
from InnerEye.Azure.tensorboard_monitor import AMLTensorBoardMonitorConfig, monitor
from InnerEye.Azure import azure_util
from InnerEye.Azure.azure_config import AzureConfig, ParserResult, SourceConfig
from InnerEye.Azure.azure_runner import (DEFAULT_DOCKER_BASE_IMAGE, create_dataset_configs, create_experiment_name,
create_runner_parser,
get_git_tags,
parse_args_and_add_yaml_variables,
parse_arguments, additional_run_tags,
set_environment_variables_for_multi_node)
from InnerEye.Azure.azure_util import (RUN_CONTEXT, RUN_RECOVERY_ID_KEY_NAME, get_all_environment_files,
is_offline_run_context)
from InnerEye.Azure.azure_runner import (
DEFAULT_DOCKER_BASE_IMAGE, additional_run_tags, create_dataset_configs,
create_experiment_name, create_runner_parser, get_git_tags,
parse_args_and_add_yaml_variables, parse_arguments, set_environment_variables_for_multi_node
)
from InnerEye.Azure.azure_util import (
RUN_CONTEXT, RUN_RECOVERY_ID_KEY_NAME, get_all_environment_files, is_offline_run_context
)
from InnerEye.Azure.run_pytest import download_pytest_result, run_pytest
from InnerEye.Common.common_util import (FULL_METRICS_DATAFRAME_FILE, METRICS_AGGREGATES_FILE,
is_linux, logging_to_stdout)
from InnerEye.Azure.tensorboard_monitor import AMLTensorBoardMonitorConfig, monitor
from InnerEye.Common.common_util import (
FULL_METRICS_DATAFRAME_FILE, METRICS_AGGREGATES_FILE, is_linux, logging_to_stdout, merge_conda_files
)
from InnerEye.Common.generic_parsing import GenericConfig
from InnerEye.ML.common import DATASET_CSV_FILE_NAME
from InnerEye.ML.deep_learning_config import DeepLearningConfig
from InnerEye.ML.lightning_base import InnerEyeContainer
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.run_ml import MLRunner, ModelDeploymentHookSignature, PostCrossValidationHookSignature
from InnerEye.ML.utils.config_loader import ModelConfigLoader
from InnerEye.ML.lightning_container import LightningContainer
# We change the current working directory before starting the actual training. However, this throws off starting
# the child training threads because sys.argv[0] is a relative path when running in AzureML. Turn that into an absolute

Просмотреть файл

@ -62,7 +62,7 @@ dependencies:
- wheel=0.37.1=pyhd3eb1b0_0
- x264=1!157.20191217=h7b6447c_0
- xz=5.2.5=h7f8727e_1
- zlib=1.2.12=h7f8727e_2
- zlib=1.2.12=h5eee18b_3
- zstd=1.5.2=ha4553b6_0
- pip:
- absl-py==1.2.0
@ -78,7 +78,7 @@ dependencies:
- async-timeout==4.0.2
- attrs==22.1.0
- azure-common==1.1.28
- azure-core==1.25.0
- azure-core==1.25.1
- azure-graphrbac==0.61.1
- azure-identity==1.7.0
- azure-mgmt-authorization==0.61.0
@ -109,7 +109,7 @@ dependencies:
- backports-tempfile==1.0
- backports-weakref==1.0.post1
- beautifulsoup4==4.11.1
- black==22.6.0
- black==22.8.0
- bleach==5.0.1
- cachetools==4.2.4
- cffi==1.15.1
@ -139,33 +139,32 @@ dependencies:
- flake8==3.8.3
- flask==2.2.2
- frozenlist==1.3.1
- fsspec==2022.7.1
- fsspec==2022.8.2
- furo==2022.6.21
- fusepy==3.0.1
- future==0.18.2
- gitdb==4.0.9
- gitpython==3.1.7
- google-auth==1.35.0
- google-auth-oauthlib==0.4.6
- gputil==1.4.0
- greenlet==1.1.3
- grpcio==1.47.0
- grpcio==1.48.1
- gunicorn==20.1.0
- h5py==2.10.0
- hi-ml==0.2.2
- hi-ml-azure==0.2.2
- hi-ml==0.2.5
- hi-ml-azure==0.2.5
- humanize==4.3.0
- idna==3.3
- idna==3.4
- imageio==2.15.0
- imagesize==1.4.1
- importlib-metadata==4.12.0
- importlib-resources==5.9.0
- iniconfig==1.1.1
- innereye-dicom-rt==1.0.3
- ipykernel==6.15.1
- ipykernel==6.15.3
- ipython==7.31.1
- ipython-genutils==0.2.0
- ipywidgets==8.0.1
- ipywidgets==8.0.2
- isodate==0.6.1
- itsdangerous==2.1.2
- jeepney==0.8.0
@ -173,18 +172,18 @@ dependencies:
- jmespath==0.10.0
- joblib==0.16.0
- jsonpickle==2.2.0
- jsonschema==4.14.0
- jsonschema==4.16.0
- jupyter==1.0.0
- jupyter-client==6.1.5
- jupyter-console==6.4.4
- jupyter-core==4.11.1
- jupyterlab-pygments==0.2.2
- jupyterlab-widgets==3.0.2
- jupyterlab-widgets==3.0.3
- kiwisolver==1.4.4
- lightning-bolts==0.4.0
- llvmlite==0.34.0
- lxml==4.9.1
- mako==1.2.1
- mako==1.2.2
- markdown==3.4.1
- markdown-it-py==2.1.0
- markupsafe==2.1.1
@ -207,26 +206,26 @@ dependencies:
- mypy==0.910
- mypy-extensions==0.4.3
- myst-parser==0.18.0
- nbclient==0.6.7
- nbclient==0.6.8
- nbconvert==7.0.0
- nbformat==5.4.0
- nbformat==5.5.0
- ndg-httpsclient==0.5.1
- nest-asyncio==1.5.5
- networkx==2.8.6
- nibabel==4.0.1
- nibabel==4.0.2
- notebook==6.4.12
- numba==0.51.2
- numpy==1.19.1
- oauthlib==3.2.0
- oauthlib==3.2.1
- opencv-python-headless==4.5.1.48
- packaging==21.3
- pandas==1.1.0
- pandocfilters==1.5.0
- papermill==2.2.2
- param==1.9.3
- pathspec==0.9.0
- pathspec==0.10.1
- pexpect==4.8.0
- pillow==9.0.0
- pillow==9.1.1
- pkgutil-resolve-name==1.3.10
- platformdirs==2.5.2
- pluggy==0.13.1
@ -242,10 +241,10 @@ dependencies:
- pyasn1-modules==0.2.8
- pycodestyle==2.6.0
- pycparser==2.21
- pydeprecate==0.3.1
- pydeprecate==0.3.2
- pydicom==2.0.0
- pyflakes==2.2.0
- pyjwt==1.7.1
- pyjwt==2.4.0
- pynndescent==0.5.7
- pyopenssl==20.0.1
- pyparsing==3.0.9
@ -256,7 +255,7 @@ dependencies:
- pytest-forked==1.3.0
- pytest-xdist==1.34.0
- python-dateutil==2.8.2
- pytorch-lightning==1.5.5
- pytorch-lightning==1.6.5
- pytz==2022.2.1
- pywavelets==1.3.0
- pyyaml==6.0
@ -293,7 +292,7 @@ dependencies:
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.3
- sphinxcontrib-serializinghtml==1.1.5
- sqlalchemy==1.4.40
- sqlalchemy==1.4.41
- sqlparse==0.4.2
- stopit==1.1.2
- stringcase==1.2.0
@ -312,14 +311,14 @@ dependencies:
- torchio==0.18.74
- torchmetrics==0.6.0
- tornado==6.2
- tqdm==4.64.0
- tqdm==4.64.1
- typing-inspect==0.8.0
- umap-learn==0.5.2
- urllib3==1.26.7
- webencodings==0.5.1
- websocket-client==1.4.0
- websocket-client==1.4.1
- werkzeug==2.2.2
- widgetsnbextension==4.0.2
- widgetsnbextension==4.0.3
- wrapt==1.14.1
- yacs==0.1.8
- yarl==1.8.1

Просмотреть файл

@ -431,8 +431,8 @@ def test_training_2nodes(test_output_dirs: OutputFolderForTests) -> None:
assert training_indicator in log1_txt
# Check diagnostic messages that show if DDP was set up correctly. This could fail if Lightning
# changes its diagnostic outputs.
assert "initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4" in log0_txt
assert "initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4" in log1_txt
assert "Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4" in log0_txt
assert "Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4" in log1_txt
@pytest.mark.skip("The recovery job hangs after completing on AML")

Просмотреть файл

@ -32,7 +32,6 @@ def create_model_and_store_checkpoint(config: ModelConfigBase, checkpoint_path:
"""
Creates a Lightning model for the given model configuration, and stores it as a checkpoint file.
If a GPU is available, the model is moved to the GPU before storing.
The trainer properties `current_epoch` and `global_step` are set to fixed non-default values.
:param config: The model configuration.
:param checkpoint_path: The path and filename of the checkpoint file.
@ -43,10 +42,6 @@ def create_model_and_store_checkpoint(config: ModelConfigBase, checkpoint_path:
if machine_has_gpu:
model = model.cuda() # type: ignore
trainer.model = model
# Before saving, the values for epoch and step are incremented. Save them here in such a way that we can assert
# easily later. We can't mock that because otherwise the mock object would be written to disk (that fails)
trainer.fit_loop.current_epoch = FIXED_EPOCH - 1 # type: ignore
trainer.fit_loop.global_step = FIXED_GLOBAL_STEP - 1 # type: ignore
# In PL, it is the Trainer's responsibility to save the model. Checkpoint handling refers back to the trainer
# to get a save_func. Mimicking that here.
trainer.save_checkpoint(checkpoint_path, weights_only=weights_only)
@ -74,10 +69,8 @@ def test_create_model_from_lightning_checkpoint(test_output_dirs: OutputFolderFo
assert config._train_output_size is None
# method to get all devices of a model
loaded_model = load_from_checkpoint_and_adjust_for_inference(config, checkpoint_path)
# Information about epoch and global step must be present in the message that on_checkpoint_load writes
assert str(FIXED_EPOCH) in loaded_model.checkpoint_loading_message
assert str(FIXED_GLOBAL_STEP) in loaded_model.checkpoint_loading_message
assert loaded_model is not None
assert "(epoch = 0, global_step = 0)" in loaded_model.checkpoint_loading_message
if isinstance(config, SegmentationModelBase):
assert config._test_output_size is not None
assert config._train_output_size is not None

Просмотреть файл

@ -35,8 +35,8 @@ def test_update_tau() -> None:
byol_weight_update = ByolMovingAverageWeightUpdate(initial_tau=0.99)
trainer = Trainer(max_epochs=5)
trainer.train_dataloader = dummy_rsna_train_dataloader
n_steps_per_epoch = len(trainer.train_dataloader)
trainer.train_dataloader = dummy_rsna_train_dataloader # type: ignore
n_steps_per_epoch = len(trainer.train_dataloader) # type: ignore
total_steps = n_steps_per_epoch * trainer.max_epochs # type: ignore
byol_module = BYOLInnerEye(num_samples=16,
learning_rate=1e-3,

Просмотреть файл

@ -134,12 +134,12 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None:
# Check the metrics that were recorded during training
# Note: It is possible that after the PyTorch 1.10 upgrade, we can't get parity between local runs and runs on
# the hosted build agents. If that suspicion is confirmed, we need to add branching for local and cloud results.
expected_metrics = {'simclr/val/loss': 2.8736934661865234,
'ssl_online_evaluator/val/loss': 2.2684895992279053,
expected_metrics = {'simclr/val/loss': 2.859630584716797,
'ssl_online_evaluator/val/loss': 2.26649808883667,
'ssl_online_evaluator/val/AccuracyAtThreshold05': 0.20000000298023224,
'simclr/train/loss': 3.6261773109436035,
'simclr/train/loss': 3.6261844635009766,
'simclr/learning_rate': 0.0,
'ssl_online_evaluator/train/loss': 3.1140334606170654,
'ssl_online_evaluator/train/loss': 3.212641477584839,
'ssl_online_evaluator/train/online_AccuracyAtThreshold05': 0.0}
_compare_stored_metrics(runner, expected_metrics, abs=5e-5)
@ -311,7 +311,7 @@ def test_simclr_training_recovery(test_output_dirs: OutputFolderForTests) -> Non
checkpoint_folder = test_output_dirs.create_file_or_folder_path("checkpoints")
checkpoint_folder.mkdir(exist_ok=True)
checkpoint = ModelCheckpoint(dirpath=checkpoint_folder,
every_n_val_epochs=1,
every_n_epochs=1,
save_last=True)
trainer = Trainer(default_root_dir=str(test_output_dirs.root_dir),
@ -359,7 +359,7 @@ def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> No
checkpoint_folder = test_output_dirs.create_file_or_folder_path("checkpoints")
checkpoint_folder.mkdir(exist_ok=True)
checkpoints = ModelCheckpoint(dirpath=checkpoint_folder,
every_n_val_epochs=1,
every_n_epochs=1,
save_last=True)
# Create a first callback, that will be used in training.
callback1 = SSLOnlineEvaluatorInnerEye(class_weights=None,
@ -469,7 +469,6 @@ def test_online_evaluator_distributed() -> None:
trainer = Trainer(strategy="ddp", num_processes=2)
# Test the two flags that the internal logic of on_pretrain_routine_start uses
assert trainer._accelerator_connector.is_distributed
assert trainer._accelerator_connector.use_ddp
callback.on_pretrain_routine_start(trainer, mock_module)
# Check that SyncBatchNorm has been turned on
mock_sync.assert_called_once()

Просмотреть файл

@ -62,7 +62,7 @@ dependencies:
- wheel=0.37.1=pyhd3eb1b0_0
- x264=1!157.20191217=h7b6447c_0
- xz=5.2.5=h7f8727e_1
- zlib=1.2.12=h7f8727e_2
- zlib=1.2.12=h5eee18b_3
- zstd=1.5.2=ha4553b6_0
- pip:
- absl-py==1.2.0
@ -78,7 +78,7 @@ dependencies:
- async-timeout==4.0.2
- attrs==22.1.0
- azure-common==1.1.28
- azure-core==1.25.0
- azure-core==1.25.1
- azure-graphrbac==0.61.1
- azure-identity==1.7.0
- azure-mgmt-authorization==0.61.0
@ -109,7 +109,7 @@ dependencies:
- backports-tempfile==1.0
- backports-weakref==1.0.post1
- beautifulsoup4==4.11.1
- black==22.6.0
- black==22.8.0
- bleach==5.0.1
- cachetools==4.2.4
- cffi==1.15.1
@ -139,33 +139,32 @@ dependencies:
- flake8==3.8.3
- flask==2.2.2
- frozenlist==1.3.1
- fsspec==2022.7.1
- fsspec==2022.8.2
- furo==2022.6.21
- fusepy==3.0.1
- future==0.18.2
- gitdb==4.0.9
- gitpython==3.1.7
- google-auth==1.35.0
- google-auth-oauthlib==0.4.6
- gputil==1.4.0
- greenlet==1.1.3
- grpcio==1.47.0
- grpcio==1.48.1
- gunicorn==20.1.0
- h5py==2.10.0
- hi-ml==0.2.2
- hi-ml-azure==0.2.2
- hi-ml==0.2.5
- hi-ml-azure==0.2.5
- humanize==4.3.0
- idna==3.3
- idna==3.4
- imageio==2.15.0
- imagesize==1.4.1
- importlib-metadata==4.12.0
- importlib-resources==5.9.0
- iniconfig==1.1.1
- innereye-dicom-rt==1.0.3
- ipykernel==6.15.1
- ipykernel==6.15.3
- ipython==7.31.1
- ipython-genutils==0.2.0
- ipywidgets==8.0.1
- ipywidgets==8.0.2
- isodate==0.6.1
- itsdangerous==2.1.2
- jeepney==0.8.0
@ -173,18 +172,18 @@ dependencies:
- jmespath==0.10.0
- joblib==0.16.0
- jsonpickle==2.2.0
- jsonschema==4.14.0
- jsonschema==4.16.0
- jupyter==1.0.0
- jupyter-client==6.1.5
- jupyter-console==6.4.4
- jupyter-core==4.11.1
- jupyterlab-pygments==0.2.2
- jupyterlab-widgets==3.0.2
- jupyterlab-widgets==3.0.3
- kiwisolver==1.4.4
- lightning-bolts==0.4.0
- llvmlite==0.34.0
- lxml==4.9.1
- mako==1.2.1
- mako==1.2.2
- markdown==3.4.1
- markdown-it-py==2.1.0
- markupsafe==2.1.1
@ -207,26 +206,26 @@ dependencies:
- mypy==0.910
- mypy-extensions==0.4.3
- myst-parser==0.18.0
- nbclient==0.6.7
- nbclient==0.6.8
- nbconvert==7.0.0
- nbformat==5.4.0
- nbformat==5.5.0
- ndg-httpsclient==0.5.1
- nest-asyncio==1.5.5
- networkx==2.8.6
- nibabel==4.0.1
- nibabel==4.0.2
- notebook==6.4.12
- numba==0.51.2
- numpy==1.19.1
- oauthlib==3.2.0
- oauthlib==3.2.1
- opencv-python-headless==4.5.1.48
- packaging==21.3
- pandas==1.1.0
- pandocfilters==1.5.0
- papermill==2.2.2
- param==1.9.3
- pathspec==0.9.0
- pathspec==0.10.1
- pexpect==4.8.0
- pillow==9.0.0
- pillow==9.1.1
- pkgutil-resolve-name==1.3.10
- platformdirs==2.5.2
- pluggy==0.13.1
@ -242,10 +241,10 @@ dependencies:
- pyasn1-modules==0.2.8
- pycodestyle==2.6.0
- pycparser==2.21
- pydeprecate==0.3.1
- pydeprecate==0.3.2
- pydicom==2.0.0
- pyflakes==2.2.0
- pyjwt==1.7.1
- pyjwt==2.4.0
- pynndescent==0.5.7
- pyopenssl==20.0.1
- pyparsing==3.0.9
@ -256,7 +255,7 @@ dependencies:
- pytest-forked==1.3.0
- pytest-xdist==1.34.0
- python-dateutil==2.8.2
- pytorch-lightning==1.5.5
- pytorch-lightning==1.6.5
- pytz==2022.2.1
- pywavelets==1.3.0
- pyyaml==6.0
@ -293,7 +292,7 @@ dependencies:
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.3
- sphinxcontrib-serializinghtml==1.1.5
- sqlalchemy==1.4.40
- sqlalchemy==1.4.41
- sqlparse==0.4.2
- stopit==1.1.2
- stringcase==1.2.0
@ -312,14 +311,14 @@ dependencies:
- torchio==0.18.74
- torchmetrics==0.6.0
- tornado==6.2
- tqdm==4.64.0
- tqdm==4.64.1
- typing-inspect==0.8.0
- umap-learn==0.5.2
- urllib3==1.26.7
- webencodings==0.5.1
- websocket-client==1.4.0
- websocket-client==1.4.1
- werkzeug==2.2.2
- widgetsnbextension==4.0.2
- widgetsnbextension==4.0.3
- wrapt==1.14.1
- yacs==0.1.8
- yarl==1.8.1

Просмотреть файл

@ -6,7 +6,7 @@ dependencies:
- blas=1.0=mkl
- blosc=1.21.0=h19a0ad4_1
- ca-certificates=2022.4.26=haa95532_0
- certifi=2022.5.18.1=py38haa95532_0
- certifi=2022.6.15=py38haa95532_0
- cudatoolkit=11.3.1=h59b6b97_2
- freetype=2.10.4=hd328e21_0
- intel-openmp=2021.4.0=haa95532_3556
@ -20,7 +20,7 @@ dependencies:
- mkl-service=2.4.0=py38h2bbff1b_0
- mkl_fft=1.3.1=py38h277e83a_0
- mkl_random=1.2.2=py38hf11a4ad_0
- openssl=1.1.1o=h2bbff1b_0
- openssl=1.1.1p=h2bbff1b_0
- pip=20.1.1=py38_1
- python=3.8.3=he1778fa_2
- python-blosc=1.7.0=py38he774522_0
@ -124,7 +124,7 @@ dependencies:
- greenlet==1.1.2
- grpcio==1.46.3
- h5py==2.10.0
- hi-ml==0.2.2
- hi-ml==0.2.3
- hi-ml-azure==0.2.2
- humanize==4.2.0
- idna==3.3
@ -214,7 +214,7 @@ dependencies:
- pydicom==2.0.0
- pyflakes==2.2.0
- pygments==2.12.0
- pyjwt==1.7.1
- pyjwt==2.4.0
- pynndescent==0.5.7
- pyopenssl==20.0.1
- pyparsing==3.0.9
@ -225,7 +225,7 @@ dependencies:
- pytest-forked==1.3.0
- pytest-xdist==1.34.0
- python-dateutil==2.8.2
- pytorch-lightning==1.5.5
- pytorch-lightning==1.6.4
- pytz==2022.1
- pywavelets==1.3.0
- pywin32==227

Просмотреть файл

@ -28,8 +28,8 @@ dependencies:
- gitpython==3.1.7
- gputil==1.4.0
- h5py==2.10.0
- hi-ml==0.2.2
- hi-ml-azure==0.2.2
- hi-ml==0.2.5
- hi-ml-azure==0.2.5
- imageio==2.15.0
- InnerEye-DICOM-RT==1.0.3
- ipython==7.31.1
@ -40,8 +40,8 @@ dependencies:
- matplotlib==3.3.0
- mlflow==1.23.1
- monai==0.6.0
- mypy-extensions==0.4.3
- mypy==0.910
- mypy-extensions==0.4.3
- myst-parser==0.18.0
- numba==0.51.2
- numba==0.51.2
@ -50,17 +50,17 @@ dependencies:
- pandas==1.1.0
- papermill==2.2.2
- param==1.9.3
- pillow==9.0.0
- pillow==9.1.1
- protobuf<=3.20.1
- psutil==5.7.2
- pydicom==2.0.0
- pyflakes==2.2.0
- PyJWT==1.7.1
- PyJWT==2.4.0
- pytest==6.0.1
- pytest-cov==2.10.1
- pytest-forked==1.3.0
- pytest-xdist==1.34.0
- pytest==6.0.1
- pytorch-lightning==1.5.5
- pytorch-lightning==1.6.5
- rich==10.13.0
- rpdb==0.1.6
- ruamel.yaml==0.16.12