ENH: Download Checkpoints across AML workspaces (#642)

Enable checkpoint transfer across AML workspaces

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Anton Schwaighofer <antonsc@microsoft.com>
Kenza Bouzid 2022-11-01 17:38:37 +00:00, committed by GitHub
Parent d54e28d0e6
Commit 41d30807aa
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 442 additions and 230 deletions


@ -0,0 +1,61 @@
# Checkpoint Utils
The hi-ml toolbox offers utilities to parse and download pretrained checkpoints, so that you do not need to handle
each checkpoint source yourself. Refer to
[CheckpointParser](https://github.com/microsoft/hi-ml/blob/main/hi-ml/src/health_ml/utils/checkpoint_utils.py#L238) for
more details on the supported checkpoint formats. Here's how you can use the checkpoint parser depending on the source:
- For a local path, simply pass it as shown below. The parser will also check that the provided path exists:
```python
from health_ml.utils.checkpoint_utils import CheckpointParser
download_dir = 'outputs/checkpoints'
checkpoint_parser = CheckpointParser(checkpoint='local/path/to/my_checkpoint/model.ckpt')
print('Checkpoint', checkpoint_parser.checkpoint, 'is a local file', checkpoint_parser.is_local_file)
local_file = checkpoint_parser.get_path(download_dir)
```
- To download a checkpoint from a URL:
```python
from health_ml.utils.checkpoint_utils import CheckpointParser, MODEL_WEIGHTS_DIR_NAME
download_dir = 'outputs/checkpoints'
checkpoint_parser = CheckpointParser('https://my_checkpoint_url.com/model.ckpt')
print('Checkpoint', checkpoint_parser.checkpoint, 'is a URL', checkpoint_parser.is_url)
# will download the checkpoint to download_dir/MODEL_WEIGHTS_DIR_NAME
path_to_ckpt = checkpoint_parser.get_path(download_dir)
```
- Finally, checkpoints from Azure ML runs can be reused by providing an id in this format:
`<AzureML_run_id>:<optional/custom/path/to/checkpoints/><filename.ckpt>`. If no custom path is provided (e.g.,
`<AzureML_run_id>:<filename.ckpt>`), the checkpoint will be downloaded from the default checkpoint folder
(e.g., `outputs/checkpoints`). If no filename is provided (e.g., `src_checkpoint=<AzureML_run_id>`), the latest
checkpoint will be downloaded (e.g., `last.ckpt`).
```python
from health_ml.utils.checkpoint_utils import CheckpointParser
download_dir = 'outputs/checkpoints'
checkpoint_parser = CheckpointParser('AzureML_run_id:best.ckpt')
print('Checkpoint', checkpoint_parser.checkpoint, 'is an AML run', checkpoint_parser.is_aml_run_id)
path_azure_ml_ckpt = checkpoint_parser.get_path(download_dir)
```
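The three source types and the run id format described above can be illustrated with a small standalone sketch. It is not the `CheckpointParser` implementation, which this diff does not show. `classify_checkpoint` and `split_run_id_spec` are hypothetical helpers; the URL, local file and run id checks loosely follow the `src_checkpoint_is_*` properties that this commit removes from `WorkflowParams`, and the default folder and filename are taken from the text above.

```python
import re
from pathlib import Path
from typing import Tuple
from urllib.parse import urlparse

# Assumed defaults, taken from the documentation text above.
DEFAULT_CHECKPOINT_DIR = "outputs/checkpoints"
DEFAULT_CHECKPOINT_NAME = "last.ckpt"


def classify_checkpoint(checkpoint: str) -> str:
    """Return 'url', 'local_file' or 'aml_run_id' for a checkpoint string (hypothetical helper)."""
    try:
        parsed = urlparse(checkpoint)
        if parsed.scheme and parsed.netloc:
            return "url"
    except ValueError:
        pass  # Not parseable as a URL, try the other source types.
    if Path(checkpoint).is_file():
        return "local_file"
    # A run id is made of word characters and dashes; anything after ':' is the checkpoint path.
    if re.match(r"[\w-]+$", checkpoint.split(":")[0]):
        return "aml_run_id"
    raise ValueError(f"Invalid checkpoint: {checkpoint}")


def split_run_id_spec(spec: str) -> Tuple[str, str]:
    """Split '<run_id>:<optional/path/><filename.ckpt>' into (run_id, checkpoint path inside the run)."""
    run_id, _, relative_path = spec.partition(":")
    if not relative_path:
        # No filename given: fall back to the latest checkpoint.
        relative_path = DEFAULT_CHECKPOINT_NAME
    if "/" not in relative_path:
        # No custom folder given: use the default checkpoint folder.
        relative_path = f"{DEFAULT_CHECKPOINT_DIR}/{relative_path}"
    return run_id, relative_path


print(classify_checkpoint("https://my_checkpoint_url.com/model.ckpt"))
print(split_run_id_spec("AzureML_run_id:best.ckpt"))
print(split_run_id_spec("AzureML_run_id"))
```

The order of the checks matters in this sketch: a string is only treated as a run id once it is known to be neither a URL nor an existing local file.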
If the Azure ML run is in a different workspace, a temporary SAS URL to download the checkpoint can be generated as follows:
```bash
cd hi-ml-cpath
python src/health_cpath/scripts/generate_checkpoint_url.py --run_id=AzureML_run_id --checkpoint_filename=best_val_loss.ckpt --expiry_days=10
```
N.B.: `config.json` should correspond to the original workspace where the AML run lives.
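As a rough illustration of what the generated URL looks like, the sketch below assembles it from its parts without calling Azure. The blob layout and the `azureml` container name follow `generate_checkpoint_url.py` in this commit; `build_checkpoint_url` and `sas_expiry` are hypothetical helpers, the storage account name and token are placeholders, and `outputs/checkpoints` is assumed to be the value of `DEFAULT_AML_CHECKPOINT_DIR`.

```python
from datetime import datetime, timedelta, timezone

CHECKPOINT_DIR = "outputs/checkpoints"  # assumed value of DEFAULT_AML_CHECKPOINT_DIR
CONTAINER_NAME = "azureml"  # container name used by the script in this commit


def build_checkpoint_url(account_name: str, run_id: str, checkpoint_filename: str, sas_token: str) -> str:
    """Assemble the blob URL of a run's checkpoint from its parts (hypothetical helper)."""
    blob_name = f"ExperimentRun/dcid.{run_id}/{CHECKPOINT_DIR}/{checkpoint_filename}"
    return f"https://{account_name}.blob.core.windows.net/{CONTAINER_NAME}/{blob_name}?{sas_token}"


def sas_expiry(expiry_days: int) -> datetime:
    """Return the UTC time at which a token issued now would expire."""
    return datetime.now(timezone.utc) + timedelta(days=expiry_days)


# Placeholder account name and token; a real token comes from the script above.
url = build_checkpoint_url("mystorageaccount", "AzureML_run_id", "best_val_loss.ckpt", "<sas_token>")
print(url)
print("Valid until:", sas_expiry(10).isoformat())
```

The resulting URL can then be passed to `CheckpointParser` like any other URL checkpoint, as in the URL example earlier in this document.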
## Use cases
CheckpointParser is used to specify a `src_checkpoint` to [resume training from a given
checkpoint](https://github.com/microsoft/hi-ml/blob/main/docs/source/runner.md#L238),
or [run inference with a pretrained model](https://github.com/microsoft/hi-ml/blob/main/docs/source/runner.md#L215),
as well as
[ssl_checkpoint](https://github.com/microsoft/hi-ml/blob/main/hi-ml-cpath/src/health_cpath/utils/deepmil_utils.py#L62)
for computational pathology self-supervised pretrained encoders.


@ -42,6 +42,7 @@ The `hi-ml` toolbox provides
logging.md
diagnostics.md
runner.md
checkpoints.md
.. toctree::
:maxdepth: 1


@ -226,6 +226,8 @@ the model weights by setting `--src_checkpoint` argument that supports three typ
checkpoints folder `outputs/checkpoints`. If no filename is provided (e.g., `--src_checkpoint=AzureML_run_id`),
the last epoch checkpoint `outputs/checkpoints/last.ckpt` will be loaded.
Refer to [Checkpoint Utils](checkpoints.md) for more details on how checkpoints are parsed.
Running the following command line will run inference using the `MyContainer` model with weights from the checkpoint saved
in the AzureML run `MyContainer_XXXX_yyyy` at the best validation loss epoch `/outputs/checkpoints/best_val_loss.ckpt`.


@ -29,6 +29,10 @@ pip_from_conda:
sed -e '1,/pip:/ d' environment.yml | grep -v "#" | cut -d "-" -f 2- > temp_requirements.txt
pip install -r temp_requirements.txt
# Lock the current Conda environment secondary dependencies versions
lock_env:
./create_and_lock_environment.sh
# clean build artifacts
clean:
rm -rf `find . -type d -name __pycache__`
@ -78,7 +82,7 @@ pytest_coverage:
pytest --cov=health_cpath --cov SSL --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc
SSL_CKPT_RUN_ID_CRCK := CRCK_SimCLR_1655731022_85790606
SRC_CKPT_RUN_ID_CRCK := TcgaCrckSSLMIL_1664478306_144bc833
SRC_CKPT_RUN_ID_CRCK := TcgaCrckSSLMIL_1667236343_af6e293f
# Run regression tests and compare performance
define BASE_CPATH_RUNNER_COMMAND
@ -95,7 +99,7 @@ define DEEPSMILEPANDATILES_ARGS
endef
define TCGACRCKSSLMIL_ARGS
--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint_run_id=${SSL_CKPT_RUN_ID_CRCK}
--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint=${SSL_CKPT_RUN_ID_CRCK}
endef
define TCGACRCKIMANEGETMIL_ARGS


@ -81,6 +81,7 @@ dependencies:
- azure-mgmt-keyvault==10.0.0
- azure-mgmt-resource==21.1.0
- azure-mgmt-storage==20.0.0
- azure-storage-blob==12.5.0
- azureml-core==1.43.0
- azureml-dataprep==4.0.4
- azureml-dataprep-native==38.0.0


@ -1,3 +1,4 @@
azure-storage-blob==12.5.0
coloredlogs==15.0.1
cucim==22.04.00
girder-client==3.1.14


@ -19,8 +19,7 @@ from health_cpath.utils.callbacks import LossAnalysisCallback, LossCallbackParam
from health_ml.utils import fixed_paths
from health_ml.deep_learning_config import OptimizerParams
from health_ml.lightning_container import LightningContainer
from health_ml.deep_learning_config import SRC_CKPT_INFO_MESSAGE
from health_ml.utils.checkpoint_utils import get_best_checkpoint_path
from health_ml.utils.checkpoint_utils import get_best_checkpoint_path, CheckpointParser
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from health_cpath.datamodules.base_module import CacheLocation, CacheMode, HistoDataModule
@ -80,8 +79,6 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
default=False,
doc="If True, will use classifier weights from pretrained model specified in src_checkpoint. If False, will "
"initialize classifier with random weights.")
ssl_checkpoint_run_id: str = param.String(default="", doc="Optional run id from which to load checkpoint if "
"using SSLEncoder")
max_num_workers: int = param.Integer(10, bounds=(0, None),
doc="The maximum number of worker processes for dataloaders. Dataloaders use"
"a heuristic num_cpus/num_gpus to set the number of workers, which can be"
@ -100,6 +97,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
def validate(self) -> None:
super().validate()
EncoderParams.validate(self)
if not any([self.tune_encoder, self.tune_pooling, self.tune_classifier]) and not self.run_inference_only:
raise ValueError(
"At least one of the encoder, pooling or classifier should be fine tuned. Turn on one of the tune "
@ -112,7 +110,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
):
raise ValueError(
"You need to specify a source checkpoint, to use a pretrained encoder, pooling or classifier."
f"{SRC_CKPT_INFO_MESSAGE}"
f" {CheckpointParser.INFO_MESSAGE}"
)
if (
self.tune_encoder and self.encoding_chunk_size < self.max_bag_size
@ -273,8 +271,7 @@ class BaseMILTiles(BaseMIL):
def get_transforms_dict(self, image_key: str) -> Dict[ModelKey, Union[Callable, None]]:
if self.is_caching:
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_checkpoint_run_id,
self.outputs_folder)
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.outputs_folder)
transform = Compose([
LoadTilesBatchd(image_key, progress=True),
EncodeTilesBatchd(image_key, encoder, chunk_size=self.encoding_chunk_size) # type: ignore
@ -295,7 +292,7 @@ class BaseMILTiles(BaseMIL):
pretrained_classifier=self.pretrained_classifier,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
@ -339,7 +336,6 @@ class BaseMILSlides(BaseMIL):
pretrained_classifier=self.pretrained_classifier,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),


@ -27,6 +27,7 @@ from health_cpath.models.encoders import (
from health_cpath.configs.classification.BaseMIL import BaseMILTiles
from health_cpath.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from health_cpath.utils.naming import PlotOption
from health_ml.utils.checkpoint_utils import CheckpointParser
class DeepSMILECrck(BaseMILTiles):
@ -56,8 +57,6 @@ class DeepSMILECrck(BaseMILTiles):
def setup(self) -> None:
super().setup()
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_crck_4ws
def get_data_module(self) -> TilesDataModule:
return TcgaCrckTilesDataModule(
@ -93,6 +92,8 @@ class TcgaCrckImageNetSimCLRMIL(DeepSMILECrck):
class TcgaCrckSSLMIL(DeepSMILECrck):
def __init__(self, **kwargs: Any) -> None:
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_crck_4ws)
super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)


@ -22,6 +22,7 @@ from health_cpath.datasets.default_paths import (
PANDA_DATASET_ID,
PANDA_5X_TILES_DATASET_ID)
from health_cpath.utils.naming import PlotOption
from health_ml.utils.checkpoint_utils import CheckpointParser
class BaseDeepSMILEPanda(BaseMIL):
@ -70,8 +71,6 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
def setup(self) -> None:
BaseMILTiles.setup(self)
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary
def get_data_module(self) -> PandaTilesDataModule:
return PandaTilesDataModule(
@ -110,6 +109,8 @@ class TilesPandaImageNetSimCLRMIL(DeepSMILETilesPanda):
class TilesPandaSSLMIL(DeepSMILETilesPanda):
def __init__(self, **kwargs: Any) -> None:
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_binary)
super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)
@ -136,8 +137,6 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
def setup(self) -> None:
BaseMILSlides.setup(self)
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary
def get_dataloader_kwargs(self) -> dict:
return dict(
@ -181,6 +180,8 @@ class SlidesPandaImageNetSimCLRMIL(DeepSMILESlidesPanda):
class SlidesPandaSSLMIL(DeepSMILESlidesPanda):
def __init__(self, **kwargs: Any) -> None:
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_binary)
super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)


@ -7,12 +7,13 @@
from typing import Any, Dict, Callable, Union
from torch import optim
from monai.transforms import Compose, ScaleIntensityRanged, RandRotate90d, RandFlipd
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary
from health_azure.utils import create_from_matching_params
from health_ml.networks.layers.attention_layers import (
TransformerPooling,
TransformerPoolingBenchmark
)
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_ml.deep_learning_config import OptimizerParams
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.datamodules.panda_module_benchmark import PandaSlidesDataModuleBenchmark
@ -127,7 +128,6 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
class_weights=self.data_module.class_weights,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
@ -150,6 +150,8 @@ class SlidesPandaImageNetSimCLRMILBenchmark(DeepSMILESlidesPandaBenchmark):
class SlidesPandaSSLMILBenchmark(DeepSMILESlidesPandaBenchmark):
def __init__(self, **kwargs: Any) -> None:
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_binary)
super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)


@ -46,7 +46,6 @@ class BaseDeepMILModule(LightningModule):
pretrained_classifier: bool = False,
dropout_rate: Optional[float] = None,
verbose: bool = False,
ssl_ckpt_run_id: Optional[str] = None,
outputs_folder: Optional[Path] = None,
encoder_params: EncoderParams = EncoderParams(),
pooling_params: PoolingParams = PoolingParams(),
@ -63,8 +62,6 @@ class BaseDeepMILModule(LightningModule):
:param pretrained_classifier: Whether to use pretrained classifier (default=False for random init).
:param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
:param verbose: if True statements about memory usage are output at each step.
:param ssl_ckpt_run_id: Optional parameter to provide the AML run id from where to download the checkpoint
if using `SSLEncoder`.
:param outputs_folder: Path to output folder where encoder checkpoint is downloaded.
:param encoder_params: Encoder parameters that specify all encoder specific attributes.
:param pooling_params: Pooling layer parameters that specify all encoder specific attributes.
@ -98,7 +95,7 @@ class BaseDeepMILModule(LightningModule):
self.tune_classifier = tune_classifier
# Model components
self.encoder = encoder_params.get_encoder(ssl_ckpt_run_id, outputs_folder)
self.encoder = encoder_params.get_encoder(outputs_folder)
self.aggregation_fn, self.num_pooling = pooling_params.get_pooling_layer(self.encoder.num_encoding)
self.classifier_fn = self.get_classifier()
self.activation_fn = self.get_activation()


@ -0,0 +1,51 @@
from azure.storage.blob import generate_blob_sas, BlobSasPermissions
from azureml.core import Workspace
from datetime import datetime, timedelta
from health_azure import get_workspace
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from typing import Optional
def get_checkpoint_url_from_aml_run(
run_id: str,
checkpoint_filename: str,
expiry_days: int = 1,
aml_workspace: Optional[Workspace] = None,
sas_token: Optional[str] = None,
) -> str:
"""Generate a SAS URL for the checkpoint file in the given run.
:param run_id: The run ID of the checkpoint.
:param checkpoint_filename: The filename of the checkpoint.
:param expiry_days: The number of days the SAS URL is valid for, defaults to 1.
:param aml_workspace: The Azure ML workspace to use, defaults to the default workspace.
:param sas_token: The SAS token to use, defaults to None.
:return: The SAS URL for the checkpoint.
"""
datastore = get_workspace(aml_workspace=aml_workspace).get_default_datastore()
account_name = datastore.account_name
container_name = 'azureml'
blob_name = f'ExperimentRun/dcid.{run_id}/{DEFAULT_AML_CHECKPOINT_DIR}/{checkpoint_filename}'
if not sas_token:
sas_token = generate_blob_sas(account_name=datastore.account_name,
container_name=container_name,
blob_name=blob_name,
account_key=datastore.account_key,
permission=BlobSasPermissions(read=True),
expiry=datetime.utcnow() + timedelta(days=expiry_days))
return f'https://{account_name}.blob.core.windows.net/{container_name}/{blob_name}?{sas_token}'
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--run_id', type=str, help='The run id of the model checkpoint')
parser.add_argument('--checkpoint_filename', type=str, default='last.ckpt',
help='The filename of the model checkpoint. Default: last.ckpt')
parser.add_argument('--expiry_days', type=int, default=30,
help='The number of days for which the SAS token is valid. Default: 30 for 1 month')
args = parser.parse_args()
url = get_checkpoint_url_from_aml_run(args.run_id, args.checkpoint_filename, args.expiry_days)
print(f'Checkpoint URL: {url}')


@ -7,8 +7,7 @@ import param
from torch import nn
from pathlib import Path
from typing import Optional, Tuple
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CheckpointDownloader
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_cpath.models.encoders import (
HistoSSLEncoder,
ImageNetSimCLREncoder,
@ -60,11 +59,18 @@ class EncoderParams(param.Parameterized):
encoding_chunk_size: int = param.Integer(
default=0, doc="If > 0 performs encoding in chunks, by encoding_chunk_size tiles " "per chunk"
)
ssl_checkpoint: CheckpointParser = param.ClassSelector(class_=CheckpointParser, default=None,
instantiate=False, doc=CheckpointParser.DOC)
def get_encoder(self, ssl_ckpt_run_id: Optional[str], outputs_folder: Optional[Path]) -> TileEncoder:
def validate(self) -> None:
"""Validate the encoder parameters."""
if self.encoder_type == SSLEncoder.__name__ and not self.ssl_checkpoint:
raise ValueError("SSLEncoder requires an ssl_checkpoint. Please specify a valid checkpoint. "
f"{CheckpointParser.INFO_MESSAGE}")
def get_encoder(self, outputs_folder: Optional[Path]) -> TileEncoder:
"""Given the current encoder parameters, returns the encoder object.
:param ssl_ckpt_run_id: The AML run id for SSL checkpoint download.
:param outputs_folder: The output folder where SSL checkpoint should be saved.
:param encoder_params: The encoder arguments that define the encoder class object depending on the encoder type.
:raises ValueError: If the encoder type is not supported.
@ -90,15 +96,9 @@ class EncoderParams(param.Parameterized):
encoder = HistoSSLEncoder(tile_size=self.tile_size, n_channels=self.n_channels)
elif self.encoder_type == SSLEncoder.__name__:
assert ssl_ckpt_run_id and outputs_folder, "SSLEncoder requires ssl_ckpt_run_id and outputs_folder"
downloader = CheckpointDownloader(
run_id=ssl_ckpt_run_id,
download_dir=outputs_folder,
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
remote_checkpoint_dir=Path(DEFAULT_AML_CHECKPOINT_DIR),
)
assert outputs_folder is not None, "outputs_folder cannot be None for SSLEncoder"
encoder = SSLEncoder(
pl_checkpoint_path=downloader.local_checkpoint_path,
pl_checkpoint_path=self.ssl_checkpoint.get_path(outputs_folder),
tile_size=self.tile_size,
n_channels=self.n_channels,
)


@ -36,7 +36,7 @@ from SSL.configs.CXR_SSL_configs import CXRImageClassifier, NIH_RSNA_SimCLR
from health_ml.runner import Runner
from health_ml.utils import AzureMLProgressBar
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME
from health_ml.utils.fixed_paths import repository_root_directory, OutputFolderForTests
from health_ml.utils.lightning_loggers import StoringLogger
@ -247,7 +247,7 @@ def test_ssl_container_rsna() -> None:
_compare_stored_metrics(runner, expected_metrics)
# Check that we are able to load the checkpoint and create classifier model
checkpoint_path = loaded_config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
checkpoint_path = loaded_config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME
model_namespace_cxr = "SSL.configs.CXRImageClassifier"
args = common_test_args + [f"--model={model_namespace_cxr}",
f"--local_datasets={str(path_to_cxr_test_dataset)}",


@ -37,7 +37,7 @@ class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
# declared in TrainerParams:
max_epochs=2,
crossval_count=1,
ssl_checkpoint_run_id="",
ssl_checkpoint=None,
analyse_loss=analyse_loss,
)
default_kwargs.update(kwargs)


@ -15,13 +15,16 @@ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Typ
from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
from torch.utils.data._utils.collate import default_collate
from health_cpath.configs.classification.DeepSMILESlidesPandaBenchmark import SlidesPandaSSLMILBenchmark
from health_cpath.datamodules.panda_module import PandaTilesDataModule
from health_ml.networks.layers.attention_layers import AttentionLayer, TransformerPoolingBenchmark
from health_cpath.configs.classification.BaseMIL import BaseMIL, BaseMILTiles
from health_cpath.configs.classification.DeepSMILECrck import DeepSMILECrck
from health_cpath.configs.classification.DeepSMILEPanda import BaseDeepSMILEPanda, DeepSMILETilesPanda
from health_cpath.configs.classification.DeepSMILECrck import DeepSMILECrck, TcgaCrckSSLMIL
from health_cpath.configs.classification.DeepSMILEPanda import (
BaseDeepSMILEPanda, DeepSMILETilesPanda, SlidesPandaSSLMIL, TilesPandaSSLMIL
)
from health_cpath.datamodules.base_module import HistoDataModule, TilesDataModule
from health_cpath.datasets.base_dataset import DEFAULT_LABEL_COLUMN, TilesDataset
from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, TCGA_CRCK_DATASET_DIR
@ -34,6 +37,9 @@ from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPosi
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator
from testhisto.mocks.container import MockDeepSMILETilesPanda, MockDeepSMILESlidesPanda
from health_ml.utils.common_utils import is_gpu_available
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_crck_4ws, innereye_ssl_checkpoint_binary
from testhisto.models.test_encoders import TEST_SSL_RUN_ID
no_gpu = not is_gpu_available()
@ -198,9 +204,7 @@ def add_callback(fn: Callable, callback: Callable) -> Callable:
def test_metrics(n_classes: int) -> None:
input_dim = (128,)
def _mock_get_encoder( # type: ignore
self, ssl_ckpt_run_id: Optional[str], outputs_folder: Optional[Path]
) -> TileEncoder:
def _mock_get_encoder(self, outputs_folder: Optional[Path]) -> TileEncoder: # type: ignore
return IdentityEncoder(input_dim=input_dim)
with patch("health_cpath.models.deepmil.EncoderParams.get_encoder", new=_mock_get_encoder):
@ -718,3 +722,17 @@ def test_on_run_extra_val_epoch(mock_panda_tiles_root_dir: Path) -> None:
container.model.outputs_handler.test_plots_handler.plot_options # type: ignore
== container.model.outputs_handler.val_plots_handler.plot_options # type: ignore
)
@pytest.mark.parametrize(
"container_type", [TcgaCrckSSLMIL, TilesPandaSSLMIL, SlidesPandaSSLMIL, SlidesPandaSSLMILBenchmark]
)
def test_ssl_containers_default_checkpoint(container_type: BaseMIL) -> None:
if container_type == TcgaCrckSSLMIL:
default_checkpoint = innereye_ssl_checkpoint_crck_4ws
else:
default_checkpoint = innereye_ssl_checkpoint_binary
assert container_type().ssl_checkpoint.checkpoint == default_checkpoint
container = container_type(ssl_checkpoint=CheckpointParser(TEST_SSL_RUN_ID))
assert container.ssl_checkpoint.checkpoint != default_checkpoint


@ -12,7 +12,7 @@ from torch import Tensor, float32, nn, rand
from torchvision.models import resnet18
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CheckpointDownloader
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME, CheckpointDownloader
from health_cpath.models.encoders import (Resnet18, TileEncoder, HistoSSLEncoder,
ImageNetSimCLREncoder, SSLEncoder)
from health_cpath.utils.layer_utils import setup_feature_extractor
@ -35,8 +35,9 @@ def get_simclr_imagenet_encoder() -> TileEncoder:
def get_ssl_encoder(download_dir: Path) -> TileEncoder:
downloader = CheckpointDownloader(run_id=TEST_SSL_RUN_ID,
download_dir=download_dir,
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME,
remote_checkpoint_dir=Path(DEFAULT_AML_CHECKPOINT_DIR))
downloader.download_checkpoint_if_necessary()
return SSLEncoder(pl_checkpoint_path=downloader.local_checkpoint_path, tile_size=TILE_SIZE)


@ -0,0 +1,68 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch
from health_cpath.models.encoders import SSLEncoder
from health_cpath.scripts.generate_checkpoint_url import get_checkpoint_url_from_aml_run
from health_cpath.utils.deepmil_utils import EncoderParams
from health_ml.utils.checkpoint_utils import CheckpointParser, LAST_CHECKPOINT_FILE_NAME, MODEL_WEIGHTS_DIR_NAME
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from testhiml.utils.fixed_paths_for_tests import full_test_data_path
from testhisto.models.test_encoders import TEST_SSL_RUN_ID
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
LAST_CHECKPOINT = f"{DEFAULT_AML_CHECKPOINT_DIR}/{LAST_CHECKPOINT_FILE_NAME}"
def test_validate_encoder_params() -> None:
with pytest.raises(ValueError, match=r"SSLEncoder requires an ssl_checkpoint"):
encoder = EncoderParams(encoder_type=SSLEncoder.__name__)
encoder.validate()
def test_load_ssl_checkpoint_from_local_file(tmp_path: Path) -> None:
checkpoint_filename = "hello_world_checkpoint.ckpt"
local_checkpoint_path = full_test_data_path(suffix=checkpoint_filename)
encoder_params = EncoderParams(
encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(str(local_checkpoint_path))
)
assert encoder_params.ssl_checkpoint.is_local_file
ssl_checkpoint_path = encoder_params.ssl_checkpoint.get_path(tmp_path)
assert ssl_checkpoint_path.exists()
assert ssl_checkpoint_path == local_checkpoint_path
with patch("health_cpath.models.encoders.SSLEncoder._get_encoder") as mock_get_encoder:
mock_get_encoder.return_value = (MagicMock(), MagicMock())
encoder = encoder_params.get_encoder(tmp_path)
assert isinstance(encoder, SSLEncoder)
def test_load_ssl_checkpoint_from_url(tmp_path: Path) -> None:
blob_url = get_checkpoint_url_from_aml_run(
run_id=TEST_SSL_RUN_ID,
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME,
expiry_days=1,
aml_workspace=DEFAULT_WORKSPACE.workspace)
encoder_params = EncoderParams(encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(blob_url))
assert encoder_params.ssl_checkpoint.is_url
ssl_checkpoint_path = encoder_params.ssl_checkpoint.get_path(tmp_path)
assert ssl_checkpoint_path.exists()
assert ssl_checkpoint_path == tmp_path / MODEL_WEIGHTS_DIR_NAME / LAST_CHECKPOINT_FILE_NAME
encoder = encoder_params.get_encoder(tmp_path)
assert isinstance(encoder, SSLEncoder)
def test_load_ssl_checkpoint_from_run_id(tmp_path: Path) -> None:
encoder_params = EncoderParams(encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(TEST_SSL_RUN_ID))
assert encoder_params.ssl_checkpoint.is_aml_run_id
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
ssl_checkpoint_path = encoder_params.ssl_checkpoint.get_path(tmp_path)
assert ssl_checkpoint_path.exists()
assert ssl_checkpoint_path == tmp_path / TEST_SSL_RUN_ID / LAST_CHECKPOINT
encoder = encoder_params.get_encoder(tmp_path)
assert isinstance(encoder, SSLEncoder)


@ -7,12 +7,10 @@ from __future__ import annotations
import logging
import os
import param
import re
from enum import Enum, unique
from param import Parameterized
from pathlib import Path
from typing import List, Optional
from urllib.parse import urlparse
from azureml.train.hyperdrive import HyperDriveConfig
@ -23,6 +21,7 @@ from health_azure.amulet import (ENV_AMLT_PROJECT_NAME, ENV_AMLT_INPUT_OUTPUT,
is_amulet_job, get_amulet_aml_working_dir)
from health_azure.utils import (RUN_CONTEXT, PathOrString, is_global_rank_zero, is_running_in_azure_ml)
from health_ml.utils import fixed_paths
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_ml.utils.common_utils import (CHECKPOINT_FOLDER,
create_unique_timestamp_id,
DEFAULT_AML_UPLOAD_DIR,
@ -143,35 +142,13 @@ class ExperimentFolderHandler(Parameterized):
)
SRC_CHECKPOINT_FORMAT_DOC = ("<AzureML_run_id>:<optional/custom/path/to/checkpoints/><filename.ckpt>"
"If no custom path is provided (e.g., <AzureML_run_id>:<filename.ckpt>)"
"the checkpoint will be downloaded from the default checkpoint folder "
"(e.g., 'outputs/checkpoints/'). If no filename is provided, (e.g., "
"`src_checkpoint=<AzureML_run_id>`) the latest checkpoint (last.ckpt) "
"will be used to initialize the model.")
SRC_CKPT_INFO_MESSAGE = ("Please specify a valid src_checkpoint. You can either use a URL, a local file or an azureml "
"run id. For custom checkpoint paths within an azureml run, (other than last.ckpt), provide "
f"a src_checkpoint in the format {SRC_CHECKPOINT_FORMAT_DOC}.")
class WorkflowParams(param.Parameterized):
"""
This class contains all parameters that affect how the whole training and testing workflow is executed.
"""
random_seed: int = param.Integer(42, doc="The seed to use for all random number generators.")
src_checkpoint: str = param.String(default="",
doc="This flag can be used in 3 different scenarios:"
"1- Resume training from a checkpoint to train longer using"
" `resume_training` flag jointly."
"2- Run inference-only using `run_inference_only` flag jointly."
"3- Transfer learning from a pretrained model checkpoint."
"We currently support three types of checkpoints: "
" a. A local checkpoint folder that contains a checkpoint file."
" b. A URL to a remote checkpoint to be downloaded."
" c. A previous azureml run id where the checkpoint is supposed to be "
" saved ('outputs/checkpoints/' folder by default.)"
"For the latter case 'c' : src_checkpoint should be in the format of "
f"{SRC_CHECKPOINT_FORMAT_DOC}")
src_checkpoint: CheckpointParser = param.ClassSelector(class_=CheckpointParser, default=None,
instantiate=False, doc=CheckpointParser.DOC)
crossval_count: int = param.Integer(default=1, bounds=(0, None),
doc="The number of splits to use when doing cross-validation. "
"Use 1 to disable cross-validation")
@ -216,41 +193,15 @@ class WorkflowParams(param.Parameterized):
CROSSVAL_COUNT_ARG_NAME = "crossval_count"
RANDOM_SEED_ARG_NAME = "random_seed"
@property
def src_checkpoint_is_url(self) -> bool:
try:
result = urlparse(self.src_checkpoint)
return all([result.scheme, result.netloc])
except ValueError:
return False
@property
def src_checkpoint_is_local_file(self) -> bool:
return Path(self.src_checkpoint).is_file()
@property
def src_checkpoint_is_aml_run_id(self) -> bool:
match = re.match(r"[_\w-]*$", self.src_checkpoint.split(":")[0])
return match is not None and not self.src_checkpoint_is_url and not self.src_checkpoint_is_local_file
@property
def is_valid_src_checkpoint(self) -> bool:
if self.src_checkpoint:
return self.src_checkpoint_is_local_file or self.src_checkpoint_is_url or self.src_checkpoint_is_aml_run_id
return True
def validate(self) -> None:
if not self.is_valid_src_checkpoint:
raise ValueError(f"Invalid src_checkpoint: {self.src_checkpoint}. Please provide a valid URL, local file "
"or azureml run id.")
if self.crossval_count > 1:
if not (0 <= self.crossval_index < self.crossval_count):
raise ValueError(f"Attribute crossval_index out of bounds (crossval_count = {self.crossval_count})")
if self.run_inference_only and not self.src_checkpoint:
raise ValueError(f"Cannot run inference without a src_checkpoint. {SRC_CKPT_INFO_MESSAGE}")
raise ValueError(f"Cannot run inference without a src_checkpoint. {CheckpointParser.INFO_MESSAGE}")
if self.resume_training and not self.src_checkpoint:
raise ValueError(f"Cannot resume training without a src_checkpoint. {SRC_CKPT_INFO_MESSAGE}")
raise ValueError(f"Cannot resume training without a src_checkpoint. {CheckpointParser.INFO_MESSAGE}")
@property
def is_running_in_aml(self) -> bool:

View file

@ -25,8 +25,8 @@ from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.model_trainer import create_lightning_trainer, write_experiment_summary_file
from health_ml.utils import fixed_paths
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.utils.checkpoint_utils import cleanup_checkpoints
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.utils.common_utils import (
EFFECTIVE_RANDOM_SEED_KEY_NAME,
change_working_directory,

View file

@ -4,22 +4,13 @@
# -------------------------------------------------------------------------------------------
import logging
import os
import uuid
from azureml.core import Run
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
import requests
from azureml.core import Run
from health_azure.utils import is_global_rank_zero
from health_ml.lightning_container import LightningContainer
from health_ml.utils.checkpoint_utils import (
MODEL_WEIGHTS_DIR_NAME,
CheckpointDownloader,
find_recovery_checkpoint_on_disk_or_cloud,
)
from health_ml.utils.checkpoint_utils import find_recovery_checkpoint_on_disk_or_cloud
class CheckpointHandler:
@ -40,9 +31,8 @@ class CheckpointHandler:
the checkpoint_url, local_checkpoint or checkpoint from an azureml run id.
This is called at the start of training.
"""
if self.container.src_checkpoint:
self.trained_weights_path = self.get_local_checkpoints_path_or_download()
self.trained_weights_path = self.container.src_checkpoint.get_path(self.container.checkpoint_folder)
self.container.trained_weights_path = self.trained_weights_path
def additional_training_done(self) -> None:
@ -86,53 +76,3 @@ class CheckpointHandler:
logging.info(f"Using pre-trained weights from {self.trained_weights_path}")
return self.trained_weights_path
raise ValueError("Unable to determine which checkpoint should be used for testing.")
@staticmethod
def download_weights(url: str, download_folder: Path) -> Path:
"""
Download a checkpoint from `url` to the model weights directory. The file name is determined from
the file name in the URL. If that can't be determined, use a random file name.
:param url: The URL from which the weights should be downloaded.
:param download_folder: The target folder for the download.
:return: A path to the downloaded file.
"""
# assign the same filename as in the download url if possible, so that we can check for duplicates
# If that fails, map to a random uuid
file_name = os.path.basename(urlparse(url).path) or str(uuid.uuid4().hex)
checkpoint_path = download_folder / file_name
# only download if hasn't already been downloaded
if checkpoint_path.is_file():
logging.info(f"File already exists, skipping download: {checkpoint_path}")
else:
logging.info(f"Downloading weights from URL {url}")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(checkpoint_path, "wb") as file:
for chunk in response.iter_content(chunk_size=1024):
file.write(chunk)
return checkpoint_path
def get_local_checkpoints_path_or_download(self) -> Path:
"""
Get the path to the local weights to use or download them.
"""
if self.container.src_checkpoint_is_local_file:
checkpoint_path = Path(self.container.src_checkpoint)
elif self.container.src_checkpoint_is_url:
download_folder = self.container.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME
download_folder.mkdir(exist_ok=True, parents=True)
checkpoint_path = self.download_weights(url=self.container.src_checkpoint, download_folder=download_folder)
elif self.container.src_checkpoint_is_aml_run_id:
downloader = CheckpointDownloader(
run_id=self.container.src_checkpoint, download_dir=self.container.outputs_folder
)
checkpoint_path = downloader.local_checkpoint_path
else:
raise ValueError("Unable to determine how to get the checkpoint path.")
if checkpoint_path is None or not checkpoint_path.is_file():
raise FileNotFoundError(f"Could not find the weights file at {checkpoint_path}")
return checkpoint_path
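The file-name rule used by the URL download helper in this change can be tried in isolation. The snippet below is a minimal sketch using only the standard library; `filename_from_url` is a hypothetical helper name introduced here for illustration and is not part of hi-ml.

```python
import os
import uuid
from urllib.parse import urlparse


def filename_from_url(url: str) -> str:
    """Derive a local file name from a URL (hypothetical helper, not in hi-ml).

    Uses the last component of the URL path so that repeated downloads of the
    same URL map to the same file. Falls back to a random hex string when the
    path has no final component (for example, when the URL ends with a slash).
    """
    return os.path.basename(urlparse(url).path) or uuid.uuid4().hex


if __name__ == "__main__":
    # Query strings are not part of the path, so they do not end up in the name.
    print(filename_from_url("https://my_checkpoint_url.com/model.ckpt?sig=abc"))
    # No file name in the path: a random name is generated on every call.
    print(filename_from_url("https://my_checkpoint_url.com/"))
```

Because the fallback name is random, the "skip download if the file already exists" check can only take effect for URLs whose path ends in a file name.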

View file

@ -2,27 +2,29 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import re
import os
import uuid
import torch
import logging
import tempfile
import requests
from pathlib import Path
from typing import Optional
import torch
from urllib.parse import urlparse
from azureml.core import Run, Workspace
from health_azure import download_checkpoints_from_run_id, get_workspace
from health_azure.utils import (RUN_CONTEXT, download_files_from_run_id, get_run_file_names, is_running_in_azure_ml)
from health_ml.utils.common_utils import (AUTOSAVE_CHECKPOINT_CANDIDATES, DEFAULT_AML_CHECKPOINT_DIR)
from health_ml.utils.common_utils import (AUTOSAVE_CHECKPOINT_CANDIDATES, DEFAULT_AML_CHECKPOINT_DIR, CHECKPOINT_SUFFIX)
from health_ml.utils.type_annotations import PathOrString
CHECKPOINT_SUFFIX = ".ckpt"
# This is a constant that must match a filename defined in pytorch_lightning.ModelCheckpoint, but we don't want
# to import that here.
LAST_CHECKPOINT_FILE_NAME = "last"
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX = LAST_CHECKPOINT_FILE_NAME + CHECKPOINT_SUFFIX
LAST_CHECKPOINT_FILE_NAME = f"last{CHECKPOINT_SUFFIX}"
LEGACY_RECOVERY_CHECKPOINT_FILE_NAME = "recovery"
MODEL_INFERENCE_JSON_FILE_NAME = "model_inference_config.json"
MODEL_WEIGHTS_DIR_NAME = "trained_models"
MODEL_WEIGHTS_DIR_NAME = "pretrained_models"
def get_best_checkpoint_path(path: Path) -> Path:
@ -31,7 +33,7 @@ def get_best_checkpoint_path(path: Path) -> Path:
:param path to checkpoint folder
"""
return path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
return path / LAST_CHECKPOINT_FILE_NAME
def download_folder_from_run_to_temp_folder(folder: str,
@ -120,7 +122,7 @@ def find_recovery_checkpoint(path: Path) -> Optional[Path]:
logging.warning(f"Found these legacy checkpoint files: {legacy_recovery_checkpoints}")
raise ValueError("The legacy recovery checkpoint setup is no longer supported. As a workaround, you can take "
f"one of the legacy checkpoints and upload as '{AUTOSAVE_CHECKPOINT_CANDIDATES[0]}'")
candidates = [*AUTOSAVE_CHECKPOINT_CANDIDATES, LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX]
candidates = [*AUTOSAVE_CHECKPOINT_CANDIDATES, LAST_CHECKPOINT_FILE_NAME]
highest_epoch: Optional[int] = None
file_with_highest_epoch: Optional[Path] = None
for f in candidates:
@ -147,10 +149,10 @@ def cleanup_checkpoints(ckpt_folder: Path) -> None:
if len(files_in_checkpoint_folder) == 0:
return
logging.info(f"Files in checkpoint folder: {' '.join(files_in_checkpoint_folder)}")
last_ckpt = ckpt_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
last_ckpt = ckpt_folder / LAST_CHECKPOINT_FILE_NAME
all_files = f"Existing files: {' '.join(p.name for p in ckpt_folder.glob('*'))}"
if not last_ckpt.is_file():
raise FileNotFoundError(f"Checkpoint file {LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX} not found. {all_files}")
raise FileNotFoundError(f"Checkpoint file {LAST_CHECKPOINT_FILE_NAME} not found. {all_files}")
# Training is finished now. To save storage, remove the autosave checkpoint which is now obsolete.
# Lightning does not overwrite checkpoints in-place. Rather, it writes "autosave.ckpt",
# then "autosave-v1.ckpt" and deletes "autosave.ckpt", then "autosave.ckpt" and deletes "autosave-v1.ckpt"
@ -183,7 +185,6 @@ class CheckpointDownloader:
self.remote_checkpoint_dir = (
remote_checkpoint_dir or self.extract_remote_checkpoint_dir_from_checkpoint_filename()
)
self.download_checkpoint_if_necessary()
def extract_checkpoint_filename_from_run_id(self) -> str:
"""
@ -192,7 +193,7 @@ class CheckpointDownloader:
"""
run_id_split = self.run_id.split(":")
self.run_id = run_id_split[0]
return run_id_split[-1] if len(run_id_split) > 1 else LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
return run_id_split[-1] if len(run_id_split) > 1 else LAST_CHECKPOINT_FILE_NAME
def extract_remote_checkpoint_dir_from_checkpoint_filename(self) -> Path:
"""
@ -232,3 +233,111 @@ class CheckpointDownloader:
self.run_id, str(self.remote_checkpoint_path), self.local_checkpoint_dir, aml_workspace=workspace
)
assert self.local_checkpoint_path.exists(), f"Couldn't download checkpoint from run {self.run_id}."
class CheckpointParser:
"""Wrapper class for parsing checkpoint arguments. A checkpoint can be specified in one of the following ways:
1. A path to a local checkpoint file
2. A URL to a remote checkpoint file
3. An AzureML run ID from which to download the checkpoint file
"""
AML_RUN_ID_FORMAT = (f"<AzureML_run_id>:<optional/custom/path/to/checkpoints/><filename{CHECKPOINT_SUFFIX}>. "
f"If no custom path is provided (e.g., <AzureML_run_id>:<filename{CHECKPOINT_SUFFIX}>), "
"the checkpoint will be downloaded from the default checkpoint folder "
f"(e.g., '{DEFAULT_AML_CHECKPOINT_DIR}'). If no filename is provided "
"(e.g., `src_checkpoint=<AzureML_run_id>`), the latest checkpoint "
f"({LAST_CHECKPOINT_FILE_NAME}) will be downloaded.")
INFO_MESSAGE = ("Please provide a valid checkpoint path, URL or AzureML run ID. For custom checkpoint paths "
f"within an azureml run, provide a checkpoint in the format {AML_RUN_ID_FORMAT}")
DOC = ("We currently support three types of checkpoints: "
" a. A path to a local checkpoint file."
" b. A URL to a remote checkpoint to be downloaded."
" c. A previous azureml run id where the checkpoint is expected to be "
"saved (in the 'outputs/checkpoints/' folder by default). "
f"For case 'c', src_checkpoint should be in the format {AML_RUN_ID_FORMAT}")
def __init__(self, checkpoint: str = "") -> None:
self.checkpoint = checkpoint
self.validate()
@property
def is_url(self) -> bool:
try:
result = urlparse(self.checkpoint)
return all([result.scheme, result.netloc])
except ValueError:
return False
@property
def is_local_file(self) -> bool:
return Path(self.checkpoint).is_file()
@property
def is_aml_run_id(self) -> bool:
match = re.match(r"[_\w-]*$", self.checkpoint.split(":")[0])
return match is not None and not self.is_url and not self.is_local_file
@property
def is_valid(self) -> bool:
if self.checkpoint:
return self.is_local_file or self.is_url or self.is_aml_run_id
return True
def validate(self) -> None:
if not self.is_valid:
raise ValueError(f"Invalid checkpoint '{self.checkpoint}'. {self.INFO_MESSAGE}")
@staticmethod
def download_from_url(url: str, download_folder: Path) -> Path:
"""
Download a checkpoint from `url` to the download folder. The file name is determined from
the file name in the URL. If that can't be determined, use a random file name.
:param url: The URL from which to download.
:param download_folder: The target folder for the download.
:return: A path to the downloaded file.
"""
# assign the same filename as in the download url if possible, so that we can check for duplicates
# If that fails, map to a random uuid
file_name = os.path.basename(urlparse(url).path) or str(uuid.uuid4().hex)
checkpoint_path = download_folder / file_name
# only download if hasn't already been downloaded
if checkpoint_path.is_file():
logging.info(f"File already exists, skipping download: {checkpoint_path}")
else:
logging.info(f"Downloading from URL {url}")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(checkpoint_path, "wb") as file:
for chunk in response.iter_content(chunk_size=1024):
file.write(chunk)
return checkpoint_path
def get_path(self, download_dir: Path) -> Path:
"""Returns the path to the checkpoint file. If the checkpoint is a URL, it will be downloaded to the checkpoints
folder. If the checkpoint is an AzureML run ID, it will be downloaded from the run to the checkpoints folder.
If the checkpoint is a local file, it will be returned as is.
:param download_dir: The checkpoints folder to which the checkpoint should be downloaded if it is a URL or
AzureML run ID.
:raises ValueError: If the checkpoint is not a local file, URL or AzureML run ID.
:raises FileNotFoundError: If the checkpoint is a URL or AzureML run ID and the download fails.
:return: The path to the checkpoint file.
"""
if self.is_local_file:
checkpoint_path = Path(self.checkpoint)
elif self.is_url:
download_folder = download_dir / MODEL_WEIGHTS_DIR_NAME
download_folder.mkdir(exist_ok=True, parents=True)
checkpoint_path = self.download_from_url(url=self.checkpoint, download_folder=download_folder)
elif self.is_aml_run_id:
downloader = CheckpointDownloader(run_id=self.checkpoint, download_dir=download_dir)
downloader.download_checkpoint_if_necessary()
checkpoint_path = downloader.local_checkpoint_path
else:
raise ValueError("Unable to determine how to get the checkpoint path.")
if checkpoint_path is None or not checkpoint_path.is_file():
raise FileNotFoundError(f"Could not find the file at {checkpoint_path}")
return checkpoint_path
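The way `CheckpointParser` tells the three checkpoint kinds apart can be illustrated without importing hi-ml. The snippet below is a minimal sketch that mirrors the `is_local_file`, `is_url` and `is_aml_run_id` checks shown above; `classify_checkpoint` and its string labels are hypothetical names introduced here, and the check order follows `get_path` (local file, then URL, then run ID).

```python
import re
from pathlib import Path
from urllib.parse import urlparse


def classify_checkpoint(checkpoint: str) -> str:
    """Classify a checkpoint string (hypothetical helper, not part of hi-ml).

    Returns one of "empty", "local_file", "url", "aml_run_id" or "invalid".
    """
    if not checkpoint:
        # An empty string means "no source checkpoint" and is accepted as valid.
        return "empty"
    if Path(checkpoint).is_file():
        return "local_file"
    parsed = urlparse(checkpoint)
    if parsed.scheme and parsed.netloc:
        return "url"
    # Everything before the first ":" must look like a run ID:
    # word characters, underscores and hyphens only.
    if re.match(r"[_\w-]*$", checkpoint.split(":")[0]):
        return "aml_run_id"
    return "invalid"


if __name__ == "__main__":
    for candidate in [
        "https://my_checkpoint_url.com/model.ckpt",
        "dummy_run_id:custom/path/best.ckpt",
        "dummy/local/path/model.ckpt",
    ]:
        print(candidate, "->", classify_checkpoint(candidate))
```

Note that a path to a file that does not exist is rejected rather than treated as a run ID, because the `/` characters fail the run-ID pattern. This matches the invalid cases in `test_validate_checkpoint_parser`.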

View file

@ -8,52 +8,73 @@ from unittest import mock
import pytest
from health_ml.configs.hello_world import HelloWorld
from health_ml.deep_learning_config import WorkflowParams
from health_ml.lightning_container import LightningContainer
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.utils.checkpoint_utils import (
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
LAST_CHECKPOINT_FILE_NAME,
MODEL_WEIGHTS_DIR_NAME,
CheckpointDownloader,
)
CheckpointParser,)
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from testhiml.utils.fixed_paths_for_tests import full_test_data_path, mock_run_id
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
def test_checkpoint_downloader_run_id() -> None:
with mock.patch("health_ml.utils.checkpoint_utils.CheckpointDownloader.download_checkpoint_if_necessary"):
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id")
assert checkpoint_downloader.run_id == "dummy_run_id"
assert checkpoint_downloader.checkpoint_filename == LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id")
assert checkpoint_downloader.run_id == "dummy_run_id"
assert checkpoint_downloader.checkpoint_filename == LAST_CHECKPOINT_FILE_NAME
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:best.ckpt")
assert checkpoint_downloader.run_id == "dummy_run_id"
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:best.ckpt")
assert checkpoint_downloader.run_id == "dummy_run_id"
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:custom/path/best.ckpt")
assert checkpoint_downloader.run_id == "dummy_run_id"
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
assert checkpoint_downloader.remote_checkpoint_dir == Path("custom/path")
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:custom/path/best.ckpt")
assert checkpoint_downloader.run_id == "dummy_run_id"
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
assert checkpoint_downloader.remote_checkpoint_dir == Path("custom/path")
def _test_invalid_checkpoint(checkpoint: str) -> None:
with pytest.raises(ValueError, match=r"Invalid checkpoint "):
CheckpointParser(checkpoint=checkpoint)
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=checkpoint).validate()
def test_validate_checkpoint_parser() -> None:
_test_invalid_checkpoint(checkpoint="dummy/local/path/model.ckpt")
_test_invalid_checkpoint(checkpoint="INV@lid%RUN*id")
_test_invalid_checkpoint(checkpoint="http/dummy_url-com")
# The following should be okay
checkpoint = str(full_test_data_path(suffix="hello_world_checkpoint.ckpt"))
CheckpointParser(checkpoint=checkpoint)
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=CheckpointParser(checkpoint)).validate()
checkpoint = mock_run_id(id=0)
CheckpointParser(checkpoint=checkpoint)
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=CheckpointParser(checkpoint)).validate()
def get_checkpoint_handler(tmp_path: Path, src_checkpoint: str) -> Tuple[LightningContainer, CheckpointHandler]:
container = LightningContainer()
container.set_output_to(tmp_path)
container.checkpoint_folder.mkdir(parents=True)
container.src_checkpoint = src_checkpoint
container.src_checkpoint = CheckpointParser(src_checkpoint)
return container, CheckpointHandler(container=container, project_root=tmp_path)
def test_load_model_chcekpoints_from_url(tmp_path: Path) -> None:
def test_load_model_checkpoints_from_url(tmp_path: Path) -> None:
WEIGHTS_URL = (
"https://pl-bolts-weights.s3.us-east-2.amazonaws.com/" "simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
)
container, checkpoint_handler = get_checkpoint_handler(tmp_path, WEIGHTS_URL)
download_folder = container.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME
assert container.src_checkpoint_is_url
assert container.src_checkpoint.is_url
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.trained_weights_path
assert checkpoint_handler.trained_weights_path.exists()
@ -64,7 +85,7 @@ def test_load_model_checkpoints_from_local_file(tmp_path: Path) -> None:
local_checkpoint_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
container, checkpoint_handler = get_checkpoint_handler(tmp_path, str(local_checkpoint_path))
assert container.src_checkpoint_is_local_file
assert container.src_checkpoint.is_local_file
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.trained_weights_path
assert checkpoint_handler.trained_weights_path.exists()
@ -82,10 +103,10 @@ def test_load_model_checkpoints_from_aml_run_id(src_chekpoint_filename: str, tmp
src_checkpoint_filename = (
src_chekpoint_filename.split("/")[-1]
if src_chekpoint_filename
else LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
else LAST_CHECKPOINT_FILE_NAME
)
expected_weights_path = container.outputs_folder / run_id / checkpoint_path / src_checkpoint_filename
assert container.src_checkpoint_is_aml_run_id
expected_weights_path = container.checkpoint_folder / run_id / checkpoint_path / src_checkpoint_filename
assert container.src_checkpoint.is_aml_run_id
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.trained_weights_path
assert checkpoint_handler.trained_weights_path.exists()
@ -99,7 +120,7 @@ def test_custom_checkpoint_for_test(tmp_path: Path) -> None:
container = HelloWorld()
container.set_output_to(tmp_path)
container.checkpoint_folder.mkdir(parents=True)
last_checkpoint = container.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
last_checkpoint = container.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME
last_checkpoint.touch()
checkpoint_handler = CheckpointHandler(container=container, project_root=tmp_path)
checkpoint_handler.additional_training_done()
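The run-ID cases exercised by `test_checkpoint_downloader_run_id` can be summarised as a small parsing function. The snippet below is a minimal sketch, not the `CheckpointDownloader` implementation: `split_run_id_spec` is a hypothetical name, and the two default values are assumptions taken from the strings quoted in this change (`'outputs/checkpoints/'` and `last.ckpt`). It uses `str.partition`, whereas the code in the diff uses `split(":")`, so the two may differ for strings with more than one colon.

```python
from pathlib import Path
from typing import Tuple

# Assumed defaults, standing in for DEFAULT_AML_CHECKPOINT_DIR and
# LAST_CHECKPOINT_FILE_NAME from hi-ml.
DEFAULT_REMOTE_DIR = Path("outputs/checkpoints")
DEFAULT_FILE_NAME = "last.ckpt"


def split_run_id_spec(spec: str) -> Tuple[str, Path, str]:
    """Split "<run_id>:<optional/dir/><file>" into (run_id, remote_dir, file_name).

    Hypothetical helper for illustration only.
    """
    run_id, _, remainder = spec.partition(":")
    if not remainder:
        # Only a run ID was given: use the default folder and file name.
        return run_id, DEFAULT_REMOTE_DIR, DEFAULT_FILE_NAME
    remote_path = Path(remainder)
    # A bare file name has "." as its parent, meaning no custom folder was given.
    has_custom_dir = str(remote_path.parent) != "."
    remote_dir = remote_path.parent if has_custom_dir else DEFAULT_REMOTE_DIR
    return run_id, remote_dir, remote_path.name


if __name__ == "__main__":
    for spec in ["dummy_run_id", "dummy_run_id:best.ckpt", "dummy_run_id:custom/path/best.ckpt"]:
        print(spec, "->", split_run_id_spec(spec))
```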

View file

@ -13,40 +13,24 @@ from pathlib import Path
from health_ml.deep_learning_config import DatasetParams, WorkflowParams, OutputParams, OptimizerParams, \
ExperimentFolderHandler, TrainerParams
from health_ml.utils.checkpoint_utils import CheckpointParser
from testhiml.utils.fixed_paths_for_tests import full_test_data_path, mock_run_id
def _test_invalid_pre_checkpoint_workflow_params(src_checkpoint: str) -> None:
with pytest.raises(ValueError, match=r"Invalid src_checkpoint:"):
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=src_checkpoint).validate()
def test_validate_workflow_params_src_checkpoint() -> None:
_test_invalid_pre_checkpoint_workflow_params(src_checkpoint="dummy/local/path/model.ckpt")
_test_invalid_pre_checkpoint_workflow_params(src_checkpoint="INV@lid%RUN*id")
_test_invalid_pre_checkpoint_workflow_params(src_checkpoint="http/dummy_url-com")
# The following should be okay
full_file_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
WorkflowParams(local_dataset=Path("foo"), src_checkpoint=str(full_file_path)).validate()
run_id = mock_run_id(id=0)
WorkflowParams(local_dataset=Path("foo"), src_checkpoint=run_id).validate()
def test_validate_workflow_params_for_inference_only() -> None:
with pytest.raises(ValueError, match=r"Cannot run inference without a src_checkpoint."):
WorkflowParams(local_datasets=Path("foo"), run_inference_only=True).validate()
full_file_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
run_id = mock_run_id(id=0)
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True, src_checkpoint=run_id).validate()
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
src_checkpoint=f"{run_id}:best_val_loss.ckpt").validate()
src_checkpoint=CheckpointParser(run_id)).validate()
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
src_checkpoint=f"{run_id}:custom/path/model.ckpt").validate()
src_checkpoint=CheckpointParser(f"{run_id}:best_val_loss.ckpt")).validate()
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
src_checkpoint=str(full_file_path)).validate()
src_checkpoint=CheckpointParser(f"{run_id}:custom/path/model.ckpt")).validate()
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
src_checkpoint=CheckpointParser(str(full_file_path))).validate()
def test_validate_workflow_params_for_resume_training() -> None:
@ -55,13 +39,14 @@ def test_validate_workflow_params_for_resume_training() -> None:
full_file_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
run_id = mock_run_id(id=0)
WorkflowParams(local_dataset=Path("foo"), resume_training=True, src_checkpoint=run_id).validate()
WorkflowParams(local_dataset=Path("foo"), resume_training=True,
src_checkpoint=f"{run_id}:best_val_loss.ckpt").validate()
src_checkpoint=CheckpointParser(run_id)).validate()
WorkflowParams(local_dataset=Path("foo"), resume_training=True,
src_checkpoint=f"{run_id}:custom/path/model.ckpt").validate()
src_checkpoint=CheckpointParser(f"{run_id}:best_val_loss.ckpt")).validate()
WorkflowParams(local_dataset=Path("foo"), resume_training=True,
src_checkpoint=str(full_file_path)).validate()
src_checkpoint=CheckpointParser(f"{run_id}:custom/path/model.ckpt")).validate()
WorkflowParams(local_dataset=Path("foo"), resume_training=True,
src_checkpoint=CheckpointParser(str(full_file_path))).validate()
@pytest.mark.fast

View file

@ -16,6 +16,7 @@ from health_ml.configs.hello_world import HelloWorld # type: ignore
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.run_ml import MLRunner
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_ml.utils.common_utils import is_gpu_available
from health_azure.utils import is_global_rank_zero
from health_ml.utils.logging import AzureMLLogger
@ -62,7 +63,7 @@ def ml_runner_with_run_id() -> Generator:
experiment_config = ExperimentConfig(model="HelloWorld")
container = HelloWorld()
container.save_checkpoint = True
container.src_checkpoint = mock_run_id(id=0)
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
runner = MLRunner(experiment_config=experiment_config, container=container)
@ -356,7 +357,7 @@ def test_model_weights_when_resume_training() -> None:
experiment_config = ExperimentConfig(model="HelloWorld")
container = HelloWorld()
container.max_num_gpus = 0
container.src_checkpoint = mock_run_id(id=0)
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
container.resume_training = True
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
@ -375,7 +376,7 @@ def test_runner_end_to_end() -> None:
experiment_config = ExperimentConfig(model="HelloWorld")
container = HelloWorld()
container.max_num_gpus = 0
container.src_checkpoint = mock_run_id(id=0)
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
runner = MLRunner(experiment_config=experiment_config, container=container)