зеркало из https://github.com/microsoft/hi-ml.git
ENH: Download Checkpoints across AML workspaces (#642)
Enable checkpoints transfer across AML workspaces Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Anton Schwaighofer <antonsc@microsoft.com>
This commit is contained in:
Родитель
d54e28d0e6
Коммит
41d30807aa
|
@ -0,0 +1,61 @@
|
|||
# Checkpoint Utils
|
||||
|
||||
Hi-ml toolbox offers different utilities to parse and download pretrained checkpoints that help you abstract checkpoint
|
||||
downloading from different sources. Refer to
|
||||
[CheckpointParser](https://github.com/microsoft/hi-ml/blob/main/hi-ml/src/health_ml/utils/checkpoint_utils.py#L238) for
|
||||
more details on the supported checkpoints format. Here's how you can use the checkpoint parser depending on the source:
|
||||
|
||||
- For a local path, simply pass it as shown below. The parser will further check if the provided path exists:
|
||||
|
||||
```python
|
||||
from health_ml.utils.checpoint_utils import CheckpointParser
|
||||
|
||||
download_dir = 'outputs/checkpoints'
|
||||
checkpoint_parser = CheckpointParser(checkpoint='local/path/to/my_checkpoint/model.ckpt')
|
||||
print('Checkpoint', checkpoint_parser.checkpoint, 'is a local file', checkpoint_parser.is_local_file)
|
||||
local_file = parser.get_path(download_dir)
|
||||
```
|
||||
|
||||
- To download a checkpoint from a URL:
|
||||
|
||||
```python
|
||||
from health_ml.utils.checpoint_utils import CheckpointParser, MODEL_WEIGHTS_DIR_NAME
|
||||
|
||||
download_dir = 'outputs/checkpoints'
|
||||
checkpoint_parser = CheckpointParser('https://my_checkpoint_url.com/model.ckpt')
|
||||
print('Checkpoint', checkpoint_parser.checkpoint, 'is a URL', checkpoint_parser.is_url)
|
||||
# will dowload the checkpoint to download_dir/MODEL_WEIGHTS_DIR_NAME
|
||||
path_to_ckpt = checkpoint_parser.get_path(download_dir)
|
||||
```
|
||||
|
||||
- Finally checkpoints from an Azure ML runs can be reused by providing an id in this format
|
||||
`<AzureML_run_id>:<optional/custom/path/to/checkpoints/><filename.ckpt>`. If no custom path is provided (e.g.,
|
||||
`<AzureML_run_id>:<filename.ckpt>`) the checkpoint will be downloaded from the default checkpoint folder
|
||||
(e.g., `outputs/checkpoints`) If no filename is provided, (e.g., `src_checkpoint=<AzureML_run_id>`) the latest
|
||||
checkpoint will be downloaded (e.g., `last.ckpt`).
|
||||
|
||||
```python
|
||||
from health_ml.utils.checpoint_utils import CheckpointParser
|
||||
|
||||
checkpoint_parser = CheckpointParser('AzureML_run_id:best.ckpt')
|
||||
print('Checkpoint', checkpoint_parser.checkpoint, 'is a AML run', checkpoint_parser.is_aml_run_id)
|
||||
path_azure_ml_ckpt = checkpoint_parser.get_path(download_dir)
|
||||
```
|
||||
|
||||
If the Azure ML run is in a different workspace, a temporary SAS URL to download the checkpoint can be generated as follow:
|
||||
|
||||
```bash
|
||||
cd hi-ml-cpath
|
||||
python src/health_cpath/scripts/generate_checkpoint_url.py --run_id=AzureML_run_id:best_val_loss.ckpt --expiry_days=10
|
||||
```
|
||||
|
||||
N.B: config.json should correspond to the original workspace where the AML run lives.
|
||||
|
||||
## Use cases
|
||||
|
||||
CheckpointParser is used to specify a `src_checkpoint` to [resume training from a given
|
||||
checkpoint](https://github.com/microsoft/hi-ml/blob/main/docs/source/runner.md#L238),
|
||||
or [run inference with a pretrained model](https://github.com/microsoft/hi-ml/blob/main/docs/source/runner.md#L215),
|
||||
as well as
|
||||
[ssl_checkpoint](https://github.com/microsoft/hi-ml/blob/main/hi-ml-cpath/src/health_cpath/utils/deepmil_utils.py#L62)
|
||||
for computation pathology self supervised pretrained encoders.
|
|
@ -42,6 +42,7 @@ The `hi-ml` toolbox provides
|
|||
logging.md
|
||||
diagnostics.md
|
||||
runner.md
|
||||
checkpoints.md
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
|
|
@ -226,6 +226,8 @@ the model weights by setting `--src_checkpoint` argument that supports three typ
|
|||
checkpoints folder `outputs/checkpoints`. If no filename is provided (e.g., `--src_checkpoint=AzureML_run_id`),
|
||||
the last epoch checkpoint `outputs/checkpoints/last.ckpt` will be loaded.
|
||||
|
||||
Refer to [Checkpoints Utils](checkpoints.md) for more details on how checkpoints are parsed.
|
||||
|
||||
Running the following command line will run inference using `MyContainer` model with weights from the checkpoint saved
|
||||
in the AzureMl run `MyContainer_XXXX_yyyy` at the best validation loss epoch `/outputs/checkpoints/best_val_loss.ckpt`.
|
||||
|
||||
|
|
|
@ -29,6 +29,10 @@ pip_from_conda:
|
|||
sed -e '1,/pip:/ d' environment.yml | grep -v "#" | cut -d "-" -f 2- > temp_requirements.txt
|
||||
pip install -r temp_requirements.txt
|
||||
|
||||
# Lock the current Conda environment secondary dependencies versions
|
||||
lock_env:
|
||||
./create_and_lock_environment.sh
|
||||
|
||||
# clean build artifacts
|
||||
clean:
|
||||
rm -rf `find . -type d -name __pycache__`
|
||||
|
@ -78,7 +82,7 @@ pytest_coverage:
|
|||
pytest --cov=health_cpath --cov SSL --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc
|
||||
|
||||
SSL_CKPT_RUN_ID_CRCK := CRCK_SimCLR_1655731022_85790606
|
||||
SRC_CKPT_RUN_ID_CRCK := TcgaCrckSSLMIL_1664478306_144bc833
|
||||
SRC_CKPT_RUN_ID_CRCK := TcgaCrckSSLMIL_1667236343_af6e293f
|
||||
|
||||
# Run regression tests and compare performance
|
||||
define BASE_CPATH_RUNNER_COMMAND
|
||||
|
@ -95,7 +99,7 @@ define DEEPSMILEPANDATILES_ARGS
|
|||
endef
|
||||
|
||||
define TCGACRCKSSLMIL_ARGS
|
||||
--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint_run_id=${SSL_CKPT_RUN_ID_CRCK}
|
||||
--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint=${SSL_CKPT_RUN_ID_CRCK}
|
||||
endef
|
||||
|
||||
define TCGACRCKIMANEGETMIL_ARGS
|
||||
|
|
|
@ -81,6 +81,7 @@ dependencies:
|
|||
- azure-mgmt-keyvault==10.0.0
|
||||
- azure-mgmt-resource==21.1.0
|
||||
- azure-mgmt-storage==20.0.0
|
||||
- azure-storage-blob==12.5.0
|
||||
- azureml-core==1.43.0
|
||||
- azureml-dataprep==4.0.4
|
||||
- azureml-dataprep-native==38.0.0
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
azure-storage-blob==12.5.0
|
||||
coloredlogs==15.0.1
|
||||
cucim==22.04.00
|
||||
girder-client==3.1.14
|
||||
|
|
|
@ -19,8 +19,7 @@ from health_cpath.utils.callbacks import LossAnalysisCallback, LossCallbackParam
|
|||
from health_ml.utils import fixed_paths
|
||||
from health_ml.deep_learning_config import OptimizerParams
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.deep_learning_config import SRC_CKPT_INFO_MESSAGE
|
||||
from health_ml.utils.checkpoint_utils import get_best_checkpoint_path
|
||||
from health_ml.utils.checkpoint_utils import get_best_checkpoint_path, CheckpointParser
|
||||
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
|
||||
|
||||
from health_cpath.datamodules.base_module import CacheLocation, CacheMode, HistoDataModule
|
||||
|
@ -80,8 +79,6 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
|
|||
default=False,
|
||||
doc="If True, will use classifier weights from pretrained model specified in src_checkpoint. If False, will "
|
||||
"initiliaze classifier with random weights.")
|
||||
ssl_checkpoint_run_id: str = param.String(default="", doc="Optional run id from which to load checkpoint if "
|
||||
"using SSLEncoder")
|
||||
max_num_workers: int = param.Integer(10, bounds=(0, None),
|
||||
doc="The maximum number of worker processes for dataloaders. Dataloaders use"
|
||||
"a heuristic num_cpus/num_gpus to set the number of workers, which can be"
|
||||
|
@ -100,6 +97,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
|
|||
|
||||
def validate(self) -> None:
|
||||
super().validate()
|
||||
EncoderParams.validate(self)
|
||||
if not any([self.tune_encoder, self.tune_pooling, self.tune_classifier]) and not self.run_inference_only:
|
||||
raise ValueError(
|
||||
"At least one of the encoder, pooling or classifier should be fine tuned. Turn on one of the tune "
|
||||
|
@ -112,7 +110,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
|
|||
):
|
||||
raise ValueError(
|
||||
"You need to specify a source checkpoint, to use a pretrained encoder, pooling or classifier."
|
||||
f"{SRC_CKPT_INFO_MESSAGE}"
|
||||
f" {CheckpointParser.INFO_MESSAGE}"
|
||||
)
|
||||
if (
|
||||
self.tune_encoder and self.encoding_chunk_size < self.max_bag_size
|
||||
|
@ -273,8 +271,7 @@ class BaseMILTiles(BaseMIL):
|
|||
|
||||
def get_transforms_dict(self, image_key: str) -> Dict[ModelKey, Union[Callable, None]]:
|
||||
if self.is_caching:
|
||||
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_checkpoint_run_id,
|
||||
self.outputs_folder)
|
||||
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.outputs_folder)
|
||||
transform = Compose([
|
||||
LoadTilesBatchd(image_key, progress=True),
|
||||
EncodeTilesBatchd(image_key, encoder, chunk_size=self.encoding_chunk_size) # type: ignore
|
||||
|
@ -295,7 +292,7 @@ class BaseMILTiles(BaseMIL):
|
|||
pretrained_classifier=self.pretrained_classifier,
|
||||
dropout_rate=self.dropout_rate,
|
||||
outputs_folder=self.outputs_folder,
|
||||
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
|
||||
|
||||
encoder_params=create_from_matching_params(self, EncoderParams),
|
||||
pooling_params=create_from_matching_params(self, PoolingParams),
|
||||
optimizer_params=create_from_matching_params(self, OptimizerParams),
|
||||
|
@ -339,7 +336,6 @@ class BaseMILSlides(BaseMIL):
|
|||
pretrained_classifier=self.pretrained_classifier,
|
||||
dropout_rate=self.dropout_rate,
|
||||
outputs_folder=self.outputs_folder,
|
||||
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
|
||||
encoder_params=create_from_matching_params(self, EncoderParams),
|
||||
pooling_params=create_from_matching_params(self, PoolingParams),
|
||||
optimizer_params=create_from_matching_params(self, OptimizerParams),
|
||||
|
|
|
@ -27,6 +27,7 @@ from health_cpath.models.encoders import (
|
|||
from health_cpath.configs.classification.BaseMIL import BaseMILTiles
|
||||
from health_cpath.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
|
||||
from health_cpath.utils.naming import PlotOption
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||
|
||||
|
||||
class DeepSMILECrck(BaseMILTiles):
|
||||
|
@ -56,8 +57,6 @@ class DeepSMILECrck(BaseMILTiles):
|
|||
|
||||
def setup(self) -> None:
|
||||
super().setup()
|
||||
# If no SSL checkpoint is provided, use the default one
|
||||
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_crck_4ws
|
||||
|
||||
def get_data_module(self) -> TilesDataModule:
|
||||
return TcgaCrckTilesDataModule(
|
||||
|
@ -93,6 +92,8 @@ class TcgaCrckImageNetSimCLRMIL(DeepSMILECrck):
|
|||
|
||||
class TcgaCrckSSLMIL(DeepSMILECrck):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
# If no SSL checkpoint is provided, use the default one
|
||||
self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_crck_4ws)
|
||||
super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ from health_cpath.datasets.default_paths import (
|
|||
PANDA_DATASET_ID,
|
||||
PANDA_5X_TILES_DATASET_ID)
|
||||
from health_cpath.utils.naming import PlotOption
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||
|
||||
|
||||
class BaseDeepSMILEPanda(BaseMIL):
|
||||
|
@ -70,8 +71,6 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
|
|||
|
||||
def setup(self) -> None:
|
||||
BaseMILTiles.setup(self)
|
||||
# If no SSL checkpoint is provided, use the default one
|
||||
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary
|
||||
|
||||
def get_data_module(self) -> PandaTilesDataModule:
|
||||
return PandaTilesDataModule(
|
||||
|
@ -110,6 +109,8 @@ class TilesPandaImageNetSimCLRMIL(DeepSMILETilesPanda):
|
|||
|
||||
class TilesPandaSSLMIL(DeepSMILETilesPanda):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
# If no SSL checkpoint is provided, use the default one
|
||||
self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_binary)
|
||||
super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)
|
||||
|
||||
|
||||
|
@ -136,8 +137,6 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
|
|||
|
||||
def setup(self) -> None:
|
||||
BaseMILSlides.setup(self)
|
||||
# If no SSL checkpoint is provided, use the default one
|
||||
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary
|
||||
|
||||
def get_dataloader_kwargs(self) -> dict:
|
||||
return dict(
|
||||
|
@ -181,6 +180,8 @@ class SlidesPandaImageNetSimCLRMIL(DeepSMILESlidesPanda):
|
|||
|
||||
class SlidesPandaSSLMIL(DeepSMILESlidesPanda):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
# If no SSL checkpoint is provided, use the default one
|
||||
self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_binary)
|
||||
super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -7,12 +7,13 @@
|
|||
from typing import Any, Dict, Callable, Union
|
||||
from torch import optim
|
||||
from monai.transforms import Compose, ScaleIntensityRanged, RandRotate90d, RandFlipd
|
||||
|
||||
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary
|
||||
from health_azure.utils import create_from_matching_params
|
||||
from health_ml.networks.layers.attention_layers import (
|
||||
TransformerPooling,
|
||||
TransformerPoolingBenchmark
|
||||
)
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||
from health_ml.deep_learning_config import OptimizerParams
|
||||
from health_cpath.datasets.panda_dataset import PandaDataset
|
||||
from health_cpath.datamodules.panda_module_benchmark import PandaSlidesDataModuleBenchmark
|
||||
|
@ -127,7 +128,6 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
|
|||
class_weights=self.data_module.class_weights,
|
||||
dropout_rate=self.dropout_rate,
|
||||
outputs_folder=self.outputs_folder,
|
||||
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
|
||||
encoder_params=create_from_matching_params(self, EncoderParams),
|
||||
pooling_params=create_from_matching_params(self, PoolingParams),
|
||||
optimizer_params=create_from_matching_params(self, OptimizerParams),
|
||||
|
@ -150,6 +150,8 @@ class SlidesPandaImageNetSimCLRMILBenchmark(DeepSMILESlidesPandaBenchmark):
|
|||
|
||||
class SlidesPandaSSLMILBenchmark(DeepSMILESlidesPandaBenchmark):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
# If no SSL checkpoint is provided, use the default one
|
||||
self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_binary)
|
||||
super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -46,7 +46,6 @@ class BaseDeepMILModule(LightningModule):
|
|||
pretrained_classifier: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
verbose: bool = False,
|
||||
ssl_ckpt_run_id: Optional[str] = None,
|
||||
outputs_folder: Optional[Path] = None,
|
||||
encoder_params: EncoderParams = EncoderParams(),
|
||||
pooling_params: PoolingParams = PoolingParams(),
|
||||
|
@ -63,8 +62,6 @@ class BaseDeepMILModule(LightningModule):
|
|||
:param pretrained_classifier: Whether to use pretrained classifier (default=False for random init).
|
||||
:param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
|
||||
:param verbose: if True statements about memory usage are output at each step.
|
||||
:param ssl_ckpt_run_id: Optional parameter to provide the AML run id from where to download the checkpoint
|
||||
if using `SSLEncoder`.
|
||||
:param outputs_folder: Path to output folder where encoder checkpoint is downloaded.
|
||||
:param encoder_params: Encoder parameters that specify all encoder specific attributes.
|
||||
:param pooling_params: Pooling layer parameters that specify all encoder specific attributes.
|
||||
|
@ -98,7 +95,7 @@ class BaseDeepMILModule(LightningModule):
|
|||
self.tune_classifier = tune_classifier
|
||||
|
||||
# Model components
|
||||
self.encoder = encoder_params.get_encoder(ssl_ckpt_run_id, outputs_folder)
|
||||
self.encoder = encoder_params.get_encoder(outputs_folder)
|
||||
self.aggregation_fn, self.num_pooling = pooling_params.get_pooling_layer(self.encoder.num_encoding)
|
||||
self.classifier_fn = self.get_classifier()
|
||||
self.activation_fn = self.get_activation()
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
from azure.storage.blob import generate_blob_sas, BlobSasPermissions
|
||||
from azureml.core import Workspace
|
||||
from datetime import datetime, timedelta
|
||||
from health_azure import get_workspace
|
||||
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_checkpoint_url_from_aml_run(
|
||||
run_id: str,
|
||||
checkpoint_filename: str,
|
||||
expiry_days: int = 1,
|
||||
aml_workspace: Optional[Workspace] = None,
|
||||
sas_token: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Generate a SAS URL for the checkpoint file in the given run.
|
||||
|
||||
:param run_id: The run ID of the checkpoint.
|
||||
:param checkpoint_filename: The filename of the checkpoint.
|
||||
:param expiry_days: The number of days the SAS URL is valid for, defaults to 30.
|
||||
:param aml_workspace: The Azure ML workspace to use, defaults to the default workspace.
|
||||
:param sas_token: The SAS token to use, defaults to None.
|
||||
:return: The SAS URL for the checkpoint.
|
||||
"""
|
||||
datastore = get_workspace(aml_workspace=aml_workspace).get_default_datastore()
|
||||
account_name = datastore.account_name
|
||||
container_name = 'azureml'
|
||||
blob_name = f'ExperimentRun/dcid.{run_id}/{DEFAULT_AML_CHECKPOINT_DIR}/{checkpoint_filename}'
|
||||
|
||||
if not sas_token:
|
||||
sas_token = generate_blob_sas(account_name=datastore.account_name,
|
||||
container_name=container_name,
|
||||
blob_name=blob_name,
|
||||
account_key=datastore.account_key,
|
||||
permission=BlobSasPermissions(read=True),
|
||||
expiry=datetime.utcnow() + timedelta(days=expiry_days))
|
||||
|
||||
return f'https://{account_name}.blob.core.windows.net/{container_name}/{blob_name}?{sas_token}'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--run_id', type=str, help='The run id of the model checkpoint')
|
||||
parser.add_argument('--checkpoint_filename', type=str, default='last.ckpt',
|
||||
help='The filename of the model checkpoint. Default: last.ckpt')
|
||||
parser.add_argument('--expiry_days', type=int, default=30,
|
||||
help='The number of hours for which the SAS token is valid. Default: 30 for 1 month')
|
||||
args = parser.parse_args()
|
||||
url = get_checkpoint_url_from_aml_run(args.run_id, args.checkpoint_filename, args.expiry_days)
|
||||
print(f'Checkpoint URL: {url}')
|
|
@ -7,8 +7,7 @@ import param
|
|||
from torch import nn
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CheckpointDownloader
|
||||
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||
from health_cpath.models.encoders import (
|
||||
HistoSSLEncoder,
|
||||
ImageNetSimCLREncoder,
|
||||
|
@ -60,11 +59,18 @@ class EncoderParams(param.Parameterized):
|
|||
encoding_chunk_size: int = param.Integer(
|
||||
default=0, doc="If > 0 performs encoding in chunks, by enconding_chunk_size tiles " "per chunk"
|
||||
)
|
||||
ssl_checkpoint: CheckpointParser = param.ClassSelector(class_=CheckpointParser, default=None,
|
||||
instantiate=False, doc=CheckpointParser.DOC)
|
||||
|
||||
def get_encoder(self, ssl_ckpt_run_id: Optional[str], outputs_folder: Optional[Path]) -> TileEncoder:
|
||||
def validate(self) -> None:
|
||||
"""Validate the encoder parameters."""
|
||||
if self.encoder_type == SSLEncoder.__name__ and not self.ssl_checkpoint:
|
||||
raise ValueError("SSLEncoder requires an ssl_checkpoint. Please specify a valid checkpoint. "
|
||||
f"{CheckpointParser.INFO_MESSAGE}")
|
||||
|
||||
def get_encoder(self, outputs_folder: Optional[Path]) -> TileEncoder:
|
||||
"""Given the current encoder parameters, returns the encoder object.
|
||||
|
||||
:param ssl_ckpt_run_id: The AML run id for SSL checkpoint download.
|
||||
:param outputs_folder: The output folder where SSL checkpoint should be saved.
|
||||
:param encoder_params: The encoder arguments that define the encoder class object depending on the encoder type.
|
||||
:raises ValueError: If the encoder type is not supported.
|
||||
|
@ -90,15 +96,9 @@ class EncoderParams(param.Parameterized):
|
|||
encoder = HistoSSLEncoder(tile_size=self.tile_size, n_channels=self.n_channels)
|
||||
|
||||
elif self.encoder_type == SSLEncoder.__name__:
|
||||
assert ssl_ckpt_run_id and outputs_folder, "SSLEncoder requires ssl_ckpt_run_id and outputs_folder"
|
||||
downloader = CheckpointDownloader(
|
||||
run_id=ssl_ckpt_run_id,
|
||||
download_dir=outputs_folder,
|
||||
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
|
||||
remote_checkpoint_dir=Path(DEFAULT_AML_CHECKPOINT_DIR),
|
||||
)
|
||||
assert outputs_folder is not None, "outputs_folder cannot be None for SSLEncoder"
|
||||
encoder = SSLEncoder(
|
||||
pl_checkpoint_path=downloader.local_checkpoint_path,
|
||||
pl_checkpoint_path=self.ssl_checkpoint.get_path(outputs_folder),
|
||||
tile_size=self.tile_size,
|
||||
n_channels=self.n_channels,
|
||||
)
|
||||
|
|
|
@ -36,7 +36,7 @@ from SSL.configs.CXR_SSL_configs import CXRImageClassifier, NIH_RSNA_SimCLR
|
|||
|
||||
from health_ml.runner import Runner
|
||||
from health_ml.utils import AzureMLProgressBar
|
||||
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME
|
||||
from health_ml.utils.fixed_paths import repository_root_directory, OutputFolderForTests
|
||||
from health_ml.utils.lightning_loggers import StoringLogger
|
||||
|
||||
|
@ -247,7 +247,7 @@ def test_ssl_container_rsna() -> None:
|
|||
_compare_stored_metrics(runner, expected_metrics)
|
||||
|
||||
# Check that we are able to load the checkpoint and create classifier model
|
||||
checkpoint_path = loaded_config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
checkpoint_path = loaded_config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME
|
||||
model_namespace_cxr = "SSL.configs.CXRImageClassifier"
|
||||
args = common_test_args + [f"--model={model_namespace_cxr}",
|
||||
f"--local_datasets={str(path_to_cxr_test_dataset)}",
|
||||
|
|
|
@ -37,7 +37,7 @@ class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
|
|||
# declared in TrainerParams:
|
||||
max_epochs=2,
|
||||
crossval_count=1,
|
||||
ssl_checkpoint_run_id="",
|
||||
ssl_checkpoint=None,
|
||||
analyse_loss=analyse_loss,
|
||||
)
|
||||
default_kwargs.update(kwargs)
|
||||
|
|
|
@ -15,13 +15,16 @@ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Typ
|
|||
|
||||
from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
|
||||
from torch.utils.data._utils.collate import default_collate
|
||||
from health_cpath.configs.classification.DeepSMILESlidesPandaBenchmark import SlidesPandaSSLMILBenchmark
|
||||
from health_cpath.datamodules.panda_module import PandaTilesDataModule
|
||||
|
||||
from health_ml.networks.layers.attention_layers import AttentionLayer, TransformerPoolingBenchmark
|
||||
from health_cpath.configs.classification.BaseMIL import BaseMIL, BaseMILTiles
|
||||
|
||||
from health_cpath.configs.classification.DeepSMILECrck import DeepSMILECrck
|
||||
from health_cpath.configs.classification.DeepSMILEPanda import BaseDeepSMILEPanda, DeepSMILETilesPanda
|
||||
from health_cpath.configs.classification.DeepSMILECrck import DeepSMILECrck, TcgaCrckSSLMIL
|
||||
from health_cpath.configs.classification.DeepSMILEPanda import (
|
||||
BaseDeepSMILEPanda, DeepSMILETilesPanda, SlidesPandaSSLMIL, TilesPandaSSLMIL
|
||||
)
|
||||
from health_cpath.datamodules.base_module import HistoDataModule, TilesDataModule
|
||||
from health_cpath.datasets.base_dataset import DEFAULT_LABEL_COLUMN, TilesDataset
|
||||
from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, TCGA_CRCK_DATASET_DIR
|
||||
|
@ -34,6 +37,9 @@ from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPosi
|
|||
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator
|
||||
from testhisto.mocks.container import MockDeepSMILETilesPanda, MockDeepSMILESlidesPanda
|
||||
from health_ml.utils.common_utils import is_gpu_available
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_crck_4ws, innereye_ssl_checkpoint_binary
|
||||
from testhisto.models.test_encoders import TEST_SSL_RUN_ID
|
||||
|
||||
no_gpu = not is_gpu_available()
|
||||
|
||||
|
@ -198,9 +204,7 @@ def add_callback(fn: Callable, callback: Callable) -> Callable:
|
|||
def test_metrics(n_classes: int) -> None:
|
||||
input_dim = (128,)
|
||||
|
||||
def _mock_get_encoder( # type: ignore
|
||||
self, ssl_ckpt_run_id: Optional[str], outputs_folder: Optional[Path]
|
||||
) -> TileEncoder:
|
||||
def _mock_get_encoder(self, outputs_folder: Optional[Path]) -> TileEncoder: # type: ignore
|
||||
return IdentityEncoder(input_dim=input_dim)
|
||||
|
||||
with patch("health_cpath.models.deepmil.EncoderParams.get_encoder", new=_mock_get_encoder):
|
||||
|
@ -718,3 +722,17 @@ def test_on_run_extra_val_epoch(mock_panda_tiles_root_dir: Path) -> None:
|
|||
container.model.outputs_handler.test_plots_handler.plot_options # type: ignore
|
||||
== container.model.outputs_handler.val_plots_handler.plot_options # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"container_type", [TcgaCrckSSLMIL, TilesPandaSSLMIL, SlidesPandaSSLMIL, SlidesPandaSSLMILBenchmark]
|
||||
)
|
||||
def test_ssl_containers_default_checkpoint(container_type: BaseMIL) -> None:
|
||||
if container_type == TcgaCrckSSLMIL:
|
||||
default_checkpoint = innereye_ssl_checkpoint_crck_4ws
|
||||
else:
|
||||
default_checkpoint = innereye_ssl_checkpoint_binary
|
||||
assert container_type().ssl_checkpoint.checkpoint == default_checkpoint
|
||||
|
||||
container = container_type(ssl_checkpoint=CheckpointParser(TEST_SSL_RUN_ID))
|
||||
assert container.ssl_checkpoint.checkpoint != default_checkpoint
|
||||
|
|
|
@ -12,7 +12,7 @@ from torch import Tensor, float32, nn, rand
|
|||
from torchvision.models import resnet18
|
||||
|
||||
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
|
||||
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CheckpointDownloader
|
||||
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME, CheckpointDownloader
|
||||
from health_cpath.models.encoders import (Resnet18, TileEncoder, HistoSSLEncoder,
|
||||
ImageNetSimCLREncoder, SSLEncoder)
|
||||
from health_cpath.utils.layer_utils import setup_feature_extractor
|
||||
|
@ -35,8 +35,9 @@ def get_simclr_imagenet_encoder() -> TileEncoder:
|
|||
def get_ssl_encoder(download_dir: Path) -> TileEncoder:
|
||||
downloader = CheckpointDownloader(run_id=TEST_SSL_RUN_ID,
|
||||
download_dir=download_dir,
|
||||
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
|
||||
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME,
|
||||
remote_checkpoint_dir=Path(DEFAULT_AML_CHECKPOINT_DIR))
|
||||
downloader.download_checkpoint_if_necessary()
|
||||
return SSLEncoder(pl_checkpoint_path=downloader.local_checkpoint_path, tile_size=TILE_SIZE)
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from health_cpath.models.encoders import SSLEncoder
|
||||
from health_cpath.scripts.generate_checkpoint_url import get_checkpoint_url_from_aml_run
|
||||
from health_cpath.utils.deepmil_utils import EncoderParams
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser, LAST_CHECKPOINT_FILE_NAME, MODEL_WEIGHTS_DIR_NAME
|
||||
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
|
||||
from testhiml.utils.fixed_paths_for_tests import full_test_data_path
|
||||
from testhisto.models.test_encoders import TEST_SSL_RUN_ID
|
||||
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
|
||||
|
||||
|
||||
LAST_CHECKPOINT = f"{DEFAULT_AML_CHECKPOINT_DIR}/{LAST_CHECKPOINT_FILE_NAME}"
|
||||
|
||||
|
||||
def test_validate_encoder_params() -> None:
|
||||
with pytest.raises(ValueError, match=r"SSLEncoder requires an ssl_checkpoint"):
|
||||
encoder = EncoderParams(encoder_type=SSLEncoder.__name__)
|
||||
encoder.validate()
|
||||
|
||||
|
||||
def test_load_ssl_checkpoint_from_local_file(tmp_path: Path) -> None:
|
||||
checkpoint_filename = "hello_world_checkpoint.ckpt"
|
||||
local_checkpoint_path = full_test_data_path(suffix=checkpoint_filename)
|
||||
encoder_params = EncoderParams(
|
||||
encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(str(local_checkpoint_path))
|
||||
)
|
||||
assert encoder_params.ssl_checkpoint.is_local_file
|
||||
ssl_checkpoint_path = encoder_params.ssl_checkpoint.get_path(tmp_path)
|
||||
assert ssl_checkpoint_path.exists()
|
||||
assert ssl_checkpoint_path == local_checkpoint_path
|
||||
with patch("health_cpath.models.encoders.SSLEncoder._get_encoder") as mock_get_encoder:
|
||||
mock_get_encoder.return_value = (MagicMock(), MagicMock())
|
||||
encoder = encoder_params.get_encoder(tmp_path)
|
||||
assert isinstance(encoder, SSLEncoder)
|
||||
|
||||
|
||||
def test_load_ssl_checkpoint_from_url(tmp_path: Path) -> None:
|
||||
blob_url = get_checkpoint_url_from_aml_run(
|
||||
run_id=TEST_SSL_RUN_ID,
|
||||
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME,
|
||||
expiry_days=1,
|
||||
aml_workspace=DEFAULT_WORKSPACE.workspace)
|
||||
encoder_params = EncoderParams(encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(blob_url))
|
||||
assert encoder_params.ssl_checkpoint.is_url
|
||||
ssl_checkpoint_path = encoder_params.ssl_checkpoint.get_path(tmp_path)
|
||||
assert ssl_checkpoint_path.exists()
|
||||
assert ssl_checkpoint_path == tmp_path / MODEL_WEIGHTS_DIR_NAME / LAST_CHECKPOINT_FILE_NAME
|
||||
encoder = encoder_params.get_encoder(tmp_path)
|
||||
assert isinstance(encoder, SSLEncoder)
|
||||
|
||||
|
||||
def test_load_ssl_checkpoint_from_run_id(tmp_path: Path) -> None:
|
||||
encoder_params = EncoderParams(encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(TEST_SSL_RUN_ID))
|
||||
assert encoder_params.ssl_checkpoint.is_aml_run_id
|
||||
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
|
||||
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
|
||||
ssl_checkpoint_path = encoder_params.ssl_checkpoint.get_path(tmp_path)
|
||||
assert ssl_checkpoint_path.exists()
|
||||
assert ssl_checkpoint_path == tmp_path / TEST_SSL_RUN_ID / LAST_CHECKPOINT
|
||||
encoder = encoder_params.get_encoder(tmp_path)
|
||||
assert isinstance(encoder, SSLEncoder)
|
|
@ -7,12 +7,10 @@ from __future__ import annotations
|
|||
import logging
|
||||
import os
|
||||
import param
|
||||
import re
|
||||
from enum import Enum, unique
|
||||
from param import Parameterized
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from azureml.train.hyperdrive import HyperDriveConfig
|
||||
|
||||
|
@ -23,6 +21,7 @@ from health_azure.amulet import (ENV_AMLT_PROJECT_NAME, ENV_AMLT_INPUT_OUTPUT,
|
|||
is_amulet_job, get_amulet_aml_working_dir)
|
||||
from health_azure.utils import (RUN_CONTEXT, PathOrString, is_global_rank_zero, is_running_in_azure_ml)
|
||||
from health_ml.utils import fixed_paths
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||
from health_ml.utils.common_utils import (CHECKPOINT_FOLDER,
|
||||
create_unique_timestamp_id,
|
||||
DEFAULT_AML_UPLOAD_DIR,
|
||||
|
@ -143,35 +142,13 @@ class ExperimentFolderHandler(Parameterized):
|
|||
)
|
||||
|
||||
|
||||
SRC_CHECKPOINT_FORMAT_DOC = ("<AzureML_run_id>:<optional/custom/path/to/checkpoints/><filename.ckpt>"
|
||||
"If no custom path is provided (e.g., <AzureML_run_id>:<filename.ckpt>)"
|
||||
"the checkpoint will be downloaded from the default checkpoint folder "
|
||||
"(e.g., 'outputs/checkpoints/'). If no filename is provided, (e.g., "
|
||||
"`src_checkpoint=<AzureML_run_id>`) the latest checkpoint (last.ckpt) "
|
||||
"will be used to initialize the model.")
|
||||
SRC_CKPT_INFO_MESSAGE = ("Please specify a valid src_checkpoint. You can either use a URL, a local file or an azureml "
|
||||
"run id. For custom checkpoint paths within an azureml run, (other than last.ckpt), provide "
|
||||
f"a src_checkpoint in the format {SRC_CHECKPOINT_FORMAT_DOC}.")
|
||||
|
||||
|
||||
class WorkflowParams(param.Parameterized):
|
||||
"""
|
||||
This class contains all parameters that affect how the whole training and testing workflow is executed.
|
||||
"""
|
||||
random_seed: int = param.Integer(42, doc="The seed to use for all random number generators.")
|
||||
src_checkpoint: str = param.String(default="",
|
||||
doc="This flag can be used in 3 different scenarios:"
|
||||
"1- Resume training from a checkpoint to train longer using"
|
||||
" `resume_training` flag jointly."
|
||||
"2- Run inference-only using `run_inference_only` flag jointly."
|
||||
"3- Transfer learning from a pretrained model checkpoint."
|
||||
"We currently support three types of checkpoints: "
|
||||
" a. A local checkpoint folder that contains a checkpoint file."
|
||||
" b. A URL to a remote checkpoint to be downloaded."
|
||||
" c. A previous azureml run id where the checkpoint is supposed to be "
|
||||
" saved ('outputs/checkpoints/' folder by default.)"
|
||||
"For the latter case 'c' : src_checkpoint should be in the format of "
|
||||
f"{SRC_CHECKPOINT_FORMAT_DOC}")
|
||||
src_checkpoint: CheckpointParser = param.ClassSelector(class_=CheckpointParser, default=None,
|
||||
instantiate=False, doc=CheckpointParser.DOC)
|
||||
crossval_count: int = param.Integer(default=1, bounds=(0, None),
|
||||
doc="The number of splits to use when doing cross-validation. "
|
||||
"Use 1 to disable cross-validation")
|
||||
|
@ -216,41 +193,15 @@ class WorkflowParams(param.Parameterized):
|
|||
CROSSVAL_COUNT_ARG_NAME = "crossval_count"
|
||||
RANDOM_SEED_ARG_NAME = "random_seed"
|
||||
|
||||
@property
|
||||
def src_checkpoint_is_url(self) -> bool:
|
||||
try:
|
||||
result = urlparse(self.src_checkpoint)
|
||||
return all([result.scheme, result.netloc])
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def src_checkpoint_is_local_file(self) -> bool:
|
||||
return Path(self.src_checkpoint).is_file()
|
||||
|
||||
@property
|
||||
def src_checkpoint_is_aml_run_id(self) -> bool:
|
||||
match = re.match(r"[_\w-]*$", self.src_checkpoint.split(":")[0])
|
||||
return match is not None and not self.src_checkpoint_is_url and not self.src_checkpoint_is_local_file
|
||||
|
||||
@property
|
||||
def is_valid_src_checkpoint(self) -> bool:
|
||||
if self.src_checkpoint:
|
||||
return self.src_checkpoint_is_local_file or self.src_checkpoint_is_url or self.src_checkpoint_is_aml_run_id
|
||||
return True
|
||||
|
||||
def validate(self) -> None:
|
||||
if not self.is_valid_src_checkpoint:
|
||||
raise ValueError(f"Invalid src_checkpoint: {self.src_checkpoint}. Please provide a valid URL, local file "
|
||||
"or azureml run id.")
|
||||
if self.crossval_count > 1:
|
||||
if not (0 <= self.crossval_index < self.crossval_count):
|
||||
raise ValueError(f"Attribute crossval_index out of bounds (crossval_count = {self.crossval_count})")
|
||||
|
||||
if self.run_inference_only and not self.src_checkpoint:
|
||||
raise ValueError(f"Cannot run inference without a src_checkpoint. {SRC_CKPT_INFO_MESSAGE}")
|
||||
raise ValueError(f"Cannot run inference without a src_checkpoint. {CheckpointParser.INFO_MESSAGE}")
|
||||
if self.resume_training and not self.src_checkpoint:
|
||||
raise ValueError(f"Cannot resume training without a src_checkpoint. {SRC_CKPT_INFO_MESSAGE}")
|
||||
raise ValueError(f"Cannot resume training without a src_checkpoint. {CheckpointParser.INFO_MESSAGE}")
|
||||
|
||||
@property
|
||||
def is_running_in_aml(self) -> bool:
|
||||
|
|
|
@ -25,8 +25,8 @@ from health_ml.experiment_config import ExperimentConfig
|
|||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.model_trainer import create_lightning_trainer, write_experiment_summary_file
|
||||
from health_ml.utils import fixed_paths
|
||||
from health_ml.utils.checkpoint_handler import CheckpointHandler
|
||||
from health_ml.utils.checkpoint_utils import cleanup_checkpoints
|
||||
from health_ml.utils.checkpoint_handler import CheckpointHandler
|
||||
from health_ml.utils.common_utils import (
|
||||
EFFECTIVE_RANDOM_SEED_KEY_NAME,
|
||||
change_working_directory,
|
||||
|
|
|
@ -4,22 +4,13 @@
|
|||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from azureml.core import Run
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from azureml.core import Run
|
||||
|
||||
from health_azure.utils import is_global_rank_zero
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.utils.checkpoint_utils import (
|
||||
MODEL_WEIGHTS_DIR_NAME,
|
||||
CheckpointDownloader,
|
||||
find_recovery_checkpoint_on_disk_or_cloud,
|
||||
)
|
||||
from health_ml.utils.checkpoint_utils import find_recovery_checkpoint_on_disk_or_cloud
|
||||
|
||||
|
||||
class CheckpointHandler:
|
||||
|
@ -40,9 +31,8 @@ class CheckpointHandler:
|
|||
the checkpoint_url, local_checkpoint or checkpoint from an azureml run id.
|
||||
This is called at the start of training.
|
||||
"""
|
||||
|
||||
if self.container.src_checkpoint:
|
||||
self.trained_weights_path = self.get_local_checkpoints_path_or_download()
|
||||
self.trained_weights_path = self.container.src_checkpoint.get_path(self.container.checkpoint_folder)
|
||||
self.container.trained_weights_path = self.trained_weights_path
|
||||
|
||||
def additional_training_done(self) -> None:
|
||||
|
@ -86,53 +76,3 @@ class CheckpointHandler:
|
|||
logging.info(f"Using pre-trained weights from {self.trained_weights_path}")
|
||||
return self.trained_weights_path
|
||||
raise ValueError("Unable to determine which checkpoint should be used for testing.")
|
||||
|
||||
@staticmethod
|
||||
def download_weights(url: str, download_folder: Path) -> Path:
|
||||
"""
|
||||
Download a checkpoint from checkpoint_url to the modelweights directory. The file name is determined from
|
||||
from the file name in the URL. If that can't be determined, use a random file name.
|
||||
|
||||
:param url: The URL from which the weights should be downloaded.
|
||||
:param download_folder: The target folder for the download.
|
||||
:return: A path to the downloaded file.
|
||||
"""
|
||||
# assign the same filename as in the download url if possible, so that we can check for duplicates
|
||||
# If that fails, map to a random uuid
|
||||
file_name = os.path.basename(urlparse(url).path) or str(uuid.uuid4().hex)
|
||||
checkpoint_path = download_folder / file_name
|
||||
# only download if hasn't already been downloaded
|
||||
if checkpoint_path.is_file():
|
||||
logging.info(f"File already exists, skipping download: {checkpoint_path}")
|
||||
else:
|
||||
logging.info(f"Downloading weights from URL {url}")
|
||||
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
with open(checkpoint_path, "wb") as file:
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
file.write(chunk)
|
||||
return checkpoint_path
|
||||
|
||||
def get_local_checkpoints_path_or_download(self) -> Path:
|
||||
"""
|
||||
Get the path to the local weights to use or download them.
|
||||
"""
|
||||
|
||||
if self.container.src_checkpoint_is_local_file:
|
||||
checkpoint_path = Path(self.container.src_checkpoint)
|
||||
elif self.container.src_checkpoint_is_url:
|
||||
download_folder = self.container.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME
|
||||
download_folder.mkdir(exist_ok=True, parents=True)
|
||||
checkpoint_path = self.download_weights(url=self.container.src_checkpoint, download_folder=download_folder)
|
||||
elif self.container.src_checkpoint_is_aml_run_id:
|
||||
downloader = CheckpointDownloader(
|
||||
run_id=self.container.src_checkpoint, download_dir=self.container.outputs_folder
|
||||
)
|
||||
checkpoint_path = downloader.local_checkpoint_path
|
||||
else:
|
||||
raise ValueError("Unable to determine how to get the checkpoint path.")
|
||||
|
||||
if checkpoint_path is None or not checkpoint_path.is_file():
|
||||
raise FileNotFoundError(f"Could not find the weights file at {checkpoint_path}")
|
||||
return checkpoint_path
|
||||
|
|
|
@ -2,27 +2,29 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import re
|
||||
import os
|
||||
import uuid
|
||||
import torch
|
||||
import logging
|
||||
import tempfile
|
||||
import requests
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from urllib.parse import urlparse
|
||||
from azureml.core import Run, Workspace
|
||||
from health_azure import download_checkpoints_from_run_id, get_workspace
|
||||
|
||||
from health_azure.utils import (RUN_CONTEXT, download_files_from_run_id, get_run_file_names, is_running_in_azure_ml)
|
||||
from health_ml.utils.common_utils import (AUTOSAVE_CHECKPOINT_CANDIDATES, DEFAULT_AML_CHECKPOINT_DIR)
|
||||
from health_ml.utils.common_utils import (AUTOSAVE_CHECKPOINT_CANDIDATES, DEFAULT_AML_CHECKPOINT_DIR, CHECKPOINT_SUFFIX)
|
||||
from health_ml.utils.type_annotations import PathOrString
|
||||
|
||||
CHECKPOINT_SUFFIX = ".ckpt"
|
||||
# This is a constant that must match a filename defined in pytorch_lightning.ModelCheckpoint, but we don't want
|
||||
# to import that here.
|
||||
LAST_CHECKPOINT_FILE_NAME = "last"
|
||||
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX = LAST_CHECKPOINT_FILE_NAME + CHECKPOINT_SUFFIX
|
||||
LAST_CHECKPOINT_FILE_NAME = f"last{CHECKPOINT_SUFFIX}"
|
||||
LEGACY_RECOVERY_CHECKPOINT_FILE_NAME = "recovery"
|
||||
MODEL_INFERENCE_JSON_FILE_NAME = "model_inference_config.json"
|
||||
MODEL_WEIGHTS_DIR_NAME = "trained_models"
|
||||
MODEL_WEIGHTS_DIR_NAME = "pretrained_models"
|
||||
|
||||
|
||||
def get_best_checkpoint_path(path: Path) -> Path:
|
||||
|
@ -31,7 +33,7 @@ def get_best_checkpoint_path(path: Path) -> Path:
|
|||
|
||||
:param path to checkpoint folder
|
||||
"""
|
||||
return path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
return path / LAST_CHECKPOINT_FILE_NAME
|
||||
|
||||
|
||||
def download_folder_from_run_to_temp_folder(folder: str,
|
||||
|
@ -120,7 +122,7 @@ def find_recovery_checkpoint(path: Path) -> Optional[Path]:
|
|||
logging.warning(f"Found these legacy checkpoint files: {legacy_recovery_checkpoints}")
|
||||
raise ValueError("The legacy recovery checkpoint setup is no longer supported. As a workaround, you can take "
|
||||
f"one of the legacy checkpoints and upload as '{AUTOSAVE_CHECKPOINT_CANDIDATES[0]}'")
|
||||
candidates = [*AUTOSAVE_CHECKPOINT_CANDIDATES, LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX]
|
||||
candidates = [*AUTOSAVE_CHECKPOINT_CANDIDATES, LAST_CHECKPOINT_FILE_NAME]
|
||||
highest_epoch: Optional[int] = None
|
||||
file_with_highest_epoch: Optional[Path] = None
|
||||
for f in candidates:
|
||||
|
@ -147,10 +149,10 @@ def cleanup_checkpoints(ckpt_folder: Path) -> None:
|
|||
if len(files_in_checkpoint_folder) == 0:
|
||||
return
|
||||
logging.info(f"Files in checkpoint folder: {' '.join(files_in_checkpoint_folder)}")
|
||||
last_ckpt = ckpt_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
last_ckpt = ckpt_folder / LAST_CHECKPOINT_FILE_NAME
|
||||
all_files = f"Existing files: {' '.join(p.name for p in ckpt_folder.glob('*'))}"
|
||||
if not last_ckpt.is_file():
|
||||
raise FileNotFoundError(f"Checkpoint file {LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX} not found. {all_files}")
|
||||
raise FileNotFoundError(f"Checkpoint file {LAST_CHECKPOINT_FILE_NAME} not found. {all_files}")
|
||||
# Training is finished now. To save storage, remove the autosave checkpoint which is now obsolete.
|
||||
# Lightning does not overwrite checkpoints in-place. Rather, it writes "autosave.ckpt",
|
||||
# then "autosave-1.ckpt" and deletes "autosave.ckpt", then "autosave.ckpt" and deletes "autosave-v1.ckpt"
|
||||
|
@ -183,7 +185,6 @@ class CheckpointDownloader:
|
|||
self.remote_checkpoint_dir = (
|
||||
remote_checkpoint_dir or self.extract_remote_checkpoint_dir_from_checkpoint_filename()
|
||||
)
|
||||
self.download_checkpoint_if_necessary()
|
||||
|
||||
def extract_checkpoint_filename_from_run_id(self) -> str:
|
||||
"""
|
||||
|
@ -192,7 +193,7 @@ class CheckpointDownloader:
|
|||
"""
|
||||
run_id_split = self.run_id.split(":")
|
||||
self.run_id = run_id_split[0]
|
||||
return run_id_split[-1] if len(run_id_split) > 1 else LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
return run_id_split[-1] if len(run_id_split) > 1 else LAST_CHECKPOINT_FILE_NAME
|
||||
|
||||
def extract_remote_checkpoint_dir_from_checkpoint_filename(self) -> Path:
|
||||
"""
|
||||
|
@ -232,3 +233,111 @@ class CheckpointDownloader:
|
|||
self.run_id, str(self.remote_checkpoint_path), self.local_checkpoint_dir, aml_workspace=workspace
|
||||
)
|
||||
assert self.local_checkpoint_path.exists(), f"Couln't download checkpoint from run {self.run_id}."
|
||||
|
||||
|
||||
class CheckpointParser:
|
||||
"""Wrapper class for parsing checkpoint arguments. A checkpoint can be specified in one of the following ways:
|
||||
1. A local checkpoint file path
|
||||
2. A remote checkpoint file path
|
||||
3. A run ID from which to download the checkpoint file
|
||||
"""
|
||||
AML_RUN_ID_FORMAT = (f"<AzureML_run_id>:<optional/custom/path/to/checkpoints/><filename{CHECKPOINT_SUFFIX}>"
|
||||
f"If no custom path is provided (e.g., <AzureML_run_id>:<filename{CHECKPOINT_SUFFIX}>)"
|
||||
"the checkpoint will be downloaded from the default checkpoint folder "
|
||||
f"(e.g., '{DEFAULT_AML_CHECKPOINT_DIR}') If no filename is provided, "
|
||||
"(e.g., `src_checkpoint=<AzureML_run_id>`) the latest checkpoint "
|
||||
f"({LAST_CHECKPOINT_FILE_NAME}) will be downloaded.")
|
||||
INFO_MESSAGE = ("Please provide a valid checkpoint path, URL or AzureML run ID. For custom checkpoint paths "
|
||||
f"within an azureml run, provide a checkpoint in the format {AML_RUN_ID_FORMAT}.")
|
||||
DOC = ("We currently support three types of checkpoints: "
|
||||
" a. A local checkpoint folder that contains a checkpoint file."
|
||||
" b. A URL to a remote checkpoint to be downloaded."
|
||||
" c. A previous azureml run id where the checkpoint is supposed to be "
|
||||
" saved ('outputs/checkpoints/' folder by default.)"
|
||||
f"For the latter case 'c' : src_checkpoint should be in the format of {AML_RUN_ID_FORMAT}")
|
||||
|
||||
def __init__(self, checkpoint: str = "") -> None:
|
||||
self.checkpoint = checkpoint
|
||||
self.validate()
|
||||
|
||||
@property
|
||||
def is_url(self) -> bool:
|
||||
try:
|
||||
result = urlparse(self.checkpoint)
|
||||
return all([result.scheme, result.netloc])
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_local_file(self) -> bool:
|
||||
return Path(self.checkpoint).is_file()
|
||||
|
||||
@property
|
||||
def is_aml_run_id(self) -> bool:
|
||||
match = re.match(r"[_\w-]*$", self.checkpoint.split(":")[0])
|
||||
return match is not None and not self.is_url and not self.is_local_file
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
if self.checkpoint:
|
||||
return self.is_local_file or self.is_url or self.is_aml_run_id
|
||||
return True
|
||||
|
||||
def validate(self) -> None:
|
||||
if not self.is_valid:
|
||||
raise ValueError(f"Invalid checkpoint '{self.checkpoint}'. {self.INFO_MESSAGE}")
|
||||
|
||||
@staticmethod
|
||||
def download_from_url(url: str, download_folder: Path) -> Path:
|
||||
"""
|
||||
Download a checkpoint from checkpoint_url to the download folder. The file name is determined from
|
||||
from the file name in the URL. If that can't be determined, use a random file name.
|
||||
|
||||
:param url: The URL from which to download.
|
||||
:param download_folder: The target folder for the download.
|
||||
:return: A path to the downloaded file.
|
||||
"""
|
||||
# assign the same filename as in the download url if possible, so that we can check for duplicates
|
||||
# If that fails, map to a random uuid
|
||||
file_name = os.path.basename(urlparse(url).path) or str(uuid.uuid4().hex)
|
||||
checkpoint_path = download_folder / file_name
|
||||
# only download if hasn't already been downloaded
|
||||
if checkpoint_path.is_file():
|
||||
logging.info(f"File already exists, skipping download: {checkpoint_path}")
|
||||
else:
|
||||
logging.info(f"Downloading from URL {url}")
|
||||
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
with open(checkpoint_path, "wb") as file:
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
file.write(chunk)
|
||||
return checkpoint_path
|
||||
|
||||
def get_path(self, download_dir: Path) -> Path:
|
||||
"""Returns the path to the checkpoint file. If the checkpoint is a URL, it will be downloaded to the checkpoints
|
||||
folder. If the checkpoint is an AzureML run ID, it will be downloaded from the run to the checkpoints folder.
|
||||
If the checkpoint is a local file, it will be returned as is.
|
||||
|
||||
:param download_dir: The checkpoints folder to which the checkpoint should be downloaded if it is a URL or
|
||||
AzureML run ID.
|
||||
:raises ValueError: If the checkpoint is not a local file, URL or AzureML run ID.
|
||||
:raises FileNotFoundError: If the checkpoint is a URL or AzureML run ID and the download fails.
|
||||
:return: The path to the checkpoint file.
|
||||
"""
|
||||
if self.is_local_file:
|
||||
checkpoint_path = Path(self.checkpoint)
|
||||
elif self.is_url:
|
||||
download_folder = download_dir / MODEL_WEIGHTS_DIR_NAME
|
||||
download_folder.mkdir(exist_ok=True, parents=True)
|
||||
checkpoint_path = self.download_from_url(url=self.checkpoint, download_folder=download_folder)
|
||||
elif self.is_aml_run_id:
|
||||
downloader = CheckpointDownloader(run_id=self.checkpoint, download_dir=download_dir)
|
||||
downloader.download_checkpoint_if_necessary()
|
||||
checkpoint_path = downloader.local_checkpoint_path
|
||||
else:
|
||||
raise ValueError("Unable to determine how to get the checkpoint path.")
|
||||
|
||||
if checkpoint_path is None or not checkpoint_path.is_file():
|
||||
raise FileNotFoundError(f"Could not find the file at {checkpoint_path}")
|
||||
return checkpoint_path
|
||||
|
|
|
@ -8,52 +8,73 @@ from unittest import mock
|
|||
import pytest
|
||||
|
||||
from health_ml.configs.hello_world import HelloWorld
|
||||
from health_ml.deep_learning_config import WorkflowParams
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.utils.checkpoint_handler import CheckpointHandler
|
||||
from health_ml.utils.checkpoint_utils import (
|
||||
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
|
||||
LAST_CHECKPOINT_FILE_NAME,
|
||||
MODEL_WEIGHTS_DIR_NAME,
|
||||
CheckpointDownloader,
|
||||
)
|
||||
CheckpointParser,)
|
||||
from health_ml.utils.checkpoint_handler import CheckpointHandler
|
||||
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
|
||||
from testhiml.utils.fixed_paths_for_tests import full_test_data_path, mock_run_id
|
||||
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
|
||||
|
||||
|
||||
def test_checkpoint_downloader_run_id() -> None:
|
||||
with mock.patch("health_ml.utils.checkpoint_utils.CheckpointDownloader.download_checkpoint_if_necessary"):
|
||||
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id")
|
||||
assert checkpoint_downloader.run_id == "dummy_run_id"
|
||||
assert checkpoint_downloader.checkpoint_filename == LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
|
||||
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id")
|
||||
assert checkpoint_downloader.run_id == "dummy_run_id"
|
||||
assert checkpoint_downloader.checkpoint_filename == LAST_CHECKPOINT_FILE_NAME
|
||||
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
|
||||
|
||||
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:best.ckpt")
|
||||
assert checkpoint_downloader.run_id == "dummy_run_id"
|
||||
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
|
||||
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
|
||||
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:best.ckpt")
|
||||
assert checkpoint_downloader.run_id == "dummy_run_id"
|
||||
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
|
||||
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
|
||||
|
||||
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:custom/path/best.ckpt")
|
||||
assert checkpoint_downloader.run_id == "dummy_run_id"
|
||||
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
|
||||
assert checkpoint_downloader.remote_checkpoint_dir == Path("custom/path")
|
||||
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:custom/path/best.ckpt")
|
||||
assert checkpoint_downloader.run_id == "dummy_run_id"
|
||||
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
|
||||
assert checkpoint_downloader.remote_checkpoint_dir == Path("custom/path")
|
||||
|
||||
|
||||
def _test_invalid_checkpoint(checkpoint: str) -> None:
|
||||
with pytest.raises(ValueError, match=r"Invalid checkpoint "):
|
||||
CheckpointParser(checkpoint=checkpoint)
|
||||
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=checkpoint).validate()
|
||||
|
||||
|
||||
def test_validate_checkpoint_parser() -> None:
|
||||
|
||||
_test_invalid_checkpoint(checkpoint="dummy/local/path/model.ckpt")
|
||||
_test_invalid_checkpoint(checkpoint="INV@lid%RUN*id")
|
||||
_test_invalid_checkpoint(checkpoint="http/dummy_url-com")
|
||||
|
||||
# The following should be okay
|
||||
checkpoint = str(full_test_data_path(suffix="hello_world_checkpoint.ckpt"))
|
||||
CheckpointParser(checkpoint=checkpoint)
|
||||
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=CheckpointParser(checkpoint)).validate()
|
||||
checkpoint = mock_run_id(id=0)
|
||||
CheckpointParser(checkpoint=checkpoint)
|
||||
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=CheckpointParser(checkpoint)).validate()
|
||||
|
||||
|
||||
def get_checkpoint_handler(tmp_path: Path, src_checkpoint: str) -> Tuple[LightningContainer, CheckpointHandler]:
|
||||
container = LightningContainer()
|
||||
container.set_output_to(tmp_path)
|
||||
container.checkpoint_folder.mkdir(parents=True)
|
||||
container.src_checkpoint = src_checkpoint
|
||||
container.src_checkpoint = CheckpointParser(src_checkpoint)
|
||||
return container, CheckpointHandler(container=container, project_root=tmp_path)
|
||||
|
||||
|
||||
def test_load_model_chcekpoints_from_url(tmp_path: Path) -> None:
|
||||
def test_load_model_checkpoints_from_url(tmp_path: Path) -> None:
|
||||
WEIGHTS_URL = (
|
||||
"https://pl-bolts-weights.s3.us-east-2.amazonaws.com/" "simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
|
||||
)
|
||||
|
||||
container, checkpoint_handler = get_checkpoint_handler(tmp_path, WEIGHTS_URL)
|
||||
download_folder = container.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME
|
||||
assert container.src_checkpoint_is_url
|
||||
assert container.src_checkpoint.is_url
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
assert checkpoint_handler.trained_weights_path
|
||||
assert checkpoint_handler.trained_weights_path.exists()
|
||||
|
@ -64,7 +85,7 @@ def test_load_model_checkpoints_from_local_file(tmp_path: Path) -> None:
|
|||
local_checkpoint_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
|
||||
|
||||
container, checkpoint_handler = get_checkpoint_handler(tmp_path, str(local_checkpoint_path))
|
||||
assert container.src_checkpoint_is_local_file
|
||||
assert container.src_checkpoint.is_local_file
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
assert checkpoint_handler.trained_weights_path
|
||||
assert checkpoint_handler.trained_weights_path.exists()
|
||||
|
@ -82,10 +103,10 @@ def test_load_model_checkpoints_from_aml_run_id(src_chekpoint_filename: str, tmp
|
|||
src_checkpoint_filename = (
|
||||
src_chekpoint_filename.split("/")[-1]
|
||||
if src_chekpoint_filename
|
||||
else LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
else LAST_CHECKPOINT_FILE_NAME
|
||||
)
|
||||
expected_weights_path = container.outputs_folder / run_id / checkpoint_path / src_checkpoint_filename
|
||||
assert container.src_checkpoint_is_aml_run_id
|
||||
expected_weights_path = container.checkpoint_folder / run_id / checkpoint_path / src_checkpoint_filename
|
||||
assert container.src_checkpoint.is_aml_run_id
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
assert checkpoint_handler.trained_weights_path
|
||||
assert checkpoint_handler.trained_weights_path.exists()
|
||||
|
@ -99,7 +120,7 @@ def test_custom_checkpoint_for_test(tmp_path: Path) -> None:
|
|||
container = HelloWorld()
|
||||
container.set_output_to(tmp_path)
|
||||
container.checkpoint_folder.mkdir(parents=True)
|
||||
last_checkpoint = container.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
last_checkpoint = container.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME
|
||||
last_checkpoint.touch()
|
||||
checkpoint_handler = CheckpointHandler(container=container, project_root=tmp_path)
|
||||
checkpoint_handler.additional_training_done()
|
||||
|
|
|
@ -13,40 +13,24 @@ from pathlib import Path
|
|||
|
||||
from health_ml.deep_learning_config import DatasetParams, WorkflowParams, OutputParams, OptimizerParams, \
|
||||
ExperimentFolderHandler, TrainerParams
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||
from testhiml.utils.fixed_paths_for_tests import full_test_data_path, mock_run_id
|
||||
|
||||
|
||||
def _test_invalid_pre_checkpoint_workflow_params(src_checkpoint: str) -> None:
|
||||
with pytest.raises(ValueError, match=r"Invalid src_checkpoint:"):
|
||||
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=src_checkpoint).validate()
|
||||
|
||||
|
||||
def test_validate_workflow_params_src_checkpoint() -> None:
|
||||
|
||||
_test_invalid_pre_checkpoint_workflow_params(src_checkpoint="dummy/local/path/model.ckpt")
|
||||
_test_invalid_pre_checkpoint_workflow_params(src_checkpoint="INV@lid%RUN*id")
|
||||
_test_invalid_pre_checkpoint_workflow_params(src_checkpoint="http/dummy_url-com")
|
||||
|
||||
# The following should be okay
|
||||
full_file_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
|
||||
WorkflowParams(local_dataset=Path("foo"), src_checkpoint=str(full_file_path)).validate()
|
||||
run_id = mock_run_id(id=0)
|
||||
WorkflowParams(local_dataset=Path("foo"), src_checkpoint=run_id).validate()
|
||||
|
||||
|
||||
def test_validate_workflow_params_for_inference_only() -> None:
|
||||
with pytest.raises(ValueError, match=r"Cannot run inference without a src_checkpoint."):
|
||||
WorkflowParams(local_datasets=Path("foo"), run_inference_only=True).validate()
|
||||
|
||||
full_file_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
|
||||
run_id = mock_run_id(id=0)
|
||||
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True, src_checkpoint=run_id).validate()
|
||||
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
|
||||
src_checkpoint=f"{run_id}:best_val_loss.ckpt").validate()
|
||||
src_checkpoint=CheckpointParser(run_id)).validate()
|
||||
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
|
||||
src_checkpoint=f"{run_id}:custom/path/model.ckpt").validate()
|
||||
src_checkpoint=CheckpointParser(f"{run_id}:best_val_loss.ckpt")).validate()
|
||||
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
|
||||
src_checkpoint=str(full_file_path)).validate()
|
||||
src_checkpoint=CheckpointParser(f"{run_id}:custom/path/model.ckpt")).validate()
|
||||
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
|
||||
src_checkpoint=CheckpointParser(str(full_file_path))).validate()
|
||||
|
||||
|
||||
def test_validate_workflow_params_for_resume_training() -> None:
|
||||
|
@ -55,13 +39,14 @@ def test_validate_workflow_params_for_resume_training() -> None:
|
|||
|
||||
full_file_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
|
||||
run_id = mock_run_id(id=0)
|
||||
WorkflowParams(local_dataset=Path("foo"), resume_training=True, src_checkpoint=run_id).validate()
|
||||
WorkflowParams(local_dataset=Path("foo"), resume_training=True,
|
||||
src_checkpoint=f"{run_id}:best_val_loss.ckpt").validate()
|
||||
src_checkpoint=CheckpointParser(run_id)).validate()
|
||||
WorkflowParams(local_dataset=Path("foo"), resume_training=True,
|
||||
src_checkpoint=f"{run_id}:custom/path/model.ckpt").validate()
|
||||
src_checkpoint=CheckpointParser(f"{run_id}:best_val_loss.ckpt")).validate()
|
||||
WorkflowParams(local_dataset=Path("foo"), resume_training=True,
|
||||
src_checkpoint=str(full_file_path)).validate()
|
||||
src_checkpoint=CheckpointParser(f"{run_id}:custom/path/model.ckpt")).validate()
|
||||
WorkflowParams(local_dataset=Path("foo"), resume_training=True,
|
||||
src_checkpoint=CheckpointParser(str(full_file_path))).validate()
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
|
|
|
@ -16,6 +16,7 @@ from health_ml.configs.hello_world import HelloWorld # type: ignore
|
|||
from health_ml.experiment_config import ExperimentConfig
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.run_ml import MLRunner
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||
from health_ml.utils.common_utils import is_gpu_available
|
||||
from health_azure.utils import is_global_rank_zero
|
||||
from health_ml.utils.logging import AzureMLLogger
|
||||
|
@ -62,7 +63,7 @@ def ml_runner_with_run_id() -> Generator:
|
|||
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||
container = HelloWorld()
|
||||
container.save_checkpoint = True
|
||||
container.src_checkpoint = mock_run_id(id=0)
|
||||
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
|
||||
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
|
||||
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
|
||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
||||
|
@ -356,7 +357,7 @@ def test_model_weights_when_resume_training() -> None:
|
|||
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||
container = HelloWorld()
|
||||
container.max_num_gpus = 0
|
||||
container.src_checkpoint = mock_run_id(id=0)
|
||||
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
|
||||
container.resume_training = True
|
||||
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
|
||||
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
|
||||
|
@ -375,7 +376,7 @@ def test_runner_end_to_end() -> None:
|
|||
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||
container = HelloWorld()
|
||||
container.max_num_gpus = 0
|
||||
container.src_checkpoint = mock_run_id(id=0)
|
||||
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
|
||||
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
|
||||
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
|
||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
||||
|
|
Загрузка…
Ссылка в новой задаче