Merge work items from transferability study.
This commit is contained in:
Kenza Bouzid 2022-11-14 12:39:58 +00:00 committed by GitHub
Parents 2fd33ab350 b8a1ae0bc0
Commit 70885901e8
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
82 changed files: 3791 additions and 1145 deletions

.github/workflows/cpath-pr.yml vendored
View file

@ -128,7 +128,22 @@ jobs:
cd ${{ env.folder }}
make smoke_test_tilespandaimagenetmil_aml
smoke_test_tcgacrcksslmil_aml:
smoke_test_tcgacrckimagenetmil:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
lfs: true
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
cd ${{ env.folder }}
make smoke_test_tcgacrckimagenetmil_aml
smoke_test_tcgacrcksslmil:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
@ -143,7 +158,7 @@ jobs:
cd ${{ env.folder }}
make smoke_test_tcgacrcksslmil_aml
smoke_test_crck_simclr_aml:
smoke_test_crck_simclr:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
@ -158,6 +173,81 @@ jobs:
cd ${{ env.folder }}
make smoke_test_crck_simclr_aml
smoke_test_crck_flexible_finetuning:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
lfs: true
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
cd ${{ env.folder }}
make smoke_test_crck_flexible_finetuning_aml
smoke_test_crck_loss_analysis:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
lfs: true
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
cd ${{ env.folder }}
make smoke_test_crck_loss_analysis_aml
smoke_test_slides_panda_loss_analysis:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
lfs: true
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
cd ${{ env.folder }}
make smoke_test_slides_panda_loss_analysis_aml
smoke_test_slides_panda_no_ddp_sampler:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
lfs: true
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
cd ${{ env.folder }}
make smoke_test_slides_panda_no_ddp_sampler_aml
smoke_test_tiles_panda_no_ddp_sampler:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
lfs: true
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
cd ${{ env.folder }}
make smoke_test_tiles_panda_no_ddp_sampler_aml
publish:
runs-on: ubuntu-20.04
needs: [
@ -166,7 +256,14 @@ jobs:
pytest,
smoke_test_slidespandaimagenetmil,
smoke_test_tilespandaimagenetmil,
smoke_test_crck_simclr_aml
smoke_test_tcgacrckimagenetmil,
smoke_test_tcgacrcksslmil,
smoke_test_crck_simclr,
smoke_test_crck_flexible_finetuning,
smoke_test_crck_loss_analysis,
smoke_test_slides_panda_loss_analysis,
smoke_test_slides_panda_no_ddp_sampler,
smoke_test_tiles_panda_no_ddp_sampler,
]
steps:
- uses: actions/checkout@v3

View file

@ -16,7 +16,9 @@ dependencies:
- azureml-train-core==1.43.0
- conda-merge==0.1.5
- msal-extensions==0.3.1
- pandas==1.3.4
- param==1.12
- protobuf==3.20.1
- pytorch-lightning>=1.6.0, <1.7
- ruamel.yaml==0.16.12
- setuptools==59.5.0

View file

@ -63,7 +63,7 @@ jobs:
- name: HelloWorld
sku: G1
command:
- python hi-ml/src/health_ml/runner.py --model=health_cpath.TilesPandaImageNetMIL --is_finetune --batch_size=2
- python hi-ml/src/health_ml/runner.py --model=health_cpath.TilesPandaImageNetMIL --tune_encoder --batch_size=2
```
Note that the SKU here refers to the number of GPUs/CPUs to reserve and the amount of memory. In this case we have specified 1 GPU.

View file

@ -0,0 +1,61 @@
# Checkpoint Utils
The hi-ml toolbox offers several utilities to parse and download pretrained checkpoints, abstracting away checkpoint
downloading from different sources. Refer to
[CheckpointParser](https://github.com/microsoft/hi-ml/blob/main/hi-ml/src/health_ml/utils/checkpoint_utils.py#L238) for
more details on the supported checkpoint formats. Here's how you can use the checkpoint parser depending on the source:
- For a local path, simply pass it as shown below. The parser will further check if the provided path exists:
```python
from health_ml.utils.checkpoint_utils import CheckpointParser
download_dir = 'outputs/checkpoints'
checkpoint_parser = CheckpointParser(checkpoint='local/path/to/my_checkpoint/model.ckpt')
print('Checkpoint', checkpoint_parser.checkpoint, 'is a local file', checkpoint_parser.is_local_file)
local_file = checkpoint_parser.get_path(download_dir)
```
- To download a checkpoint from a URL:
```python
from health_ml.utils.checkpoint_utils import CheckpointParser, MODEL_WEIGHTS_DIR_NAME
download_dir = 'outputs/checkpoints'
checkpoint_parser = CheckpointParser('https://my_checkpoint_url.com/model.ckpt')
print('Checkpoint', checkpoint_parser.checkpoint, 'is a URL', checkpoint_parser.is_url)
# will download the checkpoint to download_dir/MODEL_WEIGHTS_DIR_NAME
path_to_ckpt = checkpoint_parser.get_path(download_dir)
```
- Finally, checkpoints from Azure ML runs can be reused by providing an id in the format
`<AzureML_run_id>:<optional/custom/path/to/checkpoints/><filename.ckpt>`. If no custom path is provided (e.g.,
`<AzureML_run_id>:<filename.ckpt>`), the checkpoint will be downloaded from the default checkpoint folder
(e.g., `outputs/checkpoints`). If no filename is provided (e.g., `src_checkpoint=<AzureML_run_id>`), the latest
checkpoint (`last.ckpt`) will be downloaded.
```python
from health_ml.utils.checkpoint_utils import CheckpointParser
download_dir = 'outputs/checkpoints'
checkpoint_parser = CheckpointParser('AzureML_run_id:best.ckpt')
print('Checkpoint', checkpoint_parser.checkpoint, 'is an AML run', checkpoint_parser.is_aml_run_id)
path_azure_ml_ckpt = checkpoint_parser.get_path(download_dir)
```
If the Azure ML run is in a different workspace, a temporary SAS URL to download the checkpoint can be generated as follows:
```bash
cd hi-ml-cpath
python src/health_cpath/scripts/generate_checkpoint_url.py --run_id=AzureML_run_id:best_val_loss.ckpt --expiry_days=10
```
N.B.: `config.json` should correspond to the original workspace where the AML run lives.
## Use cases
CheckpointParser is used to specify a `src_checkpoint` to [resume training from a given
checkpoint](https://github.com/microsoft/hi-ml/blob/main/docs/source/runner.md#L238),
or [run inference with a pretrained model](https://github.com/microsoft/hi-ml/blob/main/docs/source/runner.md#L215),
as well as an
[ssl_checkpoint](https://github.com/microsoft/hi-ml/blob/main/hi-ml-cpath/src/health_cpath/utils/deepmil_utils.py#L62)
for computational pathology self-supervised pretrained encoders.
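As a rough usage sketch of the inference case (hypothetical model name and run id; `--run_inference_only` and `--src_checkpoint` are the runner flags documented in runner.md):
```bash
# Hypothetical run id: replace with the AzureML run that produced the checkpoint
himl-runner --model=health_cpath.TcgaCrckSSLMIL \
    --run_inference_only \
    --src_checkpoint=TcgaCrckSSLMIL_XXXX_yyyy:best_val_loss.ckpt
```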

View file

@ -42,6 +42,7 @@ The `hi-ml` toolbox provides
logging.md
diagnostics.md
runner.md
checkpoints.md
.. toctree::
:maxdepth: 1

View file

@ -44,7 +44,7 @@ addition, you can turn on fine-tuning of the encoder, which will improve the res
```shell
conda activate HimlHisto
python ../hi-ml/src/health_ml/runner.py --model health_cpath.SlidesPandaImageNetMILBenchmark --is_finetune --cluster=<your_cluster_name>
python ../hi-ml/src/health_ml/runner.py --model health_cpath.SlidesPandaImageNetMILBenchmark --tune_encoder --cluster=<your_cluster_name>
```
Then the script will output "Successfully queued run number ..." and a line prefixed "Run URL: ...". Open that

View file

@ -226,6 +226,8 @@ the model weights by setting `--src_checkpoint` argument that supports three typ
checkpoints folder `outputs/checkpoints`. If no filename is provided (e.g., `--src_checkpoint=AzureML_run_id`),
the last epoch checkpoint `outputs/checkpoints/last.ckpt` will be loaded.
Refer to [Checkpoints Utils](checkpoints.md) for more details on how checkpoints are parsed.
Running the following command line will run inference using the `MyContainer` model, with weights from the checkpoint saved
in the AzureML run `MyContainer_XXXX_yyyy` at the best validation loss epoch (`/outputs/checkpoints/best_val_loss.ckpt`).
@ -235,12 +237,12 @@ himl-runner --model=Mycontainer --run_inference_only --src_checkpoint=MyContaine
## Resume training from a given checkpoint
Analogously, one can resume training by setting `--src_checkpoint` to either continue training or transfer learning.
Analogously, one can resume training by setting `--src_checkpoint` and `--resume_training` to train a model longer.
The PyTorch Lightning trainer will initialize the lightning module from the given checkpoint, corresponding to the best
validation loss epoch, as set in the following command line.
```bash
himl-runner --model=Mycontainer --cluster=my_cluster_name --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt
himl-runner --model=Mycontainer --cluster=my_cluster_name --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt --resume_training
```
Warning: When resuming training, one should make sure to set `container.max_epochs` greater than the last epoch of the
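A minimal sketch of that point (hypothetical values; assuming the source run's last checkpoint was written at epoch 20, and that `max_epochs` is exposed as a runner flag like the other container parameters):
```bash
# Hypothetical run id and epoch counts: resume from the best checkpoint and train up to epoch 50
himl-runner --model=Mycontainer --cluster=my_cluster_name \
    --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt --resume_training --max_epochs=50
```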

View file

@ -29,6 +29,10 @@ pip_from_conda:
sed -e '1,/pip:/ d' environment.yml | grep -v "#" | cut -d "-" -f 2- > temp_requirements.txt
pip install -r temp_requirements.txt
# Lock the current Conda environment secondary dependencies versions
lock_env:
../create_and_lock_environment.sh
# clean build artifacts
clean:
rm -rf `find . -type d -name __pycache__`
@ -78,6 +82,7 @@ pytest_coverage:
pytest --cov=health_cpath --cov SSL --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc
SSL_CKPT_RUN_ID_CRCK := CRCK_SimCLR_1655731022_85790606
SRC_CKPT_RUN_ID_CRCK := TcgaCrckSSLMIL_1667832811_4e855804
# Run regression tests and compare performance
define BASE_CPATH_RUNNER_COMMAND
@ -86,15 +91,19 @@ python hi-ml/src/health_ml/runner.py --mount_in_azureml --conda_env=hi-ml-cpath/
endef
define DEEPSMILEPANDASLIDES_ARGS
--model=health_cpath.SlidesPandaImageNetMILBenchmark --is_finetune
--model=health_cpath.SlidesPandaImageNetMILBenchmark --tune_encoder
endef
define DEEPSMILEPANDATILES_ARGS
--model=health_cpath.TilesPandaImageNetMIL --is_finetune --batch_size=2
--model=health_cpath.TilesPandaImageNetMIL --tune_encoder --batch_size=2
endef
define TCGACRCKSSLMIL_ARGS
--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint_run_id=${SSL_CKPT_RUN_ID_CRCK}
--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint=${SSL_CKPT_RUN_ID_CRCK}
endef
define TCGACRCKIMANEGETMIL_ARGS
--model=health_cpath.TcgaCrckImageNetMIL --is_caching=False --batch_size=2 --max_num_gpus=1
endef
define CRCKSIMCLR_ARGS
@ -102,7 +111,7 @@ define CRCKSIMCLR_ARGS
endef
define REGRESSION_TEST_ARGS
--cluster dedicated-nc24s-v2 --regression_test_csv_tolerance=0.5
--cluster dedicated-nc24s-v2 --regression_test_csv_tolerance=0.5 --strictly_aml_v1=True
endef
define PANDA_REGRESSION_METRICS
@ -158,12 +167,7 @@ define AML_MULTIPLE_DEVICE_ARGS
--cluster=dedicated-nc24s-v2 --wait_for_completion
endef
define DEEPSMILEPANDASLIDES_SMOKE_TEST_ARGS
--crossval_count=0 --num_top_slides=2 --num_top_tiles=2 --max_bag_size=3 \
--max_bag_size_inf=3
endef
define DEEPSMILEPANDATILES_SMOKE_TEST_ARGS
define DEEPSMILEDEFAULT_SMOKE_TEST_ARGS
--crossval_count=0 --num_top_slides=2 --num_top_tiles=2 --max_bag_size=3 \
--max_bag_size_inf=3
endef
@ -176,24 +180,38 @@ define CRCKSIMCLR_SMOKE_TEST_ARGS
--is_debug_model=True --num_workers=0
endef
define CRCK_TUNING_ARGS
--tune_encoder=False --tune_pooling=True --tune_classifier=True --pretrained_encoder=True \
--pretrained_pooling=True --pretrained_classifier=True --src_checkpoint=${SRC_CKPT_RUN_ID_CRCK}
endef
define LOSS_ANALYSIS_ARGS
--analyse_loss=True --loss_analysis_patience=0 --loss_analysis_epochs_interval=1 \
--num_slides_scatter=2 --num_slides_heatmap=2 --save_tile_ids=True
endef
define DDP_SAMPLER_ARGS
--pl_replace_sampler_ddp=False
endef
# The following test takes around 5 minutes
smoke_test_slidespandaimagenetmil_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEPANDASLIDES_SMOKE_TEST_ARGS};}
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS};}
# Once running in AML the following test takes around 6 minutes
smoke_test_slidespandaimagenetmil_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEPANDASLIDES_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
# The following test takes about 6 minutes
smoke_test_tilespandaimagenetmil_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEPANDATILES_SMOKE_TEST_ARGS};}
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS};}
smoke_test_tilespandaimagenetmil_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEPANDATILES_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
# The following test takes about 30 seconds
smoke_test_tcgacrcksslmil_local:
@ -204,6 +222,14 @@ smoke_test_tcgacrcksslmil_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${TCGACRCKSSLMIL_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
smoke_test_tcgacrckimagenetmil_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKIMANEGETMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS};}
smoke_test_tcgacrckimagenetmil_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKIMANEGETMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
# The following test takes about 3 minutes
smoke_test_crck_simclr_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${CRCKSIMCLR_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
@ -213,6 +239,46 @@ smoke_test_crck_simclr_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${CRCKSIMCLR_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${CRCKSIMCLR_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local
smoke_test_crck_flexible_finetuning_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${TCGACRCKSSLMIL_SMOKE_TEST_ARGS} ${CRCK_TUNING_ARGS};}
smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml
smoke_test_crck_flexible_finetuning_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${TCGACRCKSSLMIL_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS} ${CRCK_TUNING_ARGS};}
smoke_test_crck_loss_analysis_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${TCGACRCKSSLMIL_SMOKE_TEST_ARGS} ${LOSS_ANALYSIS_ARGS};}
smoke_test_crck_loss_analysis_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${TCGACRCKSSLMIL_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS} ${LOSS_ANALYSIS_ARGS};}
smoke_test_slides_panda_loss_analysis_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${LOSS_ANALYSIS_ARGS};}
smoke_test_slides_panda_loss_analysis_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${LOSS_ANALYSIS_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
smoke_test_slides_panda_no_ddp_sampler_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS};}
smoke_test_slides_panda_no_ddp_sampler_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
smoke_test_tiles_panda_no_ddp_sampler_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS};}
smoke_test_tiles_panda_no_ddp_sampler_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local smoke_test_slides_panda_loss_analysis_local smoke_test_slides_panda_no_ddp_sampler_local smoke_test_tiles_panda_no_ddp_sampler_local
smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml smoke_test_slides_panda_loss_analysis_aml smoke_test_slides_panda_no_ddp_sampler_aml smoke_test_tiles_panda_no_ddp_sampler_aml
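As a usage sketch, any of these targets can be invoked locally with make, assuming the Makefile lives in the `hi-ml-cpath` folder (as the workflow's `cd ${{ env.folder }}` step suggests):
```bash
# Run one of the new smoke tests on the local machine (no AzureML submission)
cd hi-ml-cpath
make smoke_test_crck_loss_analysis_local
```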

View file

@ -21,11 +21,11 @@ jobs:
- name: TilesPandaImageNetMIL
sku: G1
command:
- python hi-ml/src/health_ml/runner.py --model=health_cpath.TilesPandaImageNetMIL --is_finetune --batch_size=2
- python hi-ml/src/health_ml/runner.py --model=health_cpath.TilesPandaImageNetMIL --tune_encoder --batch_size=2
--crossval_count=0 --num_top_slides=2 --num_top_tiles=2 --max_bag_size=3 --max_bag_size_inf=3
--mount_in_azureml=True --pl_limit_train_batches=2 --pl_limit_val_batches=2 --pl_limit_test_batches=2
- name: SlidesPandaImageNetMIL
sku: G1
command:
- python hi-ml/src/health_ml/runner.py --model=health_cpath.SlidesPandaImageNetMIL --is_finetune --crossval_count=0 --num_top_slides=2 --num_top_tiles=2 --max_bag_size=3 --max_bag_size_inf=3
- python hi-ml/src/health_ml/runner.py --model=health_cpath.SlidesPandaImageNetMIL --tune_encoder --crossval_count=0 --num_top_slides=2 --num_top_tiles=2 --max_bag_size=3 --max_bag_size_inf=3

View file

@ -1,3 +1,4 @@
azure-storage-blob==12.5.0
coloredlogs==15.0.1
cucim==22.04.00
girder-client==3.1.14

View file

@ -14,33 +14,34 @@ from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from health_azure.utils import create_from_matching_params
from health_cpath.utils.callbacks import LossAnalysisCallback, LossCallbackParams
from health_ml.utils import fixed_paths
from health_ml.deep_learning_config import OptimizerParams
from health_ml.lightning_container import LightningContainer
from health_ml.utils.checkpoint_utils import get_best_checkpoint_path
from health_ml.utils.checkpoint_utils import get_best_checkpoint_path, CheckpointParser
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from health_cpath.datamodules.base_module import CacheLocation, CacheMode, HistoDataModule
from health_cpath.datasets.base_dataset import SlidesDataset
from health_cpath.models.deepmil import TilesDeepMILModule, SlidesDeepMILModule, BaseDeepMILModule
from health_cpath.models.transforms import EncodeTilesBatchd, LoadTilesBatchd
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
from health_cpath.utils.output_utils import DeepMILOutputsHandler
from health_cpath.utils.naming import MetricsKey, PlotOption, SlideKey, ModelKey
from health_cpath.utils.tiles_selection_utils import TilesSelector
class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams, LossCallbackParams):
"""BaseMIL is an abstract container defining basic functionality for running MIL experiments in both slides and
tiles settings. It is responsible for instantiating the encoder and pooling layer. Subclasses should define the
full DeepMIL model depending on the type of dataset (tiles/slides based).
"""
dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
class_names: Optional[Sequence[str]] = param.List(None, item_type=str, doc="List of class names. If `None`, "
"defaults to `('0', '1', ...)`.")
# Data module parameters:
batch_size: int = param.Integer(16, bounds=(1, None), doc="Number of slides to load per batch.")
batch_size_inf: int = param.Integer(16, bounds=(1, None), doc="Number of slides per batch during inference.")
max_bag_size: int = param.Integer(1000, bounds=(0, None),
doc="Upper bound on number of tiles in each loaded bag during training stage. "
"If 0 (default), will return all samples in each bag. "
@ -71,12 +72,13 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
"generating outputs.")
maximise_primary_metric: bool = param.Boolean(True, doc="Whether the primary validation metric should be "
"maximised (otherwise minimised).")
ssl_checkpoint_run_id: str = param.String(default="", doc="Optional run id from which to load checkpoint if "
"using SSLEncoder")
max_num_workers: int = param.Integer(10, bounds=(0, None),
doc="The maximum number of worker processes for dataloaders. Dataloaders use"
"a heuristic num_cpus/num_gpus to set the number of workers, which can be"
"very high for small num_gpus. This parameters sets an upper bound.")
wsi_has_mask: bool = param.Boolean(default=True,
doc="Whether the WSI has a mask. If True, will use the mask to load a specific"
"region of the WSI. If False, will load the whole WSI.")
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
@ -84,6 +86,42 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
metric_optim = "max" if self.maximise_primary_metric else "min"
self.best_checkpoint_filename = f"checkpoint_{metric_optim}_val_{self.primary_val_metric.value}"
self.best_checkpoint_filename_with_suffix = self.best_checkpoint_filename + ".ckpt"
self.validate()
if not self.pl_replace_sampler_ddp and self.max_num_gpus > 1:
logging.info(
"Replacing sampler with `DistributedSampler` is disabled. Make sure to set your own DDP sampler"
)
def validate(self) -> None:
super().validate()
EncoderParams.validate(self)
if not any([self.tune_encoder, self.tune_pooling, self.tune_classifier]) and not self.run_inference_only:
raise ValueError(
"At least one of the encoder, pooling or classifier should be fine tuned. Turn on one of the tune "
"arguments `tune_encoder`, `tune_pooling`, `tune_classifier`. Otherwise, activate inference only "
"mode via `run_inference_only` flag."
)
if (
any([self.pretrained_encoder, self.pretrained_pooling, self.pretrained_classifier])
and not self.src_checkpoint
):
raise ValueError(
"You need to specify a source checkpoint, to use a pretrained encoder, pooling or classifier."
f" {CheckpointParser.INFO_MESSAGE}"
)
if (
self.tune_encoder and self.encoding_chunk_size < self.max_bag_size
and self.pl_sync_batchnorm and self.max_num_gpus > 1
):
raise ValueError(
"The encoding chunk size should be at least as large as the maximum bag size when fine tuning the "
"encoder. You might encounter Batch Norm synchronization issues if the chunk size is smaller than "
"the maximum bag size causing the processes to hang silently. This is due to the encoder being called "
"different number of times on each device, which cause Batch Norm running statistics to be updated "
"inconsistently across processes. In case you can't increase the `encoding_chunk_size` any further, set"
" `pl_sync_batchnorm=False` to simply skip Batch Norm synchronization across devices. Note that this "
"might come with some performance penalty."
)
@property
def cache_dir(self) -> Path:
@ -96,6 +134,9 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
return options
def get_val_plot_options(self) -> Set[PlotOption]:
""" Override this method if you want to produce validation plots at each epoch. By default, at the end of the
training an extra validation epoch is run where val_plot_options = test_plot_options
"""
return set()
def get_outputs_handler(self) -> DeepMILOutputsHandler:
@ -110,6 +151,8 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
maximise=self.maximise_primary_metric,
val_plot_options=self.get_val_plot_options(),
test_plot_options=self.get_test_plot_options(),
wsi_has_mask=self.wsi_has_mask,
val_set_is_dist=self.pl_replace_sampler_ddp and self.max_num_gpus > 1,
)
if self.num_top_slides > 0:
outputs_handler.tiles_selector = TilesSelector(
@ -118,12 +161,26 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
return outputs_handler
def get_callbacks(self) -> List[Callback]:
return [*super().get_callbacks(),
ModelCheckpoint(dirpath=self.checkpoint_folder,
monitor=f"{ModelKey.VAL}/{self.primary_val_metric}",
filename=self.best_checkpoint_filename,
auto_insert_metric_name=False,
mode="max" if self.maximise_primary_metric else "min")]
callbacks = [*super().get_callbacks(),
ModelCheckpoint(dirpath=self.checkpoint_folder,
monitor=f"{ModelKey.VAL}/{self.primary_val_metric}",
filename=self.best_checkpoint_filename,
auto_insert_metric_name=False,
mode="max" if self.maximise_primary_metric else "min")]
if self.analyse_loss:
callbacks.append(LossAnalysisCallback(outputs_folder=self.outputs_folder,
max_epochs=self.max_epochs,
patience=self.loss_analysis_patience,
epochs_interval=self.loss_analysis_epochs_interval,
num_slides_scatter=self.num_slides_scatter,
num_slides_heatmap=self.num_slides_heatmap,
save_tile_ids=self.save_tile_ids,
log_exceptions=self.log_exceptions,
val_set_is_dist=(
self.pl_replace_sampler_ddp and self.max_num_gpus > 1),
)
)
return callbacks
def get_checkpoint_to_test(self) -> Path:
"""
@ -144,6 +201,10 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
if absolute_checkpoint_path_parent.is_file():
return absolute_checkpoint_path_parent
checkpoint_path = Path(self.checkpoint_folder, self.best_checkpoint_filename_with_suffix)
if checkpoint_path.is_file():
return checkpoint_path
checkpoint_path = get_best_checkpoint_path(self.checkpoint_folder)
if checkpoint_path.is_file():
return checkpoint_path
@ -197,8 +258,8 @@ class BaseMILTiles(BaseMIL):
def setup(self) -> None:
super().setup()
# Fine-tuning requires tiles to be loaded on-the-fly, hence, caching is disabled by default.
# When is_finetune and is_caching are both set, below lines should disable caching automatically.
if self.is_finetune:
# When tune_encoder and is_caching are both set, below lines should disable caching automatically.
if self.tune_encoder:
self.is_caching = False
if not self.is_caching:
self.cache_mode = CacheMode.NONE
@ -213,8 +274,7 @@ class BaseMILTiles(BaseMIL):
def get_transforms_dict(self, image_key: str) -> Dict[ModelKey, Union[Callable, None]]:
if self.is_caching:
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_checkpoint_run_id,
self.outputs_folder)
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.outputs_folder)
transform = Compose([
LoadTilesBatchd(image_key, progress=True),
EncodeTilesBatchd(image_key, encoder, chunk_size=self.encoding_chunk_size) # type: ignore
@ -231,13 +291,15 @@ class BaseMILTiles(BaseMIL):
n_classes=self.data_module.train_dataset.n_classes,
class_names=self.class_names,
class_weights=self.data_module.class_weights,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
classifier_params=create_from_matching_params(self, ClassifierParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
outputs_handler=outputs_handler)
outputs_folder=self.outputs_folder,
outputs_handler=outputs_handler,
analyse_loss=self.analyse_loss,
)
deepmil_module.transfer_weights(self.trained_weights_path)
outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
return deepmil_module
@ -271,12 +333,14 @@ class BaseMILSlides(BaseMIL):
n_classes=self.data_module.train_dataset.n_classes,
class_names=self.class_names,
class_weights=self.data_module.class_weights,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
classifier_params=create_from_matching_params(self, ClassifierParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
outputs_handler=outputs_handler)
outputs_handler=outputs_handler,
analyse_loss=self.analyse_loss,
)
deepmil_module.transfer_weights(self.trained_weights_path)
outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
return deepmil_module

View file

@ -11,7 +11,7 @@ Reference:
- Schirris (2021). DeepSMILE: Self-supervised heterogeneity-aware multiple instance learning for DNA
damage response defect classification directly from H&E whole-slide images. arXiv:2107.09405
"""
from typing import Any
from typing import Any, Set
from health_ml.networks.layers.attention_layers import AttentionLayer
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_crck_4ws
@ -20,12 +20,14 @@ from health_cpath.datamodules.tcga_crck_module import TcgaCrckTilesDataModule
from health_cpath.datasets.default_paths import TCGA_CRCK_DATASET_ID
from health_cpath.models.encoders import (
HistoSSLEncoder,
ImageNetEncoder,
ImageNetSimCLREncoder,
Resnet18,
SSLEncoder,
)
from health_cpath.configs.classification.BaseMIL import BaseMILTiles
from health_cpath.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from health_cpath.utils.naming import PlotOption
from health_ml.utils.checkpoint_utils import CheckpointParser
class DeepSMILECrck(BaseMILTiles):
@ -37,7 +39,7 @@ class DeepSMILECrck(BaseMILTiles):
num_transformer_pool_layers=4,
num_transformer_pool_heads=4,
encoding_chunk_size=60,
is_finetune=False,
tune_encoder=False,
is_caching=True,
num_top_slides=0,
azure_datasets=[TCGA_CRCK_DATASET_ID],
@ -55,14 +57,13 @@ class DeepSMILECrck(BaseMILTiles):
def setup(self) -> None:
super().setup()
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_crck_4ws
def get_data_module(self) -> TilesDataModule:
return TcgaCrckTilesDataModule(
root_path=self.local_datasets[0],
max_bag_size=self.max_bag_size,
batch_size=self.batch_size,
batch_size_inf=self.batch_size_inf,
max_bag_size_inf=self.max_bag_size_inf,
transforms_dict=self.get_transforms_dict(TcgaCrck_TilesDataset.IMAGE_COLUMN),
cache_mode=self.cache_mode,
@ -72,12 +73,18 @@ class DeepSMILECrck(BaseMILTiles):
crossval_index=self.crossval_index,
dataloader_kwargs=self.get_dataloader_kwargs(),
seed=self.get_effective_random_seed(),
pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
)
def get_test_plot_options(self) -> Set[PlotOption]:
plot_options = super().get_test_plot_options()
plot_options.add(PlotOption.PR_CURVE)
return plot_options
class TcgaCrckImageNetMIL(DeepSMILECrck):
def __init__(self, **kwargs: Any) -> None:
super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs)
super().__init__(encoder_type=Resnet18.__name__, **kwargs)
class TcgaCrckImageNetSimCLRMIL(DeepSMILECrck):
@ -87,6 +94,8 @@ class TcgaCrckImageNetSimCLRMIL(DeepSMILECrck):
class TcgaCrckSSLMIL(DeepSMILECrck):
def __init__(self, **kwargs: Any) -> None:
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_crck_4ws)
super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)

View file

@ -13,8 +13,8 @@ from health_cpath.datamodules.panda_module import (
from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
from health_cpath.models.encoders import (
HistoSSLEncoder,
ImageNetEncoder,
ImageNetSimCLREncoder,
Resnet18,
SSLEncoder)
from health_cpath.configs.classification.BaseMIL import BaseMILSlides, BaseMILTiles, BaseMIL
from health_cpath.datasets.panda_dataset import PandaDataset
@ -22,6 +22,7 @@ from health_cpath.datasets.default_paths import (
PANDA_DATASET_ID,
PANDA_5X_TILES_DATASET_ID)
from health_cpath.utils.naming import PlotOption
from health_ml.utils.checkpoint_utils import CheckpointParser
class BaseDeepSMILEPanda(BaseMIL):
@ -33,7 +34,7 @@ class BaseDeepSMILEPanda(BaseMIL):
pool_type=AttentionLayer.__name__,
num_transformer_pool_layers=4,
num_transformer_pool_heads=4,
is_finetune=False,
tune_encoder=False,
# average number of tiles is 56 for PANDA
encoding_chunk_size=60,
max_bag_size=56,
@ -55,7 +56,7 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
""" DeepSMILETilesPanda is derived from BaseMILTiles and BaseDeepSMILEPanda to inherit common behaviors from both
tiles basemil and panda specific configuration.
`is_finetune` sets the fine-tuning mode. `is_finetune` sets the fine-tuning mode. For fine-tuning, batch_size = 2
`tune_encoder` sets the fine-tuning mode of the encoder. For fine-tuning the encoder, batch_size = 2
runs on multiple GPUs with ~ 6:24 min/epoch (train) and ~ 00:50 min/epoch (validation).
"""
@ -64,20 +65,20 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
# declared in BaseMILTiles:
is_caching=False,
batch_size=8,
batch_size_inf=8,
azure_datasets=[PANDA_5X_TILES_DATASET_ID, PANDA_DATASET_ID])
default_kwargs.update(kwargs)
super().__init__(**default_kwargs)
def setup(self) -> None:
BaseMILTiles.setup(self)
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary
def get_data_module(self) -> PandaTilesDataModule:
return PandaTilesDataModule(
root_path=self.local_datasets[0],
max_bag_size=self.max_bag_size,
batch_size=self.batch_size,
batch_size_inf=self.batch_size_inf,
max_bag_size=self.max_bag_size,
max_bag_size_inf=self.max_bag_size_inf,
transforms_dict=self.get_transforms_dict(PandaTilesDataset.IMAGE_COLUMN),
cache_mode=self.cache_mode,
@ -87,6 +88,7 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
crossval_index=self.crossval_index,
dataloader_kwargs=self.get_dataloader_kwargs(),
seed=self.get_effective_random_seed(),
pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
)
def get_slides_dataset(self) -> Optional[PandaDataset]:
@ -94,13 +96,13 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
def get_test_plot_options(self) -> Set[PlotOption]:
plot_options = super().get_test_plot_options()
plot_options.add(PlotOption.SLIDE_THUMBNAIL_HEATMAP)
plot_options.update([PlotOption.SLIDE_THUMBNAIL, PlotOption.ATTENTION_HEATMAP])
return plot_options
class TilesPandaImageNetMIL(DeepSMILETilesPanda):
def __init__(self, **kwargs: Any) -> None:
super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs)
super().__init__(encoder_type=Resnet18.__name__, **kwargs)
class TilesPandaImageNetSimCLRMIL(DeepSMILETilesPanda):
@ -110,6 +112,8 @@ class TilesPandaImageNetSimCLRMIL(DeepSMILETilesPanda):
class TilesPandaSSLMIL(DeepSMILETilesPanda):
def __init__(self, **kwargs: Any) -> None:
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_binary)
super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)
@ -136,8 +140,6 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
def setup(self) -> None:
BaseMILSlides.setup(self)
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary
def get_dataloader_kwargs(self) -> dict:
return dict(
@ -149,6 +151,7 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
return PandaSlidesDataModule(
root_path=self.local_datasets[0],
batch_size=self.batch_size,
batch_size_inf=self.batch_size_inf,
level=self.level,
max_bag_size=self.max_bag_size,
max_bag_size_inf=self.max_bag_size_inf,
@ -163,6 +166,7 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
crossval_count=self.crossval_count,
crossval_index=self.crossval_index,
dataloader_kwargs=self.get_dataloader_kwargs(),
pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
)
def get_slides_dataset(self) -> PandaDataset:
@ -171,7 +175,7 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
class SlidesPandaImageNetMIL(DeepSMILESlidesPanda):
def __init__(self, **kwargs: Any) -> None:
super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs)
super().__init__(encoder_type=Resnet18.__name__, **kwargs)
class SlidesPandaImageNetSimCLRMIL(DeepSMILESlidesPanda):
@ -181,6 +185,8 @@ class SlidesPandaImageNetSimCLRMIL(DeepSMILESlidesPanda):
class SlidesPandaSSLMIL(DeepSMILESlidesPanda):
def __init__(self, **kwargs: Any) -> None:
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_binary)
super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)

View file

@ -7,24 +7,25 @@
from typing import Any, Dict, Callable, Union
from torch import optim
from monai.transforms import Compose, ScaleIntensityRanged, RandRotate90d, RandFlipd
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary
from health_azure.utils import create_from_matching_params
from health_ml.networks.layers.attention_layers import (
TransformerPooling,
TransformerPoolingBenchmark
)
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_ml.deep_learning_config import OptimizerParams
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.datamodules.panda_module_benchmark import PandaSlidesDataModuleBenchmark
from health_cpath.models.encoders import (
HistoSSLEncoder,
ImageNetEncoder_Resnet50,
ImageNetSimCLREncoder,
Resnet50_NoPreproc,
SSLEncoder,
)
from health_cpath.configs.classification.DeepSMILEPanda import DeepSMILESlidesPanda
from health_cpath.models.deepmil import SlidesDeepMILModule
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
from health_cpath.utils.naming import MetricsKey, ModelKey, SlideKey
@ -50,9 +51,8 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
"""
Configuration for PANDA experiments from Myronenko et al. 2021:
(https://link.springer.com/chapter/10.1007/978-3-030-87237-3_32)
`is_finetune` sets the fine-tuning mode. For fine-tuning,
batch_size = 2 runs on 8 GPUs with
~ 6:24 min/epoch (train) and ~ 00:50 min/epoch (validation).
`tune_encoder` sets the fine-tuning mode of the encoder. For fine-tuning, batch_size = 2 runs on 8 GPUs
with ~ 6:24 min/epoch (train) and ~ 00:50 min/epoch (validation).
"""
def __init__(self, **kwargs: Any) -> None:
@ -64,6 +64,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
encoding_chunk_size=60,
max_bag_size=56,
batch_size=8, # effective batch size = batch_size * num_GPUs
batch_size_inf=8,
max_epochs=50,
l_rate=3e-4,
weight_decay=0,
@ -77,8 +78,9 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
self.l_rate = 3e-5
self.weight_decay = 0.1
# Params specific to fine-tuning
if self.is_finetune:
if self.tune_encoder:
self.batch_size = 2
self.batch_size_inf = 2
super().setup()
def get_transforms_dict(self, image_key: str) -> Dict[ModelKey, Union[Callable, None]]:
@ -100,8 +102,9 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
# Hence, inherited `PandaSlidesDataModuleBenchmark` from `SlidesDataModule`
return PandaSlidesDataModuleBenchmark(
root_path=self.local_datasets[0],
max_bag_size=self.max_bag_size,
batch_size=self.batch_size,
batch_size_inf=self.batch_size_inf,
max_bag_size=self.max_bag_size,
max_bag_size_inf=self.max_bag_size_inf,
level=self.level,
tile_size=self.tile_size,
@ -115,6 +118,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
crossval_count=self.crossval_count,
crossval_index=self.crossval_index,
dataloader_kwargs=self.get_dataloader_kwargs(),
pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
)
def create_model(self) -> SlidesDeepMILModule:
@ -126,13 +130,13 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
n_classes=self.data_module.train_dataset.n_classes,
class_names=self.class_names,
class_weights=self.data_module.class_weights,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
classifier_params=create_from_matching_params(self, ClassifierParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
outputs_handler=outputs_handler,
analyse_loss=self.analyse_loss,
)
outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
return deepmil_module
@ -140,7 +144,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
class SlidesPandaImageNetMILBenchmark(DeepSMILESlidesPandaBenchmark):
def __init__(self, **kwargs: Any) -> None:
super().__init__(encoder_type=ImageNetEncoder_Resnet50.__name__, **kwargs)
super().__init__(encoder_type=Resnet50_NoPreproc.__name__, **kwargs)
class SlidesPandaImageNetSimCLRMILBenchmark(DeepSMILESlidesPandaBenchmark):
@ -150,6 +154,8 @@ class SlidesPandaImageNetSimCLRMILBenchmark(DeepSMILESlidesPandaBenchmark):
class SlidesPandaSSLMILBenchmark(DeepSMILESlidesPandaBenchmark):
def __init__(self, **kwargs: Any) -> None:
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_binary)
super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)

View file

@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, Generic, Optional, Sequence, Tuple, Type
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, DistributedSampler
from health_ml.utils.bag_utils import BagDataset, multibag_collate
from health_ml.utils.common_utils import _create_generator
@ -47,17 +47,21 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
self,
root_path: Path,
batch_size: int = 1,
batch_size_inf: Optional[int] = None,
max_bag_size: int = 0,
max_bag_size_inf: int = 0,
seed: Optional[int] = None,
transforms_dict: Optional[Dict[ModelKey, Union[Callable, None]]] = None,
crossval_count: int = 0,
crossval_index: int = 0,
pl_replace_sampler_ddp: bool = True,
dataloader_kwargs: Optional[Dict[str, Any]] = None,
dataframe_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
:param root_path: Root directory of the source dataset.
:param batch_size: Number of slides to load per batch.
:param batch_size_inf: Number of slides to load per batch during inference. If None, use batch_size.
:param max_bag_size: Upper bound on number of tiles in each loaded bag during training stage. If 0 (default),
will return all samples in each bag. If > 0 , bags larger than `max_bag_size` will yield
random subsets of instances. For SlideDataModule, this parameter is used in TileOnGridd Transform to set the
@ -73,21 +77,25 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
By default (`None`).
:param crossval_count: Number of folds to perform.
:param crossval_index: Index of the cross validation split to be performed.
:param pl_replace_sampler_ddp: If True, replace the sampler with a DistributedSampler when using DDP.
:param dataloader_kwargs: Additional keyword arguments for the training, validation, and test dataloaders.
:param dataframe_kwargs: Keyword arguments to pass to `pd.read_csv()` when loading the dataset CSV.
"""
batch_size_inf = batch_size_inf or batch_size
super().__init__()
self.root_path = root_path
self.transforms_dict = transforms_dict
self.batch_size = batch_size
self.max_bag_size = max_bag_size
self.max_bag_size_inf = max_bag_size_inf
self.batch_sizes = {ModelKey.TRAIN: batch_size, ModelKey.VAL: batch_size_inf, ModelKey.TEST: batch_size_inf}
self.bag_sizes = {ModelKey.TRAIN: max_bag_size, ModelKey.VAL: max_bag_size_inf, ModelKey.TEST: max_bag_size_inf}
self.crossval_count = crossval_count
self.crossval_index = crossval_index
self.pl_replace_sampler_ddp = pl_replace_sampler_ddp
self.train_dataset: _SlidesOrTilesDataset
self.val_dataset: _SlidesOrTilesDataset
self.test_dataset: _SlidesOrTilesDataset
self.dataframe_kwargs = dataframe_kwargs or {}
self.train_dataset, self.val_dataset, self.test_dataset = self.get_splits()
self.class_weights = self.train_dataset.get_class_weights()
self.seed = seed
@ -97,6 +105,18 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
"""Create the training, validation, and test datasets"""
raise NotImplementedError
def _get_dataloader(
self, dataset: _SlidesOrTilesDataset, stage: ModelKey, shuffle: bool, **dataloader_kwargs: Any
) -> DataLoader:
raise NotImplementedError
def _get_ddp_sampler(self, dataset: Dataset, stage: ModelKey) -> Optional[DistributedSampler]:
is_distributed = torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1
if is_distributed and not self.pl_replace_sampler_ddp and stage == ModelKey.TRAIN:
assert self.seed is not None, "seed must be set when using distributed training for reproducibility"
return DistributedSampler(dataset, shuffle=True, seed=self.seed)
return None
def train_dataloader(self) -> DataLoader:
return self._get_dataloader(self.train_dataset, # type: ignore
shuffle=True,
@ -117,7 +137,11 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
class TilesDataModule(HistoDataModule[TilesDataset]):
"""Base class to load the tiles of a dataset as train, val, test sets"""
"""Base class to load the tiles of a dataset as train, val, test sets. Note that tiles are always shuffled by
default. This means that we sample a random subset of tiles from each bag at each epoch. This is different from
slides shuffling that is switched on during training time only. This is done to avoid overfitting to the order of
the tiles in each bag.
"""
def __init__(
self,
@ -198,16 +222,11 @@ class TilesDataModule(HistoDataModule[TilesDataset]):
generator = _create_generator(self.seed)
if stage in [ModelKey.VAL, ModelKey.TEST]:
eff_max_bag_size = self.max_bag_size_inf
else:
eff_max_bag_size = self.max_bag_size
bag_dataset = BagDataset(
tiles_dataset, # type: ignore
bag_ids=tiles_dataset.slide_ids,
max_bag_size=eff_max_bag_size,
shuffle_samples=shuffle,
max_bag_size=self.bag_sizes[stage],
shuffle_samples=True,
generator=generator,
)
if self.transforms_dict and self.transforms_dict[stage]:
@ -233,11 +252,14 @@ class TilesDataModule(HistoDataModule[TilesDataset]):
transformed_bag_dataset = self._load_dataset(dataset, stage=stage, shuffle=shuffle)
bag_dataset: BagDataset = transformed_bag_dataset.data # type: ignore
generator = bag_dataset.bag_sampler.generator
sampler = self._get_ddp_sampler(transformed_bag_dataset, stage)
return DataLoader(
transformed_bag_dataset,
batch_size=self.batch_size,
batch_size=self.batch_sizes[stage],
collate_fn=multibag_collate,
shuffle=shuffle,
sampler=sampler,
# sampler option is mutually exclusive with shuffle
shuffle=shuffle if sampler is None else None, # type: ignore
generator=generator,
**dataloader_kwargs,
)
@ -252,13 +274,13 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
def __init__(
self,
level: Optional[int] = 1,
tile_size: Optional[int] = 224,
level: int = 1,
tile_size: int = 224,
step: Optional[int] = None,
random_offset: Optional[bool] = True,
pad_full: Optional[bool] = False,
background_val: Optional[int] = 255,
filter_mode: Optional[str] = "min",
random_offset: bool = True,
pad_full: bool = False,
background_val: int = 255,
filter_mode: str = "min",
**kwargs: Any,
) -> None:
"""
@ -290,8 +312,9 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
self.filter_mode = filter_mode
# TileOnGridd transform expects None to select all foreground tile so we hardcode max_bag_size and
# max_bag_size_inf to None if set to 0
self.max_bag_size = None if self.max_bag_size == 0 else self.max_bag_size # type: ignore
self.max_bag_size_inf = None if self.max_bag_size_inf == 0 else self.max_bag_size_inf # type: ignore
for stage_key, max_bag_size in self.bag_sizes.items():
if max_bag_size == 0:
self.bag_sizes[stage_key] = None # type: ignore
def _load_dataset(self, slides_dataset: SlidesDataset, stage: ModelKey) -> Dataset:
base_transform = Compose(
@ -306,7 +329,7 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
),
TileOnGridd(
keys=slides_dataset.IMAGE_COLUMN,
tile_count=self.max_bag_size if stage == ModelKey.TRAIN else self.max_bag_size_inf,
tile_count=self.bag_sizes[stage],
tile_size=self.tile_size,
step=self.step,
random_offset=self.random_offset if stage == ModelKey.TRAIN else False,
@ -331,11 +354,14 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
**dataloader_kwargs: Any) -> DataLoader:
transformed_slides_dataset = self._load_dataset(dataset, stage)
generator = _create_generator(self.seed)
sampler = self._get_ddp_sampler(transformed_slides_dataset, stage)
return DataLoader(
transformed_slides_dataset,
batch_size=self.batch_size,
batch_size=self.batch_sizes[stage],
collate_fn=image_collate,
shuffle=shuffle,
sampler=sampler,
# sampler option is mutually exclusive with shuffle
shuffle=shuffle if not sampler else None, # type: ignore
generator=generator,
**dataloader_kwargs,
)

View file

@ -12,7 +12,7 @@ import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import Dataset
from health_cpath.utils.naming import SlideKey
from health_cpath.utils.naming import SlideKey, TileKey
DEFAULT_TRAIN_SPLIT_LABEL = "train" # Value used to indicate the training split in `SPLIT_COLUMN`
@ -50,7 +50,8 @@ class TilesDataset(Dataset):
train: Optional[bool] = None,
validate_columns: bool = True,
label_column: str = DEFAULT_LABEL_COLUMN,
n_classes: int = 1) -> None:
n_classes: int = 1,
dataframe_kwargs: Dict[str, Any] = {}) -> None:
"""
:param root: Root directory of the dataset.
:param dataset_csv: Full path to a dataset CSV file, containing at least
@ -65,6 +66,7 @@ class TilesDataset(Dataset):
for this class
:param label_column: CSV column name for tile label. Defaults to `DEFAULT_LABEL_COLUMN="label"`.
:param n_classes: Number of classes indexed in `label_column`. Default is 1 for binary classification.
:param dataframe_kwargs: Keyword arguments to pass to `pd.read_csv()` when loading the dataset CSV.
"""
if self.SPLIT_COLUMN is None and train is not None:
raise ValueError("Train/test split was specified but dataset has no split column")
@ -77,9 +79,10 @@ class TilesDataset(Dataset):
self.dataset_csv = None
else:
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
dataset_df = pd.read_csv(self.dataset_csv)
dataset_df = pd.read_csv(self.dataset_csv, **dataframe_kwargs)
dataset_df = dataset_df.set_index(self.TILE_ID_COLUMN)
if dataset_df.index.name != self.TILE_ID_COLUMN:
dataset_df = dataset_df.set_index(self.TILE_ID_COLUMN)
if train is None:
self.dataset_df = dataset_df
else:
@ -131,6 +134,19 @@ class TilesDataset(Dataset):
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=slide_labels)
return torch.as_tensor(class_weights)
def copy_coordinates_columns(self) -> None:
"""Copy columns "left" --> "tile_x" and "top" --> "tile_y" to be consistent with TilesDataset `TILE_X_COLUMN`
and `TILE_Y_COLUMN`."""
if TileKey.TILE_LEFT in self.dataset_df.columns:
self.dataset_df = self.dataset_df.assign(
**{TilesDataset.TILE_X_COLUMN: self.dataset_df[TileKey.TILE_LEFT]}
)
if TileKey.TILE_TOP in self.dataset_df.columns:
self.dataset_df = self.dataset_df.assign(
**{TilesDataset.TILE_Y_COLUMN: self.dataset_df[TileKey.TILE_TOP]}
)
class SlidesDataset(Dataset):
"""Base class for datasets of WSIs, iterating dictionaries of image paths and metadata.
@ -158,7 +174,8 @@ class SlidesDataset(Dataset):
train: Optional[bool] = None,
validate_columns: bool = True,
label_column: str = DEFAULT_LABEL_COLUMN,
n_classes: int = 1) -> None:
n_classes: int = 1,
dataframe_kwargs: Dict[str, Any] = {}) -> None:
"""
:param root: Root directory of the dataset.
:param dataset_csv: Full path to a dataset CSV file, containing at least
@ -173,6 +190,7 @@ class SlidesDataset(Dataset):
for this class
:param label_column: CSV column name for tile label. Default is `DEFAULT_LABEL_COLUMN="label"`.
:param n_classes: Number of classes indexed in `label_column`. Default is 1 for binary classification.
:param dataframe_kwargs: Keyword arguments to pass to `pd.read_csv()` when loading the dataset CSV.
"""
if self.SPLIT_COLUMN is None and train is not None:
raise ValueError("Train/test split was specified but dataset has no split column")
@ -180,14 +198,16 @@ class SlidesDataset(Dataset):
self.root_dir = Path(root)
self.label_column = label_column
self.n_classes = n_classes
self.dataframe_kwargs = dataframe_kwargs
if dataset_df is not None:
self.dataset_csv = None
else:
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
dataset_df = pd.read_csv(self.dataset_csv)
dataset_df = pd.read_csv(self.dataset_csv, **self.dataframe_kwargs)
dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN)
if dataset_df.index.name != self.SLIDE_ID_COLUMN:
dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN)
if train is None:
self.dataset_df = dataset_df
else:
@ -203,12 +223,14 @@ class SlidesDataset(Dataset):
If the constructor is overloaded in a subclass, you can pass `validate_columns=False` and
call `validate_columns()` after creating derived columns, for example.
"""
columns = [self.IMAGE_COLUMN, self.label_column, self.MASK_COLUMN,
self.SPLIT_COLUMN] + list(self.METADATA_COLUMNS)
columns_not_found = []
for column in columns:
if column is not None and column not in self.dataset_df.columns:
columns_not_found.append(column)
mandatory_columns = {self.IMAGE_COLUMN, self.label_column, self.MASK_COLUMN, self.SPLIT_COLUMN}
optional_columns = (
set(self.dataframe_kwargs["usecols"]) if "usecols" in self.dataframe_kwargs else set(self.METADATA_COLUMNS)
)
columns = mandatory_columns.union(optional_columns)
# SLIDE_ID_COLUMN is used for indexing and is not in df.columns anymore
# None might be in columns if SPLITS_COLUMN is None
columns_not_found = columns - set(self.dataset_df.columns) - {None, self.SLIDE_ID_COLUMN}
if len(columns_not_found) > 0:
raise ValueError(f"Expected columns '{columns_not_found}' not found in the dataframe")

View file

@ -10,6 +10,7 @@ import pandas as pd
from monai.config import KeysCollection
from monai.data.image_reader import ImageReader, WSIReader
from monai.transforms import MapTransform
from health_cpath.utils.naming import SlideKey
from health_ml.utils import box_utils
@ -42,9 +43,10 @@ class PandaDataset(SlidesDataset):
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None,
label_column: str = "isup_grade",
n_classes: int = 6) -> None:
n_classes: int = 6,
dataframe_kwargs: Dict[str, Any] = {}) -> None:
super().__init__(root, dataset_csv, dataset_df, validate_columns=False, label_column=label_column,
n_classes=n_classes)
n_classes=n_classes, dataframe_kwargs=dataframe_kwargs)
# PANDA CSV does not come with paths for image and mask files
slide_ids = self.dataset_df.index
self.dataset_df[self.IMAGE_COLUMN] = "train_images/" + slide_ids + ".tiff"
@ -120,8 +122,9 @@ class LoadPandaROId(MapTransform):
# but relative region size in pixels at the chosen level
scale = mask_obj.resolutions['level_downsamples'][self.level]
scaled_bbox = level0_bbox / scale
origin = (level0_bbox.y, level0_bbox.x)
get_data_kwargs = dict(
location=(level0_bbox.y, level0_bbox.x),
location=origin,
size=(scaled_bbox.h, scaled_bbox.w),
level=self.level,
)
@ -129,7 +132,8 @@ class LoadPandaROId(MapTransform):
data[self.mask_key] = mask[:1] # PANDA segmentation mask is in 'R' channel
data[self.image_key], _ = self.reader.get_data(image_obj, **get_data_kwargs) # type: ignore
data.update(get_data_kwargs)
data['scale'] = scale
data[SlideKey.SCALE] = scale
data[SlideKey.ORIGIN] = origin
mask_obj.close()
image_obj.close()

View file

@ -11,7 +11,6 @@ from torchvision.datasets.vision import VisionDataset
from health_cpath.datasets.base_dataset import TilesDataset
from health_cpath.models.transforms import load_pil_image
from health_cpath.utils.naming import TileKey
from health_cpath.datasets.dataset_return_index import DatasetWithReturnIndex
@ -73,12 +72,7 @@ class PandaTilesDataset(TilesDataset):
dataset_df_filtered = self.dataset_df.sample(n=df_length_random_subset_fraction)
self.dataset_df = dataset_df_filtered
# Copy columns "left" --> "tile_x" and "top" --> "tile_y"
# to be consistent with TilesDataset `TILE_X_COLUMN` and `TILE_Y_COLUMN`
if TileKey.TILE_LEFT in self.dataset_df.columns:
self.dataset_df[TilesDataset.TILE_X_COLUMN] = self.dataset_df[TileKey.TILE_LEFT]
if TileKey.TILE_TOP in self.dataset_df.columns:
self.dataset_df[TilesDataset.TILE_Y_COLUMN] = self.dataset_df[TileKey.TILE_TOP]
self.copy_coordinates_columns()
self.validate_columns()

View file

@ -6,20 +6,20 @@ import torch
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pathlib import Path
from pytorch_lightning import LightningModule
from torch import Tensor, argmax, mode, nn, optim, round, set_grad_enabled
from torchmetrics import AUROC, F1Score, Accuracy, ConfusionMatrix, Precision, Recall, CohenKappa # type: ignore
from torch import Tensor, argmax, mode, nn, optim, round
from torchmetrics import (AUROC, F1Score, Accuracy, ConfusionMatrix, Precision,
Recall, CohenKappa, AveragePrecision, Specificity)
from health_ml.utils import log_on_epoch
from health_ml.deep_learning_config import OptimizerParams
from health_cpath.models.encoders import IdentityEncoder
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
from health_cpath.datasets.base_dataset import TilesDataset
from health_cpath.utils.naming import MetricsKey, ResultsKey, SlideKey, ModelKey, TileKey
from health_cpath.utils.naming import DeepMILSubmodules, MetricsKey, ResultsKey, SlideKey, ModelKey, TileKey
from health_cpath.utils.output_utils import (BatchResultsType, DeepMILOutputsHandler, EpochResultsType,
validate_class_names)
validate_class_names, EXTRA_PREFIX)
RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB,
ResultsKey.CLASS_PROBS, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN]
@ -39,31 +39,31 @@ class BaseDeepMILModule(LightningModule):
n_classes: int,
class_weights: Optional[Tensor] = None,
class_names: Optional[Sequence[str]] = None,
dropout_rate: Optional[float] = None,
verbose: bool = False,
ssl_ckpt_run_id: Optional[str] = None,
outputs_folder: Optional[Path] = None,
encoder_params: EncoderParams = EncoderParams(),
pooling_params: PoolingParams = PoolingParams(),
classifier_params: ClassifierParams = ClassifierParams(),
optimizer_params: OptimizerParams = OptimizerParams(),
outputs_handler: Optional[DeepMILOutputsHandler] = None) -> None:
outputs_folder: Optional[Path] = None,
outputs_handler: Optional[DeepMILOutputsHandler] = None,
analyse_loss: Optional[bool] = False,
verbose: bool = False,
) -> None:
"""
:param label_column: Label key for input batch dictionary.
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be
set to 1.
:param class_weights: Tensor containing class weights (default=None).
:param class_names: The names of the classes if available (default=None).
:param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
:param verbose: if True statements about memory usage are output at each step.
:param ssl_ckpt_run_id: Optional parameter to provide the AML run id from where to download the checkpoint
if using `SSLEncoder`.
:param outputs_folder: Path to output folder where encoder checkpoint is downloaded.
:param encoder_params: Encoder parameters that specify all encoder specific attributes.
:param pooling_params: Pooling layer parameters that specify all pooling specific attributes.
:param classifier_params: Classifier parameters that specify all classifier specific attributes.
:param optimizer_params: Optimizer parameters that specify all attributes to be used for optimization.
:param outputs_folder: Path to output folder where encoder checkpoint is downloaded.
:param outputs_handler: A configured :py:class:`DeepMILOutputsHandler` object to save outputs for the best
validation epoch and test stage. If omitted (default), no outputs will be saved to disk (aside from usual
metrics logging).
:param analyse_loss: If True, the loss is computed per sample and analysed with LossAnalysisCallback.
:param verbose: If True, statements about memory usage are output at each step.
"""
super().__init__()
@ -73,8 +73,9 @@ class BaseDeepMILModule(LightningModule):
self.class_weights = class_weights
self.class_names = validate_class_names(class_names, self.n_classes)
self.dropout_rate = dropout_rate
self.encoder_params = encoder_params
self.pooling_params = pooling_params
self.classifier_params = classifier_params
self.optimizer_params = optimizer_params
self.save_hyperparameters()
@ -82,44 +83,84 @@ class BaseDeepMILModule(LightningModule):
self.outputs_handler = outputs_handler
# This flag can be switched on before invoking trainer.validate() to enable saving additional time/memory
# consuming validation outputs
self.run_extra_val_epoch = False
# consuming validation outputs via calling self.on_run_extra_validation_epoch()
self._on_extra_val_epoch = False
# Model components
self.encoder = encoder_params.get_encoder(ssl_ckpt_run_id, outputs_folder)
self.encoder = encoder_params.get_encoder(outputs_folder)
self.aggregation_fn, self.num_pooling = pooling_params.get_pooling_layer(self.encoder.num_encoding)
self.classifier_fn = self.get_classifier()
self.classifier_fn = classifier_params.get_classifier(self.num_pooling, self.n_classes)
self.activation_fn = self.get_activation()
self.loss_fn = self.get_loss()
# Loss function
self.loss_fn = self.get_loss(reduction="mean")
self.loss_fn_no_reduction = self.get_loss(reduction="none")
self.analyse_loss = analyse_loss
# Metrics Objects
self.train_metrics = self.get_metrics()
self.val_metrics = self.get_metrics()
self.test_metrics = self.get_metrics()
def get_classifier(self) -> Callable:
classifier_layer = nn.Linear(in_features=self.num_pooling,
out_features=self.n_classes)
if self.dropout_rate is None:
return classifier_layer
elif 0 <= self.dropout_rate < 1:
return nn.Sequential(nn.Dropout(self.dropout_rate), classifier_layer)
else:
raise ValueError(f"Dropout rate should be in [0, 1), got {self.dropout_rate}")
@staticmethod
def copy_weights(
current_submodule: nn.Module, pretrained_submodule: nn.Module, submodule_name: DeepMILSubmodules
) -> None:
"""Copy weights from pretrained submodule to current submodule.
def get_loss(self) -> Callable:
:param current_submodule: Submodule to copy weights to.
:param pretrained_submodule: Submodule to copy weights from.
:param submodule_name: Name of the submodule.
"""
def _total_params(submodule: nn.Module) -> int:
return sum(p.numel() for p in submodule.parameters())
pre_total_params = _total_params(pretrained_submodule)
cur_total_params = _total_params(current_submodule)
if pre_total_params != cur_total_params:
raise ValueError(f"Submodule {submodule_name} has different number of parameters "
f"({cur_total_params} vs {pre_total_params}) from pretrained model.")
for param, pretrained_param in zip(
current_submodule.state_dict().values(), pretrained_submodule.state_dict().values()
):
try:
param.data.copy_(pretrained_param.data)
except Exception as e:
raise ValueError(f"Failed to copy weights for {submodule_name} because of the following exception: {e}")
def transfer_weights(self, pretrained_checkpoint_path: Optional[Path]) -> None:
"""Transfer weights from pretrained checkpoint if provided."""
if pretrained_checkpoint_path:
pretrained_model = self.load_from_checkpoint(checkpoint_path=str(pretrained_checkpoint_path))
if self.encoder_params.pretrained_encoder:
self.copy_weights(self.encoder, pretrained_model.encoder, DeepMILSubmodules.ENCODER)
if self.pooling_params.pretrained_pooling:
self.copy_weights(self.aggregation_fn, pretrained_model.aggregation_fn, DeepMILSubmodules.POOLING)
if self.classifier_params.pretrained_classifier:
if pretrained_model.n_classes != self.n_classes:
raise ValueError(f"Number of classes in pretrained model {pretrained_model.n_classes} "
f"does not match number of classes in current model {self.n_classes}.")
self.copy_weights(self.classifier_fn, pretrained_model.classifier_fn, DeepMILSubmodules.CLASSIFIER)
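A hedged sketch of how `transfer_weights` might be called after constructing a module; the constructor arguments and checkpoint path below are placeholders, and the pretrained_* flags are assumed to live on the respective parameter classes as referenced above:

model = TilesDeepMILModule(label_column="label", n_classes=2,
                           encoder_params=EncoderParams(pretrained_encoder=True),
                           pooling_params=PoolingParams(pretrained_pooling=True),
                           classifier_params=ClassifierParams(pretrained_classifier=False))
model.transfer_weights(Path("outputs/checkpoints/pretrained_deepmil.ckpt"))  # placeholder path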
def get_loss(self, reduction: str = "mean") -> Callable:
if self.n_classes > 1:
if self.class_weights is None:
return nn.CrossEntropyLoss()
return nn.CrossEntropyLoss(reduction=reduction)
else:
class_weights = self.class_weights.float()
return nn.CrossEntropyLoss(weight=class_weights)
return nn.CrossEntropyLoss(weight=class_weights, reduction=reduction)
else:
pos_weight = None
if self.class_weights is not None:
pos_weight = Tensor([self.class_weights[1] / (self.class_weights[0] + 1e-5)])
return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
return nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=reduction)
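For illustration only, the two reductions requested above behave as follows for a binary head (logits and labels are made up):

import torch
from torch import nn

logits = torch.tensor([0.3, -1.2])
labels = torch.tensor([1.0, 0.0])
mean_loss = nn.BCEWithLogitsLoss(reduction="mean")(logits, labels)       # scalar, drives optimisation
per_sample = nn.BCEWithLogitsLoss(reduction="none")(logits, labels)      # one value per bag, used for loss analysis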
def get_activation(self) -> Callable:
if self.n_classes > 1:
@ -133,26 +174,38 @@ class BaseDeepMILModule(LightningModule):
def get_metrics(self) -> nn.ModuleDict:
if self.n_classes > 1:
return nn.ModuleDict({MetricsKey.ACC: Accuracy(num_classes=self.n_classes),
MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'),
MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted'),
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
# Quadratic Weighted Kappa (QWK) used in PANDA challenge
# is calculated using Cohen's Kappa with quadratic weights
# https://www.kaggle.com/code/reighns/understanding-the-quadratic-weighted-kappa/
MetricsKey.COHENKAPPA: CohenKappa(num_classes=self.n_classes, weights='quadratic'),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes)})
return nn.ModuleDict({
MetricsKey.ACC: Accuracy(num_classes=self.n_classes),
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
MetricsKey.AVERAGE_PRECISION: AveragePrecision(num_classes=self.n_classes),
# Quadratic Weighted Kappa (QWK) used in PANDA challenge
# is calculated using Cohen's Kappa with quadratic weights
# https://www.kaggle.com/code/reighns/understanding-the-quadratic-weighted-kappa/
MetricsKey.COHENKAPPA: CohenKappa(num_classes=self.n_classes, weights='quadratic'),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes),
# Metrics below are computed for multi-class case only
MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'),
MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted')})
else:
threshold = 0.5
return nn.ModuleDict({MetricsKey.ACC: Accuracy(threshold=threshold),
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
MetricsKey.PRECISION: Precision(threshold=threshold),
MetricsKey.RECALL: Recall(threshold=threshold),
MetricsKey.F1: F1Score(threshold=threshold),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2, threshold=threshold)})
return nn.ModuleDict({
MetricsKey.ACC: Accuracy(),
MetricsKey.AUROC: AUROC(num_classes=None),
# Average precision is a measure of area under the PR curve
MetricsKey.AVERAGE_PRECISION: AveragePrecision(),
MetricsKey.COHENKAPPA: CohenKappa(num_classes=2, weights='quadratic'),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2),
# Metrics below are computed for binary case only
MetricsKey.F1: F1Score(),
MetricsKey.PRECISION: Precision(),
MetricsKey.RECALL: Recall(),
MetricsKey.SPECIFICITY: Specificity()})
def log_metrics(self, stage: str) -> None:
valid_stages = [stage for stage in ModelKey]
def get_extra_prefix(self) -> str:
"""Get extra prefix for the metrics name to avoir overriding best validation metrics."""
return EXTRA_PREFIX if self._on_extra_val_epoch else ""
def log_metrics(self, stage: str, prefix: str = '') -> None:
valid_stages = set([stage for stage in ModelKey])
if stage not in valid_stages:
raise Exception(f"Invalid stage. Chose one of {valid_stages}")
for metric_name, metric_object in self.get_metrics_dict(stage).items():
@ -160,29 +213,46 @@ class BaseDeepMILModule(LightningModule):
metric_value = metric_object.compute()
metric_value_n = metric_value / metric_value.sum(axis=1, keepdims=True)
for i in range(metric_value_n.shape[0]):
log_on_epoch(self, f'{stage}/{self.class_names[i]}', metric_value_n[i, i])
log_on_epoch(self, f'{prefix}{stage}/{self.class_names[i]}', metric_value_n[i, i])
else:
log_on_epoch(self, f'{stage}/{metric_name}', metric_object)
log_on_epoch(self, f'{prefix}{stage}/{metric_name}', metric_object)
def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
should_enable_encoder_grad = torch.is_grad_enabled() and self.encoder_params.is_finetune
with set_grad_enabled(should_enable_encoder_grad):
if self.encoder_params.encoding_chunk_size > 0:
embeddings = []
chunks = torch.split(instances, self.encoder_params.encoding_chunk_size)
for chunk in chunks:
chunk_embeddings = self.encoder(chunk)
embeddings.append(chunk_embeddings)
instance_features = torch.cat(embeddings)
else:
instance_features = self.encoder(instances) # N X L x 1 x 1
def get_instance_features(self, instances: Tensor) -> Tensor:
if not self.encoder_params.tune_encoder:
self.encoder.eval()
if self.encoder_params.encoding_chunk_size > 0:
embeddings = []
chunks = torch.split(instances, self.encoder_params.encoding_chunk_size)
for chunk in chunks:
chunk_embeddings = self.encoder(chunk)
embeddings.append(chunk_embeddings)
instance_features = torch.cat(embeddings)
else:
instance_features = self.encoder(instances) # N X L x 1 x 1
return instance_features
def get_attentions_and_bag_features(self, instance_features: Tensor) -> Tuple[Tensor, Tensor]:
if not self.pooling_params.tune_pooling:
self.aggregation_fn.eval()
attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L
bag_features = bag_features.view(1, -1)
return attentions, bag_features
def get_bag_logit(self, bag_features: Tensor) -> Tensor:
if not self.classifier_params.tune_classifier:
self.classifier_fn.eval()
bag_logit = self.classifier_fn(bag_features)
return bag_logit
def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
instance_features = self.get_instance_features(instances)
attentions, bag_features = self.get_attentions_and_bag_features(instance_features)
bag_logit = self.get_bag_logit(bag_features)
return bag_logit, attentions
def configure_optimizers(self) -> optim.Optimizer:
return optim.Adam(self.parameters(), lr=self.optimizer_params.l_rate,
return optim.Adam(filter(lambda p: p.requires_grad, self.parameters()),
lr=self.optimizer_params.l_rate,
weight_decay=self.optimizer_params.weight_decay,
betas=self.optimizer_params.adam_betas)
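The `filter` above means frozen submodules contribute no parameters to Adam; a standalone toy example of the same pattern:

from torch import nn, optim

toy = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
for p in toy[0].parameters():   # freeze the first layer, a stand-in for a frozen encoder
    p.requires_grad = False
optimizer = optim.Adam(filter(lambda p: p.requires_grad, toy.parameters()), lr=1e-4)
assert len(optimizer.param_groups[0]["params"]) == 2   # only the second layer's weight and bias remain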
@ -212,21 +282,32 @@ class BaseDeepMILModule(LightningModule):
"""Update training results with data specific info. This can be either tiles or slides related metadata."""
raise NotImplementedError
def update_slides_selection(self, stage: str, batch: Dict, results: Dict) -> None:
if (
(stage == ModelKey.TEST or (stage == ModelKey.VAL and self._on_extra_val_epoch))
and self.outputs_handler
and self.outputs_handler.tiles_selector
):
self.outputs_handler.tiles_selector.update_slides_selection(batch, results)
def _compute_loss(self, loss_fn: Callable, bag_logits: Tensor, bag_labels: Tensor) -> Tensor:
if self.n_classes > 1:
return loss_fn(bag_logits, bag_labels.long())
else:
return loss_fn(bag_logits.squeeze(1), bag_labels.float())
def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> BatchResultsType:
bag_logits, bag_labels, bag_attn_list = self.compute_bag_labels_logits_and_attn_maps(batch)
if self.n_classes > 1:
loss = self.loss_fn(bag_logits, bag_labels.long())
else:
loss = self.loss_fn(bag_logits.squeeze(1), bag_labels.float())
loss = self._compute_loss(self.loss_fn, bag_logits, bag_labels)
predicted_probs = self.activation_fn(bag_logits)
if self.n_classes > 1:
predicted_labels = argmax(predicted_probs, dim=1)
probs_perclass = predicted_probs
else:
predicted_labels = round(predicted_probs)
predicted_labels = round(predicted_probs).int()
probs_perclass = Tensor([[1.0 - predicted_probs[i][0].item(), predicted_probs[i][0].item()]
for i in range(len(predicted_probs))])
@ -237,9 +318,14 @@ class BaseDeepMILModule(LightningModule):
if self.n_classes == 1:
predicted_probs = predicted_probs.squeeze(dim=1)
results = dict()
if self.analyse_loss and stage in [ModelKey.TRAIN, ModelKey.VAL]:
loss_per_sample = self._compute_loss(self.loss_fn_no_reduction, bag_logits, bag_labels)
results[ResultsKey.LOSS_PER_SAMPLE] = loss_per_sample.detach().cpu().numpy()
bag_labels = bag_labels.view(-1, 1)
results = dict()
for metric_object in self.get_metrics_dict(stage).values():
metric_object.update(predicted_probs, bag_labels.view(batch_size,).int())
results.update({ResultsKey.LOSS: loss,
@ -250,56 +336,46 @@ class BaseDeepMILModule(LightningModule):
ResultsKey.BAG_ATTN: bag_attn_list
})
self.update_results_with_data_specific_info(batch=batch, results=results)
if (
(stage == ModelKey.TEST or (stage == ModelKey.VAL and self.run_extra_val_epoch))
and self.outputs_handler
and self.outputs_handler.tiles_selector
):
self.outputs_handler.tiles_selector.update_slides_selection(batch, results)
self.update_slides_selection(stage=stage, batch=batch, results=results)
return results
def training_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore
def training_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
train_result = self._shared_step(batch, batch_idx, ModelKey.TRAIN)
self.log('train/loss', train_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
sync_dist=True)
self.log('train/loss', train_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, sync_dist=True)
if self.verbose:
print(f"After loading images batch {batch_idx} -", _format_cuda_memory_stats())
return train_result[ResultsKey.LOSS]
results = {ResultsKey.LOSS: train_result[ResultsKey.LOSS]}
if self.analyse_loss:
results.update({ResultsKey.LOSS_PER_SAMPLE: train_result[ResultsKey.LOSS_PER_SAMPLE],
ResultsKey.CLASS_PROBS: train_result[ResultsKey.CLASS_PROBS],
ResultsKey.TILE_ID: train_result[ResultsKey.TILE_ID]})
return results
def validation_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
val_result = self._shared_step(batch, batch_idx, ModelKey.VAL)
self.log('val/loss', val_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
sync_dist=True)
name = f'{self.get_extra_prefix()}val/loss'
self.log(name, val_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, sync_dist=True)
return val_result
def test_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
test_result = self._shared_step(batch, batch_idx, ModelKey.TEST)
self.log('test/loss', test_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
sync_dist=True)
self.log('test/loss', test_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, sync_dist=True)
return test_result
def training_epoch_end(self, outputs: EpochResultsType) -> None: # type: ignore
self.log_metrics(ModelKey.TRAIN)
def validation_epoch_end(self, epoch_results: EpochResultsType) -> None: # type: ignore
self.log_metrics(ModelKey.VAL)
self.log_metrics(stage=ModelKey.VAL, prefix=self.get_extra_prefix())
if self.outputs_handler:
if self.run_extra_val_epoch:
self.outputs_handler.val_plots_handler.plot_options = (
self.outputs_handler.test_plots_handler.plot_options
)
self.outputs_handler.save_validation_outputs(
epoch_results=epoch_results,
metrics_dict=self.get_metrics_dict(ModelKey.VAL), # type: ignore
epoch=self.current_epoch,
is_global_rank_zero=self.global_rank == 0,
run_extra_val_epoch=self.run_extra_val_epoch
on_extra_val=self._on_extra_val_epoch
)
# Reset the top and bottom slides heaps
if self.outputs_handler.tiles_selector is not None:
self.outputs_handler.tiles_selector._clear_cached_slides_heaps()
def test_epoch_end(self, epoch_results: EpochResultsType) -> None: # type: ignore
self.log_metrics(ModelKey.TEST)
if self.outputs_handler:
@ -308,6 +384,13 @@ class BaseDeepMILModule(LightningModule):
is_global_rank_zero=self.global_rank == 0
)
def on_run_extra_validation_epoch(self) -> None:
"""Hook to be called at the beginning of an extra validation epoch to set validation plots options to the same
as the test plots options."""
self._on_extra_val_epoch = True
if self.outputs_handler:
self.outputs_handler.val_plots_handler.plot_options = self.outputs_handler.test_plots_handler.plot_options
class TilesDeepMILModule(BaseDeepMILModule):
"""Base class for Tiles based deep multiple-instance learning."""

View file

@ -10,8 +10,8 @@ import numpy as np
import torch
from pl_bolts.models.self_supervised import SimCLR
from torch import Tensor as T, nn
from torchvision.models import resnet18
from torchvision.transforms import Compose
from torchvision.models import resnet18, resnet50
from monai.transforms import Compose
from health_cpath.utils.layer_utils import (get_imagenet_preprocessing,
load_weights_to_model,
@ -62,24 +62,49 @@ class ImageNetEncoder(TileEncoder):
"""Feature extractor pretrained for classification on ImageNet"""
def __init__(self, feature_extraction_model: Callable[..., nn.Module],
tile_size: int, n_channels: int = 3) -> None:
tile_size: int, n_channels: int = 3, apply_imagenet_preprocessing: bool = True) -> None:
"""
:param feature_extraction_model: A function accepting a `pretrained` keyword argument that
returns a classifier pretrained on ImageNet, such as the ones from `torchvision.models.*`.
:param tile_size: Tile width/height, in pixels.
:param n_channels: Number of channels in the tile (default=3).
:param apply_imagenet_preprocessing: Whether to apply ImageNet preprocessing to the input.
"""
self.create_feature_extractor_fn = feature_extraction_model
self.apply_imagenet_preprocessing = apply_imagenet_preprocessing
super().__init__(tile_size=tile_size, n_channels=n_channels)
def _get_preprocessing(self) -> Callable:
return get_imagenet_preprocessing()
base_preprocessing = super()._get_preprocessing()
if self.apply_imagenet_preprocessing:
return Compose([get_imagenet_preprocessing(), base_preprocessing]).flatten()
return base_preprocessing
def _get_encoder(self) -> Tuple[torch.nn.Module, int]:
pretrained_model = self.create_feature_extractor_fn(pretrained=True)
return setup_feature_extractor(pretrained_model, self.input_dim)
class Resnet18(ImageNetEncoder):
def __init__(self, tile_size: int, n_channels: int = 3) -> None:
super().__init__(resnet18, tile_size, n_channels, apply_imagenet_preprocessing=True)
class Resnet18_NoPreproc(ImageNetEncoder):
def __init__(self, tile_size: int, n_channels: int = 3) -> None:
super().__init__(resnet18, tile_size, n_channels, apply_imagenet_preprocessing=False)
class Resnet50(ImageNetEncoder):
def __init__(self, tile_size: int, n_channels: int = 3) -> None:
super().__init__(resnet50, tile_size, n_channels, apply_imagenet_preprocessing=True)
class Resnet50_NoPreproc(ImageNetEncoder):
def __init__(self, tile_size: int, n_channels: int = 3) -> None:
super().__init__(resnet50, tile_size, n_channels, apply_imagenet_preprocessing=False)
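A hypothetical instantiation of the new wrappers (tile size and batch are illustrative, the first call downloads ImageNet weights, and the encoder is assumed to behave as a standard nn.Module exposing `num_encoding`):

import torch

encoder = Resnet18_NoPreproc(tile_size=224)
tiles = torch.rand(4, 3, 224, 224)
features = encoder(tiles)   # expected shape: [4, encoder.num_encoding]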
class ImageNetSimCLREncoder(TileEncoder):
"""SimCLR encoder pretrained on ImageNet"""
@ -143,23 +168,3 @@ class HistoSSLEncoder(TileEncoder):
resnet18_model = resnet18(pretrained=False)
histossl_encoder = load_weights_to_model(self.WEIGHTS_URL, resnet18_model)
return setup_feature_extractor(histossl_encoder, self.input_dim)
class ImageNetEncoder_Resnet50(TileEncoder):
# Myronenko et al. 2021 uses intensity scaling (0-255)-->(0-1), and no ImageNet preprocessing is used.
# ResNet50 CNN encoder without ImageNet preprocessing is defined below.
def __init__(self, feature_extraction_model: Callable[..., nn.Module],
tile_size: int, n_channels: int = 3) -> None:
"""
:param feature_extraction_model: A function accepting a `pretrained` keyword argument that
returns a classifier pretrained on ImageNet, such as the ones from `torchvision.models.*`.
:param tile_size: Tile width/height, in pixels.
:param n_channels: Number of channels in the tile (default=3).
"""
self.create_feature_extractor_fn = feature_extraction_model
super().__init__(tile_size=tile_size, n_channels=n_channels)
def _get_encoder(self) -> Tuple[nn.Module, int]:
pretrained_model = self.create_feature_extractor_fn(pretrained=True)
return setup_feature_extractor(pretrained_model, self.input_dim)

View file

@ -26,16 +26,35 @@ def load_pil_image(image_path: PathOrString) -> PIL.Image.Image:
return image
def load_image_as_tensor(image_path: PathOrString) -> torch.Tensor:
"""Load an image as a tensor from the given path"""
pil_image = load_pil_image(image_path)
return to_tensor(pil_image)
def load_image_as_tensor(image_path: PathOrString, scale_intensity: bool = True) -> torch.Tensor:
"""Load an image as a tensor from the given path
:param image_path: path to the image
:param scale_intensity: If True, use `to_tensor` from torchvision, which scales the image pixel intensities to
[0, 1] by default and returns a [C, H, W] tensor. Otherwise, only transpose the image to [C, H, W] format and
return it as a torch tensor.
"""
pil_image = load_pil_image(image_path) # pil_image is in channels last format [H, W, C]
if scale_intensity:
return to_tensor(pil_image) # to_tensor scales the image pixel intensities to [0, 1] as [C, H, W] tensors
else:
return torch.from_numpy(pil_image.transpose((2, 0, 1))).contiguous() # only transpose to [C, H, W]
def load_image_stack_as_tensor(image_paths: Sequence[PathOrString],
progress: bool = False) -> torch.Tensor:
"""Load a batch of images of the same size as a tensor from the given paths"""
loading_generator = (load_image_as_tensor(path) for path in image_paths)
progress: bool = False,
scale_intensity: bool = True) -> torch.Tensor:
"""Load a batch of images of the same size as a tensor from the given paths
:param image_paths: paths to the images
:param progress: if True, show a progress bar
:param scale_intensity: If True, use `to_tensor` from torchvision, which scales the image pixel intensities to
[0, 1] by default and returns a [C, H, W] tensor. Otherwise, only transpose the image to [C, H, W] format and
return it as a torch tensor.
"""
loading_generator = (load_image_as_tensor(path, scale_intensity) for path in image_paths)
if progress:
from tqdm import tqdm
loading_generator = tqdm(loading_generator, desc="Loading image stack",
@ -98,20 +117,27 @@ class LoadTilesBatchd(MapTransform):
# Cannot reuse MONAI readers because they support stacking only images with no channels
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False,
progress: bool = False) -> None:
progress: bool = False, scale_intensity: bool = True) -> None:
"""
:param keys: Key(s) for the image path(s) in the input dictionary.
:param allow_missing_keys: If `False` (default), raises an exception when an input
dictionary is missing any of the specified keys.
:param progress: Whether to display a tqdm progress bar.
:param scale_intensity: If True, use `to_tensor` from torchvision, which scales the image pixel intensities to
[0, 1] by default and returns a [C, H, W] tensor. Otherwise, only transpose the image to [C, H, W] format and
return it as a torch tensor.
"""
super().__init__(keys, allow_missing_keys)
self.progress = progress
self.scale_intensity = scale_intensity
def __call__(self, data: Mapping) -> Mapping:
out_data = dict(data) # create shallow copy
for key in self.key_iterator(out_data):
out_data[key] = load_image_stack_as_tensor(data[key], progress=self.progress)
out_data[key] = load_image_stack_as_tensor(
data[key], progress=self.progress, scale_intensity=self.scale_intensity
)
return out_data
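A usage sketch of the new `scale_intensity` flag (the tile paths are placeholders):

transform = LoadTilesBatchd(keys="image", progress=False, scale_intensity=False)
bag = {"image": ["tiles/slide_0/tile_000.png", "tiles/slide_0/tile_001.png"]}
bag = transform(bag)   # bag["image"] becomes a stacked [N, C, H, W] tensor with unscaled intensities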

View file

@ -0,0 +1,51 @@
from azure.storage.blob import generate_blob_sas, BlobSasPermissions
from azureml.core import Workspace
from datetime import datetime, timedelta
from health_azure import get_workspace
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from typing import Optional
def get_checkpoint_url_from_aml_run(
run_id: str,
checkpoint_filename: str,
expiry_days: int = 1,
aml_workspace: Optional[Workspace] = None,
sas_token: Optional[str] = None,
) -> str:
"""Generate a SAS URL for the checkpoint file in the given run.
:param run_id: The run ID of the checkpoint.
:param checkpoint_filename: The filename of the checkpoint.
:param expiry_days: The number of days for which the SAS URL is valid, defaults to 1.
:param aml_workspace: The Azure ML workspace to use, defaults to the default workspace.
:param sas_token: The SAS token to use, defaults to None.
:return: The SAS URL for the checkpoint.
"""
datastore = get_workspace(aml_workspace=aml_workspace).get_default_datastore()
account_name = datastore.account_name
container_name = 'azureml'
blob_name = f'ExperimentRun/dcid.{run_id}/{DEFAULT_AML_CHECKPOINT_DIR}/{checkpoint_filename}'
if not sas_token:
sas_token = generate_blob_sas(account_name=datastore.account_name,
container_name=container_name,
blob_name=blob_name,
account_key=datastore.account_key,
permission=BlobSasPermissions(read=True),
expiry=datetime.utcnow() + timedelta(days=expiry_days))
return f'https://{account_name}.blob.core.windows.net/{container_name}/{blob_name}?{sas_token}'
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--run_id', type=str, help='The run id of the model checkpoint')
parser.add_argument('--checkpoint_filename', type=str, default='last.ckpt',
help='The filename of the model checkpoint. Default: last.ckpt')
parser.add_argument('--expiry_days', type=int, default=30,
help='The number of days for which the SAS URL is valid. Default: 30, i.e. about one month')
args = parser.parse_args()
url = get_checkpoint_url_from_aml_run(args.run_id, args.checkpoint_filename, args.expiry_days)
print(f'Checkpoint URL: {url}')
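A hypothetical call from Python rather than the CLI; the run id is a placeholder and the workspace is resolved via `get_workspace`:

url = get_checkpoint_url_from_aml_run(run_id="HD_1a2b3c4d_0",
                                      checkpoint_filename="last.ckpt",
                                      expiry_days=7)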

View file

@ -12,22 +12,25 @@ from matplotlib import pyplot as plt
from health_azure.utils import get_aml_run_from_run_id, get_workspace
from health_ml.utils.reports import HTMLReport
from health_cpath.utils.analysis_plot_utils import (add_training_curves_legend, plot_confusion_matrices,
plot_crossval_roc_and_pr_curves,
plot_crossval_training_curves)
plot_hyperdrive_roc_and_pr_curves,
plot_hyperdrive_training_curves)
from health_cpath.utils.output_utils import (AML_LEGACY_TEST_OUTPUTS_CSV, AML_TEST_OUTPUTS_CSV,
AML_VAL_OUTPUTS_CSV)
from health_cpath.utils.report_utils import (collect_crossval_metrics, collect_crossval_outputs,
crossval_runs_have_val_and_test_outputs, get_best_epoch_metrics,
get_best_epochs, get_crossval_metrics_table, get_formatted_run_info,
collect_class_info)
from health_cpath.utils.report_utils import (collect_hyperdrive_metrics, collect_hyperdrive_outputs,
child_runs_have_val_and_test_outputs, get_best_epoch_metrics,
get_best_epochs, get_child_runs_hyperparams, get_hyperdrive_metrics_table,
get_formatted_run_info, collect_class_info, get_max_epochs,
download_hyperdrive_metrics_if_required)
from health_cpath.utils.naming import MetricsKey, ModelKey
def generate_html_report(parent_run_id: str, output_dir: Path,
workspace_config_path: Optional[Path] = None,
include_test: bool = False, overwrite: bool = False) -> None:
include_test: bool = False, overwrite: bool = False,
hyperdrive_arg_name: str = "crossval_index",
primary_metric: str = MetricsKey.AUROC) -> None:
"""
Function to generate an HTML report of a cross validation AML run.
Function to generate an HTML report of a Hyperdrive AML run (e.g., cross validation, different random seeds, ...).
:param parent_run_id: The parent Hyperdrive run ID.
:param output_dir: Directory where to download Azure ML data and save the report.
@ -35,6 +38,9 @@ def generate_html_report(parent_run_id: str, output_dir: Path,
If omitted, will try to load default workspace.
:param include_test: Include test results in the generated report.
:param overwrite: Forces (re)download of metrics and output files, even if they already exist locally.
:param hyperdrive_arg_name: Name of the Hyperdrive argument used for indexing the child runs.
Default `crossval_index`.
:param primary_metric: Name of the reference metric to optimise. Default `MetricsKey.AUROC`.
"""
aml_workspace = get_workspace(workspace_config_path=workspace_config_path)
parent_run = get_aml_run_from_run_id(parent_run_id, aml_workspace=aml_workspace)
@ -48,26 +54,38 @@ def generate_html_report(parent_run_id: str, output_dir: Path,
report.add_heading("Azure ML metrics", level=2)
# Download metrics from AML. Can take several seconds for each child run
metrics_df = collect_crossval_metrics(parent_run_id, report_dir, aml_workspace, overwrite=overwrite)
best_epochs = get_best_epochs(metrics_df, f'{ModelKey.VAL}/{MetricsKey.AUROC}', maximise=True)
metrics_json = download_hyperdrive_metrics_if_required(parent_run_id, report_dir, aml_workspace,
overwrite=overwrite, hyperdrive_arg_name=hyperdrive_arg_name)
# Get metrics dataframe from the downloaded json file
metrics_df = collect_hyperdrive_metrics(metrics_json=metrics_json)
hyperparameters_children = get_child_runs_hyperparams(metrics_df)
max_epochs_dict = get_max_epochs(hyperparams_children=hyperparameters_children)
best_epochs = get_best_epochs(metrics_df=metrics_df, primary_metric=f'{ModelKey.VAL}/{primary_metric}',
max_epochs_dict=max_epochs_dict, maximise=True)
# Add training curves for loss and AUROC (train and val.)
render_training_curves(report, heading="Training curves", level=3,
metrics_df=metrics_df, best_epochs=best_epochs, report_dir=report_dir)
metrics_df=metrics_df, best_epochs=best_epochs, report_dir=report_dir,
primary_metric=primary_metric)
# Get metrics list with class names
num_classes, class_names = collect_class_info(metrics_df=metrics_df)
num_classes, class_names = collect_class_info(hyperparams_children=hyperparameters_children)
base_metrics_list: List[str]
base_metrics_list: List[str] = [MetricsKey.ACC, MetricsKey.AUROC, MetricsKey.AVERAGE_PRECISION,
MetricsKey.COHENKAPPA]
if num_classes > 1:
base_metrics_list = [MetricsKey.ACC, MetricsKey.AUROC]
base_metrics_list += [MetricsKey.ACC_MACRO, MetricsKey.ACC_WEIGHTED]
else:
base_metrics_list = [MetricsKey.ACC, MetricsKey.AUROC, MetricsKey.PRECISION, MetricsKey.RECALL, MetricsKey.F1]
base_metrics_list += [MetricsKey.F1, MetricsKey.PRECISION, MetricsKey.RECALL, MetricsKey.SPECIFICITY]
base_metrics_list += class_names
# Add tables with relevant metrics (val. and test)
render_metrics_table(report, heading="Validation metrics (best epoch based on maximum validation AUROC)", level=3,
render_metrics_table(report,
heading=f"Validation metrics (best epoch based on maximum validation {primary_metric})",
level=3,
metrics_df=metrics_df, best_epochs=best_epochs,
base_metrics_list=base_metrics_list, metrics_prefix=f'{ModelKey.VAL}/')
@ -76,54 +94,58 @@ def generate_html_report(parent_run_id: str, output_dir: Path,
metrics_df=metrics_df, best_epochs=None,
base_metrics_list=base_metrics_list, metrics_prefix=f'{ModelKey.TEST}/')
has_val_and_test_outputs = crossval_runs_have_val_and_test_outputs(parent_run)
# Get output data frames
if has_val_and_test_outputs:
output_filename_val = AML_VAL_OUTPUTS_CSV
outputs_dfs_val = collect_crossval_outputs(parent_run_id=parent_run_id, download_dir=report_dir,
aml_workspace=aml_workspace,
output_filename=output_filename_val, overwrite=overwrite)
if include_test:
output_filename_test = AML_TEST_OUTPUTS_CSV if has_val_and_test_outputs else AML_LEGACY_TEST_OUTPUTS_CSV
outputs_dfs_test = collect_crossval_outputs(parent_run_id=parent_run_id, download_dir=report_dir,
aml_workspace=aml_workspace,
output_filename=output_filename_test, overwrite=overwrite)
if num_classes == 1:
# Currently ROC and PR curves rendered only for binary case
# TODO: Enable rendering of multi-class ROC and PR curves
report.add_heading("ROC and PR curves", level=2)
# Get output data frames if available
try:
has_val_and_test_outputs = child_runs_have_val_and_test_outputs(parent_run)
if has_val_and_test_outputs:
# Add val. ROC and PR curves
render_roc_and_pr_curves(report=report, heading="Validation ROC and PR curves", level=3,
report_dir=report_dir,
outputs_dfs=outputs_dfs_val,
prefix=f'{ModelKey.VAL}_')
output_filename_val = AML_VAL_OUTPUTS_CSV
outputs_dfs_val = collect_hyperdrive_outputs(parent_run_id=parent_run_id, download_dir=report_dir,
aml_workspace=aml_workspace,
output_filename=output_filename_val, overwrite=overwrite,
hyperdrive_arg_name=hyperdrive_arg_name)
if include_test:
output_filename_test = AML_TEST_OUTPUTS_CSV if has_val_and_test_outputs else AML_LEGACY_TEST_OUTPUTS_CSV
outputs_dfs_test = collect_hyperdrive_outputs(parent_run_id=parent_run_id, download_dir=report_dir,
aml_workspace=aml_workspace,
output_filename=output_filename_test, overwrite=overwrite,
hyperdrive_arg_name=hyperdrive_arg_name)
if num_classes == 1:
# Currently ROC and PR curves rendered only for binary case
# TODO: Enable rendering of multi-class ROC and PR curves
report.add_heading("ROC and PR curves", level=2)
if has_val_and_test_outputs:
# Add val. ROC and PR curves
render_roc_and_pr_curves(report=report, heading="Validation ROC and PR curves", level=3,
report_dir=report_dir,
outputs_dfs=outputs_dfs_val,
prefix=f'{ModelKey.VAL}_')
if include_test:
# Add test ROC and PR curves
render_roc_and_pr_curves(report=report, heading="Test ROC and PR curves", level=3,
report_dir=report_dir,
outputs_dfs=outputs_dfs_test,
prefix=f'{ModelKey.TEST}_')
# Add confusion matrices for each fold
report.add_heading("Confusion matrices", level=2)
if has_val_and_test_outputs:
# Add val. confusion matrices
render_confusion_matrices(report=report, heading="Validation confusion matrices", level=3,
class_names=class_names,
report_dir=report_dir, outputs_dfs=outputs_dfs_val,
prefix=f'{ModelKey.VAL}_')
if include_test:
# Add test ROC and PR curves
render_roc_and_pr_curves(report=report, heading="Test ROC and PR curves", level=3,
report_dir=report_dir,
outputs_dfs=outputs_dfs_test,
prefix=f'{ModelKey.TEST}_')
# Add test confusion matrices
render_confusion_matrices(report=report, heading="Test confusion matrices", level=3,
class_names=class_names,
report_dir=report_dir, outputs_dfs=outputs_dfs_test,
prefix=f'{ModelKey.TEST}_')
# Add confusion matrices for each fold
report.add_heading("Confusion matrices", level=2)
if has_val_and_test_outputs:
# Add val. confusion matrices
render_confusion_matrices(report=report, heading="Validation confusion matrices", level=3,
class_names=class_names,
report_dir=report_dir, outputs_dfs=outputs_dfs_val,
prefix=f'{ModelKey.VAL}_')
if include_test:
# Add test confusion matrices
render_confusion_matrices(report=report, heading="Test confusion matrices", level=3,
class_names=class_names,
report_dir=report_dir, outputs_dfs=outputs_dfs_test,
prefix=f'{ModelKey.TEST}_')
except ValueError as e:
print(e)
print("Since all expected output files were not found, skipping ROC-PR curves and confusion matrices.")
# TODO: Add qualitative model outputs
# report.add_heading("Qualitative model outputs", level=2)
@ -134,26 +156,26 @@ def generate_html_report(parent_run_id: str, output_dir: Path,
def render_training_curves(report: HTMLReport, heading: str, level: int,
metrics_df: pd.DataFrame, best_epochs: Optional[Dict[int, int]],
report_dir: Path) -> None:
report_dir: Path, primary_metric: str = MetricsKey.AUROC) -> None:
"""
Function to render training curves for HTML reports.
:param report: HTML report to perform the rendering.
:param heading: Heading of the section.
:param level: Level of HTML heading (e.g. sub-section, sub-sub-section) corresponding to HTML heading levels.
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
:py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
:param best_epochs: Dictionary mapping each cross-validation index to its best epoch.
:param best_epochs: Dictionary mapping each hyperdrive child index to its best epoch.
:param report_dir: Directory of the HTML report.
:param primary_metric: Primary metric name. Default is AUROC.
"""
report.add_heading(heading, level=level)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
plot_crossval_training_curves(metrics_df, train_metric=f'{ModelKey.TRAIN}/loss_epoch',
val_metric=f'{ModelKey.VAL}/loss_epoch',
ylabel="Loss", best_epochs=best_epochs, ax=ax1)
plot_crossval_training_curves(metrics_df, train_metric=f'{ModelKey.TRAIN}/{MetricsKey.AUROC}',
val_metric=f'{ModelKey.VAL}/{MetricsKey.AUROC}',
ylabel="AUROC", best_epochs=best_epochs, ax=ax2)
metrics = {"loss_epoch", MetricsKey.AUROC.value, primary_metric}
fig, axs = plt.subplots(1, len(metrics), figsize=(5 * len(metrics), 4))
for i, metric in enumerate(metrics):
plot_hyperdrive_training_curves(metrics_df, train_metric=f'{ModelKey.TRAIN}/{metric}',
val_metric=f'{ModelKey.VAL}/{metric}',
ylabel=metric, best_epochs=best_epochs, ax=axs[i])
add_training_curves_legend(fig, include_best_epoch=True)
training_curves_fig_path = report_dir / "training_curves.png"
fig.savefig(training_curves_fig_path, bbox_inches='tight')
@ -169,17 +191,17 @@ def render_metrics_table(report: HTMLReport, heading: str, level: int,
:param report: HTML report to perform the rendering.
:param heading: Heading of the section.
:param level: Level of HTML heading (e.g. sub-section, sub-sub-section) corresponding to HTML heading levels.
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
:py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
:param base_metrics_list: List of metric names to include in the table.
:param best_epochs: Dictionary mapping each cross-validation index to its best epoch.
:param best_epochs: Dictionary mapping each hyperdrive child index to its best epoch.
:param metrics_prefix: Prefix to add to the metrics names (e.g. `val`, `test`)
"""
report.add_heading(heading, level=level)
metrics_list = [metrics_prefix + metric for metric in base_metrics_list]
if best_epochs:
metrics_df = get_best_epoch_metrics(metrics_df, metrics_list, best_epochs)
metrics_table = get_crossval_metrics_table(metrics_df, metrics_list)
metrics_table = get_hyperdrive_metrics_table(metrics_df, metrics_list)
report.add_tables([metrics_table])
@ -192,11 +214,11 @@ def render_roc_and_pr_curves(report: HTMLReport, heading: str, level: int, repor
:param heading: Heading of the section.
:param level: Level of HTML heading (e.g. sub-section, sub-sub-section) corresponding to HTML heading levels.
:param report_dir: Local directory where the report is stored.
:param outputs_dfs: A dictionary of dataframes with the sorted cross-validation indices as keys.
:param outputs_dfs: A dictionary of dataframes with the sorted hyperdrive child runs indices as keys.
:param prefix: Prefix to add to the figures saved (e.g. `val`, `test`).
"""
report.add_heading(heading, level=level)
fig = plot_crossval_roc_and_pr_curves(outputs_dfs, scores_column='prob_class1')
fig = plot_hyperdrive_roc_and_pr_curves(outputs_dfs, scores_column='prob_class1')
roc_pr_curves_fig_path = report_dir / f"{prefix}roc_pr_curves.png"
fig.savefig(roc_pr_curves_fig_path, bbox_inches='tight')
report.add_images([roc_pr_curves_fig_path], base64_encode=True)
@ -212,11 +234,11 @@ def render_confusion_matrices(report: HTMLReport, heading: str, level: int, clas
:param level: Level of HTML heading (e.g. sub-section, sub-sub-section) corresponding to HTML heading levels.
:param class_names: Names of classes.
:param report_dir: Local directory where the report is stored.
:param outputs_dfs: A dictionary of dataframes with the sorted cross-validation indices as keys.
:param outputs_dfs: A dictionary of dataframes with the sorted hyperdrive child runs indices as keys.
:param prefix: Prefix to add to the figures saved (e.g. `val`, `test`).
"""
report.add_heading(heading, level=level)
fig = plot_confusion_matrices(crossval_dfs=outputs_dfs, class_names=class_names)
fig = plot_confusion_matrices(hyperdrive_dfs=outputs_dfs, class_names=class_names)
confusion_matrices_fig_path = report_dir / f"{prefix}confusion_matrices.png"
fig.savefig(confusion_matrices_fig_path, bbox_inches='tight')
report.add_images([confusion_matrices_fig_path], base64_encode=True)
@ -225,7 +247,7 @@ def render_confusion_matrices(report: HTMLReport, heading: str, level: int, clas
if __name__ == "__main__":
"""
Usage example from CLI:
python generate_crossval_report.py \
python generate_hyperdrive_report.py \
--run_id <insert AML run ID here> \
--output_dir outputs \
--include_test
@ -242,6 +264,9 @@ if __name__ == "__main__":
"in the generated report.")
parser.add_argument('--overwrite', action='store_true', help="Forces (re)download of metrics and output files, "
"even if they already exist locally.")
parser.add_argument("--hyper_arg_name", default="crossval_index",
help="Name of the Hyperdrive argument used for indexing the child runs.")
parser.add_argument("--primary_metric", default=MetricsKey.AUROC, help="Name of the reference metric to optimise.")
args = parser.parse_args()
if args.output_dir is None:
@ -258,4 +283,6 @@ if __name__ == "__main__":
output_dir=Path(args.output_dir),
workspace_config_path=workspace_config,
include_test=args.include_test,
overwrite=args.overwrite)
overwrite=args.overwrite,
hyperdrive_arg_name=args.hyper_arg_name,
primary_metric=args.primary_metric)

View file

@ -120,32 +120,32 @@ def plot_histogram(data: List[Any], title: str = "") -> None:
plt.gca().set(title=title, xlabel='Values', ylabel='Frequency')
def plot_roc_curve(labels: Sequence, scores: Sequence, label: str, ax: Axes) -> None:
def plot_roc_curve(labels: Sequence, scores: Sequence, legend_label: str, ax: Axes) -> None:
"""Plot ROC curve for the given labels and scores, with AUROC in the line legend.
:param labels: The true binary labels.
:param scores: Scores predicted by the model.
:param label: An line identifier to be displayed in the legend.
:param legend_label: A line identifier to be displayed in the legend.
:param ax: `Axes` object onto which to plot.
"""
fpr, tpr, _ = roc_curve(labels, scores)
auroc = auc(fpr, tpr)
label = f"{label} (AUROC: {auroc:.3f})"
ax.plot(fpr, tpr, label=label)
legend_label = f"{legend_label} (AUROC: {auroc:.3f})"
ax.plot(fpr, tpr, label=legend_label)
def plot_pr_curve(labels: Sequence, scores: Sequence, label: str, ax: Axes) -> None:
"""Plot precision-recall curve for the given labels and scores, with AUROC in the line legend.
def plot_pr_curve(labels: Sequence, scores: Sequence, legend_label: str, ax: Axes) -> None:
"""Plot precision-recall curve for the given labels and scores, with AUPR in the line legend.
:param labels: The true binary labels.
:param scores: Scores predicted by the model.
:param label: An line identifier to be displayed in the legend.
:param legend_label: A line identifier to be displayed in the legend.
:param ax: `Axes` object onto which to plot.
"""
precision, recall, _ = precision_recall_curve(labels, scores)
aupr = auc(recall, precision)
label = f"{label} (AUPR: {aupr:.3f})"
ax.plot(recall, precision, label=label)
legend_label = f"{legend_label} (AUPR: {aupr:.3f})"
ax.plot(recall, precision, label=legend_label)
def format_pr_or_roc_axes(plot_type: str, ax: Axes) -> None:
@ -168,18 +168,18 @@ def format_pr_or_roc_axes(plot_type: str, ax: Axes) -> None:
ax.grid(color='0.9')
def _plot_crossval_roc_and_pr_curves(crossval_dfs: Dict[int, pd.DataFrame], roc_ax: Axes, pr_ax: Axes,
scores_column: str = ResultsKey.PROB) -> None:
"""Plot ROC and precision-recall curves for multiple cross-validation runs onto provided axes.
def _plot_hyperdrive_roc_and_pr_curves(hyperdrive_dfs: Dict[int, pd.DataFrame], roc_ax: Axes, pr_ax: Axes,
scores_column: str = ResultsKey.PROB) -> None:
"""Plot ROC and precision-recall curves for multiple hyperdrive runs onto provided axes.
This is called by :py:func:`plot_crossval_roc_and_pr_curves()`, which additionally creates a figure and the axes.
This is called by :py:func:`plot_hyperdrive_roc_and_pr_curves()`, which additionally creates a figure and the axes.
:param crossval_dfs: Dictionary of dataframes with cross-validation indices as keys,
as returned by :py:func:`collect_crossval_outputs()`.
:param hyperdrive_dfs: Dictionary of dataframes with hyperdrive child runs indices as keys,
as returned by :py:func:`collect_hyperdrive_outputs()`.
:param roc_ax: `Axes` object onto which to plot ROC curves.
:param pr_ax: `Axes` object onto which to plot precision-recall curves.
"""
for k, tiles_df in crossval_dfs.items():
for k, tiles_df in hyperdrive_dfs.items():
slides_groupby = tiles_df.groupby(ResultsKey.SLIDE_ID)
tile_labels = slides_groupby[ResultsKey.TRUE_LABEL]
@ -197,8 +197,8 @@ def _plot_crossval_roc_and_pr_curves(crossval_dfs: Dict[int, pd.DataFrame], roc_
# assert len(non_unique_slides) == 0
scores = tile_scores.first()
plot_roc_curve(labels, scores, label=f"Fold {k}", ax=roc_ax)
plot_pr_curve(labels, scores, label=f"Fold {k}", ax=pr_ax)
plot_roc_curve(labels, scores, legend_label=f"Child {k}", ax=roc_ax)
plot_pr_curve(labels, scores, legend_label=f"Child {k}", ax=pr_ax)
legend_kwargs = dict(edgecolor='none', fontsize='small')
roc_ax.legend(**legend_kwargs)
pr_ax.legend(**legend_kwargs)
@ -206,26 +206,26 @@ def _plot_crossval_roc_and_pr_curves(crossval_dfs: Dict[int, pd.DataFrame], roc_
format_pr_or_roc_axes('pr', pr_ax)
def plot_crossval_roc_and_pr_curves(crossval_dfs: Dict[int, pd.DataFrame],
scores_column: str = ResultsKey.PROB) -> Figure:
"""Plot ROC and precision-recall curves for multiple cross-validation runs.
def plot_hyperdrive_roc_and_pr_curves(hyperdrive_dfs: Dict[int, pd.DataFrame],
scores_column: str = ResultsKey.PROB) -> Figure:
"""Plot ROC and precision-recall curves for multiple hyperdrive child runs.
This will create a new figure with two subplots (left: ROC, right: PR).
:param crossval_dfs: Dictionary of dataframes with cross-validation indices as keys,
as returned by :py:func:`collect_crossval_outputs()`.
:param hyperdrive_dfs: Dictionary of dataframes with hyperdrive child indices as keys,
as returned by :py:func:`collect_hyperdrive_outputs()`.
:return: The created `Figure` object.
"""
fig, axs = plt.subplots(1, 2, figsize=(8, 4))
_plot_crossval_roc_and_pr_curves(crossval_dfs, scores_column=scores_column, roc_ax=axs[0], pr_ax=axs[1])
_plot_hyperdrive_roc_and_pr_curves(hyperdrive_dfs, scores_column=scores_column, roc_ax=axs[0], pr_ax=axs[1])
return fig
def plot_crossval_training_curves(metrics_df: pd.DataFrame, train_metric: str, val_metric: str, ax: Axes,
best_epochs: Optional[Dict[int, int]] = None, ylabel: Optional[str] = None) -> None:
"""Plot paired training and validation metrics for every training epoch of cross-validation runs.
def plot_hyperdrive_training_curves(metrics_df: pd.DataFrame, train_metric: str, val_metric: str, ax: Axes,
best_epochs: Optional[Dict[int, int]] = None, ylabel: Optional[str] = None) -> None:
"""Plot paired training and validation metrics for every training epoch of hyperdrive child runs.
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
:py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
:param train_metric: Name of the training metric to plot.
:param val_metric: Name of the validation metric to plot.
@ -236,14 +236,17 @@ def plot_crossval_training_curves(metrics_df: pd.DataFrame, train_metric: str, v
for k in sorted(metrics_df.columns):
train_values = metrics_df.loc[train_metric, k]
val_values = metrics_df.loc[val_metric, k]
line, = ax.plot(train_values, **TRAIN_STYLE, label=f"Fold {k}")
color = line.get_color()
ax.plot(val_values, color=color, **VAL_STYLE)
if train_values is not None:
line, = ax.plot(train_values, **TRAIN_STYLE, label=f"Child {k}")
color = line.get_color()
if val_values is not None:
ax.plot(val_values, color=color, **VAL_STYLE)
if best_epochs is not None:
best_epoch = best_epochs[k]
ax.plot(best_epoch, train_values[best_epoch], color=color, zorder=1000, **BEST_TRAIN_MARKER_STYLE)
ax.plot(best_epoch, val_values[best_epoch], color=color, zorder=1000, **BEST_VAL_MARKER_STYLE)
ax.axvline(best_epoch, color=color, **BEST_EPOCH_LINE_STYLE)
if best_epoch is not None:
ax.plot(best_epoch, train_values[best_epoch], color=color, zorder=1000, **BEST_TRAIN_MARKER_STYLE)
ax.plot(best_epoch, val_values[best_epoch], color=color, zorder=1000, **BEST_VAL_MARKER_STYLE)
ax.axvline(best_epoch, color=color, **BEST_EPOCH_LINE_STYLE)
ax.grid(color='0.9')
ax.set_xlabel("Epoch")
if ylabel:
@ -251,15 +254,15 @@ def plot_crossval_training_curves(metrics_df: pd.DataFrame, train_metric: str, v
def add_training_curves_legend(fig: Figure, include_best_epoch: bool = False) -> None:
"""Add a legend to a training curves figure, indicating cross-validation indices and train/val.
"""Add a legend to a training curves figure, indicating hyperdrive child indices and train/val.
:param fig: `Figure` object onto which to add the legend.
:param include_best_epoch: If `True`, adds legend items for the best epoch indicators from
:py:func:`plot_crossval_training_curves()`.
:py:func:`plot_hyperdrive_training_curves()`.
"""
legend_kwargs = dict(edgecolor='none', fontsize='small', borderpad=.2)
# Add primary legend for main lines (crossval folds)
# Add primary legend for main lines (hyperdrive runs)
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
fig.legend(by_label.values(), by_label.keys(), **legend_kwargs, loc='lower center',
@ -277,17 +280,18 @@ def add_training_curves_legend(fig: Figure, include_best_epoch: bool = False) ->
bbox_to_anchor=(0.5, -0.1), ncol=len(legend_handles))
def plot_confusion_matrices(crossval_dfs: Dict[int, pd.DataFrame], class_names: List[str]) -> Figure:
def plot_confusion_matrices(hyperdrive_dfs: Dict[int, pd.DataFrame], class_names: List[str]) -> Figure:
"""
Plot normalized confusion matrices from HyperDrive child runs.
:param crossval_dfs: Dictionary of dataframes with cross-validation indices as keys,
as returned by :py:func:`collect_crossval_outputs()`.
:param hyperdrive_dfs: Dictionary of dataframes with hyperdrive indices as keys,
as returned by :py:func:`collect_hyperdrive_outputs()`.
:param class_names: Names of classes.
:return: The created `Figure` object.
"""
crossval_count = len(crossval_dfs)
fig, axs = plt.subplots(1, crossval_count, figsize=(crossval_count * 6, 5))
for k, tiles_df in crossval_dfs.items():
hyperdrive_count = len(hyperdrive_dfs)
fig, axs = plt.subplots(1, hyperdrive_count, figsize=(hyperdrive_count * 6, 5))
ax_index = 0
for k, tiles_df in hyperdrive_dfs.items():
slides_groupby = tiles_df.groupby(ResultsKey.SLIDE_ID)
tile_labels_true = slides_groupby[ResultsKey.TRUE_LABEL]
# True slide label is guaranteed unique
@ -303,9 +307,10 @@ def plot_confusion_matrices(crossval_dfs: Dict[int, pd.DataFrame], class_names:
labels_pred = tile_labels_pred.first()
cf_matrix_n = confusion_matrix(y_true=labels_true, y_pred=labels_pred, normalize='true')
sns.heatmap(cf_matrix_n, annot=True, cmap='Blues', fmt=".2%", ax=axs[k],
sns.heatmap(cf_matrix_n, annot=True, cmap='Blues', fmt=".2%", ax=axs[ax_index],
xticklabels=class_names, yticklabels=class_names)
axs[k].set_xlabel('Predicted')
axs[k].set_ylabel('True')
axs[k].set_title(f'Fold {k}')
axs[ax_index].set_xlabel('Predicted')
axs[ax_index].set_ylabel('True')
axs[ax_index].set_title(f'Child {k}')
ax_index += 1
return fig
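A hedged usage sketch of the renamed helper; the per-child dataframes below are placeholders that would normally come from `collect_hyperdrive_outputs()`:

hyperdrive_dfs = {0: child0_tiles_df, 1: child1_tiles_df}   # hypothetical outputs per hyperdrive child
fig = plot_confusion_matrices(hyperdrive_dfs=hyperdrive_dfs, class_names=["negative", "positive"])
fig.savefig("confusion_matrices.png", bbox_inches="tight")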

View file

@ -0,0 +1,559 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os
import torch
import param
import logging
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from health_cpath.models.deepmil import BaseDeepMILModule
from health_cpath.utils.naming import ModelKey, ResultsKey
from health_cpath.utils.output_utils import BatchResultsType
LossCacheDictType = Dict[Union[ResultsKey, str], List]
LossDictType = Dict[str, List]
AnomalyDictType = Dict[ModelKey, List[str]]
class LossCallbackParams(param.Parameterized):
"""Parameters class to group all attributes for loss values analysis callback"""
analyse_loss: bool = param.Boolean(
False,
doc="If True, will use `LossValuesAnalysisCallback` to cache loss values per slide/epoch for further analysis."
"See `LossValuesAnalysisCallback` for more details.",
)
loss_analysis_patience: int = param.Integer(
0,
bounds=(0, None),
doc="Number of epochs to wait before starting loss values per slide analysis. Default: 0, It will start"
"caching loss values per epoch immediately. Use loss_analysis_patience=n>0 to wait for a few epochs "
"before starting the analysis.",
)
loss_analysis_epochs_interval: int = param.Integer(
1,
bounds=(1, None),
doc="Epochs interval to save loss values. Default: 1, It will save loss values every epoch. Use "
"loss_analysis_epochs_interval=n>1 to save loss values every n epochs.",
)
num_slides_scatter: int = param.Integer(
20,
bounds=(1, None),
doc="Number of slides to plot in the scatter plot. Default: 10, It will plot a scatter of the 10 slides "
"with highest/lowest loss values across epochs.",
)
num_slides_heatmap: int = param.Integer(
20,
bounds=(1, None),
doc="Number of slides to plot in the heatmap plot. Default: 20, It will plot the loss values for the 20 slides "
"with highest/lowest loss values.",
)
save_tile_ids: bool = param.Boolean(
True,
doc="If True, will save the tile ids for each bag in the loss cache. Default: True. If False, will only save "
"the slide ids and their loss values.",
)
log_exceptions: bool = param.Boolean(
True,
doc="If True, will log exceptions raised during loss values analysis. Default: True. If False, will raise the "
"intercepted exceptions.",
)
class LossAnalysisCallback(Callback):
"""Callback to analyse loss values per slide across epochs. It saves the loss values per slide in a csv file every n
epochs and plots the evolution of the loss values per slide in a heatmap as well as the slides with the
highest/lowest loss values per epoch in a scatter plot."""
TILES_JOIN_TOKEN = "$"
X_LABEL, Y_LABEL = "Epoch", "Slide ids"
TOP, BOTTOM = "top", "bottom"
HIGHEST, LOWEST = "highest", "lowest"
def __init__(
self,
outputs_folder: Path,
max_epochs: int = 30,
patience: int = 0,
epochs_interval: int = 1,
num_slides_scatter: int = 10,
num_slides_heatmap: int = 20,
save_tile_ids: bool = False,
log_exceptions: bool = True,
create_outputs_folders: bool = True,
val_set_is_dist: bool = True,
) -> None:
"""
:param outputs_folder: Path to the folder where the outputs will be saved.
:param patience: Number of epochs to wait before starting to cache loss values, defaults to 0.
:param epochs_interval: Epochs interval to save loss values, defaults to 1.
:param max_epochs: Maximum number of epochs to train, defaults to 30.
:param num_slides_scatter: Number of slides to plot in the scatter plot, defaults to 10.
:param num_slides_heatmap: Number of slides to plot in the heatmap, defaults to 20.
:param save_tile_ids: If True, will save the tile ids of the tiles in the bag in the loss cache, defaults to
False. This is useful to analyse the tiles that are contributing to the loss value of a slide.
:param log_exceptions: If True, will log exceptions raised during loss values analysis, defaults to True. If
False will raise the intercepted exceptions.
:param create_outputs_folders: If True, will create the output folders if they don't exist, defaults to True.
:param val_set_is_dist: If True, will assume that the validation set is distributed, defaults to True.
"""
self.patience = patience
self.epochs_interval = epochs_interval
self.max_epochs = max_epochs
self.num_slides_scatter = num_slides_scatter
self.num_slides_heatmap = num_slides_heatmap
self.save_tile_ids = save_tile_ids
self.log_exceptions = log_exceptions
self.val_set_is_dist = val_set_is_dist
self.outputs_folder = outputs_folder / "loss_analysis_callback"
if create_outputs_folders:
self.create_outputs_folders()
self.loss_cache: Dict[ModelKey, LossCacheDictType] = {
stage: self.get_empty_loss_cache() for stage in [ModelKey.TRAIN, ModelKey.VAL]
}
self.epochs_range = list(range(self.patience, self.max_epochs, self.epochs_interval))
self.nan_slides: AnomalyDictType = {stage: [] for stage in [ModelKey.TRAIN, ModelKey.VAL]}
self.anomaly_slides: AnomalyDictType = {stage: [] for stage in [ModelKey.TRAIN, ModelKey.VAL]}
def get_cache_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_cache"
def get_scatter_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_scatter"
def get_heatmap_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_heatmap"
def get_stats_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_stats"
def get_anomalies_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_anomalies"
def create_outputs_folders(self) -> None:
"""Creates the output folders if they don't exist."""
folders = [
self.get_cache_folder,
self.get_scatter_folder,
self.get_heatmap_folder,
self.get_stats_folder,
self.get_anomalies_folder,
]
stages = [ModelKey.TRAIN, ModelKey.VAL]
for folder in folders:
for stage in stages:
os.makedirs(folder(stage), exist_ok=True)
def get_empty_loss_cache(self) -> LossCacheDictType:
"""Returns an empty loss cache dictionary for keys: slide_id, loss, entropy and tile_ids if save_tile_ids."""
keys = [ResultsKey.SLIDE_ID, ResultsKey.LOSS, ResultsKey.ENTROPY]
if self.save_tile_ids:
keys.append(ResultsKey.TILE_ID)
return {key: [] for key in keys}
def _format_epoch(self, epoch: int) -> str:
"""Formats the epoch number to a string with 3 digits."""
return str(epoch).zfill(len(str(self.max_epochs)))
def get_loss_cache_file(self, epoch: int, stage: ModelKey) -> Path:
return self.get_cache_folder(stage) / f"epoch_{self._format_epoch(epoch)}.csv"
def get_all_epochs_loss_cache_file(self, stage: ModelKey) -> Path:
return self.get_cache_folder(stage) / f"all_epochs_{stage}.csv"
def get_loss_stats_file(self, stage: ModelKey) -> Path:
return self.get_stats_folder(stage) / f"loss_stats_{stage}.csv"
def get_loss_ranks_file(self, stage: ModelKey) -> Path:
return self.get_stats_folder(stage) / f"loss_ranks_{stage}.csv"
def get_loss_ranks_stats_file(self, stage: ModelKey) -> Path:
return self.get_stats_folder(stage) / f"loss_ranks_stats_{stage}.csv"
def get_nan_slides_file(self, stage: ModelKey) -> Path:
return self.get_anomalies_folder(stage) / f"nan_slides_{stage}.txt"
def get_anomaly_slides_file(self, stage: ModelKey) -> Path:
return self.get_anomalies_folder(stage) / f"anomaly_slides_{stage}.txt"
def get_scatter_plot_file(self, order: str, stage: ModelKey) -> Path:
return self.get_scatter_folder(stage) / f"slides_with_{order}_loss_values_{stage}.png"
def get_heatmap_plot_file(self, epoch: int, order: str, stage: ModelKey) -> Path:
return self.get_heatmap_folder(stage) / f"epoch_{self._format_epoch(epoch)}_{order}_slides_{stage}.png"
def read_loss_cache(self, epoch: int, stage: ModelKey, idx_col: Optional[ResultsKey] = None) -> pd.DataFrame:
columns = [ResultsKey.SLIDE_ID, ResultsKey.LOSS, ResultsKey.ENTROPY]
return pd.read_csv(self.get_loss_cache_file(epoch, stage), index_col=idx_col, usecols=columns)
def should_cache_loss_values(self, current_epoch: int) -> bool:
"""Returns True if the current epoch is a multiple of the epochs_interval."""
if current_epoch >= self.max_epochs:
return False # Don't cache loss values for the extra validation epoch
current_epoch = current_epoch + 1
first_time = (current_epoch - self.patience) == 1
return first_time or (
current_epoch > self.patience and (current_epoch - self.patience) % self.epochs_interval == 0
)
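# Illustration (not part of the original code): with patience=3 and epochs_interval=1, caching starts at
# 0-based epoch 3 and then happens at every subsequent epoch, except for the extra validation epoch that
# runs with current_epoch == max_epochs, which is skipped.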
def merge_loss_caches(self, loss_caches: List[LossCacheDictType]) -> LossCacheDictType:
"""Merges the loss caches from all the workers into a single loss cache"""
loss_cache = self.get_empty_loss_cache()
for loss_cache_per_device in loss_caches:
for key in loss_cache.keys():
loss_cache[key].extend(loss_cache_per_device[key])
return loss_cache
def gather_loss_cache(self, rank: int, stage: ModelKey) -> None:
"""Gathers the loss cache from all the workers"""
if self.val_set_is_dist and torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
if world_size > 1:
loss_caches = [None] * world_size
torch.distributed.all_gather_object(loss_caches, self.loss_cache[stage])
if rank == 0:
self.loss_cache[stage] = self.merge_loss_caches(loss_caches) # type: ignore
def save_loss_cache(self, current_epoch: int, stage: ModelKey) -> None:
"""Saves the loss cache to a csv file"""
loss_cache_df = pd.DataFrame(self.loss_cache[stage])
# Some slides may appear multiple times in the loss cache in DDP mode: the DistributedSampler duplicates
# some slides to even out the number of samples per device, so we only keep the first occurrence.
loss_cache_df.drop_duplicates(subset=ResultsKey.SLIDE_ID, inplace=True, keep="first")
loss_cache_df = loss_cache_df.sort_values(by=ResultsKey.LOSS, ascending=False)
loss_cache_df.to_csv(self.get_loss_cache_file(current_epoch, stage), index=False)
def _select_values_for_epoch(
self,
keys: List[ResultsKey],
epoch: int,
stage: ModelKey,
high: Optional[bool] = None,
num_values: Optional[int] = None
) -> List[np.ndarray]:
"""Selects the values corresponding to keys from a dataframe for the given epoch and stage.
:param keys: The keys to select.
:param epoch: The epoch to select.
:param stage: The model's stage e.g. train, val, test.
:param high: If True, selects the highest values, if False, selects the lowest values, if None, selects all
values.
:param num_values: The number of values to select.
"""
loss_cache = self.read_loss_cache(epoch, stage)
return_values = []
for key in keys:
values = loss_cache[key].values
if high is not None:
assert num_values is not None, "num_values must be specified if high is specified"
if high:
return_values.append(values[:num_values])
elif not high:
return_values.append(values[-num_values:])
else:
return_values.append(values)
return return_values
def select_slides_for_epoch(
self, epoch: int, stage: ModelKey, high: Optional[bool] = None, num_slides: Optional[int] = None
) -> np.ndarray:
"""Selects slides in ascending/descending order of loss values at a given epoch
:param epoch: The epoch to select the slides from.
:param stage: The model's stage (train/val).
:param high: If True, selects the slides with the highest loss values, else selects the slides with the lowest
loss values. If None, selects all slides.
:param num_slides: The number of slides to select. If None, selects all slides.
"""
return self._select_values_for_epoch([ResultsKey.SLIDE_ID], epoch, stage, high, num_slides)[0]
def select_slides_entropy_for_epoch(
self, epoch: int, stage: ModelKey, high: Optional[bool] = None
) -> Tuple[np.ndarray, np.ndarray]:
"""Selects slides and loss values of slides in ascending/descending order of loss at a given epoch
:param epoch: The epoch to select the slides from.
:param stage: The model's stage (train/val).
:param high: If True, selects the slides with the highest loss values, else selects the slides with the lowest
loss values. If None, selects all slides.
"""
keys = [ResultsKey.SLIDE_ID, ResultsKey.ENTROPY]
return_values = self._select_values_for_epoch(keys, epoch, stage, high, self.num_slides_scatter)
return return_values[0], return_values[1]
def select_all_losses_for_selected_slides(self, slides: np.ndarray, stage: ModelKey) -> LossDictType:
"""Selects the loss values for a given set of slides
:param slides: The slides to select the loss values for.
:param stage: The model's stage (train/val).
"""
slides_loss_values: LossDictType = {slide_id: [] for slide_id in slides}
for epoch in self.epochs_range:
loss_cache = self.read_loss_cache(epoch, stage, idx_col=ResultsKey.SLIDE_ID)
for slide_id in slides:
slides_loss_values[slide_id].append(loss_cache.loc[slide_id, ResultsKey.LOSS])
return slides_loss_values
def select_slides_and_entropy_across_epochs(
self, stage: ModelKey, high: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
"""Selects the slides with the highest/lowest loss values across epochs and their entropy values
:param stage: The model's stage (train/val).
:param high: If True, selects the slides with the highest loss values, else selects the slides with the lowest
loss values.
"""
slides = []
slides_entropy = []
for epoch in self.epochs_range:
epoch_slides, epoch_slides_entropy = self.select_slides_entropy_for_epoch(epoch, stage, high)
slides.append(epoch_slides)
slides_entropy.append(epoch_slides_entropy)
return np.array(slides).T, np.array(slides_entropy).T
def save_slide_ids(self, slide_ids: List[str], path: Path) -> None:
"""Dumps the slides ids in a txt file.
:param slide_ids: The slides ids to save.
:param path: The path to save the slides ids to.
"""
if slide_ids:
with open(path, "w") as f:
for slide_id in slide_ids:
f.write(f"{slide_id}\n")
def sanity_check_loss_values(self, loss_values: LossDictType, stage: ModelKey) -> None:
"""Checks if there are any NaNs or any other potential annomalies in the loss values.
:param loss_values: The loss values for all slides.
:param stage: The model's stage (train/val).
"""
# We don't want any of these exceptions to interrupt validation. So we catch them and log them.
loss_values_copy = loss_values.copy()
for slide_id, loss in loss_values_copy.items():
try:
if np.isnan(loss).any():
logging.warning(f"NaNs found in loss values for slide {slide_id}.")
self.nan_slides[stage].append(slide_id)
loss_values.pop(slide_id)
except Exception as e:
logging.warning(f"Error while checking for NaNs in loss values for slide {slide_id} with error {e}.")
logging.warning(f"Loss values that caused the issue: {loss}")
self.anomaly_slides[stage].append(slide_id)
loss_values.pop(slide_id)
self.save_slide_ids(self.nan_slides[stage], self.get_nan_slides_file(stage))
self.save_slide_ids(self.anomaly_slides[stage], self.get_anomaly_slides_file(stage))
def save_loss_ranks(self, slides_loss_values: LossDictType, stage: ModelKey) -> None:
"""Saves the loss ranks for each slide across all epochs and their respective statistics in csv files.
:param slides_loss_values: The loss values for all slides.
:param stage: The model's stage (train/val).
"""
loss_df = pd.DataFrame(slides_loss_values).T
loss_df.index.names = [ResultsKey.SLIDE_ID.value]
loss_df.to_csv(self.get_all_epochs_loss_cache_file(stage))
loss_stats = loss_df.T.describe().T.sort_values(by="mean", ascending=False)
loss_stats.to_csv(self.get_loss_stats_file(stage))
loss_ranks = loss_df.rank(ascending=False)
loss_ranks.to_csv(self.get_loss_ranks_file(stage))
loss_ranks_stats = loss_ranks.T.describe().T.sort_values("mean", ascending=True)
loss_ranks_stats.to_csv(self.get_loss_ranks_stats_file(stage))
def plot_slides_loss_scatter(
self,
slides: np.ndarray,
slides_entropy: np.ndarray,
stage: ModelKey,
high: bool = True,
figsize: Tuple[float, float] = (30, 30),
) -> None:
"""Plots the slides with the highest/lowest loss values across epochs in a scatter plot
:param slides: The slides ids.
:param slides_entropy: The entropy values for each slide.
:param stage: The model's stage (train/val).
:param figsize: The figure size, defaults to (30, 30)
:param high: If True, plots the slides with the highest loss values, else plots the slides with the lowest loss.
"""
label = self.TOP if high else self.BOTTOM
plt.figure(figsize=figsize)
markers_size = [15 * i for i in range(1, self.num_slides_scatter + 1)]
markers_size = markers_size[::-1] if high else markers_size
colors = cm.rainbow(np.linspace(0, 1, self.num_slides_scatter))
for i in range(self.num_slides_scatter - 1, -1, -1):
plt.scatter(self.epochs_range, slides[i], label=f"{label}_{i+1}", s=markers_size[i], color=colors[i])
for entropy, epoch, slide in zip(slides_entropy[i], self.epochs_range, slides[i]):
plt.annotate(f"{entropy:.3f}", (epoch, slide))
plt.xlabel(self.X_LABEL)
plt.ylabel(self.Y_LABEL)
order = self.HIGHEST if high else self.LOWEST
plt.title(f"Slides with {order} loss values per epoch and their entropy values")
plt.xticks(self.epochs_range)
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.grid(True, linestyle="--")
plt.savefig(self.get_scatter_plot_file(order, stage), bbox_inches="tight")
def plot_loss_heatmap_for_slides_of_epoch(
self,
slides_loss_values: LossDictType,
epoch: int,
stage: ModelKey,
high: bool,
figsize: Tuple[float, float] = (15, 15)
) -> None:
"""Plots the loss values for each slide across all epochs in a heatmap.
:param slides_loss_values: The loss values for each slide across all epochs.
:param epoch: The epoch used to select the slides.
:param stage: The model's stage (train/val).
:param high: If True, plots the slides with the highest loss values, else plots the slides with the lowest loss.
:param figsize: The figure size, defaults to (15, 15)
"""
order = self.HIGHEST if high else self.LOWEST
loss_values = np.array(list(slides_loss_values.values()))
slides = list(slides_loss_values.keys())
plt.figure(figsize=figsize)
_ = sns.heatmap(loss_values, linewidth=0.5, annot=True, yticklabels=slides)
plt.xlabel(self.X_LABEL)
plt.ylabel(self.Y_LABEL)
plt.title(f"Loss values evolution for {order} slides of epoch {epoch}")
plt.savefig(self.get_heatmap_plot_file(epoch, order, stage), bbox_inches="tight")
@staticmethod
def compute_entropy(class_probs: torch.Tensor) -> List[float]:
"""Computes the entropy of the class probabilities.
:param class_probs: The class probabilities.
"""
return (-torch.sum(class_probs * torch.log(class_probs), dim=1)).tolist()
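# Worked example (not part of the original code): for class_probs = [[0.5, 0.5]] the entropy is
# -(0.5 * ln 0.5 + 0.5 * ln 0.5) ≈ 0.693, the maximum for two classes, whereas a confident prediction
# such as [[0.99, 0.01]] gives ≈ 0.056, so higher entropy flags more uncertain slides.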
def update_loss_cache(self, trainer: Trainer, outputs: BatchResultsType, batch: Dict, stage: ModelKey) -> None:
"""Updates the loss cache with the loss values for the current batch."""
if self.should_cache_loss_values(trainer.current_epoch):
self.loss_cache[stage][ResultsKey.LOSS].extend(outputs[ResultsKey.LOSS_PER_SAMPLE])
self.loss_cache[stage][ResultsKey.SLIDE_ID].extend([slides[0] for slides in batch[ResultsKey.SLIDE_ID]])
self.loss_cache[stage][ResultsKey.ENTROPY].extend(self.compute_entropy(outputs[ResultsKey.CLASS_PROBS]))
if self.save_tile_ids:
self.loss_cache[stage][ResultsKey.TILE_ID].extend(
[self.TILES_JOIN_TOKEN.join(tiles) for tiles in outputs[ResultsKey.TILE_ID]]
)
def synchronise_processes_and_reset(self, trainer: Trainer, pl_module: BaseDeepMILModule, stage: ModelKey) -> None:
"""Synchronises the processes in DDP mode and resets the loss cache for the next epoch."""
if self.should_cache_loss_values(trainer.current_epoch):
self.gather_loss_cache(rank=pl_module.global_rank, stage=stage)
if pl_module.global_rank == 0:
self.save_loss_cache(trainer.current_epoch, stage)
self.loss_cache[stage] = self.get_empty_loss_cache() # reset loss cache for all processes
def save_loss_outliers_analaysis_results(self, stage: ModelKey) -> None:
"""Saves the loss outliers analysis results."""
all_slides = self.select_slides_for_epoch(epoch=0, stage=stage)
all_loss_values_per_slides = self.select_all_losses_for_selected_slides(all_slides, stage=stage)
self.sanity_check_loss_values(all_loss_values_per_slides, stage=stage)
self.save_loss_ranks(all_loss_values_per_slides, stage=stage)
top_slides, top_slides_entropy = self.select_slides_and_entropy_across_epochs(stage, high=True)
self.plot_slides_loss_scatter(top_slides, top_slides_entropy, stage, high=True)
bottom_slides, bottom_slides_entropy = self.select_slides_and_entropy_across_epochs(stage, high=False)
self.plot_slides_loss_scatter(bottom_slides, bottom_slides_entropy, stage, high=False)
for epoch in self.epochs_range:
epoch_slides = self.select_slides_for_epoch(epoch, stage=stage)
top_slides = epoch_slides[:self.num_slides_heatmap]
top_slides_loss_values = self.select_all_losses_for_selected_slides(top_slides, stage=stage)
self.plot_loss_heatmap_for_slides_of_epoch(top_slides_loss_values, epoch, stage, high=True)
bottom_slides = epoch_slides[-self.num_slides_heatmap:]
bottom_slides_loss_values = self.select_all_losses_for_selected_slides(bottom_slides, stage=stage)
self.plot_loss_heatmap_for_slides_of_epoch(bottom_slides_loss_values, epoch, stage, high=False)
self.loss_cache[stage] = self.get_empty_loss_cache() # reset loss cache
def handle_loss_exceptions(self, stage: ModelKey, exception: Exception) -> None:
"""Handles the loss exceptions. If log_exceptions is True, logs the exception as warnings, else raises it."""
if self.log_exceptions:
# If something goes wrong, we don't want to crash the training. We just log the error and carry on
# validation.
logging.warning(f"Error while detecting {stage} loss values outliers: {exception}")
else:
# If we want to debug the error, we raise it. This will crash the training. This is useful when
# running smoke tests.
raise Exception(f"Error while detecting {stage} loss values outliers: {exception}")
def on_train_batch_end( # type: ignore
self,
trainer: Trainer,
pl_module: BaseDeepMILModule,
outputs: BatchResultsType,
batch: Dict,
batch_idx: int,
unused: int = 0
) -> None:
"""Caches train loss values per slide at each training step in a local variable self.loss_cache."""
self.update_loss_cache(trainer, outputs, batch, stage=ModelKey.TRAIN)
def on_validation_batch_end( # type: ignore
self,
trainer: Trainer,
pl_module: BaseDeepMILModule,
outputs: BatchResultsType,
batch: Any,
batch_idx: int,
dataloader_idx: int
) -> None:
"""Caches validation loss values per slide at each training step in a local variable self.loss_cache."""
self.update_loss_cache(trainer, outputs, batch, stage=ModelKey.VAL)
def on_train_epoch_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None: # type: ignore
"""Gathers loss values per slide from all processes at the end of each epoch and saves them to a csv file."""
self.synchronise_processes_and_reset(trainer, pl_module, ModelKey.TRAIN)
def on_validation_epoch_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None: # type: ignore
"""Gathers loss values per slide from all processes at the end of each epoch and saves them to a csv file."""
self.synchronise_processes_and_reset(trainer, pl_module, ModelKey.VAL)
def on_train_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None: # type: ignore
"""Hook called at the end of training. Plot the loss heatmap and scratter plots after ranking the slides by loss
values."""
if pl_module.global_rank == 0:
try:
self.save_loss_outliers_analaysis_results(stage=ModelKey.TRAIN)
except Exception as e:
self.handle_loss_exceptions(stage=ModelKey.TRAIN, exception=e)
def on_validation_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None: # type: ignore
"""Hook called at the end of validation. Plot the loss heatmap and scratter plots after ranking the slides by
loss values."""
epoch = trainer.current_epoch
if pl_module.global_rank == 0 and not pl_module._on_extra_val_epoch and epoch == (self.max_epochs - 1):
try:
self.save_loss_outliers_analaysis_results(stage=ModelKey.VAL)
except Exception as e:
self.handle_loss_exceptions(stage=ModelKey.VAL, exception=e)
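
A minimal wiring sketch for the callback above (not part of this commit; in the repo the callback is created from LossCallbackParams by the container rather than by hand). The import path, the model and the data module are placeholders.

from pathlib import Path
from pytorch_lightning import Trainer

# Hypothetical manual construction; outputs land under outputs/loss_analysis_callback/.
loss_callback = LossAnalysisCallback(
    outputs_folder=Path("outputs"),
    max_epochs=10,
    patience=2,
    epochs_interval=1,
    num_slides_scatter=10,
    num_slides_heatmap=20,
    save_tile_ids=True,
)
trainer = Trainer(max_epochs=10, callbacks=[loss_callback])
# trainer.fit(deepmil_module, datamodule=...)  # the module's outputs must include LOSS_PER_SAMPLE and CLASS_PROBS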

View file

@ -7,16 +7,16 @@ import param
from torch import nn
from pathlib import Path
from typing import Optional, Tuple
from torchvision.models.resnet import resnet18, resnet50
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CheckpointDownloader
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_cpath.models.encoders import (
HistoSSLEncoder,
ImageNetEncoder,
ImageNetEncoder_Resnet50,
ImageNetSimCLREncoder,
SSLEncoder,
TileEncoder,
Resnet18,
Resnet50,
Resnet18_NoPreproc,
Resnet50_NoPreproc,
)
from health_ml.networks.layers.attention_layers import (
AttentionLayer,
@ -28,14 +28,28 @@ from health_ml.networks.layers.attention_layers import (
)
def set_module_gradients_enabled(model: nn.Module, tuning_flag: bool) -> None:
"""Given a model, enable or disable gradients for all parameters.
:param model: A PyTorch model.
:param tuning_flag: A boolean indicating whether to enable or disable gradients for the model parameters.
"""
for params in model.parameters():
params.requires_grad = tuning_flag
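# Example (not part of the original code): set_module_gradients_enabled(encoder, tuning_flag=False)
# freezes every parameter of the encoder, while tuning_flag=True re-enables training for all of them.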
class EncoderParams(param.Parameterized):
"""Parameters class to group all encoder specific attributes for deepmil module. """
encoder_type: str = param.String(doc="Name of the encoder class to use.")
tile_size: int = param.Integer(default=224, bounds=(1, None), doc="Tile width/height, in pixels.")
n_channels: int = param.Integer(default=3, bounds=(1, None), doc="Number of channels in the tile.")
is_finetune: bool = param.Boolean(
False, doc="If True, fine-tune the encoder during training. If False (default), " "keep the encoder frozen."
tune_encoder: bool = param.Boolean(
False, doc="If True, fine-tune the encoder during training. If False (default), keep the encoder frozen."
)
pretrained_encoder = param.Boolean(
False, doc="If True, transfer weights from the pretrained model (specified in `src_checkpoint`) to the encoder."
"Else (False), keep the encoder weights as defined by the `encoder_type`."
)
is_caching: bool = param.Boolean(
default=False,
@ -45,26 +59,35 @@ class EncoderParams(param.Parameterized):
encoding_chunk_size: int = param.Integer(
default=0, doc="If > 0 performs encoding in chunks, by enconding_chunk_size tiles " "per chunk"
)
ssl_checkpoint: CheckpointParser = param.ClassSelector(class_=CheckpointParser, default=None,
instantiate=False, doc=CheckpointParser.DOC)
def get_encoder(self, ssl_ckpt_run_id: Optional[str], outputs_folder: Optional[Path]) -> TileEncoder:
def validate(self) -> None:
"""Validate the encoder parameters."""
if self.encoder_type == SSLEncoder.__name__ and not self.ssl_checkpoint:
raise ValueError("SSLEncoder requires an ssl_checkpoint. Please specify a valid checkpoint. "
f"{CheckpointParser.INFO_MESSAGE}")
def get_encoder(self, outputs_folder: Optional[Path]) -> TileEncoder:
"""Given the current encoder parameters, returns the encoder object.
:param ssl_ckpt_run_id: The AML run id for SSL checkpoint download.
:param outputs_folder: The output folder where SSL checkpoint should be saved.
:param encoder_params: The encoder arguments that define the encoder class object depending on the encoder type.
:raises ValueError: If the encoder type is not supported.
:return: A TileEncoder instance for deepmil module.
"""
encoder: TileEncoder
if self.encoder_type == ImageNetEncoder.__name__:
encoder = ImageNetEncoder(
feature_extraction_model=resnet18, tile_size=self.tile_size, n_channels=self.n_channels,
)
elif self.encoder_type == ImageNetEncoder_Resnet50.__name__:
# Myronenko et al. 2021 uses Resnet50 CNN encoder
encoder = ImageNetEncoder_Resnet50(
feature_extraction_model=resnet50, tile_size=self.tile_size, n_channels=self.n_channels,
)
if self.encoder_type == Resnet18.__name__:
encoder = Resnet18(tile_size=self.tile_size, n_channels=self.n_channels)
elif self.encoder_type == Resnet18_NoPreproc.__name__:
encoder = Resnet18_NoPreproc(tile_size=self.tile_size, n_channels=self.n_channels)
elif self.encoder_type == Resnet50.__name__:
encoder = Resnet50(tile_size=self.tile_size, n_channels=self.n_channels)
elif self.encoder_type == Resnet50_NoPreproc.__name__:
encoder = Resnet50_NoPreproc(tile_size=self.tile_size, n_channels=self.n_channels)
elif self.encoder_type == ImageNetSimCLREncoder.__name__:
encoder = ImageNetSimCLREncoder(tile_size=self.tile_size, n_channels=self.n_channels)
@ -73,24 +96,15 @@ class EncoderParams(param.Parameterized):
encoder = HistoSSLEncoder(tile_size=self.tile_size, n_channels=self.n_channels)
elif self.encoder_type == SSLEncoder.__name__:
assert ssl_ckpt_run_id and outputs_folder, "SSLEncoder requires ssl_ckpt_run_id and outputs_folder"
downloader = CheckpointDownloader(run_id=ssl_ckpt_run_id,
download_dir=outputs_folder,
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
remote_checkpoint_dir=Path(DEFAULT_AML_CHECKPOINT_DIR))
assert outputs_folder is not None, "outputs_folder cannot be None for SSLEncoder"
encoder = SSLEncoder(
pl_checkpoint_path=downloader.local_checkpoint_path,
pl_checkpoint_path=self.ssl_checkpoint.get_path(outputs_folder),
tile_size=self.tile_size,
n_channels=self.n_channels,
)
else:
raise ValueError(f"Unsupported encoder type: {self.encoder_type}")
if self.is_finetune:
for params in encoder.parameters():
params.requires_grad = True
else:
encoder.eval()
set_module_gradients_enabled(encoder, tuning_flag=self.tune_encoder)
return encoder
@ -106,9 +120,20 @@ class PoolingParams(param.Parameterized):
default=4, doc="If transformer pooling is chosen, this defines the number of encoding layers.",
)
num_transformer_pool_heads: int = param.Integer(
4,
doc="If transformer pooling is chosen, this defines the number\
of attention heads.",
default=4, doc="If transformer pooling is chosen, this defines the number of attention heads.",
)
tune_pooling: bool = param.Boolean(
default=True,
doc="If True (default), fine-tune the pooling layer during training. If False, keep the pooling layer frozen.",
)
pretrained_pooling = param.Boolean(
default=False,
doc="If True, transfer weights from the pretrained model (specified in `src_checkpoint`) to the pooling"
"layer. Else (False), initialize the pooling layer randomly.",
)
transformer_dropout: float = param.Number(
default=0.0,
doc="If transformer pooling is chosen, this defines the dropout of the tranformer encoder layers.",
)
def get_pooling_layer(self, num_encoding: int) -> Tuple[nn.Module, int]:
@ -121,24 +146,53 @@ class PoolingParams(param.Parameterized):
"""
pooling_layer: nn.Module
if self.pool_type == AttentionLayer.__name__:
pooling_layer = AttentionLayer(num_encoding, self.pool_hidden_dim, self.pool_out_dim)
pooling_layer = AttentionLayer(input_dims=num_encoding,
hidden_dims=self.pool_hidden_dim,
attention_dims=self.pool_out_dim)
elif self.pool_type == GatedAttentionLayer.__name__:
pooling_layer = GatedAttentionLayer(num_encoding, self.pool_hidden_dim, self.pool_out_dim)
pooling_layer = GatedAttentionLayer(input_dims=num_encoding,
hidden_dims=self.pool_hidden_dim,
attention_dims=self.pool_out_dim)
elif self.pool_type == MeanPoolingLayer.__name__:
pooling_layer = MeanPoolingLayer()
elif self.pool_type == MaxPoolingLayer.__name__:
pooling_layer = MaxPoolingLayer()
elif self.pool_type == TransformerPooling.__name__:
pooling_layer = TransformerPooling(
self.num_transformer_pool_layers, self.num_transformer_pool_heads, num_encoding
)
num_layers=self.num_transformer_pool_layers,
num_heads=self.num_transformer_pool_heads,
dim_representation=num_encoding,
transformer_dropout=self.transformer_dropout)
self.pool_out_dim = 1 # currently this is hardcoded in forward of the TransformerPooling
elif self.pool_type == TransformerPoolingBenchmark.__name__:
pooling_layer = TransformerPoolingBenchmark(
self.num_transformer_pool_layers, self.num_transformer_pool_heads, num_encoding, self.pool_hidden_dim
)
num_layers=self.num_transformer_pool_layers,
num_heads=self.num_transformer_pool_heads,
dim_representation=num_encoding,
hidden_dim=self.pool_hidden_dim,
transformer_dropout=self.transformer_dropout)
self.pool_out_dim = 1 # currently this is hardcoded in forward of the TransformerPooling
else:
raise ValueError(f"Unsupported pooling type: {self.pooling_type}")
raise ValueError(f"Unsupported pooling type: {self.pool_type}")
num_features = num_encoding * self.pool_out_dim
set_module_gradients_enabled(pooling_layer, tuning_flag=self.tune_pooling)
return pooling_layer, num_features
class ClassifierParams(param.Parameterized):
dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
tune_classifier: bool = param.Boolean(
default=True,
doc="If True (default), fine-tune the classifier during training. If False, keep the classifier frozen.")
pretrained_classifier: bool = param.Boolean(
default=False,
doc="If True, will use classifier weights from pretrained model specified in src_checkpoint. If False, will "
"initiliaze classifier with random weights.")
def get_classifier(self, in_features: int, out_features: int) -> nn.Module:
classifier_layer = nn.Linear(in_features=in_features,
out_features=out_features)
set_module_gradients_enabled(classifier_layer, tuning_flag=self.tune_classifier)
if self.dropout_rate is None:
return classifier_layer
return nn.Sequential(nn.Dropout(self.dropout_rate), classifier_layer)
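
A hypothetical sketch (not part of this commit) of how the parameter groups above could be composed into a flexibly fine-tuned DeepMIL head: freeze the encoder, train pooling and classifier. The direct keyword construction of the param classes and the feature dimension of 512 for Resnet18 are assumptions; in the repo these calls happen inside the DeepMIL module.

from pathlib import Path

encoder_params = EncoderParams(encoder_type="Resnet18", tune_encoder=False)
pooling_params = PoolingParams(pool_type="AttentionLayer", tune_pooling=True)
classifier_params = ClassifierParams(dropout_rate=0.1, tune_classifier=True)

encoder = encoder_params.get_encoder(outputs_folder=Path("outputs"))  # gradients disabled via set_module_gradients_enabled
pooling_layer, num_features = pooling_params.get_pooling_layer(num_encoding=512)  # 512 assumed for Resnet18 features
classifier = classifier_params.get_classifier(in_features=num_features, out_features=1)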

View file

@ -50,6 +50,7 @@ class ResultsKey(str, Enum):
FEATURES = 'features'
IMAGE_PATH = 'image_path'
LOSS = 'loss'
LOSS_PER_SAMPLE = 'loss_per_sample'
PROB = 'prob'
CLASS_PROBS = 'prob_class'
PRED_LABEL = 'pred_label'
@ -59,6 +60,7 @@ class ResultsKey(str, Enum):
TILE_TOP = 'top'
TILE_RIGHT = 'right'
TILE_BOTTOM = 'bottom'
ENTROPY = 'entropy'
class MetricsKey(str, Enum):
@ -71,6 +73,8 @@ class MetricsKey(str, Enum):
RECALL = 'recall'
F1 = 'f1score'
COHENKAPPA = 'cohenkappa'
AVERAGE_PRECISION = 'average_precision'
SPECIFICITY = 'specificity'
class ModelKey(str, Enum):
@ -85,10 +89,19 @@ class AMLMetricsJsonKey(str, Enum):
VALUE = 'value'
N_CLASSES = 'n_classes'
CLASS_NAMES = 'class_names'
MAX_EPOCHS = 'max_epochs'
class PlotOption(Enum):
TOP_BOTTOM_TILES = "top_bottom_tiles"
SLIDE_THUMBNAIL_HEATMAP = "slide_thumbnail_heatmap"
SLIDE_THUMBNAIL = "slide_thumbnail"
ATTENTION_HEATMAP = "attention_heatmap"
CONFUSION_MATRIX = "confusion_matrix"
HISTOGRAM = "histogram"
PR_CURVE = "pr_curve"
class DeepMILSubmodules(str, Enum):
ENCODER = 'encoder'
POOLING = 'aggregation_fn'
CLASSIFIER = 'classifier_fn'

View file

@ -25,6 +25,8 @@ OUTPUTS_CSV_FILENAME = "test_output.csv"
VAL_OUTPUTS_SUBDIR = "val"
PREV_VAL_OUTPUTS_SUBDIR = "val_old"
TEST_OUTPUTS_SUBDIR = "test"
EXTRA_VAL_OUTPUTS_SUBDIR = "extra_val"
EXTRA_PREFIX = "extra_"
AML_OUTPUTS_DIR = "outputs"
AML_LEGACY_TEST_OUTPUTS_CSV = "/".join([AML_OUTPUTS_DIR, OUTPUTS_CSV_FILENAME])
@ -55,10 +57,18 @@ def validate_class_names(class_names: Optional[Sequence[str]], n_classes: int) -
def validate_slide_datasets_for_plot_options(
plot_options: Collection[PlotOption], slides_dataset: Optional[SlidesDataset]
) -> None:
if PlotOption.SLIDE_THUMBNAIL_HEATMAP in plot_options and not slides_dataset:
raise ValueError("You can not plot slide thumbnails and heatmaps without setting a slides_dataset. "
"Please remove `PlotOption.SLIDE_THUMBNAIL_HEATMAP` from your plot options or provide "
"a slide dataset.")
"""Validate that the specified plot options are compatible with the specified slides dataset.
:param plot_options: Plot options to validate.
:param slides_dataset: Slides dataset to validate against.
"""
def _validate_slide_plot_option(plot_option: PlotOption) -> None:
if plot_option in plot_options and not slides_dataset:
raise ValueError(f"Plot option {plot_option} requires a slides dataset")
_validate_slide_plot_option(PlotOption.SLIDE_THUMBNAIL)
_validate_slide_plot_option(PlotOption.ATTENTION_HEATMAP)
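# Example (not part of the original code): requesting slide-level plots without a slides dataset fails fast.
# validate_slide_datasets_for_plot_options({PlotOption.ATTENTION_HEATMAP}, slides_dataset=None)
# -> ValueError: Plot option PlotOption.ATTENTION_HEATMAP requires a slides dataset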
def normalize_dict_for_df(dict_old: Dict[ResultsKey, Any]) -> Dict[str, Any]:
@ -194,7 +204,7 @@ class OutputsPolicy:
YAML().dump(contents, self.best_metric_file_path)
def should_save_validation_outputs(self, metrics_dict: Mapping[MetricsKey, Metric], epoch: int,
is_global_rank_zero: bool = True) -> bool:
is_global_rank_zero: bool = True, on_extra_val: bool = False) -> bool:
"""Determine whether validation outputs should be saved given the current epoch's metrics.
:param metrics_dict: Current epoch's metrics dictionary from
@ -202,8 +212,11 @@ class OutputsPolicy:
:param epoch: Current epoch number.
:param is_global_rank_zero: Whether this is the global rank-0 process in distributed scenarios.
Set to `True` (default) if running a single process.
:param on_extra_val: Whether this is an extra validation epoch (e.g. after training).
:return: Whether this is the best validation epoch so far.
"""
if on_extra_val:
return False
metric = metrics_dict[self.primary_val_metric]
# If the metric hasn't been updated we don't want to save it
if not metric._update_called:
@ -244,7 +257,8 @@ class DeepMILOutputsHandler:
def __init__(self, outputs_root: Path, n_classes: int, tile_size: int, level: int,
class_names: Optional[Sequence[str]], primary_val_metric: MetricsKey,
maximise: bool, val_plot_options: Collection[PlotOption],
test_plot_options: Collection[PlotOption]) -> None:
test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True,
val_set_is_dist: bool = True) -> None:
"""
:param outputs_root: Root directory where to save all produced outputs.
:param n_classes: Number of MIL classes (set `n_classes=1` for binary).
@ -256,6 +270,11 @@ class DeepMILOutputsHandler:
:param maximise: Whether higher is better for `primary_val_metric`.
:param val_plot_options: The desired plot options for validation time.
:param test_plot_options: The desired plot options for test time.
:param wsi_has_mask: Whether the whole slides have a mask to crop specific ROIs.
:param val_set_is_dist: If True, the validation set is distributed across processes. Otherwise, the validation
set is replicated on each process. This shouldn't affect the results, as we take the mean of the validation
set metrics across processes. This is only relevant for the outputs_handler, which needs to know whether to
gather the validation set outputs across processes or not before saving them.
"""
self.outputs_root = outputs_root
self.n_classes = n_classes
@ -273,19 +292,28 @@ class DeepMILOutputsHandler:
plot_options=val_plot_options,
level=self.level,
tile_size=self.tile_size,
class_names=self.class_names
class_names=self.class_names,
stage=ModelKey.VAL,
wsi_has_mask=wsi_has_mask
)
self.test_plots_handler = DeepMILPlotsHandler(
plot_options=test_plot_options,
level=self.level,
tile_size=self.tile_size,
class_names=self.class_names
class_names=self.class_names,
stage=ModelKey.TEST,
wsi_has_mask=wsi_has_mask
)
self.val_set_is_dist = val_set_is_dist
@property
def validation_outputs_dir(self) -> Path:
return self.outputs_root / VAL_OUTPUTS_SUBDIR
@property
def extra_validation_outputs_dir(self) -> Path:
return self.outputs_root / EXTRA_VAL_OUTPUTS_SUBDIR
@property
def previous_validation_outputs_dir(self) -> Path:
return self.validation_outputs_dir.with_name(PREV_VAL_OUTPUTS_SUBDIR)
@ -300,11 +328,15 @@ class DeepMILOutputsHandler:
self.test_plots_handler.slides_dataset = slides_dataset
self.val_plots_handler.slides_dataset = slides_dataset
def should_gather_tiles(self, plots_handler: DeepMILPlotsHandler) -> bool:
return PlotOption.TOP_BOTTOM_TILES in plots_handler.plot_options and self.tiles_selector is not None
def _save_outputs(self, epoch_results: EpochResultsType, outputs_dir: Path, stage: ModelKey = ModelKey.VAL) -> None:
"""Trigger the rendering and saving of DeepMIL outputs and figures.
:param epoch_results: Aggregated results from all epoch batches.
:param outputs_dir: Specific directory into which outputs should be saved (different for validation and test).
:param stage: The stage of the model (e.g. `ModelKey.VAL` or `ModelKey.TEST`).
"""
# outputs object consists of a list of dictionaries (of metadata and results, including encoded features)
# It can be indexed as outputs[batch_idx][batch_key][bag_idx][tile_idx]
@ -321,8 +353,7 @@ class DeepMILOutputsHandler:
plots_handler.save_plots(outputs_dir, self.tiles_selector, results)
def save_validation_outputs(self, epoch_results: EpochResultsType, metrics_dict: Mapping[MetricsKey, Metric],
epoch: int, is_global_rank_zero: bool = True, run_extra_val_epoch: bool = False
) -> None:
epoch: int, is_global_rank_zero: bool = True, on_extra_val: bool = False) -> None:
"""Render and save validation epoch outputs, according to the configured :py:class:`OutputsPolicy`.
:param epoch_results: Aggregated results from all epoch batches, as passed to :py:meth:`validation_epoch_end()`.
@ -331,29 +362,35 @@ class DeepMILOutputsHandler:
:param is_global_rank_zero: Whether this is the global rank-0 process in distributed scenarios.
Set to `True` (default) if running a single process.
:param epoch: Current epoch number.
:param on_extra_val: Whether this is an extra validation epoch (e.g. after training).
"""
# All DDP processes must reach this point to allow synchronising epoch results
gathered_epoch_results = gather_results(epoch_results)
if PlotOption.TOP_BOTTOM_TILES in self.val_plots_handler.plot_options and self.tiles_selector:
self.tiles_selector.gather_selected_tiles_across_devices()
# All DDP processes must reach this point to allow synchronising epoch results if val_set_is_dist is True
if self.val_set_is_dist:
epoch_results = gather_results(epoch_results)
if self.should_gather_tiles(self.val_plots_handler):
self.tiles_selector.gather_selected_tiles_across_devices() # type: ignore
# Only global rank-0 process should actually render and save the outputs
# We also want to save the plots of the extra validation epoch
if (
self.outputs_policy.should_save_validation_outputs(metrics_dict, epoch, is_global_rank_zero)
or (run_extra_val_epoch and is_global_rank_zero)
):
if self.outputs_policy.should_save_validation_outputs(metrics_dict, epoch, is_global_rank_zero, on_extra_val):
# First move existing outputs to a temporary directory, to avoid mixing
# outputs of different epochs in case writing fails halfway through
if self.validation_outputs_dir.exists():
replace_directory(source=self.validation_outputs_dir,
target=self.previous_validation_outputs_dir)
self._save_outputs(gathered_epoch_results, self.validation_outputs_dir, ModelKey.VAL)
self._save_outputs(epoch_results, self.validation_outputs_dir, ModelKey.VAL)
# Writing completed successfully; delete temporary back-up
if self.previous_validation_outputs_dir.exists():
shutil.rmtree(self.previous_validation_outputs_dir, ignore_errors=True)
elif on_extra_val and is_global_rank_zero:
self._save_outputs(epoch_results, self.extra_validation_outputs_dir, ModelKey.VAL)
# Reset the top and bottom slides heaps
if self.should_gather_tiles(self.val_plots_handler):
self.tiles_selector._clear_cached_slides_heaps() # type: ignore
def save_test_outputs(self, epoch_results: EpochResultsType, is_global_rank_zero: bool = True) -> None:
"""Render and save test epoch outputs.
@ -365,8 +402,8 @@ class DeepMILOutputsHandler:
"""
# All DDP processes must reach this point to allow synchronising epoch results
gathered_epoch_results = gather_results(epoch_results)
if PlotOption.TOP_BOTTOM_TILES in self.test_plots_handler.plot_options and self.tiles_selector:
self.tiles_selector.gather_selected_tiles_across_devices()
if self.should_gather_tiles(self.test_plots_handler):
self.tiles_selector.gather_selected_tiles_across_devices() # type: ignore
# Only global rank-0 process should actually render and save the outputs
if self.outputs_policy.should_save_test_outputs(is_global_rank_zero):

View file

@ -8,6 +8,7 @@ from typing import Any, Collection, List, Optional, Sequence, Tuple, Dict
from sklearn.metrics import confusion_matrix
from torch import Tensor
import matplotlib.pyplot as plt
from health_cpath.datasets.base_dataset import SlidesDataset
from health_cpath.utils.viz_utils import (
@ -17,12 +18,14 @@ from health_cpath.utils.viz_utils import (
plot_scores_hist,
plot_slide,
)
from health_cpath.utils.analysis_plot_utils import plot_pr_curve, format_pr_or_roc_axes
from health_cpath.utils.naming import PlotOption, ResultsKey, SlideKey
from health_cpath.utils.tiles_selection_utils import SlideNode, TilesSelector
from health_cpath.utils.viz_utils import load_image_dict, save_figure
ResultsType = Dict[ResultsKey, List[Any]]
SlideDictType = Dict[SlideKey, Any]
def validate_class_names_for_plot_options(
@ -37,21 +40,44 @@ def validate_class_names_for_plot_options(
return class_names
def save_scores_histogram(results: ResultsType, figures_dir: Path) -> None:
def save_scores_histogram(results: ResultsType, figures_dir: Path, stage: str = '') -> None:
"""Plots and saves histogram scores figure in its dedicated directory.
:param results: List that contains slide_level dicts
:param results: Dict of lists that contains slide_level results
:param figures_dir: The path to the directory where to save the histogram scores.
:param stage: Test or validation, used to name the figure. Empty string by default.
"""
fig = plot_scores_hist(results)
save_figure(fig=fig, figpath=figures_dir / "hist_scores.png")
save_figure(fig=fig, figpath=figures_dir / f"hist_scores_{stage}.png")
def save_confusion_matrix(results: ResultsType, class_names: Sequence[str], figures_dir: Path) -> None:
def save_pr_curve(results: ResultsType, figures_dir: Path, stage: str = '') -> None:
"""Plots and saves PR curve figure in its dedicated directory. This implementation
only works for binary classification.
''
:param results: Dict of lists that contains slide_level results
:param figures_dir: The path to the directory where to save the histogram scores
:param stage: Test or validation, used to name the figure. Empty string by default.
"""
true_labels = [i.item() if isinstance(i, Tensor) else i for i in results[ResultsKey.TRUE_LABEL]]
if len(set(true_labels)) == 2:
scores = [i.item() if isinstance(i, Tensor) else i for i in results[ResultsKey.PROB]]
fig, ax = plt.subplots()
plot_pr_curve(true_labels, scores, legend_label=stage, ax=ax)
ax.legend()
format_pr_or_roc_axes(plot_type='pr', ax=ax)
save_figure(fig=fig, figpath=figures_dir / f"pr_curve_{stage}.png")
else:
logging.warning("The PR curve plot implementation works only for binary cases, this plot will be skipped.")
def save_confusion_matrix(results: ResultsType, class_names: Sequence[str], figures_dir: Path, stage: str = '') -> None:
"""Plots and saves confusion matrix figure in its dedicated directory.
:param results: Dict of lists that contains slide_level results
:param class_names: List of class names.
:param figures_dir: The path to the directory where to save the confusion matrix.
:param stage: Test or validation, used to name the figure. Empty string by default.
"""
true_labels = [i.item() if isinstance(i, Tensor) else i for i in results[ResultsKey.TRUE_LABEL]]
pred_labels = [i.item() if isinstance(i, Tensor) else i for i in results[ResultsKey.PRED_LABEL]]
@ -68,11 +94,11 @@ def save_confusion_matrix(results: ResultsType, class_names: Sequence[str], figu
true_labels,
pred_labels,
labels=all_potential_labels,
normalize="pred"
normalize="true"
)
fig = plot_normalized_confusion_matrix(cm=cf_matrix_n, class_names=(class_names))
save_figure(fig=fig, figpath=figures_dir / "normalized_confusion_matrix.png")
save_figure(fig=fig, figpath=figures_dir / f"normalized_confusion_matrix_{stage}.png")
def save_top_and_bottom_tiles(
@ -95,12 +121,24 @@ def save_top_and_bottom_tiles(
save_figure(fig=bottom_tiles_fig, figpath=figures_dir / f"{slide_node.slide_id}_bottom.png")
def save_slide_thumbnail_and_heatmap(
def save_slide_thumbnail(case: str, slide_node: SlideNode, slide_dict: SlideDictType, figures_dir: Path) -> None:
"""Plots and saves a slide thumbnail
:param case: The report case (e.g., TP, FN, ...)
:param slide_node: The slide node that encapsulates the slide metadata.
:param slide_dict: The slide dictionary that contains the slide image and other metadata.
:param figures_dir: The path to the directory where to save the plots.
"""
fig = plot_slide(case=case, slide_node=slide_node, slide_image=slide_dict[SlideKey.IMAGE], scale=1.0)
save_figure(fig=fig, figpath=figures_dir / f"{slide_node.slide_id}_thumbnail.png")
def save_attention_heatmap(
case: str,
slide_node: SlideNode,
slide_dict: SlideDictType,
figures_dir: Path,
results: ResultsType,
slides_dataset: SlidesDataset,
tile_size: int = 224,
level: int = 1,
) -> None:
@ -108,29 +146,19 @@ def save_slide_thumbnail_and_heatmap(
:param case: The report case (e.g., TP, FN, ...)
:param slide_node: The slide node that encapsulates the slide metadata.
:param slide_dict: The slide dictionary that contains the slide image and other metadata.
:param figures_dir: The path to the directory where to save the plots.
:param results: Dict containing ResultsKey keys (e.g. slide id) and values as lists of output slides.
:param slides_dataset: The slides dataset from where to pick the slide.
:param tile_size: Size of each tile. Default 224.
:param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original,
1 for 4x downsampled, 2 for 16x downsampled). Default 1.
"""
slide_index = slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
slide_dict = slides_dataset[slide_index]
slide_dict = load_image_dict(slide_dict, level=level, margin=0)
slide_image = slide_dict[SlideKey.IMAGE]
location_bbox = slide_dict[SlideKey.LOCATION]
fig = plot_slide(case=case, slide_node=slide_node, slide_image=slide_image, scale=1.0)
save_figure(fig=fig, figpath=figures_dir / f"{slide_node.slide_id}_thumbnail.png")
fig = plot_heatmap_overlay(
case=case,
slide_node=slide_node,
slide_image=slide_image,
slide_image=slide_dict[SlideKey.IMAGE],
results=results,
location_bbox=location_bbox,
location_bbox=slide_dict[SlideKey.ORIGIN],
tile_size=tile_size,
level=level,
)
@ -152,9 +180,11 @@ class DeepMILPlotsHandler:
tile_size: int = 224,
num_columns: int = 4,
figsize: Tuple[int, int] = (10, 10),
stage: str = '',
class_names: Optional[Sequence[str]] = None,
wsi_has_mask: bool = True,
) -> None:
"""_summary_
"""Class that handles the plotting of DeepMIL results.
:param plot_options: A set of plot options to produce the desired plot outputs.
:param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original,
@ -162,6 +192,7 @@ class DeepMILPlotsHandler:
:param tile_size: Tile width/height in pixels, defaults to 224
:param num_columns: Number of columns to create the subfigures grid, defaults to 4
:param figsize: The figure size of tiles attention plots, defaults to (10, 10)
:param stage: Test or Validation, used to name the plots
:param class_names: List of class names, defaults to None
:param slides_dataset: The slides dataset from where to load the whole slide images, defaults to None
"""
@ -171,34 +202,38 @@ class DeepMILPlotsHandler:
self.tile_size = tile_size
self.num_columns = num_columns
self.figsize = figsize
self.stage = stage
self.wsi_has_mask = wsi_has_mask
self.slides_dataset: Optional[SlidesDataset] = None
def get_slide_dict(self, slide_node: SlideNode) -> SlideDictType:
"""Returns the slide dictionary for a given slide node"""
assert self.slides_dataset is not None, "Cannot plot attention heatmap or wsi without slides dataset"
slide_index = self.slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
slide_dict = self.slides_dataset[slide_index]
slide_dict = load_image_dict(slide_dict, level=self.level, margin=0, wsi_has_mask=self.wsi_has_mask)
return slide_dict
def save_slide_node_figures(
self, case: str, slide_node: SlideNode, outputs_dir: Path, results: ResultsType
) -> None:
"""Plots and saves all slide related figures, e.g., `TOP_BOTTOM_TILES` and `SLIDE_THUMBNAIL_HEATMAP`"""
"""Plots and saves all slide related figures: `TOP_BOTTOM_TILES`, `SLIDE_THUMBNAIL` and `ATTENTION_HEATMAP`."""
case_dir = make_figure_dirs(subfolder=case, parent_dir=outputs_dir)
if PlotOption.TOP_BOTTOM_TILES in self.plot_options:
save_top_and_bottom_tiles(
case=case,
slide_node=slide_node,
figures_dir=case_dir,
num_columns=self.num_columns,
figsize=self.figsize,
)
if PlotOption.SLIDE_THUMBNAIL_HEATMAP in self.plot_options:
assert self.slides_dataset
save_slide_thumbnail_and_heatmap(
case=case,
slide_node=slide_node,
figures_dir=case_dir,
results=results,
slides_dataset=self.slides_dataset,
tile_size=self.tile_size,
level=self.level,
)
save_top_and_bottom_tiles(case, slide_node, case_dir, self.num_columns, self.figsize)
if PlotOption.ATTENTION_HEATMAP in self.plot_options or PlotOption.SLIDE_THUMBNAIL in self.plot_options:
slide_dict = self.get_slide_dict(slide_node=slide_node)
if PlotOption.SLIDE_THUMBNAIL in self.plot_options:
save_slide_thumbnail(case=case, slide_node=slide_node, slide_dict=slide_dict, figures_dir=case_dir)
if PlotOption.ATTENTION_HEATMAP in self.plot_options:
save_attention_heatmap(
case, slide_node, slide_dict, case_dir, results, self.tile_size, level=self.level
)
def save_plots(self, outputs_dir: Path, tiles_selector: Optional[TilesSelector], results: ResultsType) -> None:
"""Plots and saves all selected plot options during inference (validation or test) time.
@ -214,12 +249,15 @@ class DeepMILPlotsHandler:
)
figures_dir = make_figure_dirs(subfolder="fig", parent_dir=outputs_dir)
if PlotOption.PR_CURVE in self.plot_options:
save_pr_curve(results=results, figures_dir=figures_dir, stage=self.stage)
if PlotOption.HISTOGRAM in self.plot_options:
save_scores_histogram(results=results, figures_dir=figures_dir)
save_scores_histogram(results=results, figures_dir=figures_dir, stage=self.stage)
if PlotOption.CONFUSION_MATRIX in self.plot_options:
assert self.class_names
save_confusion_matrix(results, class_names=self.class_names, figures_dir=figures_dir)
save_confusion_matrix(results, class_names=self.class_names, figures_dir=figures_dir, stage=self.stage)
if tiles_selector:
for class_id in range(tiles_selector.n_classes):

View file

@ -4,7 +4,7 @@
# -------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Dict, List, Sequence, Tuple
from typing import Any, Dict, List, Sequence, Tuple
import dateutil.parser
import numpy as np
@ -39,8 +39,8 @@ def run_has_val_and_test_outputs(run: Run) -> bool:
f"{AML_TEST_OUTPUTS_CSV}): {available_files}")
def crossval_runs_have_val_and_test_outputs(parent_run: Run) -> bool:
"""Checks whether all child cross-validation runs have both validation and test outputs files.
def child_runs_have_val_and_test_outputs(parent_run: Run) -> bool:
"""Checks whether all child hyperdrive runs have both validation and test outputs files.
:param parent_run: The parent Hyperdrive run.
:raises ValueError: If any of the child runs does not have the expected output files, or if
@ -58,30 +58,30 @@ def crossval_runs_have_val_and_test_outputs(parent_run: Run) -> bool:
"test-only outputs and with both validation and test outputs")
def collect_crossval_outputs(parent_run_id: str, download_dir: Path, aml_workspace: Workspace,
crossval_arg_name: str = "crossval_index",
output_filename: str = "test_output.csv",
overwrite: bool = False) -> Dict[int, pd.DataFrame]:
"""Fetch output CSV files from cross-validation runs as dataframes.
def collect_hyperdrive_outputs(parent_run_id: str, download_dir: Path, aml_workspace: Workspace,
hyperdrive_arg_name: str = "crossval_index",
output_filename: str = "test_output.csv",
overwrite: bool = False) -> Dict[int, pd.DataFrame]:
"""Fetch output CSV files from Hyperdrive child runs as dataframes.
Will only download the CSV files if they do not already exist locally.
:param parent_run_id: Azure ML run ID for the parent Hyperdrive run.
:param download_dir: Base directory where to download the CSV files. A new sub-directory will
be created for each child run (e.g. `<download_dir>/<crossval index>/*.csv`).
be created for each child run (e.g. `<download_dir>/<child run index>/*.csv`).
:param aml_workspace: Azure ML workspace in which the runs were executed.
:param crossval_arg_name: Name of the Hyperdrive argument used for indexing the child runs.
:param hyperdrive_arg_name: Name of the Hyperdrive argument used for indexing the child runs.
:param output_filename: Filename of the output CSVs to download.
:param overwrite: Whether to force the download even if each file already exists locally.
:return: A dictionary of dataframes with the sorted cross-validation indices as keys.
:return: A dictionary of dataframes with the sorted child run indices as keys.
"""
parent_run = get_aml_run_from_run_id(parent_run_id, aml_workspace)
all_outputs_dfs = {}
for child_run in parent_run.get_children():
child_run_index = get_tags_from_hyperdrive_run(child_run, crossval_arg_name)
child_run_index = get_tags_from_hyperdrive_run(child_run, hyperdrive_arg_name)
if child_run_index is None:
raise ValueError(f"Child run expected to have the tag '{crossval_arg_name}'")
raise ValueError(f"Child run expected to have the tag '{hyperdrive_arg_name}'")
child_dir = download_dir / str(child_run_index)
try:
child_csv = download_file_if_necessary(child_run, output_filename, child_dir / output_filename,
@ -92,10 +92,10 @@ def collect_crossval_outputs(parent_run_id: str, download_dir: Path, aml_workspa
return dict(sorted(all_outputs_dfs.items())) # type: ignore
def collect_crossval_metrics(parent_run_id: str, download_dir: Path, aml_workspace: Workspace,
crossval_arg_name: str = "crossval_index",
overwrite: bool = False) -> pd.DataFrame:
"""Fetch metrics logged to Azure ML from cross-validation runs as a dataframe.
def download_hyperdrive_metrics_if_required(parent_run_id: str, download_dir: Path, aml_workspace: Workspace,
hyperdrive_arg_name: str = "crossval_index",
overwrite: bool = False) -> Path:
"""Fetch metrics logged to Azure ML from hyperdrive runs.
Will only download the metrics if they do not already exist locally, as this can take several
seconds for each child run.
@ -103,89 +103,111 @@ def collect_crossval_metrics(parent_run_id: str, download_dir: Path, aml_workspa
:param parent_run_id: Azure ML run ID for the parent Hyperdrive run.
:param download_dir: Directory where to save the downloaded metrics as `aml_metrics.json`.
:param aml_workspace: Azure ML workspace in which the runs were executed.
:param crossval_arg_name: Name of the Hyperdrive argument used for indexing the child runs.
:param hyperdrive_arg_name: Name of the Hyperdrive argument used for indexing the child runs.
:param overwrite: Whether to force the download even if metrics are already saved locally.
:return: A dataframe in the format returned by :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
:return: The path of the downloaded json file.
"""
metrics_json = download_dir / "aml_metrics.json"
if not overwrite and metrics_json.is_file():
print(f"AML metrics file already exists at {metrics_json}")
metrics_df = pd.read_json(metrics_json)
else:
metrics_df = aggregate_hyperdrive_metrics(run_id=parent_run_id,
child_run_arg_name=crossval_arg_name,
child_run_arg_name=hyperdrive_arg_name,
aml_workspace=aml_workspace)
metrics_json.parent.mkdir(parents=True, exist_ok=True)
print(f"Writing AML metrics file to {metrics_json}")
df_to_json(metrics_df, metrics_json)
return metrics_df.sort_index(axis='columns')
return metrics_json
def get_crossval_metrics_table(metrics_df: pd.DataFrame, metrics_list: Sequence[str]) -> pd.DataFrame:
"""Format raw cross-validation metrics into a table with a summary "Mean ± Std" column.
def collect_hyperdrive_metrics(metrics_json: Path) -> pd.DataFrame:
"""
Load the hyperdrive metrics from the downloaded metrics JSON file into a dataframe.
:param metrics_json: Path of the downloaded metrics file `aml_metrics.json`.
:return: A dataframe in the format returned by :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
"""
metrics_df = pd.read_json(metrics_json).sort_index(axis='columns')
return metrics_df
def get_hyperdrive_metrics_table(metrics_df: pd.DataFrame, metrics_list: Sequence[str]) -> pd.DataFrame:
"""Format raw hyperdrive metrics into a table with a summary "Mean ± Std" column.
Note that this function only supports scalar metrics. To format metrics that are logged
throughout training, you should call :py:func:`get_best_epoch_metrics()` first.
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
:py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
:param metrics_list: The list of metrics to include in the table.
:return: A dataframe with the values of the selected metrics formatted as strings, including a
header and a summary column.
"""
header = ["Metric"] + [f"Split {k}" for k in metrics_df.columns] + ["Mean ± Std"]
header = ["Metric"] + [f"Child {k}" for k in metrics_df.columns] + ["Mean ± Std"]
metrics_rows = []
for metric in metrics_list:
values: pd.Series = metrics_df.loc[metric]
mean = values.mean()
std = values.std()
row = [metric] + [f"{v:.3f}" for v in values] + [f"{mean:.3f} ± {std:.3f}"]
round_values: List[str] = [f"{v:.3f}" if v is not None else str(np.nan) for v in values]
agg_values: List[str] = [f"{mean:.3f} ± {std:.3f}"]
row = [metric] + round_values + agg_values
metrics_rows.append(row)
table = pd.DataFrame(metrics_rows, columns=header).set_index(header[0])
return table
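
As an illustration of how these report helpers chain together, here is a minimal sketch, assuming the functions are imported from this report utilities module, a configured AzureML workspace, and placeholder run ID and metric names:

from pathlib import Path
from azureml.core import Workspace

# Placeholder inputs: adapt to your own workspace configuration and parent run.
aml_workspace = Workspace.from_config()
parent_run_id = "HD_00000000-0000-0000-0000-000000000000"
download_dir = Path("outputs/hyperdrive_report")

# Download the per-child metrics once, then load and tabulate them.
metrics_json = download_hyperdrive_metrics_if_required(
    parent_run_id, download_dir, aml_workspace, hyperdrive_arg_name="crossval_index")
metrics_df = collect_hyperdrive_metrics(metrics_json)
table = get_hyperdrive_metrics_table(metrics_df, metrics_list=["test/accuracy"])  # metric name is illustrative
print(table)
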
def get_best_epochs(metrics_df: pd.DataFrame, primary_metric: str, maximise: bool = True) -> Dict[int, int]:
"""Determine the best epoch for each cross-validation run based on a given metric.
def get_best_epochs(metrics_df: pd.DataFrame, primary_metric: str, max_epochs_dict: Dict[int, int],
maximise: bool = True) -> Dict[int, Any]:
"""Determine the best epoch for each hyperdrive child run based on a given metric.
The returned epoch indices are relative to the logging frequency of the chosen metric, i.e.
should not be mixed between pipeline stages that log metrics at different epoch intervals.
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
:py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
:param primary_metric: Name of the reference metric to optimise.
:param max_epochs_dict: A dictionary of the maximum number of epochs in each hyperdrive child run.
:param maximise: Whether the given metric should be maximised (minimised if `False`).
:return: Dictionary mapping each cross-validation index to its best epoch.
:return: Dictionary mapping each hyperdrive child index to its best epoch.
"""
best_fn = np.argmax if maximise else np.argmin
best_epochs = metrics_df.loc[primary_metric].apply(best_fn)
return best_epochs.to_dict()
best_epochs: Dict[int, Any] = {}
for child_index in metrics_df.columns:
primary_metric_list = metrics_df[child_index][primary_metric]
if primary_metric_list is not None:
# If extra validation epoch was logged (N+1), return only the first N elements
primary_metric_list = primary_metric_list[:-1] \
if (len(primary_metric_list) == max_epochs_dict[child_index] + 1) else primary_metric_list
best_epochs[child_index] = int(np.argmax(primary_metric_list)
if maximise else np.argmin(primary_metric_list))
else:
best_epochs[child_index] = None
return best_epochs
def get_best_epoch_metrics(metrics_df: pd.DataFrame, metrics_list: Sequence[str],
best_epochs: Dict[int, int]) -> pd.DataFrame:
"""Extract the values of the selected cross-validation metrics at the given best epochs.
best_epochs: Dict[int, Any]) -> pd.DataFrame:
"""Extract the values of the selected hyperdrive metrics at the given best epochs.
The `best_epoch` indices are relative to the logging frequency of the chosen primary metric,
i.e. the metrics in `metrics_list` must have been logged at the same epoch intervals.
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
:py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
:param metrics_list: Names of the metrics to index by the best epoch indices provided. Their
values in `metrics_df` should be lists.
:param best_epochs: Dictionary of cross-validation indices to best epochs, as returned by
:param best_epochs: Dictionary of hyperdrive child run indices to best epochs, as returned by
:py:func:`get_best_epochs()`.
:return: Dataframe with the same columns as `metrics_df` and rows specified by `metrics_list`,
containing only scalar values.
"""
best_metrics = [metrics_df.loc[metrics_list, k].apply(lambda values: values[epoch])
for k, epoch in best_epochs.items()]
if epoch is not None else metrics_df.loc[metrics_list, k] for k, epoch in best_epochs.items()]
best_metrics_df = pd.DataFrame(best_metrics).T
return best_metrics_df
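
To make the expected layout concrete, here is a self-contained sketch on synthetic data (plain pandas/numpy, mirroring rather than calling the repository functions; metric names are illustrative): each column is a child run index and each cell of a per-epoch metric holds a list of values.

import numpy as np
import pandas as pd

metrics_df = pd.DataFrame({
    0: {"val/auroc": [0.70, 0.80, 0.75], "val/accuracy": [0.60, 0.72, 0.68]},
    1: {"val/auroc": [0.65, 0.66, 0.81], "val/accuracy": [0.55, 0.58, 0.70]},
})

# Best epoch per child run, maximising the primary metric (as in get_best_epochs).
best_epochs = {k: int(np.argmax(metrics_df[k]["val/auroc"])) for k in metrics_df.columns}
print(best_epochs)  # {0: 1, 1: 2}

# Scalar metric values at those epochs (as in get_best_epoch_metrics).
best_metrics_df = pd.DataFrame(
    {k: {m: metrics_df[k][m][e] for m in ["val/auroc", "val/accuracy"]} for k, e in best_epochs.items()})
print(best_metrics_df)
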
def get_formatted_run_info(parent_run: Run) -> str:
"""Format Azure ML cross-validation run information as HTML.
"""Format Azure ML hyperdrive run information as HTML.
Includes details of the parent and child runs, as well as submission information.
@ -216,20 +238,30 @@ def get_formatted_run_info(parent_run: Run) -> str:
return html
def collect_class_info(metrics_df: pd.DataFrame) -> Tuple[int, List[str]]:
def get_child_runs_hyperparams(metrics_df: pd.DataFrame) -> Dict[int, Dict]:
"""
Get the class names from metrics dataframe
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
Get the hyperparameters of each child run from the metrics dataframe.
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
:py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
:return: Number of classes and list of class names
:return: A dictionary of hyperparameter dictionaries for the child runs.
"""
hyperparams = metrics_df[0][AMLMetricsJsonKey.HYPERPARAMS]
hyperparams_name = hyperparams[AMLMetricsJsonKey.NAME]
hyperparams_value = hyperparams[AMLMetricsJsonKey.VALUE]
num_classes_index = hyperparams_name.index(AMLMetricsJsonKey.N_CLASSES)
num_classes = int(hyperparams_value[num_classes_index])
class_names_index = hyperparams_name.index(AMLMetricsJsonKey.CLASS_NAMES)
class_names = hyperparams_value[class_names_index]
hyperparams_children = {}
for child_index in metrics_df.columns:
hyperparams = metrics_df[child_index][AMLMetricsJsonKey.HYPERPARAMS]
hyperparams_dict = dict(zip(hyperparams[AMLMetricsJsonKey.NAME], hyperparams[AMLMetricsJsonKey.VALUE]))
hyperparams_children[child_index] = hyperparams_dict
return hyperparams_children
def collect_class_info(hyperparams_children: Dict[int, Dict]) -> Tuple[int, List[str]]:
"""
Get the class names from the hyperparameters of child runs.
:param hyperparams_children: Dict of hyperparameter dicts, as returned by :py:func:`get_child_runs_hyperparams()`.
:return: Number of classes and list of class names.
"""
hyperparams_single_run = list(hyperparams_children.values())[0]
num_classes = int(hyperparams_single_run[AMLMetricsJsonKey.N_CLASSES])
class_names = hyperparams_single_run[AMLMetricsJsonKey.CLASS_NAMES]
if class_names == "None":
class_names = None
else:
@ -237,3 +269,15 @@ def collect_class_info(metrics_df: pd.DataFrame) -> Tuple[int, List[str]]:
class_names = [name.lstrip() for name in class_names[1:-1].replace("'", "").split(',')]
class_names = validate_class_names(class_names=class_names, n_classes=num_classes)
return (num_classes, list(class_names))
def get_max_epochs(hyperparams_children: Dict[int, Dict]) -> Dict[int, int]:
"""
Get the maximum number of epochs for each hyperdrive child run from its hyperparameters.
:param hyperparams_children: Dict of hyperparameter dicts, as returned by :py:func:`get_child_runs_hyperparams()`.
:return: Dictionary with the maximum number of epochs for each hyperdrive child run.
"""
max_epochs_dict = {}
for child_index in hyperparams_children.keys():
max_epochs_dict[child_index] = int(hyperparams_children[child_index][AMLMetricsJsonKey.MAX_EPOCHS])
return max_epochs_dict
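
As a rough illustration of the structure these helpers consume (the key strings below are placeholders standing in for the actual `AMLMetricsJsonKey` values):

# Hypothetical shape of the per-child hyperparameter dicts returned by get_child_runs_hyperparams().
hyperparams_children = {
    0: {"n_classes": "6", "class_names": "None", "max_epochs": "50"},
    1: {"n_classes": "6", "class_names": "None", "max_epochs": "50"},
}

# get_max_epochs() reduces this to one integer per child run.
max_epochs_dict = {k: int(v["max_epochs"]) for k, v in hyperparams_children.items()}
assert max_epochs_dict == {0: 50, 1: 50}
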
return max_epochs_dict

Просмотреть файл

@ -42,23 +42,27 @@ class SlideNode:
"""Data structure class for slide nodes used by `TopBottomTilesHandler` to store top and bottom slides by
probability score. """
def __init__(self, slide_id: str, prob_score: float, true_label: int, pred_label: int) -> None:
def __init__(
self, slide_id: str, gt_prob_score: float, pred_prob_score: float, true_label: int, pred_label: int
) -> None:
"""
:param prob_score: The probability score assigned to the slide node. This scalar defines the order of the
slide_node among the nodes in the max/min heap.
:param gt_prob_score: The probability score assigned to the ground truth label of the slide node. This scalar defines
the order of the slide_node among the nodes in the max/min heap.
:param pred_prob_score: The probability score assigned to the predicted label of the slide node.
:param slide_id: The slide id in the data cohort.
:param true_label: The ground truth label of the slide node.
:param pred_label: The label predicted by the model.
"""
self.slide_id = slide_id
self.prob_score = prob_score
self.gt_prob_score = gt_prob_score
self.pred_prob_score = pred_prob_score
self.true_label = true_label
self.pred_label = pred_label
self.top_tiles: List[TileNode] = []
self.bottom_tiles: List[TileNode] = []
def __lt__(self, other: "SlideNode") -> bool:
return self.prob_score < other.prob_score
return self.gt_prob_score < other.gt_prob_score
def update_selected_tiles(self, tiles: Tensor, attn_scores: Tensor, num_top_tiles: int) -> None:
"""Update top and bottom k tiles values from a set of tiles and their assigned attention scores.
@ -77,7 +81,7 @@ class SlideNode:
def _shallow_copy(self) -> "SlideNode":
"""Returns a shallow copy of the current slide node contaning only the slide_id and its probability score."""
return SlideNode(self.slide_id, self.prob_score, self.true_label, self.pred_label)
return SlideNode(self.slide_id, self.gt_prob_score, self.pred_prob_score, self.true_label, self.pred_label)
SlideOrTileKey = Union[SlideKey, TileKey]
@ -106,23 +110,10 @@ class TilesSelector:
self.num_tiles = num_tiles
self.top_slides_heaps: SlideDict = self._initialise_slide_heaps()
self.bottom_slides_heaps: SlideDict = self._initialise_slide_heaps()
self.report_cases_slide_ids = self.init_report_cases()
def _initialise_slide_heaps(self) -> SlideDict:
return {class_id: [] for class_id in range(self.n_classes)}
def init_report_cases(self) -> Dict[str, List[str]]:
""" Initializes the report cases dictionary to store slide_ids per case.
Possible cases are set such as class 0 is considered to be the negative class.
:return: A dictionary that maps TN/FP TP_{i}/FN_{i}, i={1,..., self.n_classes+1} cases to an empty list to be
filled with corresponding slide ids.
"""
report_cases: Dict[str, List[str]] = {"TN": [], "FP": []}
report_cases.update({f"TP_{class_id}": [] for class_id in range(1, self.n_classes)})
report_cases.update({f"FN_{class_id}": [] for class_id in range(1, self.n_classes)})
return report_cases
def _clear_cached_slides_heaps(self) -> None:
self.top_slides_heaps = self._initialise_slide_heaps()
self.bottom_slides_heaps = self._initialise_slide_heaps()
@ -136,7 +127,7 @@ class TilesSelector:
) -> None:
"""Update the selected slides of a given class label on the fly by updating the content of class_slides_heap.
First, we push a shallow slide_node into the slides_heaps[gt_label]. The order in slides_heaps[gt_label] is
defined by the slide_node.prob_score that is positive in top_slides_heaps nodes and negative in
defined by the slide_node.gt_prob_score that is positive in top_slides_heaps nodes and negative in
bottom_slides_heaps nodes.
Second, we check if we have exceeded self.num_slides to be selected.
If so, we update the slide_node's top and bottom tiles only if it has been kept in the heap.
@ -147,7 +138,7 @@ class TilesSelector:
:param tiles: Tiles of a given whole slide retrieved from the current validation or test batch.
(n_tiles, channels, height, width)
:param attn_scores: The tiles attention scores to determine top and bottom tiles. (n_tiles, )
:param slide_node: A shallow version of slide_node that contains only slide_id and its assigned prob_score.
:param slide_node: A shallow version of slide_node that contains only slide_id and additional metadata.
"""
heapq.heappush(class_slides_heap, slide_node)
if len(class_slides_heap) == self.num_slides + 1:
@ -170,33 +161,39 @@ class TilesSelector:
slide_ids = [slide_id[0] for slide_id in slide_ids] # to account for repetitions in tiles pipeline
batch_size = len(batch[SlideKey.IMAGE])
for i in range(batch_size):
label = results[ResultsKey.TRUE_LABEL][i].item()
probs_gt_label = results[ResultsKey.CLASS_PROBS][:, label][i].item()
gt_label = results[ResultsKey.TRUE_LABEL][i].item()
pred_label = results[ResultsKey.PRED_LABEL][i].item()
gt_prob_score = results[ResultsKey.CLASS_PROBS][:, gt_label][i].item()
pred_prob_score = results[ResultsKey.CLASS_PROBS][:, pred_label][i].item()
tiles = batch[SlideKey.IMAGE][i]
attn_scores = results[ResultsKey.BAG_ATTN][i].squeeze(0)
pred_label = results[ResultsKey.PRED_LABEL][i].item()
self._update_label_slides(
class_slides_heap=self.top_slides_heaps[label],
tiles=tiles,
attn_scores=attn_scores,
slide_node=SlideNode(
slide_id=slide_ids[i],
prob_score=probs_gt_label,
true_label=label,
pred_label=pred_label,
),
)
self._update_label_slides(
class_slides_heap=self.bottom_slides_heaps[label],
tiles=tiles,
attn_scores=attn_scores,
slide_node=SlideNode(
slide_id=slide_ids[i],
prob_score=-probs_gt_label, # negative score for bottom slides to reverse order in max heap
true_label=label,
pred_label=pred_label,
),
)
if pred_label == gt_label:
self._update_label_slides(
class_slides_heap=self.top_slides_heaps[gt_label],
tiles=tiles,
attn_scores=attn_scores,
slide_node=SlideNode(
slide_id=slide_ids[i],
gt_prob_score=gt_prob_score,
pred_prob_score=pred_prob_score,
true_label=gt_label,
pred_label=pred_label,
),
)
elif pred_label != gt_label:
self._update_label_slides(
class_slides_heap=self.bottom_slides_heaps[gt_label],
tiles=tiles,
attn_scores=attn_scores,
slide_node=SlideNode(
slide_id=slide_ids[i],
# negative score for bottom slides to reverse order in max heap
gt_prob_score=-gt_prob_score,
pred_prob_score=pred_prob_score,
true_label=gt_label,
pred_label=pred_label,
),
)
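
The negative ground-truth score used for the bottom heap relies on Python's min-heap semantics; a tiny self-contained sketch with plain tuples instead of SlideNode objects (illustrative only):

import heapq

num_slides = 2  # keep only the 2 most extreme slides per heap
scores = [("a", 0.9), ("b", 0.4), ("c", 0.7), ("d", 0.95)]

# Top heap: the min-heap drops the lowest score, so the highest-probability slides survive.
top: list = []
for slide_id, score in scores:
    heapq.heappush(top, (score, slide_id))
    if len(top) > num_slides:
        heapq.heappop(top)
print(sorted(top, reverse=True))  # [(0.95, 'd'), (0.9, 'a')]

# Bottom heap: pushing the negated score reverses the order, so the lowest-probability slides survive.
bottom: list = []
for slide_id, score in scores:
    heapq.heappush(bottom, (-score, slide_id))
    if len(bottom) > num_slides:
        heapq.heappop(bottom)
print(sorted(bottom, reverse=True))  # [(-0.4, 'b'), (-0.7, 'c')]
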
def _shallow_copy_slides_heaps(self, slides_heaps: SlideDict) -> SlideDict:
"""Returns a shallow copy of slides heaps to be synchronised across devices.

Просмотреть файл

@ -18,6 +18,7 @@ from typing import Sequence, List, Any, Dict, Optional, Union, Tuple
from monai.data.dataset import Dataset
from monai.data.image_reader import WSIReader
from torch.utils.data import DataLoader
from health_cpath.preprocessing.loading import LoadROId
from health_cpath.utils.naming import SlideKey
from health_cpath.utils.naming import ResultsKey
@ -26,7 +27,7 @@ from health_cpath.utils.tiles_selection_utils import SlideNode
from health_cpath.datasets.panda_dataset import PandaDataset, LoadPandaROId
def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any]:
def load_image_dict(sample: dict, level: int, margin: int, wsi_has_mask: bool = True) -> Dict[SlideKey, Any]:
"""
Load image from metadata dictionary
:param sample: dict describing image metadata. Example:
@ -40,7 +41,8 @@ def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any
:param margin: margin to be included
:param wsi_has_mask: Whether the whole slide image has an available mask. If `True` (default), the sample is loaded
    with `LoadPandaROId`, otherwise with `LoadROId`.
:return: a dict containing the image data and metadata
"""
loader = LoadPandaROId(WSIReader("cuCIM"), level=level, margin=margin)
transform = LoadPandaROId if wsi_has_mask else LoadROId
loader = transform(WSIReader("cuCIM"), level=level, margin=margin)
img = loader(sample)
return img
@ -76,7 +78,7 @@ def plot_panda_data_sample(
def plot_scores_hist(
results: Dict, prob_col: str = ResultsKey.CLASS_PROBS, gt_col: str = ResultsKey.TRUE_LABEL
) -> plt.Figure:
"""Plot scores as a historgram.
"""Plot scores as a histogram.
:param results: List that contains slide_level dicts
:param prob_col: column name that contains the scores
@ -95,6 +97,18 @@ def plot_scores_hist(
return fig
def _get_histo_plot_title(case: str, slide_node: SlideNode) -> str:
"""Return the standard title for histopathology plots.
:param case: case id e.g., TP, FN, FP, TN
:param slide_node: SlideNode object that encapsulates the slide information
"""
return (
f"{case}: {slide_node.slide_id} P={slide_node.pred_prob_score:.2f} \n Predicted label: {slide_node.pred_label} "
f"True label: {slide_node.true_label}"
)
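
For reference, a hedged usage sketch of the shared title helper, assuming SlideNode and _get_histo_plot_title are importable from their respective modules:

node = SlideNode(slide_id="slide_0", gt_prob_score=0.95, pred_prob_score=0.95, true_label=1, pred_label=1)
print(_get_histo_plot_title("TP", node))
# TP: slide_0 P=0.95
#  Predicted label: 1 True label: 1
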
def plot_attention_tiles(
case: str, slide_node: SlideNode, top: bool, num_columns: int, figsize: Tuple[int, int]
) -> Optional[plt.Figure]:
@ -116,11 +130,7 @@ def plot_attention_tiles(
return None
fig, axs = plt.subplots(nrows=num_rows, ncols=num_columns, figsize=figsize)
fig.suptitle(
f"{case}: {slide_node.slide_id} P={abs(slide_node.prob_score):.2f} \n Predicted label: {slide_node.pred_label} "
f"True label: {slide_node.true_label}"
)
fig.suptitle(_get_histo_plot_title(case, slide_node))
for ax, tile_node in zip(axs.flat, tile_nodes):
ax.imshow(np.transpose(tile_node.data.numpy(), (1, 2, 0)), clim=(0, 255), cmap="gray")
ax.set_title("%.6f" % tile_node.attn)
@ -141,10 +151,7 @@ def plot_slide(case: str, slide_node: SlideNode, slide_image: np.ndarray, scale:
fig, ax = plt.subplots()
slide_image = slide_image.transpose(1, 2, 0)
ax.imshow(slide_image)
fig.suptitle(
f"{case}: {slide_node.slide_id} P={abs(slide_node.prob_score):.2f} \n Predicted label: {slide_node.pred_label} "
f"True label: {slide_node.true_label}"
)
fig.suptitle(_get_histo_plot_title(case, slide_node))
ax.set_axis_off()
original_size = fig.get_size_inches()
fig.set_size_inches((original_size[0] * scale, original_size[1] * scale))
@ -173,10 +180,8 @@ def plot_heatmap_overlay(
:return: matplotlib figure of the heatmap of the given tiles on slide.
"""
fig, ax = plt.subplots()
fig.suptitle(
f"{case}: {slide_node.slide_id} P={abs(slide_node.prob_score):.2f} \n Predicted label: {slide_node.pred_label} "
f"True label: {slide_node.true_label}"
)
fig.suptitle(_get_histo_plot_title(case, slide_node))
slide_image = slide_image.transpose(1, 2, 0)
ax.imshow(slide_image)
ax.set_xlim(0, slide_image.shape[1])

Просмотреть файл

@ -16,7 +16,10 @@ def image_collate(batch: List) -> Any:
for i, item in enumerate(batch):
data = item[0]
data[SlideKey.IMAGE] = torch.tensor(np.array([ix[SlideKey.IMAGE] for ix in item]))
if isinstance(data[SlideKey.IMAGE], torch.Tensor):
data[SlideKey.IMAGE] = torch.stack([ix[SlideKey.IMAGE] for ix in item], dim=0)
else:
data[SlideKey.IMAGE] = torch.tensor(np.array([ix[SlideKey.IMAGE] for ix in item]))
data[SlideKey.LABEL] = torch.tensor(data[SlideKey.LABEL])
batch[i] = data
return multibag_collate(batch)
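
A self-contained sketch of the branching above: tensor tiles from the slides pipeline can be stacked directly, whereas numpy tiles still need the tensor conversion (string keys are used here purely for illustration):

import numpy as np
import torch

items_tensor = [{"image": torch.rand(3, 28, 28)} for _ in range(4)]
items_numpy = [{"image": np.random.rand(3, 28, 28).astype(np.float32)} for _ in range(4)]

for items in (items_tensor, items_numpy):
    first = items[0]["image"]
    if isinstance(first, torch.Tensor):
        # Stacking avoids the extra copy incurred by wrapping tensors in np.array.
        bag = torch.stack([ix["image"] for ix in items], dim=0)
    else:
        bag = torch.tensor(np.array([ix["image"] for ix in items]))
    assert bag.shape == (4, 3, 28, 28)
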

Просмотреть файл

@ -36,7 +36,7 @@ from SSL.configs.CXR_SSL_configs import CXRImageClassifier, NIH_RSNA_SimCLR
from health_ml.runner import Runner
from health_ml.utils import AzureMLProgressBar
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME
from health_ml.utils.fixed_paths import repository_root_directory, OutputFolderForTests
from health_ml.utils.lightning_loggers import StoringLogger
@ -247,7 +247,7 @@ def test_ssl_container_rsna() -> None:
_compare_stored_metrics(runner, expected_metrics)
# Check that we are able to load the checkpoint and create classifier model
checkpoint_path = loaded_config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
checkpoint_path = loaded_config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME
model_namespace_cxr = "SSL.configs.CXRImageClassifier"
args = common_test_args + [f"--model={model_namespace_cxr}",
f"--local_datasets={str(path_to_cxr_test_dataset)}",

Просмотреть файл

@ -8,11 +8,10 @@ import logging
import shutil
import sys
import uuid
import pytest
from pathlib import Path
from typing import Generator
import pytest
# temporary workaround until these hi-ml packages are released
testhisto_root_dir = Path(__file__).parent
print(f"Adding {testhisto_root_dir} to sys path")
@ -29,6 +28,9 @@ for package, subpackages in packages.items():
sys.path.insert(0, str(himl_root / package / subpackage))
from health_ml.utils.fixed_paths import OutputFolderForTests # noqa: E402
from testhisto.mocks.base_data_generator import MockHistoDataType # noqa: E402
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator # noqa: E402
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPositioningType # noqa: E402
def remove_and_create_folder(folder: Path) -> None:
@ -73,3 +75,47 @@ def tmp_path_to_pathmnist_dataset(tmp_path_factory: pytest.TempPathFactory) -> G
download_azure_dataset(tmp_dir, dataset_id=MockHistoDataType.PATHMNIST.value)
yield tmp_dir
shutil.rmtree(tmp_dir)
@pytest.fixture(scope="session")
def mock_panda_tiles_root_dir(
tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
tmp_root_dir = tmp_path_factory.mktemp("mock_tiles")
tiles_generator = MockPandaTilesGenerator(
dest_data_path=tmp_root_dir,
src_data_path=tmp_path_to_pathmnist_dataset,
mock_type=MockHistoDataType.PATHMNIST,
n_tiles=4,
n_slides=15,
n_channels=3,
tile_size=28,
img_size=224,
)
logging.info("Generating temporary mock tiles that will be deleted at the end of the session.")
tiles_generator.generate_mock_histo_data()
yield tmp_root_dir
shutil.rmtree(tmp_root_dir)
@pytest.fixture(scope="session")
def mock_panda_slides_root_dir(
tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
tmp_root_dir = tmp_path_factory.mktemp("mock_slides")
wsi_generator = MockPandaSlidesGenerator(
dest_data_path=tmp_root_dir,
src_data_path=tmp_path_to_pathmnist_dataset,
mock_type=MockHistoDataType.PATHMNIST,
n_tiles=4,
n_slides=15,
n_channels=3,
n_levels=3,
tile_size=28,
background_val=255,
tiles_pos_type=TilesPositioningType.RANDOM
)
logging.info("Generating temporary mock slides that will be deleted at the end of the session.")
wsi_generator.generate_mock_histo_data()
yield tmp_root_dir
shutil.rmtree(tmp_root_dir)

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 310 KiB

После

Ширина:  |  Высота:  |  Размер: 311 KiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 474 KiB

Двоичные данные
hi-ml-cpath/testhisto/test_data/histo_heatmaps/slide_0.1_TP.png Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 474 KiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 474 KiB

Двоичные данные
hi-ml-cpath/testhisto/test_data/histo_heatmaps/slide_1.2_TN.png Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 474 KiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 474 KiB

Двоичные данные
hi-ml-cpath/testhisto/test_data/histo_heatmaps/slide_2.4_FP.png Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 474 KiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 474 KiB

Двоичные данные
hi-ml-cpath/testhisto/test_data/histo_heatmaps/slide_3.6_FN.png Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 474 KiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 1.1 MiB

После

Ширина:  |  Высота:  |  Размер: 1.1 MiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 1.1 MiB

После

Ширина:  |  Высота:  |  Размер: 1.1 MiB

Просмотреть файл

@ -0,0 +1,179 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pytest
import torch
from pathlib import Path
from typing import List, Optional
from unittest.mock import MagicMock, patch
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
from health_cpath.datamodules.base_module import HistoDataModule
from health_cpath.datamodules.panda_module import PandaSlidesDataModule, PandaTilesDataModule
from health_cpath.utils.naming import ModelKey, SlideKey
from health_ml.utils.common_utils import is_gpu_available
from testhisto.utils.utils_testhisto import run_distributed
no_gpu = not is_gpu_available()
def _assert_correct_bag_sizes(datamodule: HistoDataModule, max_bag_size: int, max_bag_size_inf: Optional[int],
true_bag_sizes: List[int]) -> None:
# True bag sizes are the bag sizes generated by the mock data generator for a fixed seed: the tile counts
# (and therefore the bag sizes) are random, to reflect real data with a varying number of tiles per slide.
for stage_key, bag_size in zip([m for m in ModelKey], [max_bag_size, max_bag_size_inf, max_bag_size_inf]):
assert datamodule.bag_sizes[stage_key] == bag_size
def _assert_bag_size_matching(dataloader: DataLoader, expected_bag_sizes: List[int]) -> None:
sample = next(iter(dataloader))
for i, slide in enumerate(sample[SlideKey.IMAGE]):
assert slide.shape[0] == expected_bag_sizes[i]
_assert_bag_size_matching(datamodule.train_dataloader(), [max_bag_size, max_bag_size])
expected_bag_sizes = true_bag_sizes if not max_bag_size_inf else [max_bag_size_inf, max_bag_size_inf]
_assert_bag_size_matching(datamodule.val_dataloader(), expected_bag_sizes)
_assert_bag_size_matching(datamodule.test_dataloader(), expected_bag_sizes)
def _assert_correct_batch_sizes(datamodule: HistoDataModule, batch_size: int, batch_size_inf: Optional[int]) -> None:
batch_size_inf = batch_size_inf if batch_size_inf is not None else batch_size
for stage_key, _batch_size in zip([m for m in ModelKey], [batch_size, batch_size_inf, batch_size_inf]):
assert datamodule.batch_sizes[stage_key] == _batch_size
def _assert_batch_size_matching(dataloader: DataLoader, expected_batch_size: int) -> None:
sample = next(iter(dataloader))
assert len(sample[SlideKey.IMAGE]) == expected_batch_size
_assert_batch_size_matching(datamodule.train_dataloader(), batch_size)
_assert_batch_size_matching(datamodule.val_dataloader(), batch_size_inf)
_assert_batch_size_matching(datamodule.test_dataloader(), batch_size_inf)
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("max_bag_size, max_bag_size_inf", [(2, 0), (2, 3)])
def test_slides_datamodule_different_bag_sizes(
mock_panda_slides_root_dir: Path, max_bag_size: int, max_bag_size_inf: int
) -> None:
datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
batch_size=2,
max_bag_size=max_bag_size,
max_bag_size_inf=max_bag_size_inf,
tile_size=28,
level=0,
)
# To account for the fact that the slides datamodule formats 0 to None so that it's compatible with the TileOnGrid transform
max_bag_size_inf = max_bag_size_inf if max_bag_size_inf != 0 else None # type: ignore
# For slides datamodule, the true bag sizes [4, 4] are the same as requested to TileOnGrid transform
_assert_correct_bag_sizes(datamodule, max_bag_size, max_bag_size_inf, true_bag_sizes=[4, 4])
@pytest.mark.parametrize("max_bag_size, max_bag_size_inf", [(2, 0), (2, 3)])
def test_tiles_datamodule_different_bag_sizes(
mock_panda_tiles_root_dir: Path, max_bag_size: int, max_bag_size_inf: int
) -> None:
datamodule = PandaTilesDataModule(
root_path=mock_panda_tiles_root_dir,
batch_size=2,
max_bag_size=max_bag_size,
max_bag_size_inf=max_bag_size_inf,
)
# For tiles datamodule, the true bag sizes [4, 5] were generated by the mock data generator for a fixed seed 42
# If test fails, check if the mock data generator has changed and update the true bag sizes
_assert_correct_bag_sizes(datamodule, max_bag_size, max_bag_size_inf, true_bag_sizes=[4, 5])
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("batch_size, batch_size_inf", [(2, 2), (2, 1), (2, None)])
def test_slides_datamodule_different_batch_sizes(
mock_panda_slides_root_dir: Path, batch_size: int, batch_size_inf: Optional[int],
) -> None:
datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
batch_size=batch_size,
batch_size_inf=batch_size_inf,
max_bag_size=16,
max_bag_size_inf=16,
tile_size=28,
level=0,
)
_assert_correct_batch_sizes(datamodule, batch_size, batch_size_inf)
@pytest.mark.parametrize("batch_size, batch_size_inf", [(2, 2), (2, 1), (2, None)])
def test_tiles_datamodule_different_batch_sizes(
mock_panda_tiles_root_dir: Path, batch_size: int, batch_size_inf: Optional[int],
) -> None:
datamodule = PandaTilesDataModule(
root_path=mock_panda_tiles_root_dir,
batch_size=batch_size,
batch_size_inf=batch_size_inf,
max_bag_size=16,
max_bag_size_inf=16,
)
_assert_correct_batch_sizes(datamodule, batch_size, batch_size_inf)
def _validate_sampler_type(datamodule: HistoDataModule, stages: List[ModelKey], expected_none: bool) -> None:
expected_sampler_types = {
ModelKey.TRAIN: RandomSampler if expected_none else DistributedSampler,
ModelKey.VAL: SequentialSampler,
ModelKey.TEST: SequentialSampler,
}
for stage in stages:
datamodule_sampler = datamodule._get_ddp_sampler(getattr(datamodule, f'{stage}_dataset'), stage)
assert (datamodule_sampler is None) == expected_none
dataloader = getattr(datamodule, f'{stage.value}_dataloader')()
assert isinstance(dataloader.sampler, expected_sampler_types[stage])
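
For reference, a minimal sketch of constructing a training sampler by hand, as a datamodule would when pl_replace_sampler_ddp is disabled; explicit num_replicas/rank values keep it runnable on a single process without an initialised process group, and the fixed seed mirrors the reproducibility assertion tested below:

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=42)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
print([batch[0].tolist() for batch in loader])  # this rank sees half of the dataset
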
def _test_datamodule_pl_ddp_sampler_true(
datamodule: HistoDataModule, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
datamodule.setup()
_validate_sampler_type(datamodule, [ModelKey.TRAIN, ModelKey.VAL, ModelKey.TEST], expected_none=True)
def _test_datamodule_pl_ddp_sampler_false(
datamodule: HistoDataModule, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
datamodule.setup()
_validate_sampler_type(datamodule, [ModelKey.VAL, ModelKey.TEST], expected_none=True)
_validate_sampler_type(datamodule, [ModelKey.TRAIN], expected_none=False)
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_slides_datamodule_pl_replace_sampler_ddp(mock_panda_slides_root_dir: Path) -> None:
slides_datamodule = PandaSlidesDataModule(root_path=mock_panda_slides_root_dir,
pl_replace_sampler_ddp=True,
seed=42)
run_distributed(_test_datamodule_pl_ddp_sampler_true, [slides_datamodule], world_size=2)
slides_datamodule.pl_replace_sampler_ddp = False
run_distributed(_test_datamodule_pl_ddp_sampler_false, [slides_datamodule], world_size=2)
@pytest.mark.skip(reason="Test fails with Broken Pipe Error. To be fixed.")
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_tiles_datamodule_pl_replace_sampler_ddp(mock_panda_tiles_root_dir: Path) -> None:
tiles_datamodule = PandaTilesDataModule(root_path=mock_panda_tiles_root_dir, seed=42, pl_replace_sampler_ddp=True)
run_distributed(_test_datamodule_pl_ddp_sampler_true, [tiles_datamodule], world_size=2)
tiles_datamodule.pl_replace_sampler_ddp = False
run_distributed(_test_datamodule_pl_ddp_sampler_false, [tiles_datamodule], world_size=2)
def test_assertion_error_missing_seed(mock_panda_slides_root_dir: Path) -> None:
with pytest.raises(AssertionError, match="seed must be set when using distributed training for reproducibility"):
with patch("torch.distributed.is_initialized", return_value=True):
with patch("torch.distributed.get_world_size", return_value=2):
slides_datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir, pl_replace_sampler_ddp=False
)
slides_datamodule._get_ddp_sampler(MagicMock(), ModelKey.TRAIN)

Просмотреть файл

@ -1,17 +1,17 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import shutil
from typing import Generator, Dict, Callable, Union, Tuple
import pytest
import logging
import numpy as np
import torch
from pathlib import Path
from monai.transforms import RandFlipd
from typing import Generator, Dict, Callable, Union, Tuple
from torch.utils.data import DataLoader
from health_ml.utils.common_utils import is_gpu_available
from health_cpath.datamodules.base_module import SlidesDataModule
@ -29,7 +29,7 @@ no_gpu = not is_gpu_available()
@pytest.fixture(scope="session")
def mock_panda_slides_root_dir(
def mock_panda_slides_root_dir_diagonal(
tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
tmp_root_dir = tmp_path_factory.mktemp("mock_wsi")
@ -38,7 +38,7 @@ def mock_panda_slides_root_dir(
src_data_path=tmp_path_to_pathmnist_dataset,
mock_type=MockHistoDataType.PATHMNIST,
n_tiles=1,
n_slides=10,
n_slides=16,
n_repeat_diag=4,
n_repeat_tile=2,
n_channels=3,
@ -83,7 +83,7 @@ def get_original_tile(mock_dir: Path, wsi_id: str) -> np.ndarray:
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_tiling_on_the_fly(mock_panda_slides_root_dir: Path) -> None:
def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
batch_size = 1
tile_count = 16
tile_size = 28
@ -91,7 +91,7 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir: Path) -> None:
channels = 3
assert_batch_index = 0
datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
root_path=mock_panda_slides_root_dir_diagonal,
batch_size=batch_size,
max_bag_size=tile_count,
tile_size=tile_size,
@ -105,14 +105,14 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir: Path) -> None:
assert tiles[assert_batch_index].shape == (tile_count, channels, tile_size, tile_size)
# check tiling on the fly
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
for i in range(tile_count):
assert (original_tile == tiles[assert_batch_index][i].numpy()).all()
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir: Path) -> None:
def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Path) -> None:
batch_size = 1
tile_count = None
tile_size = 28
@ -120,7 +120,7 @@ def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir: Path) -> No
assert_batch_index = 0
min_expected_tile_count = 16
datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
root_path=mock_panda_slides_root_dir_diagonal,
batch_size=batch_size,
max_bag_size=tile_count,
tile_size=tile_size,
@ -135,14 +135,14 @@ def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir: Path) -> No
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("level", [0, 1, 2])
def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir: Path) -> None:
def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
batch_size = 1
tile_count = 16
channels = 3
tile_size = 28 // 2 ** level
assert_batch_index = 0
datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
root_path=mock_panda_slides_root_dir_diagonal,
batch_size=batch_size,
max_bag_size=tile_count,
tile_size=tile_size,
@ -155,7 +155,7 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir: Path) -
assert tiles[assert_batch_index].shape == (tile_count, channels, tile_size, tile_size)
# check tiling on the fly at different resolutions
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
for i in range(tile_count):
# multi resolution mock data has been created via 2 factor downsampling
assert (original_tile[:, :: 2 ** level, :: 2 ** level] == tiles[assert_batch_index][i].numpy()).all()
@ -164,7 +164,7 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir: Path) -
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("batch_size", [1, 2])
def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) -> None:
def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
tile_size = 28
level = 0
step = 14
@ -172,7 +172,7 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) ->
min_expected_tile_count = 32
assert_batch_index = 0
datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
root_path=mock_panda_slides_root_dir_diagonal,
max_bag_size=None,
batch_size=batch_size,
tile_size=tile_size,
@ -184,7 +184,7 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) ->
tiles, wsi_id = sample[SlideKey.IMAGE], sample[SlideKey.SLIDE_ID][assert_batch_index]
assert tiles[assert_batch_index].shape[0] >= min_expected_tile_count
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
tile_matches = 0
for _, tile in enumerate(tiles[assert_batch_index]):
tile_matches += int((tile.numpy() == original_tile).all())
@ -193,12 +193,12 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) ->
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_train_test_transforms(mock_panda_slides_root_dir: Path) -> None:
def test_train_test_transforms(mock_panda_slides_root_dir_diagonal: Path) -> None:
def get_transforms_dict() -> Dict[ModelKey, Union[Callable, None]]:
train_transform = RandFlipd(keys=[SlideKey.IMAGE], spatial_axis=0, prob=1.0)
return {ModelKey.TRAIN: train_transform, ModelKey.VAL: None, ModelKey.TEST: None} # type: ignore
def retrieve_tiles(dataloader: torch.utils.data.DataLoader) -> Dict[str, torch.Tensor]:
def retrieve_tiles(dataloader: DataLoader) -> Dict[str, torch.Tensor]:
tiles_dict = {}
assert_batch_index = 0
for sample in dataloader:
@ -211,7 +211,7 @@ def test_train_test_transforms(mock_panda_slides_root_dir: Path) -> None:
tile_size = 28
level = 0
flipdatamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
root_path=mock_panda_slides_root_dir_diagonal,
batch_size=batch_size,
max_bag_size=tile_count,
max_bag_size_inf=0,
@ -224,20 +224,20 @@ def test_train_test_transforms(mock_panda_slides_root_dir: Path) -> None:
flip_test_tiles = retrieve_tiles(flipdatamodule.test_dataloader())
for wsi_id in flip_train_tiles.keys():
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
# the first dimension is the channel, flipping happened on the horizontal axis of the image
transformed_original_tile = np.flip(original_tile, axis=1)
for tile in flip_train_tiles[wsi_id]:
assert (tile.numpy() == transformed_original_tile).all()
for wsi_id in flip_val_tiles.keys():
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
for tile in flip_val_tiles[wsi_id]:
# no transformation has been applied to val tiles
assert (tile.numpy() == original_tile).all()
for wsi_id in flip_test_tiles.keys():
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
for tile in flip_test_tiles[wsi_id]:
# no transformation has been applied to test tiles
assert (tile.numpy() == original_tile).all()
@ -251,7 +251,6 @@ class MockPandaSlidesDataModule(SlidesDataModule):
"""
def get_splits(self) -> Tuple[PandaDataset, PandaDataset, PandaDataset]:
return (PandaDataset(self.root_path), PandaDataset(self.root_path), PandaDataset(self.root_path))
@ -278,7 +277,7 @@ def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_diff
tiles = sample[SlideKey.IMAGE]
assert tiles[assert_batch_index].shape[0] == tile_count
def assert_whole_slide_inference_with_all_tiles(dataloader: torch.utils.data.DataLoader) -> None:
def assert_whole_slide_inference_with_all_tiles(dataloader: DataLoader) -> None:
for i, sample in enumerate(dataloader):
tiles = sample[SlideKey.IMAGE]
assert tiles[assert_batch_index].shape[0] == n_tiles_list[i * batch_size]

Просмотреть файл

@ -14,6 +14,7 @@ import torch
from pytorch_lightning import seed_everything
from health_cpath.configs.classification.DeepSMILESlidesPandaBenchmark import DeepSMILESlidesPandaBenchmark
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.utils.naming import SlideKey
from testhisto.mocks.base_data_generator import MockHistoDataType
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPositioningType
@ -89,3 +90,21 @@ def test_panda_reproducibility(tmp_path: Path) -> None:
# When using a fixed seed, all 3 dataloaders should return identical items. Validation and test dataloaders
# are at present not randomized, but checking those as well just in case.
test_data_items_are_equal(["train_dataloader", "val_dataloader", "test_dataloader"])
def test_validate_columns(tmp_path: Path) -> None:
_ = MockPandaSlidesGenerator(
dest_data_path=tmp_path,
mock_type=MockHistoDataType.FAKE,
n_tiles=4,
n_slides=10,
n_channels=3,
n_levels=3,
tile_size=28,
background_val=255,
tiles_pos_type=TilesPositioningType.RANDOM,
)
usecols = [PandaDataset.SLIDE_ID_COLUMN, PandaDataset.MASK_COLUMN]
with pytest.raises(ValueError, match=r"Expected columns"):
_ = PandaDataset(root=tmp_path, dataframe_kwargs={"usecols": usecols})
_ = PandaDataset(root=tmp_path, dataframe_kwargs={"usecols": usecols + [PandaDataset.METADATA_COLUMNS[1]]})

Просмотреть файл

@ -0,0 +1,47 @@
from pathlib import Path
import pytest
import pandas as pd
from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
from health_cpath.utils.naming import TileKey
from testhisto.mocks.base_data_generator import MockHistoDataType
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator
@pytest.mark.parametrize("tiling_version", [0, 1])
def test_panda_tiles_dataset(tiling_version: int, tmp_path: Path) -> None:
"""Test the PandaTilesDataset class.
:param tiling_version: The version of the tiles dataset, defaults to 0. This is used to support both the old
and new tiling scheme where coordinates are stored as tile_x and tile_y in v0 and as tile_left and tile_top
in v1.
:param tmp_path: The temporary path where to store the mock dataset.
"""
_ = MockPandaTilesGenerator(
dest_data_path=tmp_path,
mock_type=MockHistoDataType.FAKE,
n_tiles=4,
n_slides=10,
n_channels=3,
tile_size=28,
img_size=224,
tiling_version=tiling_version,
)
base_df = pd.read_csv(tmp_path / PandaTilesDataset.DEFAULT_CSV_FILENAME).set_index(PandaTilesDataset.TILE_ID_COLUMN)
dataset = PandaTilesDataset(root=tmp_path)
coordinates_columns_v0 = {PandaTilesDataset.TILE_X_COLUMN, PandaTilesDataset.TILE_Y_COLUMN}
coordinates_columns_v1 = {TileKey.TILE_LEFT, TileKey.TILE_TOP}
dataset_columns = set(dataset.dataset_df.columns)
base_df_columns = set(base_df.columns)
assert coordinates_columns_v0.issubset(dataset_columns) # v0 columns are always present
if tiling_version == 0:
assert coordinates_columns_v0.issubset(dataset_columns)
assert not coordinates_columns_v1.issubset(dataset_columns)
assert base_df_columns == dataset_columns
elif tiling_version == 1:
assert coordinates_columns_v1.issubset(dataset_columns)
assert not coordinates_columns_v0.issubset(base_df_columns)
assert dataset_columns == base_df_columns.union(coordinates_columns_v0)
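
The backwards-compatibility expectation encoded by the assertions above can be summarised in a small sketch (plain pandas, illustrative only, not the dataset implementation): when only v1 coordinate columns exist, the legacy v0 columns are expected to be exposed alongside them.

import pandas as pd

df_v1 = pd.DataFrame({"tile_id": ["t0", "t1"], "tile_left": [0, 28], "tile_top": [0, 0]})
df = df_v1.assign(tile_x=df_v1["tile_left"], tile_y=df_v1["tile_top"])
assert {"tile_x", "tile_y", "tile_left", "tile_top"}.issubset(df.columns)
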

Просмотреть файл

@ -8,23 +8,22 @@ from typing import Any, Optional, Set
from health_ml.networks.layers.attention_layers import AttentionLayer
from health_cpath.configs.classification.DeepSMILEPanda import DeepSMILESlidesPanda, DeepSMILETilesPanda
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.models.encoders import ImageNetEncoder
from health_cpath.models.encoders import Resnet18
from health_cpath.datamodules.base_module import CacheMode, CacheLocation
from health_cpath.utils.naming import PlotOption
class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
def __init__(self, tmp_path: Path, **kwargs: Any) -> None:
def __init__(self, tmp_path: Path, analyse_loss: bool = False, **kwargs: Any) -> None:
default_kwargs = dict(
# Model parameters:
pool_type=AttentionLayer.__name__,
pool_hidden_dim=16,
num_transformer_pool_layers=1,
num_transformer_pool_heads=1,
is_finetune=False,
class_names=["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"],
# Encoder parameters
encoder_type=ImageNetEncoder.__name__,
encoder_type=Resnet18.__name__,
tile_size=28,
# Data Module parameters
batch_size=2,
@ -38,6 +37,8 @@ class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
# declared in TrainerParams:
max_epochs=2,
crossval_count=1,
ssl_checkpoint=None,
analyse_loss=analyse_loss,
)
default_kwargs.update(kwargs)
super().__init__(**default_kwargs)
@ -62,10 +63,9 @@ class MockDeepSMILESlidesPanda(DeepSMILESlidesPanda):
pool_hidden_dim=16,
num_transformer_pool_layers=1,
num_transformer_pool_heads=1,
is_finetune=True,
class_names=["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"],
# Encoder parameters
encoder_type=ImageNetEncoder.__name__,
encoder_type=Resnet18.__name__,
tile_size=28,
# Data Module parameters
batch_size=2,

Просмотреть файл

@ -79,9 +79,12 @@ class MockPandaSlidesGenerator(MockHistoDataGenerator):
def create_mock_metadata_dataframe(self) -> pd.DataFrame:
"""Create a mock dataframe with random metadata."""
isup_grades = np.tile(list(self.ISUP_GRADE_MAPPING.keys()), self.n_slides // PANDA_N_CLASSES + 1,)
mock_metadata: dict = {col: [] for col in [PandaDataset.SLIDE_ID_COLUMN, *PandaDataset.METADATA_COLUMNS]}
mock_metadata: dict = {
col: [] for col in [PandaDataset.SLIDE_ID_COLUMN, PandaDataset.MASK_COLUMN, *PandaDataset.METADATA_COLUMNS]
}
for slide_id in range(self.n_slides):
mock_metadata[PandaDataset.SLIDE_ID_COLUMN].append(f"_{slide_id}")
mock_metadata[PandaDataset.MASK_COLUMN].append(f"_{slide_id}_mask")
mock_metadata[self.DATA_PROVIDER].append(np.random.choice(self.DATA_PROVIDERS_VALUES))
mock_metadata[self.ISUP_GRADE].append(isup_grades[slide_id])
mock_metadata[self.GLEASON_SCORE].append(np.random.choice(self.ISUP_GRADE_MAPPING[isup_grades[slide_id]]))

Просмотреть файл

@ -10,18 +10,23 @@ import torch
from torchvision.utils import save_image
from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
from health_cpath.utils.naming import TileKey
from testhisto.mocks.base_data_generator import MockHistoDataGenerator, MockHistoDataType, PANDA_N_CLASSES
class MockPandaTilesGenerator(MockHistoDataGenerator):
"""Generator class to create mock tiles dataset on the fly. The tiles are positioned randomly in a wsi grid."""
def __init__(self, img_size: int = 224, **kwargs: Any) -> None:
def __init__(self, img_size: int = 224, tiling_version: int = 0, **kwargs: Any) -> None:
"""
:param img_size: The whole slide image resolution, defaults to 224.
:param tiling_version: The version of the tiles dataset, defaults to 0. This is used to support both the old
and new tiling scheme where coordinates are stored as tile_x and tile_y in v0 and as tile_left and tile_top
in v1.
:param kwargs: Same params passed to MockHistoDataGenerator.
"""
self.img_size = img_size
self.tiling_version = tiling_version
super().__init__(**kwargs)
def validate(self) -> None:
@ -32,16 +37,19 @@ class MockPandaTilesGenerator(MockHistoDataGenerator):
f"The image of size {self.img_size} can't contain more than {(self.img_size // self.tile_size)**2} tiles."
f"Choose a number of tiles 0 < n_tiles <= {(self.img_size // self.tile_size)**2} "
)
assert self.tiling_version in [0, 1], f"Tiling version should be 0 or 1, got {self.tiling_version}"
def create_mock_metadata_dataframe(self) -> pd.DataFrame:
"""Create a mock dataframe with random metadata."""
x_column = TileKey.TILE_LEFT if self.tiling_version == 1 else PandaTilesDataset.TILE_X_COLUMN
y_column = TileKey.TILE_TOP if self.tiling_version == 1 else PandaTilesDataset.TILE_Y_COLUMN
csv_columns = [
PandaTilesDataset.SLIDE_ID_COLUMN,
PandaTilesDataset.TILE_ID_COLUMN,
PandaTilesDataset.IMAGE_COLUMN,
self.MASK_COLUMN,
PandaTilesDataset.TILE_X_COLUMN,
PandaTilesDataset.TILE_Y_COLUMN,
x_column,
y_column,
self.OCCUPANCY,
self.DATA_PROVIDER,
self.ISUP_GRADE,
@ -81,8 +89,9 @@ class MockPandaTilesGenerator(MockHistoDataGenerator):
f"_{slide_id}/train_images/{tile_x}x_{tile_y}y.png"
)
mock_metadata[self.MASK_COLUMN].append(f"_{slide_id}/train_label_masks/{tile_x}x_{tile_y}y_mask.png")
mock_metadata[PandaTilesDataset.TILE_X_COLUMN].append(tile_x)
mock_metadata[PandaTilesDataset.TILE_Y_COLUMN].append(tile_y)
mock_metadata[x_column].append(tile_x)
mock_metadata[y_column].append(tile_y)
mock_metadata[self.OCCUPANCY].append(1.0)
mock_metadata[self.DATA_PROVIDER].append(data_provider)
mock_metadata[self.ISUP_GRADE].append(isup_grade)
@ -114,7 +123,7 @@ class MockPandaTilesGenerator(MockHistoDataGenerator):
raise NotImplementedError
tile_filename = self.dest_data_path / row[PandaTilesDataset.IMAGE_COLUMN]
save_image(tile.float(), tile_filename)
save_image(tile.div(255.), tile_filename)
random_mask = torch.randint(0, 256, size=(self.n_channels, self.tile_size, self.tile_size))
mask_filename = self.dest_data_path / row[self.MASK_COLUMN]
save_image(random_mask.float(), mask_filename)
save_image(random_mask.div(255), mask_filename)
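
The switch from .float() to .div(255.) matters because torchvision's save_image expects float tensors in the [0, 1] range; a quick sketch of the difference, assuming 8-bit tile values:

import torch
from torchvision.utils import save_image

tile = torch.randint(0, 256, size=(3, 28, 28))
save_image(tile.float(), "tile_saturated.png")  # values in [0, 255] are clamped, giving a mostly white image
save_image(tile.div(255.), "tile_scaled.png")   # values in [0, 1] preserve the mock tile content
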

Просмотреть файл

@ -2,45 +2,61 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from copy import deepcopy
import os
import shutil
from pytorch_lightning import Trainer
import torch
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type
from typing import Any, Callable, Dict, Iterable, List, Optional, Type
from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
from torch.utils.data._utils.collate import default_collate
from health_cpath.configs.classification.DeepSMILESlidesPandaBenchmark import SlidesPandaSSLMILBenchmark
from health_cpath.datamodules.panda_module import PandaTilesDataModule
from health_ml.networks.layers.attention_layers import AttentionLayer
from health_cpath.configs.classification.BaseMIL import BaseMILTiles
from health_ml.networks.layers.attention_layers import AttentionLayer, TransformerPoolingBenchmark
from health_cpath.configs.classification.BaseMIL import BaseMIL, BaseMILTiles
from health_cpath.configs.classification.DeepSMILECrck import DeepSMILECrck
from health_cpath.configs.classification.DeepSMILEPanda import BaseDeepSMILEPanda, DeepSMILETilesPanda
from health_cpath.configs.classification.DeepSMILECrck import DeepSMILECrck, TcgaCrckSSLMIL
from health_cpath.configs.classification.DeepSMILEPanda import (
BaseDeepSMILEPanda, DeepSMILETilesPanda, SlidesPandaSSLMIL, TilesPandaSSLMIL
)
from health_cpath.datamodules.base_module import HistoDataModule, TilesDataModule
from health_cpath.datasets.base_dataset import DEFAULT_LABEL_COLUMN, TilesDataset
from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, TCGA_CRCK_DATASET_DIR
from health_cpath.models.deepmil import BaseDeepMILModule, TilesDeepMILModule
from health_cpath.models.encoders import IdentityEncoder, ImageNetEncoder, TileEncoder
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
from health_cpath.utils.naming import MetricsKey, ResultsKey
from health_cpath.models.encoders import IdentityEncoder, ImageNetEncoder, Resnet18, TileEncoder
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
from health_cpath.utils.naming import DeepMILSubmodules, MetricsKey, ResultsKey
from testhisto.mocks.base_data_generator import MockHistoDataType
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPositioningType
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator
from testhisto.mocks.container import MockDeepSMILETilesPanda, MockDeepSMILESlidesPanda
from health_ml.utils.common_utils import is_gpu_available
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_crck_4ws, innereye_ssl_checkpoint_binary
from testhisto.models.test_encoders import TEST_SSL_RUN_ID
no_gpu = not is_gpu_available()
def get_supervised_imagenet_encoder_params() -> EncoderParams:
return EncoderParams(encoder_type=ImageNetEncoder.__name__)
def get_supervised_imagenet_encoder_params(tune_encoder: bool = True) -> EncoderParams:
return EncoderParams(encoder_type=Resnet18.__name__, tune_encoder=tune_encoder)
def get_attention_pooling_layer_params(pool_out_dim: int = 1) -> PoolingParams:
return PoolingParams(pool_type=AttentionLayer.__name__, pool_out_dim=pool_out_dim, pool_hidden_dim=5)
def get_attention_pooling_layer_params(pool_out_dim: int = 1, tune_pooling: bool = True) -> PoolingParams:
return PoolingParams(pool_type=AttentionLayer.__name__, pool_out_dim=pool_out_dim, pool_hidden_dim=5,
tune_pooling=tune_pooling)
def get_transformer_pooling_layer_params(num_layers: int, num_heads: int,
hidden_dim: int, transformer_dropout: float) -> PoolingParams:
return PoolingParams(pool_type=TransformerPoolingBenchmark.__name__,
num_transformer_pool_layers=num_layers,
num_transformer_pool_heads=num_heads,
pool_hidden_dim=hidden_dim,
transformer_dropout=transformer_dropout)
def _test_lightningmodule(
@ -56,7 +72,7 @@ def _test_lightningmodule(
module = TilesDeepMILModule(
label_column=DEFAULT_LABEL_COLUMN,
n_classes=n_classes,
dropout_rate=dropout_rate,
classifier_params=ClassifierParams(dropout_rate=dropout_rate),
encoder_params=get_supervised_imagenet_encoder_params(),
pooling_params=get_attention_pooling_layer_params(pool_out_dim)
)
@ -114,58 +130,17 @@ def _test_lightningmodule(
# A NaN value could result due to a division-by-zero error
assert torch.all(score[~score.isnan()] >= -1)
assert torch.all(score[~score.isnan()] <= 1)
elif metric_name == MetricsKey.AVERAGE_PRECISION:
assert torch.all(score[~score.isnan()] >= 0)
assert torch.all(score[~score.isnan()] <= 1)
else:
assert torch.all(score >= 0)
assert torch.all(score <= 1)
@pytest.fixture(scope="session")
def mock_panda_tiles_root_dir(
tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
tmp_root_dir = tmp_path_factory.mktemp("mock_tiles")
tiles_generator = MockPandaTilesGenerator(
dest_data_path=tmp_root_dir,
src_data_path=tmp_path_to_pathmnist_dataset,
mock_type=MockHistoDataType.PATHMNIST,
n_tiles=4,
n_slides=10,
n_channels=3,
tile_size=28,
img_size=224,
)
logging.info("Generating temporary mock tiles that will be deleted at the end of the session.")
tiles_generator.generate_mock_histo_data()
yield tmp_root_dir
shutil.rmtree(tmp_root_dir)
@pytest.fixture(scope="session")
def mock_panda_slides_root_dir(
tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
tmp_root_dir = tmp_path_factory.mktemp("mock_slides")
wsi_generator = MockPandaSlidesGenerator(
dest_data_path=tmp_root_dir,
src_data_path=tmp_path_to_pathmnist_dataset,
mock_type=MockHistoDataType.PATHMNIST,
n_tiles=4,
n_slides=10,
n_channels=3,
n_levels=3,
tile_size=28,
background_val=255,
tiles_pos_type=TilesPositioningType.RANDOM
)
logging.info("Generating temporary mock slides that will be deleted at the end of the session.")
wsi_generator.generate_mock_histo_data()
yield tmp_root_dir
shutil.rmtree(tmp_root_dir)
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("max_bag_size", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_attention(
@ -203,9 +178,7 @@ def add_callback(fn: Callable, callback: Callable) -> Callable:
def test_metrics(n_classes: int) -> None:
input_dim = (128,)
def _mock_get_encoder( # type: ignore
self, ssl_ckpt_run_id: Optional[str], outputs_folder: Optional[Path]
) -> TileEncoder:
def _mock_get_encoder(self, outputs_folder: Optional[Path]) -> TileEncoder: # type: ignore
return IdentityEncoder(input_dim=input_dim)
with patch("health_cpath.models.deepmil.EncoderParams.get_encoder", new=_mock_get_encoder):
@ -274,7 +247,7 @@ def assert_train_step(module: BaseDeepMILModule, data_module: HistoDataModule, u
train_data_loader = data_module.train_dataloader()
for batch_idx, batch in enumerate(train_data_loader):
batch = move_batch_to_expected_device(batch, use_gpu)
loss = module.training_step(batch, batch_idx)
loss = module.training_step(batch, batch_idx)[ResultsKey.LOSS]
loss.retain_grad()
loss.backward()
assert loss.grad is not None
@ -327,9 +300,10 @@ def test_container(container_type: Type[BaseMILTiles], use_gpu: bool) -> None:
container = container_type()
container.setup()
container.batch_size = 10
container.batch_size_inf = 10
data_module: TilesDataModule = container.get_data_module() # type: ignore
data_module.max_bag_size = 10
module = container.create_model()
module.outputs_handler = MagicMock()
@ -361,7 +335,7 @@ def _test_mock_panda_container(use_gpu: bool, mock_container: BaseDeepSMILEPanda
def test_mock_tiles_panda_container_cpu(mock_panda_tiles_root_dir: Path) -> None:
_test_mock_panda_container(use_gpu=False, mock_container=MockDeepSMILETilesPanda,
_test_mock_panda_container(use_gpu=False, mock_container=MockDeepSMILETilesPanda, # type: ignore
tmp_path=mock_panda_tiles_root_dir)
@ -425,6 +399,268 @@ def test_class_weights_multiclass() -> None:
assert allclose(loss_weighted, loss_unweighted)
def test_wrong_tuning_options() -> None:
with pytest.raises(ValueError,
match=r"At least one of the encoder, pooling or classifier should be fine tuned"):
_ = MockDeepSMILETilesPanda(
tmp_path=Path("foo"),
tune_encoder=False,
tune_pooling=False,
tune_classifier=False
)
def _get_datamodule(tmp_path: Path) -> PandaTilesDataModule:
tiles_generator = MockPandaTilesGenerator(
dest_data_path=tmp_path,
mock_type=MockHistoDataType.FAKE,
n_tiles=4,
n_slides=10,
n_channels=3,
tile_size=28,
img_size=224,
)
tiles_generator.generate_mock_histo_data()
datamodule = PandaTilesDataModule(root_path=tmp_path, batch_size=2, max_bag_size=4)
return datamodule
@pytest.mark.parametrize("tune_classifier", [False, True])
@pytest.mark.parametrize("tune_pooling", [False, True])
@pytest.mark.parametrize("tune_encoder", [False, True])
def test_finetuning_options(
tune_encoder: bool, tune_pooling: bool, tune_classifier: bool, tmp_path: Path
) -> None:
module = TilesDeepMILModule(
n_classes=1,
label_column=DEFAULT_LABEL_COLUMN,
encoder_params=get_supervised_imagenet_encoder_params(tune_encoder=tune_encoder),
pooling_params=get_attention_pooling_layer_params(pool_out_dim=1, tune_pooling=tune_pooling),
classifier_params=ClassifierParams(tune_classifier=tune_classifier),
)
assert module.encoder_params.tune_encoder == tune_encoder
assert module.pooling_params.tune_pooling == tune_pooling
assert module.classifier_params.tune_classifier == tune_classifier
for params in module.encoder.parameters():
assert params.requires_grad == tune_encoder
for params in module.aggregation_fn.parameters():
assert params.requires_grad == tune_pooling
for params in module.classifier_fn.parameters():
assert params.requires_grad == tune_classifier
instances = torch.randn(4, 3, 224, 224)
def _assert_existing_gradients_fn(tensor: Tensor, tuning_flag: bool) -> None:
assert tensor.requires_grad == tuning_flag
if tuning_flag:
assert tensor.grad_fn is not None
else:
assert tensor.grad_fn is None
with torch.enable_grad():
instance_features = module.get_instance_features(instances)
_assert_existing_gradients_fn(instance_features, tuning_flag=tune_encoder)
assert module.encoder.training == tune_encoder
attentions, bag_features = module.get_attentions_and_bag_features(instances)
_assert_existing_gradients_fn(attentions, tuning_flag=tune_pooling)
_assert_existing_gradients_fn(bag_features, tuning_flag=tune_pooling)
assert module.aggregation_fn.training == tune_pooling
bag_logit = module.get_bag_logit(bag_features)
# bag_logit gradients are required for pooling layer gradients computation, hence
# "tuning_flag=tune_classifier or tune_pooling"
_assert_existing_gradients_fn(bag_logit, tuning_flag=tune_classifier or tune_pooling)
assert module.classifier_fn.training == tune_classifier
@pytest.mark.parametrize("tune_classifier", [False, True])
@pytest.mark.parametrize("tune_pooling", [False, True])
@pytest.mark.parametrize("tune_encoder", [False, True])
def test_training_with_different_finetuning_options(
tune_encoder: bool, tune_pooling: bool, tune_classifier: bool, tmp_path: Path
) -> None:
if any([tune_encoder, tune_pooling, tune_classifier]):
module = TilesDeepMILModule(
n_classes=6,
label_column=MockPandaTilesGenerator.ISUP_GRADE,
encoder_params=get_supervised_imagenet_encoder_params(tune_encoder=tune_encoder),
pooling_params=get_attention_pooling_layer_params(pool_out_dim=1, tune_pooling=tune_pooling),
classifier_params=ClassifierParams(tune_classifier=tune_classifier),
)
def _assert_existing_gradients(module: nn.Module, tuning_flag: bool) -> None:
for param in module.parameters():
if tuning_flag:
assert param.grad is not None
else:
assert param.grad is None
with patch.object(module, "validation_step"):
trainer = Trainer(max_epochs=1)
trainer.fit(module, datamodule=_get_datamodule(tmp_path))
_assert_existing_gradients(module.classifier_fn, tuning_flag=tune_classifier)
_assert_existing_gradients(module.aggregation_fn, tuning_flag=tune_pooling)
_assert_existing_gradients(module.encoder, tuning_flag=tune_encoder)
def test_missing_src_checkpoint_with_pretraining_flags() -> None:
with pytest.raises(ValueError, match=r"You need to specify a source checkpoint, to use a pretrained"):
_ = MockDeepSMILETilesPanda(tmp_path=Path("foo"), pretrained_classifier=True, pretrained_encoder=True)
@pytest.mark.parametrize("pretrained_classifier", [False, True])
@pytest.mark.parametrize("pretrained_pooling", [False, True])
@pytest.mark.parametrize("pretrained_encoder", [False, True])
def test_init_weights_options(pretrained_encoder: bool, pretrained_pooling: bool, pretrained_classifier: bool) -> None:
n_classes = 1
module = BaseDeepMILModule(
n_classes=n_classes,
label_column=DEFAULT_LABEL_COLUMN,
encoder_params=get_supervised_imagenet_encoder_params(),
pooling_params=get_attention_pooling_layer_params(pool_out_dim=1),
)
module.encoder_params.pretrained_encoder = pretrained_encoder
module.pooling_params.pretrained_pooling = pretrained_pooling
module.classifier_params.pretrained_classifier = pretrained_classifier
with patch.object(module, "load_from_checkpoint") as mock_load_from_checkpoint:
with patch.object(module, "copy_weights") as mock_copy_weights:
mock_load_from_checkpoint.return_value = MagicMock(n_classes=n_classes)
module.transfer_weights(Path("foo"))
assert mock_copy_weights.call_count == sum(
[int(pretrained_encoder), int(pretrained_pooling), int(pretrained_classifier)]
)
def _get_tiles_deepmil_module(
pretrained_encoder: bool = True,
pretrained_pooling: bool = True,
pretrained_classifier: bool = True,
n_classes: int = 3,
num_layers: int = 2,
num_heads: int = 1,
hidden_dim: int = 8,
transformer_dropout: float = 0.1
) -> TilesDeepMILModule:
module = TilesDeepMILModule(
n_classes=n_classes,
label_column=MockPandaTilesGenerator.ISUP_GRADE,
encoder_params=get_supervised_imagenet_encoder_params(),
pooling_params=get_transformer_pooling_layer_params(num_layers, num_heads, hidden_dim, transformer_dropout),
classifier_params=ClassifierParams(pretrained_classifier=pretrained_classifier),
)
module.encoder_params.pretrained_encoder = pretrained_encoder
module.pooling_params.pretrained_pooling = pretrained_pooling
module.classifier_params.pretrained_classifier = pretrained_classifier
return module
def get_pretrained_module(encoder_val: int = 5, pooling_val: int = 6, classifier_val: int = 7) -> nn.Module:
module = _get_tiles_deepmil_module()
def _fix_sub_module_weights(submodule: nn.Module, constant_val: int) -> None:
for param in submodule.state_dict().values():
param.data.fill_(constant_val)
_fix_sub_module_weights(module.encoder, encoder_val)
_fix_sub_module_weights(module.aggregation_fn, pooling_val)
_fix_sub_module_weights(module.classifier_fn, classifier_val)
return module
@pytest.mark.parametrize("pretrained_classifier", [False, True])
@pytest.mark.parametrize("pretrained_pooling", [False, True])
@pytest.mark.parametrize("pretrained_encoder", [False, True])
def test_transfer_weights_same_config(
pretrained_encoder: bool, pretrained_pooling: bool, pretrained_classifier: bool,
) -> None:
encoder_val = 5
pooling_val = 6
classifier_val = 7
module = _get_tiles_deepmil_module(pretrained_encoder, pretrained_pooling, pretrained_classifier)
pretrained_module = get_pretrained_module(encoder_val, pooling_val, classifier_val)
encoder_random_weights = deepcopy(module.encoder.state_dict())
pooling_random_weights = deepcopy(module.aggregation_fn.state_dict())
classification_random_weights = deepcopy(module.classifier_fn.state_dict())
with patch.object(module, "load_from_checkpoint") as mock_load_from_checkpoint:
mock_load_from_checkpoint.return_value = pretrained_module
module.transfer_weights(Path("foo"))
encoder_transfer_weights = module.encoder.state_dict()
pooling_transfer_weights = module.aggregation_fn.state_dict()
classification_transfer_weights = module.classifier_fn.state_dict()
def _assert_weights_equal(
random_weights: Dict, transfer_weights: Dict, pretrained_flag: bool, expected_val: int
) -> None:
for r_param_name, t_param_name in zip(random_weights, transfer_weights):
assert r_param_name == t_param_name, "Param names do not match"
r_param = random_weights[r_param_name]
t_param = transfer_weights[t_param_name]
if pretrained_flag:
assert torch.equal(t_param.data, torch.full_like(t_param.data, expected_val))
else:
assert torch.equal(t_param.data, r_param.data)
_assert_weights_equal(encoder_random_weights, encoder_transfer_weights, pretrained_encoder, encoder_val)
_assert_weights_equal(pooling_random_weights, pooling_transfer_weights, pretrained_pooling, pooling_val)
_assert_weights_equal(
classification_random_weights, classification_transfer_weights, pretrained_classifier, classifier_val
)
def test_transfer_weights_different_encoder() -> None:
module = _get_tiles_deepmil_module(pretrained_encoder=True)
pretrained_module = _get_tiles_deepmil_module()
pretrained_module.encoder = IdentityEncoder(tile_size=224)
with patch.object(module, "load_from_checkpoint") as mock_load_from_checkpoint:
mock_load_from_checkpoint.return_value = pretrained_module
with pytest.raises(
ValueError, match=rf"Submodule {DeepMILSubmodules.ENCODER} has different number of parameters "
):
module.transfer_weights(Path("foo"))
def test_transfer_weights_different_pooling() -> None:
module = _get_tiles_deepmil_module(num_heads=2, hidden_dim=24, pretrained_pooling=True)
pretrained_module = _get_tiles_deepmil_module(num_heads=1, hidden_dim=8)
with patch.object(module, "load_from_checkpoint") as mock_load_from_checkpoint:
mock_load_from_checkpoint.return_value = pretrained_module
with pytest.raises(
ValueError, match=rf"Submodule {DeepMILSubmodules.POOLING} has different number of parameters "
):
module.transfer_weights(Path("foo"))
def test_transfer_weights_different_classifier() -> None:
module = _get_tiles_deepmil_module(n_classes=4, pretrained_classifier=True)
pretrained_module = _get_tiles_deepmil_module(n_classes=3)
with patch.object(module, "load_from_checkpoint") as mock_load_from_checkpoint:
mock_load_from_checkpoint.return_value = pretrained_module
with pytest.raises(
ValueError,
match=r"Number of classes in pretrained model 3 does not match number of classes in current model 4."
):
module.transfer_weights(Path("foo"))
def test_wrong_encoding_chunk_size() -> None:
with pytest.raises(
ValueError, match=r"The encoding chunk size should be at least as large as the maximum bag size"
):
_ = BaseMIL(encoding_chunk_size=1, max_bag_size=4, tune_encoder=True, max_num_gpus=2, pl_sync_batchnorm=True)
@pytest.mark.parametrize("container_type", [DeepSMILETilesPanda,
DeepSMILECrck])
@pytest.mark.parametrize("primary_val_metric", [m for m in MetricsKey])
@ -444,3 +680,35 @@ def test_checkpoint_name(container_type: Type[BaseMILTiles], primary_val_metric:
metric_optim = "max" if maximise_primary_metric else "min"
assert container.best_checkpoint_filename == f"checkpoint_{metric_optim}_val_{primary_val_metric.value}"
def test_on_run_extra_val_epoch(mock_panda_tiles_root_dir: Path) -> None:
container = MockDeepSMILETilesPanda(tmp_path=mock_panda_tiles_root_dir)
container.setup()
container.data_module = MagicMock()
container.create_lightning_module_and_store()
assert not container.model._on_extra_val_epoch
assert (
container.model.outputs_handler.test_plots_handler.plot_options # type: ignore
!= container.model.outputs_handler.val_plots_handler.plot_options # type: ignore
)
container.on_run_extra_validation_epoch()
assert container.model._on_extra_val_epoch
assert (
container.model.outputs_handler.test_plots_handler.plot_options # type: ignore
== container.model.outputs_handler.val_plots_handler.plot_options # type: ignore
)
@pytest.mark.parametrize(
"container_type", [TcgaCrckSSLMIL, TilesPandaSSLMIL, SlidesPandaSSLMIL, SlidesPandaSSLMILBenchmark]
)
def test_ssl_containers_default_checkpoint(container_type: BaseMIL) -> None:
if container_type == TcgaCrckSSLMIL:
default_checkpoint = innereye_ssl_checkpoint_crck_4ws
else:
default_checkpoint = innereye_ssl_checkpoint_binary
assert container_type().ssl_checkpoint.checkpoint == default_checkpoint
container = container_type(ssl_checkpoint=CheckpointParser(TEST_SSL_RUN_ID))
assert container.ssl_checkpoint.checkpoint != default_checkpoint

Просмотреть файл

@ -12,8 +12,8 @@ from torch import Tensor, float32, nn, rand
from torchvision.models import resnet18
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CheckpointDownloader
from health_cpath.models.encoders import (TileEncoder, HistoSSLEncoder, ImageNetEncoder,
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME, CheckpointDownloader
from health_cpath.models.encoders import (Resnet18, TileEncoder, HistoSSLEncoder,
ImageNetSimCLREncoder, SSLEncoder)
from health_cpath.utils.layer_utils import setup_feature_extractor
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
@ -25,7 +25,7 @@ TEST_SSL_RUN_ID = "CRCK_SimCLR_1654677598_49a66020"
def get_supervised_imagenet_encoder() -> TileEncoder:
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=TILE_SIZE)
return Resnet18(tile_size=TILE_SIZE)
def get_simclr_imagenet_encoder() -> TileEncoder:
@ -35,8 +35,9 @@ def get_simclr_imagenet_encoder() -> TileEncoder:
def get_ssl_encoder(download_dir: Path) -> TileEncoder:
downloader = CheckpointDownloader(run_id=TEST_SSL_RUN_ID,
download_dir=download_dir,
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME,
remote_checkpoint_dir=Path(DEFAULT_AML_CHECKPOINT_DIR))
downloader.download_checkpoint_if_necessary()
return SSLEncoder(pl_checkpoint_path=downloader.local_checkpoint_path, tile_size=TILE_SIZE)

Просмотреть файл

@ -14,15 +14,15 @@ from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
from monai.transforms import Compose
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import Subset
from torchvision.models import resnet18
from torchvision.transforms import RandomHorizontalFlip
from health_ml.utils.bag_utils import BagDataset
from health_ml.utils.data_augmentations import HEDJitter
from health_cpath.datasets.default_paths import TCGA_CRCK_DATASET_DIR
from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
from health_cpath.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from health_cpath.models.encoders import ImageNetEncoder
from health_cpath.models.encoders import Resnet18
from health_cpath.models.transforms import (EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, Subsampled,
transform_dict_adaptor)
@ -89,6 +89,32 @@ def test_load_tiles_batch() -> None:
assert_dicts_equal(bagged_loaded_batch, loaded_bagged_batch)
@pytest.mark.parametrize("scale_intensity", [True, False])
def test_intensity_scaling_load_tiles_batch(scale_intensity: bool, mock_panda_tiles_root_dir: Path) -> None:
tiles_dataset = PandaTilesDataset(mock_panda_tiles_root_dir)
image_key = tiles_dataset.IMAGE_COLUMN
max_bag_size = 4
bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore
max_bag_size=max_bag_size)
load_batch_transform = LoadTilesBatchd(image_key, scale_intensity=scale_intensity)
index = 0
    # Check that the transform returns tiles with the expected dtype and value range
bagged_batch = bagged_dataset[index]
manually_loaded_batch = load_batch_transform(bagged_batch)
pixels_dtype = torch.uint8 if not scale_intensity else torch.float32
max_val = 255 if not scale_intensity else 1.
for tile in manually_loaded_batch[image_key]:
assert tile.dtype == pixels_dtype
assert tile.max() <= max_val
assert tile.min() >= 0
if not scale_intensity:
assert manually_loaded_batch[image_key][index].max() > 1
assert tile.unique().shape[0] > 1
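A minimal sketch of the intensity-scaling convention that the assertions above rely on (the actual LoadTilesBatchd internals may differ; torch is assumed to be imported as in the rest of this file): unscaled tiles stay uint8 in [0, 255], scaled tiles become float32 in [0, 1].

# Hypothetical helper illustrating the assumed convention; not part of the library.
def _scale_tile_sketch(tile_uint8: torch.Tensor, scale_intensity: bool) -> torch.Tensor:
    if scale_intensity:
        return tile_uint8.to(torch.float32) / 255.0  # float32 in [0, 1]
    return tile_uint8  # unchanged uint8 in [0, 255]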
def _test_cache_and_persistent_datasets(tmp_path: Path,
base_dataset: TorchDataset,
transform: Union[Sequence[Callable], Callable],
@ -142,7 +168,7 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore
max_bag_size=max_bag_size)
encoder = ImageNetEncoder(resnet18, tile_size=224, n_channels=3)
encoder = Resnet18(tile_size=224, n_channels=3)
if use_gpu:
encoder.cuda()

Просмотреть файл

@ -0,0 +1,281 @@
import time
import torch
import pytest
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Callable, List
from unittest.mock import MagicMock, patch
from health_cpath.configs.classification.BaseMIL import BaseMIL
from health_cpath.utils.naming import ModelKey, ResultsKey
from health_cpath.utils.callbacks import LossAnalysisCallback, LossCacheDictType
from testhisto.mocks.container import MockDeepSMILETilesPanda
from testhisto.utils.utils_testhisto import run_distributed
def _assert_is_sorted(array: np.ndarray) -> None:
    assert np.all(np.diff(array) <= 0)  # non-increasing, i.e. sorted in descending order
def _assert_loss_cache_contains_n_elements(loss_cache: LossCacheDictType, n: int) -> None:
for key in loss_cache:
assert len(loss_cache[key]) == n
def get_loss_cache(n_slides: int = 4, rank: int = 0) -> LossCacheDictType:
return {
ResultsKey.LOSS: list(range(1, n_slides + 1)),
ResultsKey.ENTROPY: list(range(1, n_slides + 1)),
ResultsKey.SLIDE_ID: [f"id_{i}" for i in range(rank * n_slides, (rank + 1) * n_slides)],
ResultsKey.TILE_ID: [f"a${i * (rank + 1)}$b" for i in range(rank * n_slides, (rank + 1) * n_slides)],
}
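As a worked example of the helper above, get_loss_cache(n_slides=2, rank=1) produces the following cache (values follow directly from the definition):

# get_loss_cache(n_slides=2, rank=1) ->
# {
#     ResultsKey.LOSS: [1, 2],
#     ResultsKey.ENTROPY: [1, 2],
#     ResultsKey.SLIDE_ID: ["id_2", "id_3"],
#     ResultsKey.TILE_ID: ["a$4$b", "a$6$b"],
# }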
def dump_loss_cache_for_epochs(loss_callback: LossAnalysisCallback, epochs: int, stage: ModelKey) -> None:
for epoch in range(epochs):
loss_callback.loss_cache[stage] = get_loss_cache(n_slides=4, rank=0)
loss_callback.save_loss_cache(epoch, stage)
@pytest.mark.parametrize("create_outputs_folders", [True, False])
def test_loss_callback_outputs_folder_exist(create_outputs_folders: bool, tmp_path: Path) -> None:
outputs_folder = tmp_path / "outputs"
callback = LossAnalysisCallback(outputs_folder=outputs_folder, create_outputs_folders=create_outputs_folders)
for stage in [ModelKey.TRAIN, ModelKey.VAL]:
for folder in [
callback.outputs_folder,
callback.get_cache_folder(stage),
callback.get_scatter_folder(stage),
callback.get_heatmap_folder(stage),
callback.get_anomalies_folder(stage),
]:
assert folder.exists() == create_outputs_folders
@pytest.mark.parametrize("analyse_loss", [True, False])
def test_analyse_loss_param(analyse_loss: bool) -> None:
container = BaseMIL(analyse_loss=analyse_loss)
container.data_module = MagicMock()
callbacks = container.get_callbacks()
assert isinstance(callbacks[-1], LossAnalysisCallback) == analyse_loss
@pytest.mark.parametrize("save_tile_ids", [True, False])
def test_save_tile_ids_param(save_tile_ids: bool) -> None:
callback = LossAnalysisCallback(outputs_folder=Path("foo"), save_tile_ids=save_tile_ids)
assert callback.save_tile_ids == save_tile_ids
assert (ResultsKey.TILE_ID in callback.loss_cache[ModelKey.TRAIN]) == save_tile_ids
assert (ResultsKey.TILE_ID in callback.loss_cache[ModelKey.VAL]) == save_tile_ids
@pytest.mark.parametrize("patience", [0, 1, 2])
def test_loss_analysis_patience(patience: int) -> None:
callback = LossAnalysisCallback(outputs_folder=Path("foo"), patience=patience, max_epochs=10)
assert callback.patience == patience
assert callback.epochs_range[0] == patience
current_epoch = 0
if patience > 0:
assert callback.should_cache_loss_values(current_epoch) is False
else:
assert callback.should_cache_loss_values(current_epoch)
current_epoch = 5
assert callback.should_cache_loss_values(current_epoch)
@pytest.mark.parametrize("epochs_interval", [1, 2])
def test_loss_analysis_epochs_interval(epochs_interval: int) -> None:
max_epochs = 10
callback = LossAnalysisCallback(
outputs_folder=Path("foo"), patience=0, max_epochs=max_epochs, epochs_interval=epochs_interval
)
assert callback.epochs_interval == epochs_interval
assert len(callback.epochs_range) == max_epochs // epochs_interval
# First time to cache loss values
current_epoch = 0
assert callback.should_cache_loss_values(current_epoch)
    current_epoch = 4  # Note that PL counts epochs from 0, so epoch index 4 is actually the 5th epoch
if epochs_interval == 2:
assert callback.should_cache_loss_values(current_epoch) is False
else:
assert callback.should_cache_loss_values(current_epoch)
current_epoch = 5
assert callback.should_cache_loss_values(current_epoch)
current_epoch = max_epochs # no loss caching for extra validation epoch
assert callback.should_cache_loss_values(current_epoch) is False
def test_on_train_and_val_batch_end(tmp_path: Path, mock_panda_tiles_root_dir: Path) -> None:
batch_size = 2
container = MockDeepSMILETilesPanda(tmp_path=mock_panda_tiles_root_dir, analyse_loss=True, batch_size=batch_size)
container.setup()
container.create_lightning_module_and_store()
current_epoch = 5
trainer = MagicMock(current_epoch=current_epoch)
callback = LossAnalysisCallback(outputs_folder=tmp_path)
_assert_loss_cache_contains_n_elements(callback.loss_cache[ModelKey.TRAIN], 0)
_assert_loss_cache_contains_n_elements(callback.loss_cache[ModelKey.VAL], 0)
dataloader = iter(container.data_module.train_dataloader())
def _call_on_batch_end_hook(on_batch_end_hook: Callable, batch_idx: int) -> None:
batch = next(dataloader)
outputs = container.model.training_step(batch, batch_idx)
on_batch_end_hook(trainer, container.model, outputs, batch, batch_idx, 0) # type: ignore
stages = [ModelKey.TRAIN, ModelKey.VAL]
hooks: List[Callable] = [callback.on_train_batch_end, callback.on_validation_batch_end]
for stage, on_batch_end_hook in zip(stages, hooks):
_call_on_batch_end_hook(on_batch_end_hook, batch_idx=0)
_assert_loss_cache_contains_n_elements(callback.loss_cache[stage], batch_size)
_call_on_batch_end_hook(on_batch_end_hook, batch_idx=1)
_assert_loss_cache_contains_n_elements(callback.loss_cache[stage], 2 * batch_size)
def test_on_train_and_val_epoch_end(
tmp_path: Path, duplicate: bool = False, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
current_epoch = 2
n_slides_per_process = 4
trainer = MagicMock(current_epoch=current_epoch)
pl_module = MagicMock(global_rank=rank)
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, num_slides_heatmap=2, num_slides_scatter=2, max_epochs=10
)
stages = [ModelKey.TRAIN, ModelKey.VAL]
hooks = [loss_callback.on_train_epoch_end, loss_callback.on_validation_epoch_end]
for stage, on_epoch_hook in zip(stages, hooks):
loss_callback.loss_cache[stage] = get_loss_cache(rank=rank, n_slides=n_slides_per_process)
if duplicate:
# Duplicate slide "id_0" to test that the duplicates are removed
loss_callback.loss_cache[stage][ResultsKey.SLIDE_ID][0] = "id_0"
_assert_loss_cache_contains_n_elements(loss_callback.loss_cache[stage], n_slides_per_process)
on_epoch_hook(trainer, pl_module)
# Loss cache is flushed after each epoch
_assert_loss_cache_contains_n_elements(loss_callback.loss_cache[stage], 0)
if rank > 0:
time.sleep(10) # Wait for rank 0 to save the loss cache in a csv file
loss_cache_path = loss_callback.get_loss_cache_file(current_epoch, stage)
assert loss_callback.get_cache_folder(stage).exists()
assert loss_cache_path.exists()
assert loss_cache_path.parent == loss_callback.get_cache_folder(stage)
loss_cache = pd.read_csv(loss_cache_path)
total_slides = n_slides_per_process * world_size if not duplicate else n_slides_per_process * world_size - 1
_assert_loss_cache_contains_n_elements(loss_cache, total_slides)
_assert_is_sorted(loss_cache[ResultsKey.LOSS].values)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_on_train_epoch_end_distributed(tmp_path: Path) -> None:
# Test that the loss cache is saved correctly when using multiple GPUs
# First scenario: no duplicates
run_distributed(test_on_train_and_val_epoch_end, [tmp_path, False], world_size=2)
# Second scenario: introduce duplicates
run_distributed(test_on_train_and_val_epoch_end, [tmp_path, True], world_size=2)
def test_on_train_and_val_end(tmp_path: Path) -> None:
pl_module = MagicMock(global_rank=0, _on_extra_val_epoch=False)
max_epochs = 4
trainer = MagicMock(current_epoch=max_epochs - 1)
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, max_epochs=max_epochs, num_slides_heatmap=2, num_slides_scatter=2
)
stages = [ModelKey.TRAIN, ModelKey.VAL]
hooks = [loss_callback.on_train_end, loss_callback.on_validation_end]
for stage, on_end_hook in zip(stages, hooks):
dump_loss_cache_for_epochs(loss_callback, max_epochs, stage)
on_end_hook(trainer, pl_module)
for epoch in range(max_epochs):
assert loss_callback.get_loss_cache_file(epoch, stage).exists()
# check save_loss_ranks outputs
assert loss_callback.get_all_epochs_loss_cache_file(stage).exists()
assert loss_callback.get_loss_stats_file(stage).exists()
assert loss_callback.get_loss_ranks_file(stage).exists()
assert loss_callback.get_loss_ranks_stats_file(stage).exists()
# check plot_slides_loss_scatter outputs
assert loss_callback.get_scatter_plot_file(loss_callback.HIGHEST, stage).exists()
assert loss_callback.get_scatter_plot_file(loss_callback.LOWEST, stage).exists()
# check plot_loss_heatmap_for_slides_of_epoch outputs
for epoch in range(max_epochs):
assert loss_callback.get_heatmap_plot_file(epoch, loss_callback.HIGHEST, stage).exists()
assert loss_callback.get_heatmap_plot_file(epoch, loss_callback.LOWEST, stage).exists()
def test_on_validation_end_not_called_if_extra_val_epoch(tmp_path: Path) -> None:
pl_module = MagicMock(global_rank=0, _on_extra_val_epoch=True)
max_epochs = 4
trainer = MagicMock(current_epoch=0)
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, max_epochs=max_epochs, num_slides_heatmap=2, num_slides_scatter=2
)
with patch.object(loss_callback, "save_loss_outliers_analaysis_results") as mock_func:
loss_callback.on_validation_end(trainer, pl_module)
mock_func.assert_not_called()
def test_nans_detection(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
max_epochs = 2
n_slides = 4
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, max_epochs=max_epochs, num_slides_heatmap=2, num_slides_scatter=2
)
stages = [ModelKey.TRAIN, ModelKey.VAL]
for stage in stages:
for epoch in range(max_epochs):
loss_callback.loss_cache[stage] = get_loss_cache(n_slides)
loss_callback.loss_cache[stage][ResultsKey.LOSS][epoch] = np.nan # Introduce NaNs
loss_callback.save_loss_cache(epoch, stage)
all_slides = loss_callback.select_slides_for_epoch(epoch=0, stage=stage)
all_loss_values_per_slides = loss_callback.select_all_losses_for_selected_slides(all_slides, stage)
loss_callback.sanity_check_loss_values(all_loss_values_per_slides, stage)
assert "NaNs found in loss values for slide id_0" in caplog.records[-1].getMessage()
assert "NaNs found in loss values for slide id_1" in caplog.records[0].getMessage()
assert loss_callback.nan_slides[stage] == ["id_1", "id_0"]
assert loss_callback.get_nan_slides_file(stage).exists()
assert loss_callback.get_nan_slides_file(stage).parent == loss_callback.get_anomalies_folder(stage)
@pytest.mark.parametrize("log_exceptions", [True, False])
def test_log_exceptions_flag(log_exceptions: bool, tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
max_epochs = 3
trainer = MagicMock(current_epoch=max_epochs - 1)
pl_module = MagicMock(global_rank=0, _on_extra_val_epoch=False)
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, max_epochs=max_epochs,
num_slides_heatmap=2, num_slides_scatter=2, log_exceptions=log_exceptions
)
stages = [ModelKey.TRAIN, ModelKey.VAL]
hooks = [loss_callback.on_train_end, loss_callback.on_validation_end]
for stage, on_end_hook in zip(stages, hooks):
message = "Error while detecting " + stage.value + " loss values outliers"
if log_exceptions:
on_end_hook(trainer, pl_module)
assert message in caplog.records[-1].getMessage()
else:
with pytest.raises(Exception, match=fr"{message}"):
on_end_hook(trainer, pl_module)

Просмотреть файл

@ -0,0 +1,70 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch
from health_cpath.models.encoders import SSLEncoder
from health_cpath.scripts.generate_checkpoint_url import get_checkpoint_url_from_aml_run
from health_cpath.utils.deepmil_utils import EncoderParams
from health_ml.utils.checkpoint_utils import CheckpointParser, LAST_CHECKPOINT_FILE_NAME, MODEL_WEIGHTS_DIR_NAME
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from testhiml.utils.fixed_paths_for_tests import full_test_data_path
from testhisto.models.test_encoders import TEST_SSL_RUN_ID
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
LAST_CHECKPOINT = f"{DEFAULT_AML_CHECKPOINT_DIR}/{LAST_CHECKPOINT_FILE_NAME}"
def test_validate_encoder_params() -> None:
with pytest.raises(ValueError, match=r"SSLEncoder requires an ssl_checkpoint"):
encoder = EncoderParams(encoder_type=SSLEncoder.__name__)
encoder.validate()
def test_load_ssl_checkpoint_from_local_file(tmp_path: Path) -> None:
checkpoint_filename = "hello_world_checkpoint.ckpt"
local_checkpoint_path = full_test_data_path(suffix=checkpoint_filename)
encoder_params = EncoderParams(
encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(str(local_checkpoint_path))
)
assert encoder_params.ssl_checkpoint.is_local_file
ssl_checkpoint_path = encoder_params.ssl_checkpoint.get_path(tmp_path)
assert ssl_checkpoint_path.exists()
assert ssl_checkpoint_path == local_checkpoint_path
with patch("health_cpath.models.encoders.SSLEncoder._get_encoder") as mock_get_encoder:
mock_get_encoder.return_value = (MagicMock(), MagicMock())
encoder = encoder_params.get_encoder(tmp_path)
assert isinstance(encoder, SSLEncoder)
@pytest.mark.skip(reason="This test is failing because of issue #655")
def test_load_ssl_checkpoint_from_url(tmp_path: Path) -> None:
blob_url = get_checkpoint_url_from_aml_run(
run_id=TEST_SSL_RUN_ID,
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME,
expiry_days=1,
aml_workspace=DEFAULT_WORKSPACE.workspace)
encoder_params = EncoderParams(encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(blob_url))
assert encoder_params.ssl_checkpoint.is_url
ssl_checkpoint_path = encoder_params.ssl_checkpoint.get_path(tmp_path)
assert ssl_checkpoint_path.exists()
assert ssl_checkpoint_path == tmp_path / MODEL_WEIGHTS_DIR_NAME / LAST_CHECKPOINT_FILE_NAME
encoder = encoder_params.get_encoder(tmp_path)
assert isinstance(encoder, SSLEncoder)
@pytest.mark.skip(reason="This test is failing because of issue #655")
def test_load_ssl_checkpoint_from_run_id(tmp_path: Path) -> None:
encoder_params = EncoderParams(encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(TEST_SSL_RUN_ID))
assert encoder_params.ssl_checkpoint.is_aml_run_id
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
ssl_checkpoint_path = encoder_params.ssl_checkpoint.get_path(tmp_path)
assert ssl_checkpoint_path.exists()
assert ssl_checkpoint_path == tmp_path / TEST_SSL_RUN_ID / LAST_CHECKPOINT
encoder = encoder_params.get_encoder(tmp_path)
assert isinstance(encoder, SSLEncoder)

Просмотреть файл

@ -1,12 +1,13 @@
from pathlib import Path
from typing import Dict, List
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.distributed
import torch.multiprocessing
from ruamel.yaml import YAML
from health_cpath.utils.tiles_selection_utils import TilesSelector
from testhisto.utils.utils_testhisto import run_distributed
from torch.testing import assert_close
from torchmetrics.metric import Metric
@ -251,3 +252,27 @@ def test_collate_results_multigpu() -> None:
epoch_results = _create_epoch_results(batch_size, num_batches, rank=0, device='cuda:0') \
+ _create_epoch_results(batch_size, num_batches, rank=1, device='cuda:1')
_test_collate_results(epoch_results, total_num_samples=2 * num_batches * batch_size)
@pytest.mark.parametrize('val_set_is_dist', [True, False])
def test_results_gathering_with_val_set_is_dist_flag(val_set_is_dist: bool, tmp_path: Path) -> None:
outputs_handler = _create_outputs_handler(tmp_path)
outputs_handler.tiles_selector = TilesSelector(2, num_slides=2, num_tiles=2)
outputs_handler.val_set_is_dist = val_set_is_dist
outputs_handler._save_outputs = MagicMock() # type: ignore
metric_value = 0.5
with patch("health_cpath.utils.output_utils.gather_results") as mock_gather_results:
with patch.object(outputs_handler.tiles_selector, "gather_selected_tiles_across_devices") as mock_gather_tiles:
with patch.object(outputs_handler.tiles_selector, "_clear_cached_slides_heaps") as mock_clear_cache:
with patch.object(outputs_handler, "should_gather_tiles") as mock_should_gather_tiles:
mock_should_gather_tiles.return_value = True
for rank in range(2):
epoch_results = [{_PRIMARY_METRIC_KEY: [metric_value] * 5, _RANK_KEY: rank}]
outputs_handler.save_validation_outputs(
epoch_results=epoch_results, # type: ignore
metrics_dict=_get_mock_metrics_dict(metric_value),
epoch=0,
is_global_rank_zero=rank == 0)
assert mock_gather_results.called == val_set_is_dist
assert mock_gather_tiles.called == val_set_is_dist
mock_clear_cache.assert_called()

Просмотреть файл

@ -3,33 +3,42 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import os
from pathlib import Path
from typing import Collection
from typing import Any, Collection, Dict, List
from unittest.mock import MagicMock, patch
import pytest
from health_cpath.utils.naming import PlotOption, ResultsKey
from health_cpath.utils.plots_utils import DeepMILPlotsHandler, save_confusion_matrix
from health_cpath.utils.plots_utils import DeepMILPlotsHandler, save_confusion_matrix, save_pr_curve
from health_cpath.utils.tiles_selection_utils import SlideNode, TilesSelector
from testhisto.mocks.container import MockDeepSMILETilesPanda
def test_plots_handler_wrong_class_names() -> None:
plot_options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX}
with pytest.raises(ValueError) as ex:
with pytest.raises(ValueError, match=r"No class_names were provided while activating confusion matrix plotting."):
_ = DeepMILPlotsHandler(plot_options, class_names=[])
assert "No class_names were provided while activating confusion matrix plotting." in str(ex)
def test_plots_handler_slide_thumbnails_without_slide_dataset() -> None:
with pytest.raises(ValueError) as ex:
@pytest.mark.parametrize(
"slide_plot_options",
[
[PlotOption.SLIDE_THUMBNAIL],
[PlotOption.ATTENTION_HEATMAP],
[PlotOption.SLIDE_THUMBNAIL, PlotOption.ATTENTION_HEATMAP]
],
)
def test_plots_handler_slide_plot_options_without_slide_dataset(slide_plot_options: List[PlotOption]) -> None:
exception_prompt = f"Plot option {slide_plot_options[0]} requires a slides dataset"
with pytest.raises(ValueError, match=rf"{exception_prompt}"):
container = MockDeepSMILETilesPanda(tmp_path=Path("foo"))
container.setup()
container.data_module = MagicMock()
container.data_module.train_dataset.n_classes = 6
outputs_handler = container.get_outputs_handler()
outputs_handler.test_plots_handler.plot_options = {PlotOption.SLIDE_THUMBNAIL_HEATMAP}
outputs_handler.test_plots_handler.plot_options = slide_plot_options
outputs_handler.set_slides_dataset_for_plots_handlers(container.get_slides_dataset())
assert "You can not plot slide thumbnails and heatmaps without setting a slides_dataset." in str(ex)
def assert_plot_func_called_if_among_plot_options(
@ -48,42 +57,46 @@ def assert_plot_func_called_if_among_plot_options(
"plot_options",
[
{},
{PlotOption.HISTOGRAM},
{PlotOption.HISTOGRAM, PlotOption.PR_CURVE},
{PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX},
{PlotOption.HISTOGRAM, PlotOption.TOP_BOTTOM_TILES, PlotOption.SLIDE_THUMBNAIL_HEATMAP},
{PlotOption.HISTOGRAM, PlotOption.TOP_BOTTOM_TILES, PlotOption.ATTENTION_HEATMAP},
{
PlotOption.HISTOGRAM,
PlotOption.PR_CURVE,
PlotOption.CONFUSION_MATRIX,
PlotOption.TOP_BOTTOM_TILES,
PlotOption.SLIDE_THUMBNAIL_HEATMAP,
PlotOption.SLIDE_THUMBNAIL,
PlotOption.ATTENTION_HEATMAP,
},
],
)
def test_plots_handler_plots_only_desired_plot_options(plot_options: Collection[PlotOption]) -> None:
plots_handler = DeepMILPlotsHandler(plot_options, class_names=["foo"])
plots_handler = DeepMILPlotsHandler(plot_options, class_names=["foo1", "foo2"])
plots_handler.slides_dataset = MagicMock()
n_tiles = 4
slide_node = SlideNode(slide_id="1", prob_score=0.5, true_label=1, pred_label=0)
slide_node = SlideNode(slide_id="1", gt_prob_score=0.2, pred_prob_score=0.8, true_label=1, pred_label=0)
tiles_selector = TilesSelector(n_classes=2, num_slides=4, num_tiles=2)
tiles_selector.top_slides_heaps = {0: [slide_node] * n_tiles, 1: [slide_node] * n_tiles}
tiles_selector.bottom_slides_heaps = {0: [slide_node] * n_tiles, 1: [slide_node] * n_tiles}
with patch("health_cpath.utils.plots_utils.save_slide_thumbnail_and_heatmap") as mock_slide:
with patch("health_cpath.utils.plots_utils.save_top_and_bottom_tiles") as mock_tile:
with patch("health_cpath.utils.plots_utils.save_scores_histogram") as mock_histogram:
with patch("health_cpath.utils.plots_utils.save_confusion_matrix") as mock_conf:
plots_handler.save_plots(
outputs_dir=MagicMock(), tiles_selector=tiles_selector, results=MagicMock()
)
patchers: Dict[PlotOption, Any] = {
PlotOption.SLIDE_THUMBNAIL: patch("health_cpath.utils.plots_utils.save_slide_thumbnail"),
PlotOption.ATTENTION_HEATMAP: patch("health_cpath.utils.plots_utils.save_attention_heatmap"),
PlotOption.TOP_BOTTOM_TILES: patch("health_cpath.utils.plots_utils.save_top_and_bottom_tiles"),
PlotOption.CONFUSION_MATRIX: patch("health_cpath.utils.plots_utils.save_confusion_matrix"),
PlotOption.HISTOGRAM: patch("health_cpath.utils.plots_utils.save_scores_histogram"),
PlotOption.PR_CURVE: patch("health_cpath.utils.plots_utils.save_pr_curve"),
}
mock_funcs = {option: patcher.start() for option, patcher in patchers.items()} # type: ignore
with patch.object(plots_handler, "get_slide_dict"):
plots_handler.save_plots(outputs_dir=MagicMock(), tiles_selector=tiles_selector, results=MagicMock())
patch.stopall()
calls_count = 0
calls_count += assert_plot_func_called_if_among_plot_options(
mock_slide, PlotOption.SLIDE_THUMBNAIL_HEATMAP, plot_options
)
calls_count += assert_plot_func_called_if_among_plot_options(mock_tile, PlotOption.TOP_BOTTOM_TILES, plot_options)
calls_count += assert_plot_func_called_if_among_plot_options(mock_histogram, PlotOption.HISTOGRAM, plot_options)
calls_count += assert_plot_func_called_if_among_plot_options(mock_conf, PlotOption.CONFUSION_MATRIX, plot_options)
for option, mock_func in mock_funcs.items():
calls_count += assert_plot_func_called_if_among_plot_options(mock_func, option, plot_options)
assert calls_count == len(plot_options)
@ -97,8 +110,8 @@ def test_save_conf_matrix_integration(tmp_path: Path) -> None:
}
class_names = ["foo", "bar"]
save_confusion_matrix(results, class_names, tmp_path)
file = Path(tmp_path) / "normalized_confusion_matrix.png"
save_confusion_matrix(results, class_names, tmp_path, stage='foo')
file = Path(tmp_path) / "normalized_confusion_matrix_foo.png"
assert file.exists()
# check that an error is raised if true labels include indices greater than the expected number of classes
@ -119,7 +132,7 @@ def test_save_conf_matrix_integration(tmp_path: Path) -> None:
save_confusion_matrix(invalid_results_2, class_names, tmp_path)
assert "More entries were found in predicted labels than are available in class names" in str(e)
# check that if confusion matrix still has correct shape even if results don't cover all expected labels
# check that confusion matrix still has correct shape even if results don't cover all expected labels
class_names_extended = ["foo", "bar", "baz"]
num_classes = len(class_names_extended)
expected_conf_matrix_shape = (num_classes, num_classes)
@ -129,3 +142,25 @@ def test_save_conf_matrix_integration(tmp_path: Path) -> None:
mock_plot_conf_matrix.assert_called_once()
actual_conf_matrix = mock_plot_conf_matrix.call_args[1].get('cm')
assert actual_conf_matrix.shape == expected_conf_matrix_shape
def test_pr_curve_integration(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
results = {
ResultsKey.TRUE_LABEL: [0, 1, 0, 1, 0, 1],
ResultsKey.PROB: [0.1, 0.8, 0.6, 0.3, 0.5, 0.4]
}
    # check that the plot is produced and has the right filename
save_pr_curve(results, tmp_path, stage='foo') # type: ignore
file = Path(tmp_path) / "pr_curve_foo.png"
assert file.exists()
os.remove(file)
    # check that a warning is logged and no plot is produced when the case is not binary
results[ResultsKey.TRUE_LABEL] = [0, 1, 0, 2, 0, 1]
save_pr_curve(results, tmp_path, stage='foo') # type: ignore
warning_message = "The PR curve plot implementation works only for binary cases, this plot will be skipped."
assert warning_message in caplog.records[-1].getMessage()
assert not file.exists()

Просмотреть файл

@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Dict, List, Sequence, Union
from typing import Any, Dict, List, Sequence, Union
from unittest.mock import MagicMock, patch
import numpy as np
@ -11,10 +11,10 @@ import pytest
from health_azure.utils import download_file_if_necessary
from health_cpath.utils.output_utils import (AML_LEGACY_TEST_OUTPUTS_CSV, AML_OUTPUTS_DIR, AML_TEST_OUTPUTS_CSV,
AML_VAL_OUTPUTS_CSV)
from health_cpath.utils.report_utils import (collect_crossval_metrics, collect_crossval_outputs,
crossval_runs_have_val_and_test_outputs, get_best_epoch_metrics,
get_best_epochs, get_crossval_metrics_table,
run_has_val_and_test_outputs)
from health_cpath.utils.report_utils import (collect_hyperdrive_metrics, collect_hyperdrive_outputs,
child_runs_have_val_and_test_outputs, get_best_epoch_metrics,
get_best_epochs, get_hyperdrive_metrics_table,
run_has_val_and_test_outputs, download_hyperdrive_metrics_if_required)
def test_run_has_val_and_test_outputs() -> None:
@ -41,7 +41,7 @@ def test_run_has_val_and_test_outputs() -> None:
run_has_val_and_test_outputs(run)
def test_crossval_runs_have_val_and_test_outputs() -> None:
def test_hyperdrive_runs_have_val_and_test_outputs() -> None:
legacy_run = MagicMock(display_name="legacy run", id="child1")
legacy_run.get_file_names.return_value = [AML_LEGACY_TEST_OUTPUTS_CSV]
@ -55,22 +55,22 @@ def test_crossval_runs_have_val_and_test_outputs() -> None:
parent_run = MagicMock()
with patch.object(parent_run, 'get_children', return_value=[legacy_run, legacy_run]):
assert not crossval_runs_have_val_and_test_outputs(parent_run)
assert not child_runs_have_val_and_test_outputs(parent_run)
with patch.object(parent_run, 'get_children', return_value=[run_with_val_and_test, run_with_val_and_test]):
assert crossval_runs_have_val_and_test_outputs(parent_run)
assert child_runs_have_val_and_test_outputs(parent_run)
with patch.object(parent_run, 'get_children', return_value=[legacy_run, invalid_run]):
with pytest.raises(ValueError, match="does not have the expected files"):
crossval_runs_have_val_and_test_outputs(parent_run)
child_runs_have_val_and_test_outputs(parent_run)
with patch.object(parent_run, 'get_children', return_value=[run_with_val_and_test, invalid_run]):
with pytest.raises(ValueError, match="does not have the expected files"):
crossval_runs_have_val_and_test_outputs(parent_run)
child_runs_have_val_and_test_outputs(parent_run)
with patch.object(parent_run, 'get_children', return_value=[legacy_run, run_with_val_and_test]):
with pytest.raises(ValueError, match="has mixed children"):
crossval_runs_have_val_and_test_outputs(parent_run)
child_runs_have_val_and_test_outputs(parent_run)
@pytest.mark.parametrize('overwrite', [False, True])
@ -102,9 +102,9 @@ def test_download_from_run_if_necessary(tmp_path: Path, overwrite: bool) -> None
class MockChildRun:
def __init__(self, run_id: str, cross_val_index: int):
def __init__(self, run_id: str, child_run_index: int):
self.run_id = run_id
self.tags = {"hyperparameters": json.dumps({"child_run_index": cross_val_index})}
self.tags = {"hyperparameters": json.dumps({"child_run_index": child_run_index})}
def get_metrics(self) -> Dict[str, Union[float, List[Union[int, float]]]]:
num_epochs = 5
@ -127,9 +127,9 @@ class MockHyperDriveRun:
return [MockChildRun(f"run_abc_{i}456", i) for i in self.child_indices]
def test_collect_crossval_outputs(tmp_path: Path) -> None:
def test_collect_hyperdrive_outputs(tmp_path: Path) -> None:
download_dir = tmp_path
crossval_arg_name = "child_run_index"
hyperdrive_arg_name = "child_run_index"
output_filename = "output.csv"
child_indices = [0, 3, 1] # Missing and unsorted children
@ -142,16 +142,16 @@ def test_collect_crossval_outputs(tmp_path: Path) -> None:
with patch('health_cpath.utils.report_utils.get_aml_run_from_run_id',
return_value=MockHyperDriveRun(child_indices)):
crossval_dfs = collect_crossval_outputs(parent_run_id="",
download_dir=download_dir,
aml_workspace=None,
crossval_arg_name=crossval_arg_name,
output_filename=output_filename)
hyperdrive_dfs = collect_hyperdrive_outputs(parent_run_id="",
download_dir=download_dir,
aml_workspace=None,
hyperdrive_arg_name=hyperdrive_arg_name,
output_filename=output_filename)
assert set(crossval_dfs.keys()) == set(child_indices)
assert list(crossval_dfs.keys()) == sorted(crossval_dfs.keys())
assert set(hyperdrive_dfs.keys()) == set(child_indices)
assert list(hyperdrive_dfs.keys()) == sorted(hyperdrive_dfs.keys())
for child_index, child_df in crossval_dfs.items():
for child_index, child_df in hyperdrive_dfs.items():
assert child_df.columns.tolist() == columns
assert child_df.loc[0, 'split'] == child_index
@ -176,13 +176,24 @@ def metrics_df() -> pd.DataFrame:
'val/auroc': [0.8, 0.9, 0.7],
'test/accuracy': 0.9,
'test/auroc': 0.9
}
},
4: {'val/accuracy': None,
'val/auroc': None,
'test/accuracy': None,
'test/auroc': None
}
})
@pytest.fixture
def best_epochs(metrics_df: pd.DataFrame) -> Dict[int, int]:
return get_best_epochs(metrics_df, 'val/accuracy', maximise=True)
def max_epochs_dict() -> Dict[int, int]:
return {0: 3, 1: 10, 3: 3, 4: 3}
@pytest.fixture
def best_epochs(metrics_df: pd.DataFrame, max_epochs_dict: Dict[int, int]) -> Dict[int, Any]:
return get_best_epochs(metrics_df=metrics_df, primary_metric='val/accuracy',
max_epochs_dict=max_epochs_dict, maximise=True)
@pytest.fixture
@ -192,16 +203,18 @@ def best_epoch_metrics(metrics_df: pd.DataFrame, best_epochs: Dict[int, int]) ->
@pytest.mark.parametrize('overwrite', [False, True])
def test_collect_crossval_metrics(metrics_df: pd.DataFrame, tmp_path: Path, overwrite: bool) -> None:
def test_collect_hyperdrive_metrics(metrics_df: pd.DataFrame, tmp_path: Path, overwrite: bool) -> None:
with patch('health_cpath.utils.report_utils.aggregate_hyperdrive_metrics',
return_value=metrics_df) as mock_aggregate:
returned_df = collect_crossval_metrics(parent_run_id="", download_dir=tmp_path,
aml_workspace=None, overwrite=overwrite)
returned_json = download_hyperdrive_metrics_if_required(parent_run_id="", download_dir=tmp_path,
aml_workspace=None, overwrite=overwrite)
returned_df = collect_hyperdrive_metrics(metrics_json=returned_json)
mock_aggregate.assert_called_once()
mock_aggregate.reset_mock()
new_returned_df = collect_crossval_metrics(parent_run_id="", download_dir=tmp_path,
aml_workspace=None, overwrite=overwrite)
new_returned_json = download_hyperdrive_metrics_if_required(parent_run_id="", download_dir=tmp_path,
aml_workspace=None, overwrite=overwrite)
new_returned_df = collect_hyperdrive_metrics(metrics_json=new_returned_json)
if overwrite:
mock_aggregate.assert_called_once()
else:
@ -211,12 +224,13 @@ def test_collect_crossval_metrics(metrics_df: pd.DataFrame, tmp_path: Path, over
@pytest.mark.parametrize('maximise', [True, False])
def test_get_best_epochs(metrics_df: pd.DataFrame, maximise: bool) -> None:
best_epochs = get_best_epochs(metrics_df, 'val/accuracy', maximise=maximise)
def test_get_best_epochs(metrics_df: pd.DataFrame, max_epochs_dict: Dict[int, int], maximise: bool) -> None:
best_epochs = get_best_epochs(metrics_df=metrics_df, primary_metric='val/accuracy',
max_epochs_dict=max_epochs_dict, maximise=maximise)
assert list(best_epochs.keys()) == list(metrics_df.columns)
assert all(isinstance(epoch, int) for epoch in best_epochs.values())
assert all(isinstance(epoch, (int, type(None))) for epoch in best_epochs.values())
expected_best = {0: 0, 1: 1, 3: 2} if maximise else {0: 1, 1: 2, 3: 0}
expected_best = {0: 0, 1: 1, 3: 2, 4: None} if maximise else {0: 1, 1: 2, 3: 0, 4: None}
for split in metrics_df.columns:
assert best_epochs[split] == expected_best[split]
@ -233,13 +247,15 @@ def test_get_best_epoch_metrics(metrics_df: pd.DataFrame, best_epochs: Dict[int,
@pytest.mark.parametrize('fixture_name, metrics_list', [('metrics_df', ['test/accuracy', 'test/auroc']),
('best_epoch_metrics', ['val/accuracy', 'val/auroc'])])
def test_get_crossval_metrics_table(fixture_name: str, metrics_list: List[str], request: pytest.FixtureRequest) -> None:
def test_get_hyperdrive_metrics_table(
fixture_name: str, metrics_list: List[str], request: pytest.FixtureRequest
) -> None:
df = request.getfixturevalue(fixture_name)
metrics_table = get_crossval_metrics_table(df, metrics_list)
metrics_table = get_hyperdrive_metrics_table(df, metrics_list)
assert list(metrics_table.index) == metrics_list
assert len(metrics_table.columns) == len(df.columns) + 1
original_values = df.loc[metrics_list].values
table_values = metrics_table.iloc[:, :-1].applymap(float).values
assert (table_values == original_values).all()
assert (table_values[~pd.isnull(table_values)] == original_values[~pd.isnull(original_values)]).all()

Просмотреть файл

@ -8,7 +8,7 @@ import pytest
import numpy as np
from unittest.mock import patch
from typing import Dict, List, Any
from typing import Dict, List, Any, Set
from testhisto.utils.utils_testhisto import run_distributed
from health_cpath.utils.naming import ResultsKey, SlideKey
from health_cpath.utils.plots_utils import TilesSelector, SlideNode
@ -42,12 +42,14 @@ def _create_mock_results(n_samples: int, n_tiles: int = 3, n_classes: int = 2, d
    :return: A dictionary containing randomly generated mock results.
"""
diff_n_tiles = [n_tiles + i for i in range(n_samples)]
probs = torch.rand((n_samples, n_classes), device=device)
probs = probs / probs.sum(dim=1, keepdim=True)
mock_results = {
ResultsKey.SLIDE_ID: np.array([f"slide_{i}" for i in range(n_samples)]),
ResultsKey.TRUE_LABEL: torch.randint(2, size=(n_samples,), device=device),
ResultsKey.PRED_LABEL: torch.randint(2, size=(n_samples,), device=device),
ResultsKey.TRUE_LABEL: torch.randint(n_classes, size=(n_samples,), device=device),
ResultsKey.PRED_LABEL: torch.argmax(probs, dim=1),
ResultsKey.BAG_ATTN: [torch.rand(size=(1, diff_n_tiles[i]), device=device) for i in range(n_samples)],
ResultsKey.CLASS_PROBS: torch.rand((n_samples, n_classes), device=device),
ResultsKey.CLASS_PROBS: probs,
}
return mock_results
@ -96,14 +98,14 @@ def _create_and_update_top_bottom_tiles_selector(
def _get_expected_slides_by_probability(
results: Dict[ResultsKey, Any], num_top_slides: int = 2, label: int = 1, top: bool = True
) -> List[str]:
) -> Set[str]:
"""Select top or bottom slides according to their probability scores from the entire dataset.
:param results: The results dictionary for the entire dataset.
    :param num_top_slides: The number of top and bottom slides to select, defaults to 2
    :param label: The current label to process given that top and bottom are grouped by class label, defaults to 1
    :param top: A flag to select top or bottom slides with highest (respectively, lowest) prob scores, defaults to True
:return: A list of selected slide ids.
:return: A set of selected slide ids.
"""
class_indices = (results[ResultsKey.TRUE_LABEL].squeeze() == label).nonzero().squeeze(1)
@ -111,37 +113,29 @@ def _get_expected_slides_by_probability(
assert class_prob.shape == (len(class_indices),)
num_top_slides = min(num_top_slides, len(class_prob))
_, sorting_indices = class_prob.topk(num_top_slides, largest=top, sorted=True)
_, sorting_indices = class_prob.topk(num_top_slides, largest=top)
sorted_class_indices = class_indices[sorting_indices]
return [results[ResultsKey.SLIDE_ID][i] for i in sorted_class_indices][::-1] # the order is inversed in the heaps
def get_expected_top_slides_by_probability(
results: Dict[ResultsKey, Any], num_top_slides: int = 5, label: int = 1
) -> List[str]:
"""Calls `_get_expected_slides_by_probability` with `top=True` to select expected top slides for the entire dataset
in one go. """
return _get_expected_slides_by_probability(results, num_top_slides, label, top=True)
def get_expected_bottom_slides_by_probability(
results: Dict[ResultsKey, Any], num_top_slides: int = 5, label: int = 1
) -> List[str]:
"""Calls `_get_expected_slides_by_probability` with `top=False` to select expected bottom slides for the entire
dataset in one go. """
return _get_expected_slides_by_probability(results, num_top_slides, label, top=False)
def _selection_condition(index: int) -> bool:
if top:
return results[ResultsKey.PRED_LABEL][index] == results[ResultsKey.TRUE_LABEL][index]
else:
return results[ResultsKey.PRED_LABEL][index] != results[ResultsKey.TRUE_LABEL][index]
return {results[ResultsKey.SLIDE_ID][i] for i in sorted_class_indices if _selection_condition(i)}
@pytest.mark.parametrize("num_top_slides", [2, 10])
@pytest.mark.parametrize("n_classes", [2, 3]) # n_classes=2 represents the binary case.
def test_aggregate_shallow_slide_nodes(n_classes: int, rank: int = 0, world_size: int = 1, device: str = "cpu") -> None:
def test_aggregate_shallow_slide_nodes(
n_classes: int, num_top_slides: int, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
"""This test ensures that shallow copies of slide nodes are gathered properlyy across devices in a ddp context."""
n_tiles = 3
batch_size = 2
n_batches = 10
total_batches = n_batches * world_size
num_top_tiles = 2
num_top_slides = 2
torch.manual_seed(42)
data = _create_mock_data(n_samples=batch_size * total_batches, n_tiles=n_tiles, device=device)
@ -167,13 +161,21 @@ def test_aggregate_shallow_slide_nodes(n_classes: int, rank: int = 0, world_size
if rank == 0:
for label in range(n_classes):
expected_top_slides_ids = _get_expected_slides_by_probability(results, num_top_slides, label, top=True)
assert expected_top_slides_ids == [slide_node.slide_id for slide_node in shallow_top_slides_heaps[label]]
assert all(slide_node.pred_label == slide_node.true_label for slide_node in shallow_top_slides_heaps[label])
selected_top_slides_ids = {slide_node.slide_id for slide_node in shallow_top_slides_heaps[label]}
expected_top_slides_ids = _get_expected_slides_by_probability(results, num_top_slides, label, top=True)
assert expected_top_slides_ids == selected_top_slides_ids
assert all(
slide_node.pred_label != slide_node.true_label for slide_node in shallow_bottom_slides_heaps[label]
)
selected_bottom_slides_ids = {slide_node.slide_id for slide_node in shallow_bottom_slides_heaps[label]}
expected_bottom_slides_ids = _get_expected_slides_by_probability(results, num_top_slides, label, top=False)
assert expected_bottom_slides_ids == [
slide_node.slide_id for slide_node in shallow_bottom_slides_heaps[label]
]
assert expected_bottom_slides_ids == selected_bottom_slides_ids
# Make sure that the top and bottom slides are disjoint.
assert not selected_top_slides_ids.intersection(selected_bottom_slides_ids)
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@ -181,27 +183,30 @@ def test_aggregate_shallow_slide_nodes(n_classes: int, rank: int = 0, world_size
@pytest.mark.gpu
def test_aggregate_shallow_slide_nodes_distributed() -> None:
"""These tests need to be called sequentially to prevent them to be run in parallel."""
# test with n_classes = 2
run_distributed(test_aggregate_shallow_slide_nodes, [2], world_size=1)
run_distributed(test_aggregate_shallow_slide_nodes, [2], world_size=2)
# test with n_classes = 3
run_distributed(test_aggregate_shallow_slide_nodes, [3], world_size=1)
run_distributed(test_aggregate_shallow_slide_nodes, [3], world_size=2)
# test with n_classes = 2, num_top_slides = 2
run_distributed(test_aggregate_shallow_slide_nodes, [2, 2], world_size=1)
run_distributed(test_aggregate_shallow_slide_nodes, [2, 2], world_size=2)
# test with n_classes = 3, num_top_slides = 2
run_distributed(test_aggregate_shallow_slide_nodes, [3, 2], world_size=1)
run_distributed(test_aggregate_shallow_slide_nodes, [3, 2], world_size=2)
def assert_equal_top_bottom_attention_tiles(
slide_ids: List[str], data: Dict, results: Dict, num_top_tiles: int, slide_nodes: List[SlideNode]
slide_ids: Set[str], data: Dict, results: Dict, num_top_tiles: int, slide_nodes: List[SlideNode]
) -> None:
"""Asserts that top and bottom tiles selected on the fly by the top bottom tiles selector are equal to the expected
top and bottom tiles in the mock dataset.
:param slide_ids: A list of expected slide ids.
:param slide_ids: A set of expected slide ids.
:param data: A dictionary containing the entire dataset.
:param results: A dictionary of data results.
:param num_top_tiles: The number of tiles to select as top and bottom tiles for each top/bottom slide.
:param slide_nodes: The top or bottom slide nodes selected on the fly by the selector.
"""
for i, slide_id in enumerate(slide_ids):
slide_nodes_dict = {slide_node.slide_id: slide_node for slide_node in slide_nodes}
for slide_id in slide_ids:
slide_batch_idx = int(slide_id.split("_")[1])
tiles = data[SlideKey.IMAGE][slide_batch_idx]
@ -215,8 +220,8 @@ def assert_equal_top_bottom_attention_tiles(
expected_top_tiles: List[torch.Tensor] = [tiles[tile_id] for tile_id in top_tiles_ids]
expected_bottom_tiles: List[torch.Tensor] = [tiles[tile_id] for tile_id in bottom_tiles_ids]
top_tiles = slide_nodes[i].top_tiles
bottom_tiles = slide_nodes[i].bottom_tiles
top_tiles = slide_nodes_dict[slide_id].top_tiles
bottom_tiles = slide_nodes_dict[slide_id].bottom_tiles
for j, expected_top_tile in enumerate(expected_top_tiles):
assert torch.equal(expected_top_tile.cpu(), top_tiles[j].data)
@ -227,9 +232,10 @@ def assert_equal_top_bottom_attention_tiles(
assert expected_bottom_attns[j].item() == bottom_tiles[j].attn
@pytest.mark.parametrize("num_top_slides", [2, 10])
@pytest.mark.parametrize("n_classes", [2, 3]) # n_classes=2 represents the binary case.
def test_select_k_top_bottom_tiles_on_the_fly(
n_classes: int, rank: int = 0, world_size: int = 1, device: str = "cpu"
n_classes: int, num_top_slides: int, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
"""This tests checks that k top and bottom tiles are selected properly `on the fly`:
1- Create a mock dataset and corresponding mock results that are small enough to fit in memory
@ -264,34 +270,46 @@ def test_select_k_top_bottom_tiles_on_the_fly(
if rank == 0:
for label in range(n_classes):
expected_top_slides_ids = get_expected_top_slides_by_probability(results, num_top_slides, label)
assert expected_top_slides_ids == [
slide_node.slide_id for slide_node in tiles_selector.top_slides_heaps[label]
]
assert all(
slide_node.pred_label == slide_node.true_label for slide_node in tiles_selector.top_slides_heaps[label]
)
selected_top_slides_ids = {slide_node.slide_id for slide_node in tiles_selector.top_slides_heaps[label]}
expected_top_slides_ids = _get_expected_slides_by_probability(results, num_top_slides, label, top=True)
assert expected_top_slides_ids == selected_top_slides_ids
assert_equal_top_bottom_attention_tiles(
expected_top_slides_ids, data, results, num_top_tiles, tiles_selector.top_slides_heaps[label]
)
expected_bottom_slides_ids = get_expected_bottom_slides_by_probability(results, num_top_slides, label)
assert expected_bottom_slides_ids == [
assert all(
slide_node.pred_label != slide_node.true_label
for slide_node in tiles_selector.bottom_slides_heaps[label]
)
selected_bottom_slides_ids = {
slide_node.slide_id for slide_node in tiles_selector.bottom_slides_heaps[label]
]
}
expected_bottom_slides_ids = _get_expected_slides_by_probability(results, num_top_slides, label, top=False)
assert expected_bottom_slides_ids == selected_bottom_slides_ids
assert_equal_top_bottom_attention_tiles(
expected_bottom_slides_ids, data, results, num_top_tiles, tiles_selector.bottom_slides_heaps[label]
)
# Make sure that the top and bottom slides are disjoint.
assert not selected_top_slides_ids.intersection(selected_bottom_slides_ids)
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_select_k_top_bottom_tiles_on_the_fly_distributed() -> None:
"""These tests need to be called sequentially to prevent them to be run in parallel"""
# test with n_classes = 2
run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [2], world_size=1)
run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [2], world_size=2)
# test with n_classes = 3
run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [3], world_size=1)
run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [3], world_size=2)
# test with n_classes = 2, num_top_slides = 2
run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [2, 2], world_size=1)
run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [2, 2], world_size=2)
# test with n_classes = 3, num_top_slides = 2
run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [3, 2], world_size=1)
run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [3, 2], world_size=2)
def test_disable_top_bottom_tiles_selector() -> None:

Просмотреть файл

@ -8,6 +8,7 @@ import math
import random
from pathlib import Path
from typing import List, Optional
from unittest.mock import MagicMock, patch
import matplotlib
import numpy as np
@ -24,9 +25,8 @@ from health_cpath.utils.viz_utils import plot_attention_tiles, plot_scores_hist,
from health_cpath.utils.naming import ResultsKey
from health_cpath.utils.heatmap_utils import location_selected_tiles
from health_cpath.utils.tiles_selection_utils import SlideNode, TileNode
from health_cpath.utils.viz_utils import save_figure
from health_cpath.utils.viz_utils import save_figure, load_image_dict
from testhisto.utils.utils_testhisto import assert_binary_files_match, full_ml_test_data_path
# import testhisto
def set_random_seed(random_seed: int, caller_name: Optional[str] = None) -> None:
@ -124,7 +124,7 @@ def slide_node() -> SlideNode:
set_random_seed(0)
tile_size = (3, 224, 224)
num_top_tiles = 12
slide_node = SlideNode(slide_id="slide_0", prob_score=0.5, true_label=1, pred_label=1)
slide_node = SlideNode(slide_id="slide_0", gt_prob_score=0.04, pred_prob_score=0.96, true_label=1, pred_label=0)
top_attn_scores = [0.99, 0.98, 0.97, 0.96, 0.95, 0.94, 0.93, 0.92, 0.91, 0.90, 0.89, 0.88]
slide_node.top_tiles = [
TileNode(attn=top_attn_scores[i], data=torch.randint(0, 255, tile_size)) for i in range(num_top_tiles)
@ -150,11 +150,11 @@ def assert_plot_tiles_figure(tiles_fig: plt.Figure, fig_name: str, test_output_d
@pytest.mark.skipif(is_windows(), reason="Rendering is different on Windows")
def test_plot_top_bottom_tiles(slide_node: SlideNode, test_output_dirs: OutputFolderForTests) -> None:
top_tiles_fig = plot_attention_tiles(
case="TP", slide_node=slide_node, top=True, num_columns=4, figsize=(10, 10)
case="FN", slide_node=slide_node, top=True, num_columns=4, figsize=(10, 10)
)
assert top_tiles_fig is not None
bottom_tiles_fig = plot_attention_tiles(
case="TP", slide_node=slide_node, top=False, num_columns=4, figsize=(10, 10)
case="FN", slide_node=slide_node, top=False, num_columns=4, figsize=(10, 10)
)
assert bottom_tiles_fig is not None
assert_plot_tiles_figure(top_tiles_fig, "slide_0_top.png", test_output_dirs)
@ -167,7 +167,7 @@ def test_plot_attention_tiles_below_min_rows(slide_node: SlideNode, caplog: LogC
slide_node.bottom_tiles = []
with caplog.at_level(logging.WARNING):
bottom_tiles_fig = plot_attention_tiles(
case="TP", slide_node=slide_node, top=False, num_columns=4, figsize=(10, 10)
case="FN", slide_node=slide_node, top=False, num_columns=4, figsize=(10, 10)
)
assert bottom_tiles_fig is None
assert expected_warning in caplog.text
@ -175,25 +175,43 @@ def test_plot_attention_tiles_below_min_rows(slide_node: SlideNode, caplog: LogC
slide_node.top_tiles = []
with caplog.at_level(logging.WARNING):
top_tiles_fig = plot_attention_tiles(
case="TP", slide_node=slide_node, top=True, num_columns=4, figsize=(10, 10)
case="FN", slide_node=slide_node, top=True, num_columns=4, figsize=(10, 10)
)
assert top_tiles_fig is None
assert expected_warning in caplog.text
@pytest.mark.parametrize("scale", [0.1, 1.2, 2.4, 3.6])
def test_plot_slide(test_output_dirs: OutputFolderForTests, scale: int) -> None:
@pytest.mark.parametrize(
"scale, gt_prob, pred_prob, gt_label, pred_label, case",
[
(0.1, 0.99, 0.99, 1, 1, "TP"),
(1.2, 0.95, 0.95, 0, 0, "TN"),
(2.4, 0.04, 0.96, 0, 1, "FP"),
(3.6, 0.03, 0.97, 1, 0, "FN"),
],
)
def test_plot_slide(
test_output_dirs: OutputFolderForTests,
scale: float,
gt_prob: float,
pred_prob: float,
case: str,
gt_label: int,
pred_label: int,
) -> None:
set_random_seed(0)
slide_image = np.random.rand(3, 1000, 2000)
slide_node = SlideNode(slide_id="slide_0", prob_score=0.5, true_label=1, pred_label=1)
fig = plot_slide(case="TP", slide_node=slide_node, slide_image=slide_image, scale=scale)
slide_node = SlideNode(
slide_id="slide_0", gt_prob_score=gt_prob, pred_prob_score=pred_prob, true_label=gt_label, pred_label=pred_label
)
fig = plot_slide(case=case, slide_node=slide_node, slide_image=slide_image, scale=scale)
assert isinstance(fig, matplotlib.figure.Figure)
file = Path(test_output_dirs.root_dir) / "plot_slide.png"
resize_and_save(5, 5, file)
assert file.exists()
expected = full_ml_test_data_path("histo_heatmaps") / f"slide_{scale}.png"
expected = full_ml_test_data_path("histo_heatmaps") / f"slide_{scale}_{case}.png"
# To update the stored results, uncomment this line:
# expected.write_bytes(file.read_bytes())
expected.write_bytes(file.read_bytes())
assert_binary_files_match(file, expected)
@ -201,11 +219,13 @@ def test_plot_slide(test_output_dirs: OutputFolderForTests, scale: int) -> None:
def test_plot_heatmap_overlay(test_output_dirs: OutputFolderForTests) -> None:
set_random_seed(0)
slide_image = np.random.rand(3, 1000, 2000)
slide_node = SlideNode(slide_id=1, prob_score=0.5, true_label=1, pred_label=1) # type: ignore
slide_node = SlideNode(
slide_id=1, gt_prob_score=0.04, pred_prob_score=0.96, true_label=1, pred_label=0 # type: ignore
)
location_bbox = [100, 100]
tile_size = 224
level = 0
fig = plot_heatmap_overlay(case="TP",
fig = plot_heatmap_overlay(case="FN",
slide_node=slide_node,
slide_image=slide_image,
results=test_dict, # type: ignore
@ -275,3 +295,12 @@ def test_location_selected_tiles(level: int) -> None:
assert max(tile_xs) <= slide_image.shape[2] // factor
assert min(tile_ys) >= 0
assert max(tile_ys) <= slide_image.shape[1] // factor
@pytest.mark.parametrize("wsi_has_mask", [True, False])
def test_load_image_dict(wsi_has_mask: bool) -> None:
with patch("health_cpath.utils.viz_utils.LoadPandaROId") as mock_load_panda_roi:
with patch("health_cpath.utils.viz_utils.LoadROId") as mock_load_roi:
_ = load_image_dict(sample=MagicMock(), level=0, margin=0, wsi_has_mask=wsi_has_mask) # type: ignore
assert mock_load_panda_roi.called == wsi_has_mask
assert mock_load_roi.called == (not wsi_has_mask)

Просмотреть файл

@ -2,7 +2,7 @@ import torch
import pytest
import numpy as np
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from typing import Sequence
from health_cpath.utils.naming import SlideKey
from health_cpath.utils.wsi_utils import image_collate
@ -15,7 +15,8 @@ class MockTiledWSIDataset(Dataset):
n_slides: int,
n_classes: int,
tile_size: Sequence[int],
random_n_tiles: bool) -> None:
random_n_tiles: bool,
img_type: str = "np") -> None:
self.n_tiles = n_tiles
self.n_slides = n_slides
@ -23,30 +24,38 @@ class MockTiledWSIDataset(Dataset):
self.n_classes = n_classes
self.random_n_tiles = random_n_tiles
self.slide_ids = torch.arange(self.n_slides)
self.img_type = img_type
def __len__(self) -> int:
return self.n_slides
def __getitem__(self, index: int) -> List[Dict[SlideKey, Any]]:
tile_count = np.random.randint(self.n_tiles) if self.random_n_tiles else self.n_tiles
tile_count = np.random.randint(low=1, high=self.n_tiles) if self.random_n_tiles else self.n_tiles
label = np.random.choice(self.n_classes)
img: Union[np.ndarray, torch.Tensor]
if self.img_type == "np":
img = np.random.randint(0, 255, size=(tile_count, *self.tile_size))
else:
img = torch.randint(0, 255, size=(tile_count, *self.tile_size))
return [{SlideKey.SLIDE_ID: self.slide_ids[index],
SlideKey.IMAGE: np.random.randint(0, 255, size=self.tile_size),
SlideKey.IMAGE: img,
SlideKey.IMAGE_PATH: f"slide_{self.slide_ids[index]}.tiff",
SlideKey.LABEL: label
} for _ in range(tile_count)
]
@pytest.mark.parametrize("img_type", ["np", "torch"])
@pytest.mark.parametrize("random_n_tiles", [False, True])
def test_image_collate(random_n_tiles: bool) -> None:
def test_image_collate(random_n_tiles: bool, img_type: str) -> None:
# random_n_tiles accounts for both training and inference settings, where the number of tiles is fixed (during
# training) and varies per slide during inference (validation and test).
dataset = MockTiledWSIDataset(n_tiles=20,
n_slides=10,
n_classes=4,
tile_size=(1, 4, 4),
random_n_tiles=random_n_tiles)
random_n_tiles=random_n_tiles,
img_type=img_type)
batch_size = 5
samples_list = [dataset[idx] for idx in range(batch_size)]

Просмотреть файл

@ -122,7 +122,7 @@ class HelloRegression(LightningModule):
self.model = torch.nn.Linear(in_features=1, out_features=1, bias=True) # type: ignore
self.test_mse: List[torch.Tensor] = []
self.test_mae = MeanAbsoluteError()
self.run_extra_val_epoch = False
self._on_extra_val_epoch = False
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
"""
@ -232,6 +232,9 @@ class HelloRegression(LightningModule):
Path("test_mae.txt").write_text(str(self.test_mae.compute().item()))
self.log("test_mse", average_mse, on_epoch=True, on_step=False)
def on_run_extra_validation_epoch(self) -> None:
self._on_extra_val_epoch = True
class HelloWorld(LightningContainer):
"""
@ -272,3 +275,6 @@ class HelloWorld(LightningContainer):
*super().get_callbacks()]
else:
return super().get_callbacks()
def get_additional_aml_run_tags(self) -> Dict[str, str]:
return {"max_epochs": str(self.max_epochs)}

Просмотреть файл

@ -7,12 +7,10 @@ from __future__ import annotations
import logging
import os
import param
import re
from enum import Enum, unique
from param import Parameterized
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from azureml.train.hyperdrive import HyperDriveConfig
@ -24,6 +22,7 @@ from health_azure.amulet import (ENV_AMLT_PROJECT_NAME, ENV_AMLT_INPUT_OUTPUT,
is_amulet_job, get_amulet_aml_working_dir)
from health_azure.utils import (RUN_CONTEXT, PathOrString, is_global_rank_zero, is_running_in_azure_ml)
from health_ml.utils import fixed_paths
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_ml.utils.common_utils import (CHECKPOINT_FOLDER,
create_unique_timestamp_id,
DEFAULT_AML_UPLOAD_DIR,
@ -144,32 +143,13 @@ class ExperimentFolderHandler(Parameterized):
)
SRC_CHECKPOINT_FORMAT_DOC = ("<AzureML_run_id>:<optional/custom/path/to/checkpoints/><filename.ckpt>"
"If no custom path is provided (e.g., <AzureML_run_id>:<filename.ckpt>)"
"the checkpoint will be downloaded from the default checkpoint folder "
"(e.g., 'outputs/checkpoints/'). If no filename is provided, (e.g., "
"`src_checkpoint=<AzureML_run_id>`) the latest checkpoint (last.ckpt) "
"will be used to initialize the model."
)
class WorkflowParams(param.Parameterized):
"""
This class contains all parameters that affect how the whole training and testing workflow is executed.
"""
random_seed: int = param.Integer(42, doc="The seed to use for all random number generators.")
src_checkpoint: str = param.String(default="",
doc="This flag can be used in 3 different scenarios:"
"1- Resume training from a checkpoint to train longer."
"2- Run inference-only using `run_inference_only` flag jointly."
"3- Transfer learning from a pretrained model checkpoint."
"We currently support three types of checkpoints: "
" a. A local checkpoint folder that contains a checkpoint file."
" b. A URL to a remote checkpoint to be downloaded."
" c. A previous azureml run id where the checkpoint is supposed to be "
" saved ('outputs/checkpoints/' folder by default.)"
"For the latter case 'c' : src_checkpoint should be in the format of "
f"{SRC_CHECKPOINT_FORMAT_DOC}")
src_checkpoint: CheckpointParser = param.ClassSelector(class_=CheckpointParser, default=None,
instantiate=False, doc=CheckpointParser.DOC)
crossval_count: int = param.Integer(default=1, bounds=(0, None),
doc="The number of splits to use when doing cross-validation. "
"Use 1 to disable cross-validation")
@ -196,6 +176,7 @@ class WorkflowParams(param.Parameterized):
run_inference_only: bool = param.Boolean(False, doc="If True, run only inference and skip training after loading"
"model weights from the specified checkpoint in "
"`src_checkpoint` flag. If False, run training and inference.")
resume_training: bool = param.Boolean(False, doc="If True, resume training from the src_checkpoint.")
tag: str = param.String(doc="A string that will be used as the display name of the run in AzureML.")
experiment: str = param.String(default="", doc="The name of the AzureML experiment to use for this run. If not "
"provided, the name of the model class will be used.")
@ -213,42 +194,15 @@ class WorkflowParams(param.Parameterized):
CROSSVAL_COUNT_ARG_NAME = "crossval_count"
RANDOM_SEED_ARG_NAME = "random_seed"
@property
def src_checkpoint_is_url(self) -> bool:
try:
result = urlparse(self.src_checkpoint)
return all([result.scheme, result.netloc])
except ValueError:
return False
@property
def src_checkpoint_is_local_file(self) -> bool:
return Path(self.src_checkpoint).is_file()
@property
def src_checkpoint_is_aml_run_id(self) -> bool:
match = re.match(r"[_\w-]*$", self.src_checkpoint.split(":")[0])
return match is not None and not self.src_checkpoint_is_url and not self.src_checkpoint_is_local_file
@property
def is_valid_src_checkpoint(self) -> bool:
if self.src_checkpoint:
return self.src_checkpoint_is_local_file or self.src_checkpoint_is_url or self.src_checkpoint_is_aml_run_id
return True
def validate(self) -> None:
if not self.is_valid_src_checkpoint:
raise ValueError(f"Invalid src_checkpoint: {self.src_checkpoint}. Please provide a valid URL, local file "
"or azureml run id.")
if self.crossval_count > 1:
if not (0 <= self.crossval_index < self.crossval_count):
raise ValueError(f"Attribute crossval_index out of bounds (crossval_count = {self.crossval_count})")
if self.run_inference_only and not self.src_checkpoint:
raise ValueError("Cannot run inference without a src_checkpoint. Please specify a valid src_checkpoint."
"You can either use a URL, a local file or an azureml run id. For custom checkpoint paths "
"within an azureml run, (other than last.ckpt), provide a src_checkpoint in the format."
f"{SRC_CHECKPOINT_FORMAT_DOC}")
raise ValueError(f"Cannot run inference without a src_checkpoint. {CheckpointParser.INFO_MESSAGE}")
if self.resume_training and not self.src_checkpoint:
raise ValueError(f"Cannot resume training without a src_checkpoint. {CheckpointParser.INFO_MESSAGE}")
@property
def is_running_in_aml(self) -> bool:
@ -533,6 +487,10 @@ class TrainerParams(param.Parameterized):
param.String(default=None,
doc="The value to use for the 'profiler' argument for the Lightning trainer. "
"Set to either 'simple', 'advanced', or 'pytorch'")
pl_sync_batchnorm: bool = param.Boolean(default=True,
doc="PyTorch Lightning trainer flag 'sync_batchnorm': If True, "
"synchronize batchnorm across all GPUs when running in ddp mode."
"If False, batchnorm is not synchronized.")
monitor_gpu: bool = param.Boolean(default=False,
doc="If True, add the GPUStatsMonitor callback to the Lightning trainer object. "
"This will write GPU utilization metrics every 50 batches by default.")
@ -545,6 +503,17 @@ class TrainerParams(param.Parameterized):
"any validation overheads during training time and produce "
"additional time or memory consuming outputs only once after "
"training is finished on the validation set.")
pl_accumulate_grad_batches: int = param.Integer(default=1,
doc="The number of batches over which gradients are accumulated, "
"before a parameter update is done.")
pl_log_every_n_steps: int = param.Integer(default=50,
doc="PyTorch Lightning trainer flag 'log_every_n_steps': How often to"
"log within steps. Default to 50.")
pl_replace_sampler_ddp: bool = param.Boolean(default=True,
doc="PyTorch Lightning trainer flag 'replace_sampler_ddp' that "
"sets the sampler for distributed training with shuffle=True during "
"training and shuffle=False during validation. Default to True. Set to"
"False to set your own sampler.")
@property
def use_gpu(self) -> bool:

Просмотреть файл

@ -1,8 +1,18 @@
from enum import Enum
from pathlib import Path
from typing import Optional
import param
class DebugDDPOptions(Enum):
OFF = "OFF"
INFO = "INFO"
DETAIL = "DETAIL"
DEBUG_DDP_ENV_VAR = "TORCH_DISTRIBUTED_DEBUG"
class ExperimentConfig(param.Parameterized):
cluster: str = param.String(default="", allow_None=False,
doc="The name of the GPU or CPU cluster inside the AzureML workspace"
@ -26,5 +36,13 @@ class ExperimentConfig(param.Parameterized):
doc="The Conda environment file that should be used when submitting the present run to "
"AzureML. If not specified, the environment file in the current folder or one of its "
"parents will be used.")
debug_ddp: DebugDDPOptions = param.ClassSelector(default=DebugDDPOptions.OFF, class_=DebugDDPOptions,
doc=f"Flag to override the environment var {DEBUG_DDP_ENV_VAR}"
"that can be used to trigger logging and collective "
"synchronization checks to ensure all ranks are synchronized "
"appropriately. Default is `OFF`. It can be set to either "
"`INFO` or `DETAIL` for different levels of logging. "
"`DETAIL` may impact the application performance and thus "
"should only be used when debugging issues")
strictly_aml_v1: bool = param.Boolean(default=False, doc="If True, use AzureML v1 SDK. If False (default), use "
"the v2 of the SDK")

Просмотреть файл

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path
import param
import logging
from azureml.core import ScriptRunConfig
from azureml.train.hyperdrive import HyperDriveConfig
from pytorch_lightning import Callback, LightningDataModule, LightningModule
@ -37,6 +38,7 @@ class LightningContainer(WorkflowParams,
self._model: Optional[LightningModule] = None
self._model_name = type(self).__name__
self.num_nodes = 1
self.trained_weights_path: Optional[Path] = None
def validate(self) -> None:
WorkflowParams.validate(self)
@ -246,6 +248,18 @@ class LightningContainer(WorkflowParams,
argument `experiment`, falling back to the model class name if not set."""
return self.experiment or self.model_name
def get_additional_aml_run_tags(self) -> Dict[str, str]:
"""Returns a dictionary of tags that should be added to the AzureML run."""
return {}
def on_run_extra_validation_epoch(self) -> None:
if hasattr(self.model, "on_run_extra_validation_epoch"):
assert self._model, "Model is not initialized."
self.model.on_run_extra_validation_epoch() # type: ignore
else:
logging.warning("Hook `on_run_extra_validation_epoch` is not implemented by lightning module."
"The extra validation epoch won't produce any extra outputs.")
class LightningModuleWithOptimizer(LightningModule):
"""

Просмотреть файл

@ -197,16 +197,19 @@ def create_lightning_trainer(container: LightningContainer,
limit_test_batches=container.pl_limit_test_batches or 1.0,
fast_dev_run=container.pl_fast_dev_run, # type: ignore
num_sanity_val_steps=container.pl_num_sanity_val_steps,
log_every_n_steps=container.pl_log_every_n_steps,
# check_val_every_n_epoch=container.pl_check_val_every_n_epoch,
callbacks=callbacks,
logger=loggers,
num_nodes=num_nodes,
devices=devices,
precision=precision,
sync_batchnorm=True,
sync_batchnorm=container.pl_sync_batchnorm,
detect_anomaly=container.detect_anomaly,
profiler=profiler,
resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
multiple_trainloader_mode=multiple_trainloader_mode,
accumulate_grad_batches=container.pl_accumulate_grad_batches,
replace_sampler_ddp=container.pl_replace_sampler_ddp,
**additional_args)
return trainer, storing_logger

Просмотреть файл

@ -187,14 +187,15 @@ class TransformerPooling(Module):
num_layers: Number of Transformer encoder layers.
num_heads: Number of attention heads per layer.
dim_representation: Dimension of input encoding.
transformer_dropout: The dropout value of transformer encoder layers.
"""
def __init__(self, num_layers: int, num_heads: int, dim_representation: int) -> None:
def __init__(self, num_layers: int, num_heads: int, dim_representation: int, transformer_dropout: float) -> None:
super(TransformerPooling, self).__init__()
self.num_layers = num_layers
self.num_heads = num_heads
self.dim_representation = dim_representation
self.transformer_dropout = transformer_dropout
self.cls_token = nn.Parameter(torch.zeros([1, dim_representation]))
self.transformer_encoder_layers = []
@ -203,8 +204,8 @@ class TransformerPooling(Module):
CustomTransformerEncoderLayer(self.dim_representation,
self.num_heads,
dim_feedforward=self.dim_representation,
dropout=0.1,
activation=F.gelu, # type: ignore
dropout=self.transformer_dropout,
activation=F.gelu,
batch_first=True))
self.transformer_encoder_layers = torch.nn.ModuleList(self.transformer_encoder_layers) # type: ignore
@ -239,17 +240,21 @@ class TransformerPoolingBenchmark(Module):
num_layers: Number of Transformer encoder layers.
num_heads: Number of attention heads per layer.
dim_representation: Dimension of input encoding.
transformer_dropout: The dropout value of transformer encoder layers.
"""
def __init__(self, num_layers: int, num_heads: int, dim_representation: int, hidden_dim: int) -> None:
def __init__(self, num_layers: int, num_heads: int,
dim_representation: int, hidden_dim: int,
transformer_dropout: float) -> None:
super().__init__()
self.num_layers = num_layers
self.num_heads = num_heads
self.dim_representation = dim_representation
self.hidden_dim = hidden_dim
self.transformer_dropout = transformer_dropout
transformer_layer = nn.TransformerEncoderLayer(d_model=self.dim_representation,
nhead=self.num_heads,
dropout=0.0,
dropout=self.transformer_dropout,
batch_first=True)
self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=self.num_layers)
self.attention = nn.Sequential(nn.Linear(self.dim_representation, self.hidden_dim),

Просмотреть файл

@ -5,7 +5,7 @@
import os
import sys
import logging
import torch.multiprocessing
import torch
from pathlib import Path
from typing import Dict, List, Optional
@ -20,14 +20,12 @@ from health_azure.utils import (create_run_recovery_id, ENV_OMPI_COMM_WORLD_RANK
aggregate_hyperdrive_metrics, get_metrics_for_childless_run,
ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK,
is_local_rank_zero, is_global_rank_zero, create_aml_run_object)
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.model_trainer import create_lightning_trainer, write_experiment_summary_file
from health_ml.utils import fixed_paths
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.utils.checkpoint_utils import cleanup_checkpoints
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.utils.common_utils import (
EFFECTIVE_RANDOM_SEED_KEY_NAME,
change_working_directory,
@ -214,7 +212,7 @@ class MLRunner:
if not self.container.run_inference_only:
checkpoint_path_for_recovery = self.checkpoint_handler.get_recovery_or_checkpoint_path_train()
if not checkpoint_path_for_recovery and self.container.src_checkpoint:
if not checkpoint_path_for_recovery and self.container.resume_training:
# If there is no recovery checkpoint (e.g., the job hasn't been resubmitted) and resume_training is requested,
# use the given source checkpoint to resume training.
checkpoint_path_for_recovery = self.checkpoint_handler.trained_weights_path
@ -236,6 +234,8 @@ class MLRunner:
self.checkpoint_handler.additional_training_done()
checkpoint_path_for_inference = self.checkpoint_handler.get_checkpoint_to_test()
self.container.load_model_checkpoint(checkpoint_path_for_inference)
best_epoch = torch.load(checkpoint_path_for_inference).get("epoch", -1)
logging.info(f"Checkpoint saved at epoch: {best_epoch}")
def after_ddp_cleanup(self, old_environ: Dict) -> None:
"""
@ -274,6 +274,27 @@ class MLRunner:
return self.container.crossval_index == 0
return True
def get_trainer_for_inference(self, checkpoint_path: Optional[Path] = None) -> Trainer:
# We run inference on a single device because distributed strategies such as DDP use DistributedSampler
# internally, which replicates some samples to make sure all devices have the same batch size in case of
# uneven inputs.
mlflow_run_id = get_mlflow_run_id_from_previous_loggers(self.trainer)
self.container.max_num_gpus = 1
if self.container.run_inference_only:
assert checkpoint_path is not None
else:
self.validate_model_weights()
trainer, _ = create_lightning_trainer(
container=self.container,
resume_from_checkpoint=checkpoint_path,
num_nodes=1,
azureml_run_for_logging=self.azureml_run_for_logging,
mlflow_run_for_logging=mlflow_run_id
)
return trainer
def run_training(self) -> None:
"""
The main training loop. It creates the Pytorch model based on the configuration options passed in,
@ -294,12 +315,20 @@ class MLRunner:
"""
Run validation on the validation set for all models to save time/memory consuming outputs.
"""
assert hasattr(self.container.model, "run_extra_val_epoch"), "Model does not have run_extra_val_epoch flag."
"This is required for running an additional validation epoch to save plots."
self.container.model.run_extra_val_epoch = True # type: ignore
self.container.on_run_extra_validation_epoch()
trainer = self.get_trainer_for_inference(checkpoint_path=None)
with change_working_directory(self.container.outputs_folder):
assert self.trainer, "Trainer should be initialized before validation. Call self.init_training() first."
self.trainer.validate(self.container.model, datamodule=self.data_module)
trainer.validate(self.container.model, datamodule=self.data_module)
def validate_model_weights(self) -> None:
logging.info("Validating model weights.")
weights = torch.load(self.checkpoint_handler.get_checkpoint_to_test())["state_dict"]
number_mismatch = 0
for name, param in self.container.model.named_parameters():
if not torch.allclose(weights[name].cpu(), param):
logging.warning(f"Parameter {name} does not match between model and checkpoint.")
number_mismatch += 1
logging.info(f"Number of mismatched parameters: {number_mismatch}")
def run_inference(self) -> None:
"""
@ -312,20 +341,12 @@ class MLRunner:
# We run inference on a single device because distributed strategies such as DDP use DistributedSampler
# internally, which replicates some samples to make sure all devices have some batch size in case of
# uneven inputs.
mlflow_run_id = get_mlflow_run_id_from_previous_loggers(self.trainer)
self.container.max_num_gpus = 1
checkpoint_path = (
self.checkpoint_handler.get_checkpoint_to_test() if self.container.src_checkpoint else None
self.checkpoint_handler.get_checkpoint_to_test() if self.container.run_inference_only else None
)
trainer, _ = create_lightning_trainer(
container=self.container,
resume_from_checkpoint=checkpoint_path,
num_nodes=1,
azureml_run_for_logging=self.azureml_run_for_logging,
mlflow_run_for_logging=mlflow_run_id
)
trainer = self.get_trainer_for_inference(checkpoint_path)
# Change to the outputs folder so that the model can write to current working directory, and still
# everything is put into the right place in AzureML (there, only the contents of the "outputs" folder
# retained)

Просмотреть файл

@ -28,13 +28,13 @@ from health_azure import AzureRunInfo, submit_to_azure_if_needed # noqa: E402
from health_azure.datasets import create_dataset_configs # noqa: E402
from health_azure.logging import logging_to_stdout # noqa: E402
from health_azure.paths import is_himl_used_from_git_repo # noqa: E402
from health_azure.amulet import prepare_amulet_job # noqa: E402
from health_azure.amulet import prepare_amulet_job, is_amulet_job # noqa: E402
from health_azure.utils import (get_workspace, get_ml_client, is_local_rank_zero, # noqa: E402
is_running_in_azure_ml, set_environment_variables_for_multi_node,
create_argparser, parse_arguments, ParserResult, apply_overrides,
filter_v2_input_output_args)
from health_ml.experiment_config import ExperimentConfig # noqa: E402
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig # noqa: E402
from health_ml.lightning_container import LightningContainer # noqa: E402
from health_ml.run_ml import MLRunner # noqa: E402
from health_ml.utils import fixed_paths # noqa: E402
@ -164,7 +164,7 @@ class Runner:
if self.lightning_container.hyperdrive:
raise ValueError("HyperDrive for hyperparameters tuning is only supported when submitting the job to "
"AzureML. You need to specify a compute cluster with the argument --cluster.")
if self.lightning_container.is_crossvalidation_enabled:
if self.lightning_container.is_crossvalidation_enabled and not is_amulet_job():
raise ValueError("Cross-validation is only supported when submitting the job to AzureML."
"You need to specify a compute cluster with the argument --cluster.")
@ -176,7 +176,8 @@ class Runner:
"""
return {
"commandline_args": " ".join(script_params),
"tag": self.lightning_container.tag
"tag": self.lightning_container.tag,
**self.lightning_container.get_additional_aml_run_tags()
}
def run(self) -> Tuple[LightningContainer, AzureRunInfo]:
@ -222,6 +223,7 @@ class Runner:
# TODO: Update environment variables
environment_variables: Dict[str, Any] = {}
environment_variables[DEBUG_DDP_ENV_VAR] = self.experiment_config.debug_ddp.value
# Get default datastore from the provided workspace. Authentication can take a few seconds, hence only do
# that if we are really submitting to AzureML.
@ -252,14 +254,13 @@ class Runner:
datastore=datastore,
use_mounting=use_mounting)
if self.experiment_config.strictly_aml_v1:
hyperdrive_config = self.lightning_container.get_hyperdrive_config()
hyperparam_args = None
else:
hyperparam_args = self.lightning_container.get_hyperparam_args()
hyperdrive_config = None
if self.experiment_config.cluster and not is_running_in_azure_ml():
if self.experiment_config.strictly_aml_v1:
hyperdrive_config = self.lightning_container.get_hyperdrive_config()
hyperparam_args = None
else:
hyperparam_args = self.lightning_container.get_hyperparam_args()
hyperdrive_config = None
ml_client = get_ml_client()
env_file = choose_conda_env_file(env_file=self.experiment_config.conda_env)
@ -295,6 +296,7 @@ class Runner:
azure_run_info = submit_to_azure_if_needed(
input_datasets=input_datasets, # type: ignore
submit_to_azureml=False,
environment_variables=environment_variables,
strictly_aml_v1=self.experiment_config.strictly_aml_v1,
)
if azure_run_info.run:
@ -308,7 +310,6 @@ class Runner:
if suffix:
current_name = self.lightning_container.tag or azure_run_info.run.display_name
azure_run_info.run.display_name = f"{current_name} {suffix}"
# submit_to_azure_if_needed calls sys.exit after submitting to AzureML. We only reach this when running
# the script locally or in AzureML.
return azure_run_info
@ -327,6 +328,11 @@ class Runner:
package_setup_and_hacks()
prepare_amulet_job()
# Add tags and arguments to Amulet runs
if is_amulet_job():
assert azure_run_info.run is not None
azure_run_info.run.set_tags(self.additional_run_tags(sys.argv[1:]))
# Set environment variables for multi-node training if needed. This function will terminate early
# if it detects that it is not in a multi-node environment.
if self.experiment_config.num_nodes > 1:

Просмотреть файл

@ -4,22 +4,13 @@
# -------------------------------------------------------------------------------------------
import logging
import os
import uuid
from azureml.core import Run
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
import requests
from azureml.core import Run
from health_azure.utils import is_global_rank_zero
from health_ml.lightning_container import LightningContainer
from health_ml.utils.checkpoint_utils import (
MODEL_WEIGHTS_DIR_NAME,
CheckpointDownloader,
find_recovery_checkpoint_on_disk_or_cloud,
)
from health_ml.utils.checkpoint_utils import find_recovery_checkpoint_on_disk_or_cloud
class CheckpointHandler:
@ -40,9 +31,9 @@ class CheckpointHandler:
the checkpoint_url, local_checkpoint or checkpoint from an azureml run id.
This is called at the start of training.
"""
if self.container.src_checkpoint:
self.trained_weights_path = self.get_local_checkpoints_path_or_download()
self.trained_weights_path = self.container.src_checkpoint.get_path(self.container.checkpoint_folder)
self.container.trained_weights_path = self.trained_weights_path
def additional_training_done(self) -> None:
"""
@ -85,53 +76,3 @@ class CheckpointHandler:
logging.info(f"Using pre-trained weights from {self.trained_weights_path}")
return self.trained_weights_path
raise ValueError("Unable to determine which checkpoint should be used for testing.")
@staticmethod
def download_weights(url: str, download_folder: Path) -> Path:
"""
Download a checkpoint from checkpoint_url to the modelweights directory. The file name is determined from
from the file name in the URL. If that can't be determined, use a random file name.
:param url: The URL from which the weights should be downloaded.
:param download_folder: The target folder for the download.
:return: A path to the downloaded file.
"""
# assign the same filename as in the download url if possible, so that we can check for duplicates
# If that fails, map to a random uuid
file_name = os.path.basename(urlparse(url).path) or str(uuid.uuid4().hex)
checkpoint_path = download_folder / file_name
# only download if hasn't already been downloaded
if checkpoint_path.is_file():
logging.info(f"File already exists, skipping download: {checkpoint_path}")
else:
logging.info(f"Downloading weights from URL {url}")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(checkpoint_path, "wb") as file:
for chunk in response.iter_content(chunk_size=1024):
file.write(chunk)
return checkpoint_path
def get_local_checkpoints_path_or_download(self) -> Path:
"""
Get the path to the local weights to use or download them.
"""
if self.container.src_checkpoint_is_local_file:
checkpoint_path = Path(self.container.src_checkpoint)
elif self.container.src_checkpoint_is_url:
download_folder = self.container.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME
download_folder.mkdir(exist_ok=True, parents=True)
checkpoint_path = self.download_weights(url=self.container.src_checkpoint, download_folder=download_folder)
elif self.container.src_checkpoint_is_aml_run_id:
downloader = CheckpointDownloader(
run_id=self.container.src_checkpoint, download_dir=self.container.outputs_folder
)
checkpoint_path = downloader.local_checkpoint_path
else:
raise ValueError("Unable to determine how to get the checkpoint path.")
if checkpoint_path is None or not checkpoint_path.is_file():
raise FileNotFoundError(f"Could not find the weights file at {checkpoint_path}")
return checkpoint_path

Просмотреть файл

@ -2,27 +2,29 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import re
import os
import uuid
import torch
import logging
import tempfile
import requests
from pathlib import Path
from typing import Optional
import torch
from urllib.parse import urlparse
from azureml.core import Run, Workspace
from health_azure import download_checkpoints_from_run_id, get_workspace
from health_azure.utils import (RUN_CONTEXT, download_files_from_run_id, get_run_file_names, is_running_in_azure_ml)
from health_ml.utils.common_utils import (AUTOSAVE_CHECKPOINT_CANDIDATES, DEFAULT_AML_CHECKPOINT_DIR)
from health_ml.utils.common_utils import (AUTOSAVE_CHECKPOINT_CANDIDATES, DEFAULT_AML_CHECKPOINT_DIR, CHECKPOINT_SUFFIX)
from health_ml.utils.type_annotations import PathOrString
CHECKPOINT_SUFFIX = ".ckpt"
# This is a constant that must match a filename defined in pytorch_lightning.ModelCheckpoint, but we don't want
# to import that here.
LAST_CHECKPOINT_FILE_NAME = "last"
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX = LAST_CHECKPOINT_FILE_NAME + CHECKPOINT_SUFFIX
LAST_CHECKPOINT_FILE_NAME = f"last{CHECKPOINT_SUFFIX}"
LEGACY_RECOVERY_CHECKPOINT_FILE_NAME = "recovery"
MODEL_INFERENCE_JSON_FILE_NAME = "model_inference_config.json"
MODEL_WEIGHTS_DIR_NAME = "trained_models"
MODEL_WEIGHTS_DIR_NAME = "pretrained_models"
def get_best_checkpoint_path(path: Path) -> Path:
@ -31,7 +33,7 @@ def get_best_checkpoint_path(path: Path) -> Path:
:param path: The path to the checkpoint folder.
"""
return path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
return path / LAST_CHECKPOINT_FILE_NAME
def download_folder_from_run_to_temp_folder(folder: str,
@ -120,7 +122,7 @@ def find_recovery_checkpoint(path: Path) -> Optional[Path]:
logging.warning(f"Found these legacy checkpoint files: {legacy_recovery_checkpoints}")
raise ValueError("The legacy recovery checkpoint setup is no longer supported. As a workaround, you can take "
f"one of the legacy checkpoints and upload as '{AUTOSAVE_CHECKPOINT_CANDIDATES[0]}'")
candidates = [*AUTOSAVE_CHECKPOINT_CANDIDATES, LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX]
candidates = [*AUTOSAVE_CHECKPOINT_CANDIDATES, LAST_CHECKPOINT_FILE_NAME]
highest_epoch: Optional[int] = None
file_with_highest_epoch: Optional[Path] = None
for f in candidates:
@ -147,10 +149,10 @@ def cleanup_checkpoints(ckpt_folder: Path) -> None:
if len(files_in_checkpoint_folder) == 0:
return
logging.info(f"Files in checkpoint folder: {' '.join(files_in_checkpoint_folder)}")
last_ckpt = ckpt_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
last_ckpt = ckpt_folder / LAST_CHECKPOINT_FILE_NAME
all_files = f"Existing files: {' '.join(p.name for p in ckpt_folder.glob('*'))}"
if not last_ckpt.is_file():
raise FileNotFoundError(f"Checkpoint file {LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX} not found. {all_files}")
raise FileNotFoundError(f"Checkpoint file {LAST_CHECKPOINT_FILE_NAME} not found. {all_files}")
# Training is finished now. To save storage, remove the autosave checkpoint which is now obsolete.
# Lightning does not overwrite checkpoints in-place. Rather, it writes "autosave.ckpt",
# then "autosave-v1.ckpt" and deletes "autosave.ckpt", then "autosave.ckpt" and deletes "autosave-v1.ckpt"
@ -183,7 +185,6 @@ class CheckpointDownloader:
self.remote_checkpoint_dir = (
remote_checkpoint_dir or self.extract_remote_checkpoint_dir_from_checkpoint_filename()
)
self.download_checkpoint_if_necessary()
def extract_checkpoint_filename_from_run_id(self) -> str:
"""
@ -192,7 +193,7 @@ class CheckpointDownloader:
"""
run_id_split = self.run_id.split(":")
self.run_id = run_id_split[0]
return run_id_split[-1] if len(run_id_split) > 1 else LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
return run_id_split[-1] if len(run_id_split) > 1 else LAST_CHECKPOINT_FILE_NAME
def extract_remote_checkpoint_dir_from_checkpoint_filename(self) -> Path:
"""
@ -232,3 +233,111 @@ class CheckpointDownloader:
self.run_id, str(self.remote_checkpoint_path), self.local_checkpoint_dir, aml_workspace=workspace
)
assert self.local_checkpoint_path.exists(), f"Couldn't download checkpoint from run {self.run_id}."
class CheckpointParser:
"""Wrapper class for parsing checkpoint arguments. A checkpoint can be specified in one of the following ways:
1. A local checkpoint file path
2. A URL to a remote checkpoint file
3. An AzureML run ID from which to download the checkpoint file
"""
AML_RUN_ID_FORMAT = (f"<AzureML_run_id>:<optional/custom/path/to/checkpoints/><filename{CHECKPOINT_SUFFIX}>"
f"If no custom path is provided (e.g., <AzureML_run_id>:<filename{CHECKPOINT_SUFFIX}>)"
"the checkpoint will be downloaded from the default checkpoint folder "
f"(e.g., '{DEFAULT_AML_CHECKPOINT_DIR}') If no filename is provided, "
"(e.g., `src_checkpoint=<AzureML_run_id>`) the latest checkpoint "
f"({LAST_CHECKPOINT_FILE_NAME}) will be downloaded.")
INFO_MESSAGE = ("Please provide a valid checkpoint path, URL or AzureML run ID. For custom checkpoint paths "
f"within an azureml run, provide a checkpoint in the format {AML_RUN_ID_FORMAT}.")
DOC = ("We currently support three types of checkpoints: "
" a. A local checkpoint folder that contains a checkpoint file."
" b. A URL to a remote checkpoint to be downloaded."
" c. A previous azureml run id where the checkpoint is supposed to be "
" saved ('outputs/checkpoints/' folder by default.)"
f"For the latter case 'c' : src_checkpoint should be in the format of {AML_RUN_ID_FORMAT}")
def __init__(self, checkpoint: str = "") -> None:
self.checkpoint = checkpoint
self.validate()
@property
def is_url(self) -> bool:
try:
result = urlparse(self.checkpoint)
return all([result.scheme, result.netloc])
except ValueError:
return False
@property
def is_local_file(self) -> bool:
return Path(self.checkpoint).is_file()
@property
def is_aml_run_id(self) -> bool:
match = re.match(r"[_\w-]*$", self.checkpoint.split(":")[0])
return match is not None and not self.is_url and not self.is_local_file
@property
def is_valid(self) -> bool:
if self.checkpoint:
return self.is_local_file or self.is_url or self.is_aml_run_id
return True
def validate(self) -> None:
if not self.is_valid:
raise ValueError(f"Invalid checkpoint '{self.checkpoint}'. {self.INFO_MESSAGE}")
@staticmethod
def download_from_url(url: str, download_folder: Path) -> Path:
"""
Download a checkpoint from checkpoint_url to the download folder. The file name is determined
from the file name in the URL. If that can't be determined, use a random file name.
:param url: The URL from which to download.
:param download_folder: The target folder for the download.
:return: A path to the downloaded file.
"""
# assign the same filename as in the download url if possible, so that we can check for duplicates
# If that fails, map to a random uuid
file_name = os.path.basename(urlparse(url).path) or str(uuid.uuid4().hex)
checkpoint_path = download_folder / file_name
# only download if hasn't already been downloaded
if checkpoint_path.is_file():
logging.info(f"File already exists, skipping download: {checkpoint_path}")
else:
logging.info(f"Downloading from URL {url}")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(checkpoint_path, "wb") as file:
for chunk in response.iter_content(chunk_size=1024):
file.write(chunk)
return checkpoint_path
def get_path(self, download_dir: Path) -> Path:
"""Returns the path to the checkpoint file. If the checkpoint is a URL, it will be downloaded to the checkpoints
folder. If the checkpoint is an AzureML run ID, it will be downloaded from the run to the checkpoints folder.
If the checkpoint is a local file, it will be returned as is.
:param download_dir: The checkpoints folder to which the checkpoint should be downloaded if it is a URL or
AzureML run ID.
:raises ValueError: If the checkpoint is not a local file, URL or AzureML run ID.
:raises FileNotFoundError: If the checkpoint is a URL or AzureML run ID and the download fails.
:return: The path to the checkpoint file.
"""
if self.is_local_file:
checkpoint_path = Path(self.checkpoint)
elif self.is_url:
download_folder = download_dir / MODEL_WEIGHTS_DIR_NAME
download_folder.mkdir(exist_ok=True, parents=True)
checkpoint_path = self.download_from_url(url=self.checkpoint, download_folder=download_folder)
elif self.is_aml_run_id:
downloader = CheckpointDownloader(run_id=self.checkpoint, download_dir=download_dir)
downloader.download_checkpoint_if_necessary()
checkpoint_path = downloader.local_checkpoint_path
else:
raise ValueError("Unable to determine how to get the checkpoint path.")
if checkpoint_path is None or not checkpoint_path.is_file():
raise FileNotFoundError(f"Could not find the file at {checkpoint_path}")
return checkpoint_path
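# Brief usage sketch (the run id and URL are placeholders, not real resources):
parser = CheckpointParser("MyRun_1234:checkpoints/best.ckpt")    # placeholder AzureML run id
assert parser.is_aml_run_id
# Would download 'best.ckpt' from the run's 'checkpoints/' folder into download_dir:
# checkpoint_path = parser.get_path(download_dir=Path("outputs"))
url_parser = CheckpointParser("https://example.org/model.ckpt")  # placeholder URL
assert url_parser.is_url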

Просмотреть файл

@ -12,10 +12,10 @@ class HEDJitter(object):
"""
Randomly perturbs the HED color space values of an RGB image.
First, it disentangled the hematoxylin and eosin color channels by color deconvolution method using a fixed matrix,
taken from Ruifrok and Johnston (2001): "Quantification of histochemical staining by color deconvolution."
First, it disentangles the hematoxylin and eosin color channels via a color deconvolution method using a fixed matrix.
Second, it perturbs the hematoxylin and eosin stains independently.
Third, it transforms the resulting stains back into regular RGB color space.
PyTorch version of: https://github.com/gatsby2016/Augmentation-PyTorch-Transforms/blob/master/myTransforms.py
Usage example:
>>> transform = HEDJitter(0.05)
@ -37,45 +37,45 @@ class HEDJitter(object):
self.hed_from_rgb = torch.tensor([[1.87798274, -1.00767869, -0.55611582],
[-0.06590806, 1.13473037, -0.1355218],
[-0.60190736, -0.48041419, 1.57358807]])
self.log_adjust = torch.log(torch.tensor(1E-6))
@staticmethod
def adjust_hed(img: torch.Tensor,
theta: float,
stain_from_rgb_mat: torch.Tensor,
rgb_from_stain_mat: torch.Tensor
) -> torch.Tensor:
def adjust_hed(self, img: torch.Tensor) -> torch.Tensor:
"""
Applies HED jitter to image.
:param img: Input image.
:param theta: Strength of the jitter. HED_light: theta=0.05; HED_strong: theta=0.2.
:param stain_from_rgb_mat: Transformation matrix from HED to RGB.
:param rgb_from_stain_mat: Transformation matrix from RGB to HED.
"""
alpha = torch.FloatTensor(1, 3).uniform_(1 - theta, 1 + theta)
beta = torch.FloatTensor(1, 3).uniform_(-theta, theta)
# Only perturb the H (=0) and E (=1) channels
alpha[0][-1] = 1.
beta[0][-1] = 0.
alpha = torch.FloatTensor(img.shape[0], 1, 1, 3).uniform_(1 - self.theta, 1 + self.theta)
beta = torch.FloatTensor(img.shape[0], 1, 1, 3).uniform_(-self.theta, self.theta)
# Separate stains
img = img.permute([0, 2, 3, 1])
img = img + 2 # for consistency with skimage
stains = -torch.log10(img) @ stain_from_rgb_mat
stains = alpha * stains + beta # perturbations in HED color space
img = torch.maximum(img, 1E-6 * torch.ones(img.shape))
stains = (torch.log(img) / self.log_adjust) @ self.hed_from_rgb
stains = torch.maximum(stains, torch.zeros(stains.shape))
# perturbations in HED color space
stains = alpha * stains + beta
# Combine stains
img = 10 ** (-stains @ rgb_from_stain_mat) - 2
img = -(stains * (-self.log_adjust)) @ self.rgb_from_hed
img = torch.exp(img)
img = torch.clip(img, 0, 1)
img = img.permute(0, 3, 1, 2)
# Normalize
imin = torch.amin(img, dim=[1, 2, 3], keepdim=True)
imax = torch.amax(img, dim=[1, 2, 3], keepdim=True)
img = (img - imin) / (imax - imin)
return img
def __call__(self, img: torch.Tensor) -> torch.Tensor:
if img.shape[1] != 3:
raise ValueError("HED jitter can only be applied to images with 3 channels (RGB).")
return self.adjust_hed(img, self.theta, self.hed_from_rgb, self.rgb_from_hed)
return self.adjust_hed(img)
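# Short usage sketch of the batched behaviour (shapes illustrative): alpha and beta are now drawn
# per image along the leading batch dimension.
import torch

transform = HEDJitter(0.05)
tiles = torch.rand(8, 3, 224, 224)   # a bag of 8 RGB tiles
jittered = transform(tiles)          # same shape, each tile perturbed with its own alpha/beta
assert jittered.shape == tiles.shape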
class StainNormalization(object):
@ -131,7 +131,20 @@ class StainNormalization(object):
return nimg
def __call__(self, img: torch.Tensor) -> torch.Tensor:
return self.stain_normalize(img, self.reference_mean, self.reference_std)
original_shape = img.shape
if len(original_shape) == 3:
img = img.unsqueeze(0) # add batch dimension if missing
# if the input is a bag of images, stain normalization needs to run on each image separately
if img.shape[0] > 1:
for i in range(img.shape[0]):
img_tile = img[i]
img[i] = self.stain_normalize(img_tile.unsqueeze(0), self.reference_mean, self.reference_std)
return img
else:
img = self.stain_normalize(img, self.reference_mean, self.reference_std)
if len(original_shape) == 3:
return img.squeeze(0)
return img
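# Sketch of the bag handling added here (shapes illustrative, assuming the default constructor):
# a single [3, H, W] image keeps its shape, and a bag [N, 3, H, W] is normalized tile by tile.
import torch

normalizer = StainNormalization()
single_tile = torch.rand(3, 224, 224)
bag_of_tiles = torch.rand(4, 3, 224, 224)
assert normalizer(single_tile).shape == (3, 224, 224)      # unsqueezed, normalized, squeezed back
assert normalizer(bag_of_tiles).shape == (4, 3, 224, 224)  # each tile normalized independently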
class GaussianBlur(object):

Просмотреть файл

@ -5,7 +5,7 @@ from torch import nn, rand, sum, allclose, ones_like
from health_ml.networks.layers.attention_layers import (AttentionLayer, GatedAttentionLayer,
MeanPoolingLayer, TransformerPooling,
MaxPoolingLayer)
MaxPoolingLayer, TransformerPoolingBenchmark)
def _test_attention_layer(attentionlayer: nn.Module, dim_in: int, dim_att: int,
@ -19,11 +19,7 @@ def _test_attention_layer(attentionlayer: nn.Module, dim_in: int, dim_att: int,
row_sums = sum(attn_weights, dim=1, keepdim=True)
assert allclose(row_sums, ones_like(row_sums))
if isinstance(attentionlayer, TransformerPooling):
pass
elif isinstance(attentionlayer, MaxPoolingLayer):
pass
else:
if not isinstance(attentionlayer, (MaxPoolingLayer, TransformerPooling, TransformerPoolingBenchmark)):
pooled_features = attn_weights @ features.flatten(start_dim=1)
assert allclose(pooled_features, output_features)
@ -59,8 +55,27 @@ def test_max_pooling(dim_in: int, batch_size: int,) -> None:
@pytest.mark.parametrize("num_heads", [1, 2])
@pytest.mark.parametrize("dim_in", [4, 8]) # dim_in % num_heads must be 0
@pytest.mark.parametrize("batch_size", [1, 7])
def test_transformer_pooling(num_layers: int, num_heads: int, dim_in: int, batch_size: int) -> None:
def test_transformer_pooling(num_layers: int, num_heads: int, dim_in: int,
batch_size: int) -> None:
transformer_dropout = 0.5
transformer_pooling = TransformerPooling(num_layers=num_layers,
num_heads=num_heads,
dim_representation=dim_in).eval()
dim_representation=dim_in,
transformer_dropout=transformer_dropout).eval()
_test_attention_layer(transformer_pooling, dim_in=dim_in, dim_att=1, batch_size=batch_size)
@pytest.mark.parametrize("num_layers", [1, 4])
@pytest.mark.parametrize("num_heads", [1, 2])
@pytest.mark.parametrize("dim_in", [4, 8]) # dim_in % num_heads must be 0
@pytest.mark.parametrize("batch_size", [1, 7])
@pytest.mark.parametrize("dim_hid", [1, 4])
def test_transformer_pooling_benchmark(num_layers: int, num_heads: int, dim_in: int,
batch_size: int, dim_hid: int) -> None:
transformer_dropout = 0.5
transformer_pooling_benchmark = TransformerPoolingBenchmark(num_layers=num_layers,
num_heads=num_heads,
dim_representation=dim_in,
hidden_dim=dim_hid,
transformer_dropout=transformer_dropout).eval()
_test_attention_layer(transformer_pooling_benchmark, dim_in=dim_in, dim_att=1, batch_size=batch_size)
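# Direct construction sketch with the new transformer_dropout argument (values arbitrary); calling
# .eval() disables dropout so that the pooled outputs stay deterministic in the tests above.
pooling = TransformerPooling(num_layers=2, num_heads=2, dim_representation=8,
                             transformer_dropout=0.1).eval()
benchmark = TransformerPoolingBenchmark(num_layers=2, num_heads=2, dim_representation=8,
                                        hidden_dim=4, transformer_dropout=0.1).eval()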

Просмотреть файл

@ -8,52 +8,73 @@ from unittest import mock
import pytest
from health_ml.configs.hello_world import HelloWorld
from health_ml.deep_learning_config import WorkflowParams
from health_ml.lightning_container import LightningContainer
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.utils.checkpoint_utils import (
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
LAST_CHECKPOINT_FILE_NAME,
MODEL_WEIGHTS_DIR_NAME,
CheckpointDownloader,
)
CheckpointParser,)
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from testhiml.utils.fixed_paths_for_tests import full_test_data_path, mock_run_id
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
def test_checkpoint_downloader_run_id() -> None:
with mock.patch("health_ml.utils.checkpoint_utils.CheckpointDownloader.download_checkpoint_if_necessary"):
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id")
assert checkpoint_downloader.run_id == "dummy_run_id"
assert checkpoint_downloader.checkpoint_filename == LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id")
assert checkpoint_downloader.run_id == "dummy_run_id"
assert checkpoint_downloader.checkpoint_filename == LAST_CHECKPOINT_FILE_NAME
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:best.ckpt")
assert checkpoint_downloader.run_id == "dummy_run_id"
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:best.ckpt")
assert checkpoint_downloader.run_id == "dummy_run_id"
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:custom/path/best.ckpt")
assert checkpoint_downloader.run_id == "dummy_run_id"
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
assert checkpoint_downloader.remote_checkpoint_dir == Path("custom/path")
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:custom/path/best.ckpt")
assert checkpoint_downloader.run_id == "dummy_run_id"
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
assert checkpoint_downloader.remote_checkpoint_dir == Path("custom/path")
def _test_invalid_checkpoint(checkpoint: str) -> None:
with pytest.raises(ValueError, match=r"Invalid checkpoint "):
CheckpointParser(checkpoint=checkpoint)
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=checkpoint).validate()
def test_validate_checkpoint_parser() -> None:
_test_invalid_checkpoint(checkpoint="dummy/local/path/model.ckpt")
_test_invalid_checkpoint(checkpoint="INV@lid%RUN*id")
_test_invalid_checkpoint(checkpoint="http/dummy_url-com")
# The following should be okay
checkpoint = str(full_test_data_path(suffix="hello_world_checkpoint.ckpt"))
CheckpointParser(checkpoint=checkpoint)
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=CheckpointParser(checkpoint)).validate()
checkpoint = mock_run_id(id=0)
CheckpointParser(checkpoint=checkpoint)
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=CheckpointParser(checkpoint)).validate()
def get_checkpoint_handler(tmp_path: Path, src_checkpoint: str) -> Tuple[LightningContainer, CheckpointHandler]:
container = LightningContainer()
container.set_output_to(tmp_path)
container.checkpoint_folder.mkdir(parents=True)
container.src_checkpoint = src_checkpoint
container.src_checkpoint = CheckpointParser(src_checkpoint)
return container, CheckpointHandler(container=container, project_root=tmp_path)
def test_load_model_chcekpoints_from_url(tmp_path: Path) -> None:
def test_load_model_checkpoints_from_url(tmp_path: Path) -> None:
WEIGHTS_URL = (
"https://pl-bolts-weights.s3.us-east-2.amazonaws.com/" "simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
)
container, checkpoint_handler = get_checkpoint_handler(tmp_path, WEIGHTS_URL)
download_folder = container.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME
assert container.src_checkpoint_is_url
assert container.src_checkpoint.is_url
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.trained_weights_path
assert checkpoint_handler.trained_weights_path.exists()
@@ -64,7 +85,7 @@ def test_load_model_checkpoints_from_local_file(tmp_path: Path) -> None:
local_checkpoint_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
container, checkpoint_handler = get_checkpoint_handler(tmp_path, str(local_checkpoint_path))
assert container.src_checkpoint_is_local_file
assert container.src_checkpoint.is_local_file
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.trained_weights_path
assert checkpoint_handler.trained_weights_path.exists()
@@ -82,10 +103,10 @@ def test_load_model_checkpoints_from_aml_run_id(src_chekpoint_filename: str, tmp
src_checkpoint_filename = (
src_chekpoint_filename.split("/")[-1]
if src_chekpoint_filename
else LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
else LAST_CHECKPOINT_FILE_NAME
)
expected_weights_path = container.outputs_folder / run_id / checkpoint_path / src_checkpoint_filename
assert container.src_checkpoint_is_aml_run_id
expected_weights_path = container.checkpoint_folder / run_id / checkpoint_path / src_checkpoint_filename
assert container.src_checkpoint.is_aml_run_id
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.trained_weights_path
assert checkpoint_handler.trained_weights_path.exists()
@@ -99,7 +120,7 @@ def test_custom_checkpoint_for_test(tmp_path: Path) -> None:
container = HelloWorld()
container.set_output_to(tmp_path)
container.checkpoint_folder.mkdir(parents=True)
last_checkpoint = container.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
last_checkpoint = container.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME
last_checkpoint.touch()
checkpoint_handler = CheckpointHandler(container=container, project_root=tmp_path)
checkpoint_handler.additional_training_done()

View file

@@ -16,6 +16,7 @@ dummy_img = torch.Tensor(
[0.8717, 0.9098]],
[[0.1592, 0.7216],
[0.8305, 0.1127]]]])
dummy_bag = torch.stack([dummy_img.squeeze(0), dummy_img.squeeze(0)])
def _test_data_augmentation(data_augmentation: Callable[[Tensor], Tensor],
@@ -59,21 +60,38 @@ def test_stain_normalization() -> None:
[0.8706, 0.4863]],
[[0.8235, 0.5294],
[0.8275, 0.7725]]]])
expected_output_bag = torch.stack([expected_output_img.squeeze(0), expected_output_img.squeeze(0)])
_test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=False)
_test_data_augmentation(data_augmentation, dummy_bag, expected_output_bag, stochastic=False)
# Test tiling on the fly (i.e. when the input image does not have a batch dimension)
_test_data_augmentation(data_augmentation, dummy_img.squeeze(0), expected_output_img.squeeze(0), stochastic=False)
def test_hed_jitter() -> None:
data_augmentation = HEDJitter(0.05)
expected_output_img = torch.Tensor(
[[[[0.6241, 0.1635],
[0.9993, 1.0000]],
[[1.0000, 1.0000],
[1.0000, 1.0000]],
[[0.2232, 0.8028],
[0.9117, 0.1742]]]])
expected_output_img1 = torch.Tensor(
[[[[0.9639, 0.4130],
[0.9134, 1.0000]],
[[0.3125, 0.0000],
[0.4474, 0.1820]],
[[0.9195, 0.5265],
[0.9118, 0.8291]]]])
expected_output_img2 = torch.Tensor(
[[[[0.8411, 0.2361],
[0.7857, 0.8766]],
[[0.7075, 0.0000],
[1.0000, 0.4138]],
[[0.9694, 0.4674],
[0.9577, 0.8476]]]])
expected_output_bag = torch.vstack([expected_output_img1,
expected_output_img2])
_test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=True)
_test_data_augmentation(data_augmentation,
dummy_bag,
expected_output_bag,
stochastic=True)
def test_gaussian_blur() -> None:

View file

@@ -13,44 +13,40 @@ from pathlib import Path
from health_ml.deep_learning_config import DatasetParams, WorkflowParams, OutputParams, OptimizerParams, \
ExperimentFolderHandler, TrainerParams
from health_ml.utils.checkpoint_utils import CheckpointParser
from testhiml.utils.fixed_paths_for_tests import full_test_data_path, mock_run_id
def _test_invalid_pre_checkpoint_workflow_params(src_checkpoint: str) -> None:
error_message = "Invalid src_checkpoint:"
with pytest.raises(ValueError) as ex:
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=src_checkpoint).validate()
assert error_message in ex.value.args[0]
def test_validate_workflow_params_src_checkpoint() -> None:
_test_invalid_pre_checkpoint_workflow_params(src_checkpoint="dummy/local/path/model.ckpt")
_test_invalid_pre_checkpoint_workflow_params(src_checkpoint="INV@lid%RUN*id")
_test_invalid_pre_checkpoint_workflow_params(src_checkpoint="http/dummy_url-com")
# The following should be okay
full_file_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
WorkflowParams(local_dataset=Path("foo"), src_checkpoint=str(full_file_path)).validate()
run_id = mock_run_id(id=0)
WorkflowParams(local_dataset=Path("foo"), src_checkpoint=run_id).validate()
def test_validate_workflow_params_for_inference_only() -> None:
error_message = "Cannot run inference without a src_checkpoint."
with pytest.raises(ValueError) as ex:
with pytest.raises(ValueError, match=r"Cannot run inference without a src_checkpoint."):
WorkflowParams(local_datasets=Path("foo"), run_inference_only=True).validate()
assert error_message in ex.value.args[0]
full_file_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
run_id = mock_run_id(id=0)
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True, src_checkpoint=run_id).validate()
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
src_checkpoint=f"{run_id}:best_val_loss.ckpt").validate()
src_checkpoint=CheckpointParser(run_id)).validate()
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
src_checkpoint=f"{run_id}:custom/path/model.ckpt").validate()
src_checkpoint=CheckpointParser(f"{run_id}:best_val_loss.ckpt")).validate()
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
src_checkpoint=str(full_file_path)).validate()
src_checkpoint=CheckpointParser(f"{run_id}:custom/path/model.ckpt")).validate()
WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
src_checkpoint=CheckpointParser(str(full_file_path))).validate()
def test_validate_workflow_params_for_resume_training() -> None:
with pytest.raises(ValueError, match=r"Cannot resume training without a src_checkpoint."):
WorkflowParams(local_datasets=Path("foo"), resume_training=True).validate()
full_file_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
run_id = mock_run_id(id=0)
WorkflowParams(local_dataset=Path("foo"), resume_training=True,
src_checkpoint=CheckpointParser(run_id)).validate()
WorkflowParams(local_dataset=Path("foo"), resume_training=True,
src_checkpoint=CheckpointParser(f"{run_id}:best_val_loss.ckpt")).validate()
WorkflowParams(local_dataset=Path("foo"), resume_training=True,
src_checkpoint=CheckpointParser(f"{run_id}:custom/path/model.ckpt")).validate()
WorkflowParams(local_dataset=Path("foo"), resume_training=True,
src_checkpoint=CheckpointParser(str(full_file_path))).validate()
@pytest.mark.fast

View file

@@ -41,6 +41,7 @@ def test_create_lightning_trainer() -> None:
assert trainer.default_root_dir == str(container.outputs_folder)
assert trainer.limit_train_batches == 1.0
assert trainer._detect_anomaly == container.detect_anomaly
assert trainer.accumulate_grad_batches == 1
assert isinstance(trainer.callbacks[0], TQDMProgressBar)
assert isinstance(trainer.callbacks[1], ModelSummary)
@@ -192,3 +193,11 @@ def test_create_lightning_trainer_limit_batches() -> None:
assert trainer3.num_training_batches == int(limit_train_batches_float * original_num_train_batches)
assert trainer3.num_val_batches[0] == int(limit_val_batches_float * original_num_val_batches)
assert trainer3.num_test_batches[0] == int(limit_test_batches_float * original_num_test_batches)
def test_flag_grad_accum() -> None:
num_batches = 4
container = LightningContainer()
container.pl_accumulate_grad_batches = num_batches
trainer, _ = create_lightning_trainer(container)
assert trainer.accumulate_grad_batches == num_batches

View file

@@ -8,6 +8,8 @@ import pytest
from pathlib import Path
from typing import Generator
from unittest.mock import DEFAULT, MagicMock, Mock, patch
from _pytest.logging import LogCaptureFixture
from pytorch_lightning import LightningModule
import mlflow
from pytorch_lightning import Trainer
@@ -16,6 +18,7 @@ from health_ml.configs.hello_world import HelloWorld # type: ignore
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.run_ml import MLRunner, get_mlflow_run_id_from_previous_loggers
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_ml.utils.common_utils import is_gpu_available
from health_ml.utils.lightning_loggers import HimlMLFlowLogger, StoringLogger
from health_azure.utils import is_global_rank_zero
@@ -62,7 +65,7 @@ def ml_runner_with_run_id() -> Generator:
experiment_config = ExperimentConfig(model="HelloWorld")
container = HelloWorld()
container.save_checkpoint = True
container.src_checkpoint = mock_run_id(id=0)
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
runner = MLRunner(experiment_config=experiment_config, container=container)
@@ -164,26 +167,79 @@ def test_run_validation(run_extra_val_epoch: bool) -> None:
container = HelloWorld()
container.create_lightning_module_and_store()
container.run_extra_val_epoch = run_extra_val_epoch
container.model.run_extra_val_epoch = run_extra_val_epoch # type: ignore
runner = MLRunner(experiment_config=experiment_config, container=container)
with patch.object(container, "get_data_module"):
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
runner.setup()
mock_trainer = MagicMock()
mock_storing_logger = MagicMock()
mock_create_trainer.return_value = mock_trainer, mock_storing_logger
runner.init_training()
with patch.object(container, "on_run_extra_validation_epoch") as mock_on_run_extra_validation_epoch:
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
runner.setup()
mock_trainer = MagicMock()
mock_storing_logger = MagicMock()
mock_create_trainer.return_value = mock_trainer, mock_storing_logger
runner.init_training()
assert runner.trainer == mock_trainer
assert runner.storing_logger == mock_storing_logger
assert runner.trainer == mock_trainer
assert runner.storing_logger == mock_storing_logger
mock_trainer.validate = Mock()
mock_trainer.validate = Mock()
if run_extra_val_epoch:
runner.run_validation()
if run_extra_val_epoch:
with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
runner.run_validation()
mock_validate_model_weights.assert_called_once()
assert mock_trainer.validate.called == run_extra_val_epoch
assert mock_on_run_extra_validation_epoch.called == run_extra_val_epoch
assert hasattr(container.model, "on_run_extra_validation_epoch")
assert mock_trainer.validate.called == run_extra_val_epoch
@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
def test_model_extra_val_epoch(run_extra_val_epoch: bool) -> None:
experiment_config = ExperimentConfig(model="HelloWorld")
with patch(
"health_ml.configs.hello_world.HelloRegression.on_run_extra_validation_epoch"
) as mock_on_run_extra_validation_epoch:
container = HelloWorld()
container.run_extra_val_epoch = run_extra_val_epoch
container.create_lightning_module_and_store()
runner = MLRunner(experiment_config=experiment_config, container=container)
with patch.object(container, "get_data_module"):
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
runner.setup()
mock_trainer = MagicMock()
mock_create_trainer.return_value = mock_trainer, MagicMock()
runner.init_training()
mock_trainer.validate = Mock()
if run_extra_val_epoch:
with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
runner.run_validation()
mock_validate_model_weights.assert_called_once()
assert mock_on_run_extra_validation_epoch.called == run_extra_val_epoch
assert mock_trainer.validate.called == run_extra_val_epoch
def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
experiment_config = ExperimentConfig(model="HelloWorld")
def _create_model(self) -> LightningModule: # type: ignore
return LightningModule()
with patch("health_ml.configs.hello_world.HelloWorld.create_model", _create_model):
container = HelloWorld()
container.create_lightning_module_and_store()
container.run_extra_val_epoch = True
runner = MLRunner(experiment_config=experiment_config, container=container)
with patch.object(container, "get_data_module"):
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
runner.setup()
mock_create_trainer.return_value = MagicMock(), MagicMock()
runner.init_training()
with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
runner.run_validation()
mock_validate_model_weights.assert_called_once()
latest_message = caplog.records[-1].getMessage()
assert "Hook `on_run_extra_validation_epoch` is not implemented by lightning module." in latest_message
def test_run_inference(ml_runner_with_container: MLRunner, tmp_path: Path) -> None:
@@ -218,7 +274,9 @@ def test_run_inference(ml_runner_with_container: MLRunner, tmp_path: Path) -> No
actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
assert actual_train_ckpt_path is None
ml_runner_with_container.run()
with patch.object(ml_runner_with_container, "validate_model_weights") as mock_validate_model_weights:
ml_runner_with_container.run()
mock_validate_model_weights.assert_called_once()
actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
assert actual_train_ckpt_path == expected_ckpt_path
@@ -268,7 +326,10 @@ def test_run_inference_only(ml_runner_with_run_id: MLRunner) -> None:
assert ml_runner_with_run_id.checkpoint_handler.trained_weights_path
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
with patch.multiple(
ml_runner_with_run_id, run_training=DEFAULT, run_validation=DEFAULT
ml_runner_with_run_id,
run_training=DEFAULT,
run_validation=DEFAULT,
validate_model_weights=DEFAULT
) as mocks:
mock_trainer = MagicMock()
mock_create_trainer.return_value = mock_trainer, MagicMock()
@@ -278,6 +339,7 @@ def test_run_inference_only(ml_runner_with_run_id: MLRunner) -> None:
assert recovery_checkpoint == ml_runner_with_run_id.checkpoint_handler.trained_weights_path
mocks["run_training"].assert_not_called()
mocks["run_validation"].assert_not_called()
mocks["validate_model_weights"].assert_not_called()
mock_trainer.test.assert_called_once()
@@ -297,7 +359,8 @@ def test_model_weights_when_resume_training() -> None:
experiment_config = ExperimentConfig(model="HelloWorld")
container = HelloWorld()
container.max_num_gpus = 0
container.src_checkpoint = mock_run_id(id=0)
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
container.resume_training = True
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
runner = MLRunner(experiment_config=experiment_config, container=container)
@@ -315,7 +378,7 @@ def test_runner_end_to_end() -> None:
experiment_config = ExperimentConfig(model="HelloWorld")
container = HelloWorld()
container.max_num_gpus = 0
container.src_checkpoint = mock_run_id(id=0)
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
runner = MLRunner(experiment_config=experiment_config, container=container)

View file

@@ -6,7 +6,7 @@ from contextlib import contextmanager
import shutil
import sys
from pathlib import Path
from typing import Generator, List, Optional
from typing import Any, Dict, Generator, List, Optional
from unittest.mock import patch, MagicMock
import pytest
@@ -17,6 +17,7 @@ from health_azure import AzureRunInfo, DatasetConfig
from health_azure.paths import ENVIRONMENT_YAML_FILE_NAME
from health_ml.configs.hello_world import HelloWorld # type: ignore
from health_ml.deep_learning_config import WorkflowParams
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, DebugDDPOptions
from health_ml.lightning_container import LightningContainer
from health_ml.runner import Runner
from health_ml.utils.common_utils import change_working_directory
@@ -83,6 +84,35 @@ def test_parse_and_load_model(mock_runner: Runner, model_name: Optional[str], cl
assert mock_runner.lightning_container.model_name == model_name
@pytest.mark.parametrize("debug_ddp", ["OFF", "INFO", "DETAIL"])
def test_ddp_debug_flag(debug_ddp: DebugDDPOptions, mock_runner: Runner) -> None:
model_name = "HelloWorld"
arguments = ["", f"--debug_ddp={debug_ddp}", f"--model={model_name}"]
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
with patch("health_ml.runner.get_workspace"):
with patch("health_ml.runner.Runner.run_in_situ"):
with patch.object(sys, "argv", arguments):
mock_runner.run()
mock_submit_to_azure_if_needed.assert_called_once()
assert mock_submit_to_azure_if_needed.call_args[1]["environment_variables"][DEBUG_DDP_ENV_VAR] == debug_ddp
def test_additional_aml_run_tags(mock_runner: Runner) -> None:
model_name = "HelloWorld"
arguments = ["", f"--model={model_name}", "--cluster=foo"]
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
with patch("health_ml.runner.check_conda_environment"):
with patch("health_ml.runner.get_workspace"):
with patch("health_ml.runner.get_ml_client"):
with patch("health_ml.runner.Runner.run_in_situ"):
with patch.object(sys, "argv", arguments):
mock_runner.run()
mock_submit_to_azure_if_needed.assert_called_once()
assert "commandline_args" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
assert "tag" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
assert "max_epochs" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
def test_run(mock_runner: Runner) -> None:
model_name = "HelloWorld"
arguments = ["", f"--model={model_name}"]
@@ -106,7 +136,8 @@ def test_submit_to_azureml_if_needed(mock_get_workspace: MagicMock,
mock_runner: Runner
) -> None:
def _mock_dont_submit_to_aml(input_datasets: List[DatasetConfig],
submit_to_azureml: bool, strictly_aml_v1: bool, # type: ignore
submit_to_azureml: bool, strictly_aml_v1: bool, # type: ignore
environment_variables: Dict[str, Any], # type: ignore
) -> AzureRunInfo:
datasets_input = [d.target_folder for d in input_datasets] if input_datasets else []
return AzureRunInfo(input_datasets=datasets_input,