Merge branch 'transfer_main' into kenzab/merge_transfer_main_main
@@ -13,10 +13,10 @@ name: "CodeQL"
 
 on:
   push:
-    branches: [ main, PR ]
+    branches: [ main, PR, transfer_main ]
   pull_request:
     # The branches below must be a subset of the branches above
-    branches: [ main ]
+    branches: [ main, transfer_main ]
   schedule:
     - cron: '20 1 * * 6' # https://crontab.guru/#20_1_*_*_6 this is: At 01:20 on Saturday.
 
@@ -1,11 +1,11 @@
 name: Pathology PR
 on:
   push:
-    branches: [ main ]
+    branches: [ main, transfer_main ]
     tags:
       - '*'
   pull_request:
-    branches: [ main ]
+    branches: [ main, transfer_main ]
     paths:
       - "hi-ml-cpath/**"
      - ".github/workflows/cpath-pr.yml"
@@ -128,7 +128,22 @@ jobs:
           cd ${{ env.folder }}
           make smoke_test_tilespandaimagenetmil_aml
 
-  smoke_test_tcgacrcksslmil_aml:
+  smoke_test_tcgacrckimagenetmil:
+    runs-on: ubuntu-20.04
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          lfs: true
+
+      - name: Prepare Conda environment
+        uses: ./.github/actions/prepare_cpath_environment
+
+      - name: smoke test
+        run: |
+          cd ${{ env.folder }}
+          make smoke_test_tcgacrckimagenetmil_aml
+
+  smoke_test_tcgacrcksslmil:
     runs-on: ubuntu-20.04
     steps:
       - uses: actions/checkout@v3
@@ -143,7 +158,7 @@ jobs:
           cd ${{ env.folder }}
           make smoke_test_tcgacrcksslmil_aml
 
-  smoke_test_crck_simclr_aml:
+  smoke_test_crck_simclr:
     runs-on: ubuntu-20.04
     steps:
       - uses: actions/checkout@v3
@@ -158,6 +173,81 @@ jobs:
           cd ${{ env.folder }}
           make smoke_test_crck_simclr_aml
 
+  smoke_test_crck_flexible_finetuning:
+    runs-on: ubuntu-20.04
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          lfs: true
+
+      - name: Prepare Conda environment
+        uses: ./.github/actions/prepare_cpath_environment
+
+      - name: smoke test
+        run: |
+          cd ${{ env.folder }}
+          make smoke_test_crck_flexible_finetuning_aml
+
+  smoke_test_crck_loss_analysis:
+    runs-on: ubuntu-20.04
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          lfs: true
+
+      - name: Prepare Conda environment
+        uses: ./.github/actions/prepare_cpath_environment
+
+      - name: smoke test
+        run: |
+          cd ${{ env.folder }}
+          make smoke_test_crck_loss_analysis_aml
+
+  smoke_test_slides_panda_loss_analysis:
+    runs-on: ubuntu-20.04
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          lfs: true
+
+      - name: Prepare Conda environment
+        uses: ./.github/actions/prepare_cpath_environment
+
+      - name: smoke test
+        run: |
+          cd ${{ env.folder }}
+          make smoke_test_slides_panda_loss_analysis_aml
+
+  smoke_test_slides_panda_no_ddp_sampler:
+    runs-on: ubuntu-20.04
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          lfs: true
+
+      - name: Prepare Conda environment
+        uses: ./.github/actions/prepare_cpath_environment
+
+      - name: smoke test
+        run: |
+          cd ${{ env.folder }}
+          make smoke_test_slides_panda_no_ddp_sampler_aml
+
+  smoke_test_tiles_panda_no_ddp_sampler:
+    runs-on: ubuntu-20.04
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          lfs: true
+
+      - name: Prepare Conda environment
+        uses: ./.github/actions/prepare_cpath_environment
+
+      - name: smoke test
+        run: |
+          cd ${{ env.folder }}
+          make smoke_test_tiles_panda_no_ddp_sampler_aml
+
   publish:
     runs-on: ubuntu-20.04
     needs: [
@@ -166,7 +256,14 @@ jobs:
       pytest,
       smoke_test_slidespandaimagenetmil,
       smoke_test_tilespandaimagenetmil,
-      smoke_test_crck_simclr_aml
+      smoke_test_tcgacrckimagenetmil,
+      smoke_test_tcgacrcksslmil,
+      smoke_test_crck_simclr,
+      smoke_test_crck_flexible_finetuning,
+      smoke_test_crck_loss_analysis,
+      smoke_test_slides_panda_loss_analysis,
+      smoke_test_slides_panda_no_ddp_sampler,
+      smoke_test_tiles_panda_no_ddp_sampler,
     ]
     steps:
       - uses: actions/checkout@v3
@@ -2,9 +2,9 @@ name: CredScan
 
 on:
   push:
-    branches: [ main ]
+    branches: [ main, transfer_main ]
   pull_request:
-    branches: [ main ]
+    branches: [ main, transfer_main ]
 
 jobs:
   CredScan:
@@ -1,11 +1,11 @@
 name: HI-ML HI-ML-Azure PR
 on:
   push:
-    branches: [ main ]
+    branches: [ main, transfer_main ]
     tags:
       - '*'
   pull_request:
-    branches: [ main ]
+    branches: [ main, transfer_main ]
     paths:
       - "hi-ml-azure/**"
       - "hi-ml/**"
@@ -16,7 +16,9 @@ dependencies:
   - azureml-train-core==1.43.0
   - conda-merge==0.1.5
+  - msal-extensions==0.3.1
   - pandas==1.3.4
+  - param==1.12
   - protobuf==3.20.1
   - pytorch-lightning>=1.6.0, <1.7
   - ruamel.yaml==0.16.12
   - setuptools==59.5.0
@@ -63,7 +63,7 @@ jobs:
   - name: HelloWorld
     sku: G1
     command:
-      - python hi-ml/src/health_ml/runner.py --model=health_cpath.TilesPandaImageNetMIL --is_finetune --batch_size=2
+      - python hi-ml/src/health_ml/runner.py --model=health_cpath.TilesPandaImageNetMIL --tune_encoder --batch_size=2
 ```
 
 Note that SKU here refers to the number of GPUs/CPUs to reserve, and its memory. In this case we have specified 1 GPU.
@@ -0,0 +1,61 @@
+# Checkpoint Utils
+
+Hi-ml toolbox offers different utilities to parse and download pretrained checkpoints that help you abstract checkpoint
+downloading from different sources. Refer to
+[CheckpointParser](https://github.com/microsoft/hi-ml/blob/main/hi-ml/src/health_ml/utils/checkpoint_utils.py#L238) for
+more details on the supported checkpoint formats. Here's how you can use the checkpoint parser depending on the source:
+
+- For a local path, simply pass it as shown below. The parser will further check if the provided path exists:
+
+```python
+from health_ml.utils.checkpoint_utils import CheckpointParser
+
+download_dir = 'outputs/checkpoints'
+checkpoint_parser = CheckpointParser(checkpoint='local/path/to/my_checkpoint/model.ckpt')
+print('Checkpoint', checkpoint_parser.checkpoint, 'is a local file', checkpoint_parser.is_local_file)
+local_file = checkpoint_parser.get_path(download_dir)
+```
+
+- To download a checkpoint from a URL:
+
+```python
+from health_ml.utils.checkpoint_utils import CheckpointParser, MODEL_WEIGHTS_DIR_NAME
+
+download_dir = 'outputs/checkpoints'
+checkpoint_parser = CheckpointParser('https://my_checkpoint_url.com/model.ckpt')
+print('Checkpoint', checkpoint_parser.checkpoint, 'is a URL', checkpoint_parser.is_url)
+# will download the checkpoint to download_dir/MODEL_WEIGHTS_DIR_NAME
+path_to_ckpt = checkpoint_parser.get_path(download_dir)
+```
+
+- Finally, checkpoints from Azure ML runs can be reused by providing an id in this format
+`<AzureML_run_id>:<optional/custom/path/to/checkpoints/><filename.ckpt>`. If no custom path is provided (e.g.,
+`<AzureML_run_id>:<filename.ckpt>`) the checkpoint will be downloaded from the default checkpoint folder
+(e.g., `outputs/checkpoints`). If no filename is provided (e.g., `src_checkpoint=<AzureML_run_id>`) the latest
+checkpoint will be downloaded (e.g., `last.ckpt`).
+
+```python
+from health_ml.utils.checkpoint_utils import CheckpointParser
+
+checkpoint_parser = CheckpointParser('AzureML_run_id:best.ckpt')
+print('Checkpoint', checkpoint_parser.checkpoint, 'is an AML run', checkpoint_parser.is_aml_run_id)
+path_azure_ml_ckpt = checkpoint_parser.get_path(download_dir)
+```
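+
+The other two id forms described above work the same way; a quick sketch (the run ids here are placeholders):
+
+```python
+# custom path inside the run's outputs
+parser_custom = CheckpointParser('AzureML_run_id:outputs/checkpoints/model.ckpt')
+# run id only: resolves to the latest checkpoint, e.g. last.ckpt
+parser_latest = CheckpointParser('AzureML_run_id')
+```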
+
+If the Azure ML run is in a different workspace, a temporary SAS URL to download the checkpoint can be generated as follows:
+
+```bash
+cd hi-ml-cpath
+python src/health_cpath/scripts/generate_checkpoint_url.py --run_id=AzureML_run_id:best_val_loss.ckpt --expiry_days=10
+```
+
+N.B.: `config.json` should correspond to the original workspace where the AML run lives.
+
+## Use cases
+
+CheckpointParser is used to specify a `src_checkpoint` to [resume training from a given
+checkpoint](https://github.com/microsoft/hi-ml/blob/main/docs/source/runner.md#L238),
+or [run inference with a pretrained model](https://github.com/microsoft/hi-ml/blob/main/docs/source/runner.md#L215),
+as well as
+[ssl_checkpoint](https://github.com/microsoft/hi-ml/blob/main/hi-ml-cpath/src/health_cpath/utils/deepmil_utils.py#L62)
+for computational pathology self-supervised pretrained encoders.
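+
+For example, the cpath containers default their SSL encoder checkpoint this way (a sketch; the run id is the
+`SSL_CKPT_RUN_ID_CRCK` value used by the hi-ml-cpath smoke tests):
+
+```python
+from health_ml.utils.checkpoint_utils import CheckpointParser
+
+ssl_checkpoint = CheckpointParser('CRCK_SimCLR_1655731022_85790606')
+```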
@@ -42,6 +42,7 @@ The `hi-ml` toolbox provides
    logging.md
    diagnostics.md
    runner.md
+   checkpoints.md
 
 .. toctree::
    :maxdepth: 1
@@ -44,7 +44,7 @@ addition, you can turn on fine-tuning of the encoder, which will improve the res
 
 ```shell
 conda activate HimlHisto
-python ../hi-ml/src/health_ml/runner.py --model health_cpath.SlidesPandaImageNetMILBenchmark --is_finetune --cluster=<your_cluster_name>
+python ../hi-ml/src/health_ml/runner.py --model health_cpath.SlidesPandaImageNetMILBenchmark --tune_encoder --cluster=<your_cluster_name>
 ```
 
 Then the script will output "Successfully queued run number ..." and a line prefixed "Run URL: ...". Open that
@@ -226,6 +226,8 @@ the model weights by setting `--src_checkpoint` argument that supports three typ
 checkpoints folder `outputs/checkpoints`. If no filename is provided (e.g., `--src_checkpoint=AzureML_run_id`),
 the last epoch checkpoint `outputs/checkpoints/last.ckpt` will be loaded.
 
+Refer to [Checkpoints Utils](checkpoints.md) for more details on how checkpoints are parsed.
+
 Running the following command line will run inference using `MyContainer` model with weights from the checkpoint saved
 in the AzureMl run `MyContainer_XXXX_yyyy` at the best validation loss epoch `/outputs/checkpoints/best_val_loss.ckpt`.
 
@@ -235,12 +237,12 @@ himl-runner --model=Mycontainer --run_inference_only --src_checkpoint=MyContaine
 
 ## Resume training from a given checkpoint
 
-Analogously, one can resume training by setting `--src_checkpoint` to either continue training or transfer learning.
+Analogously, one can resume training by setting `--src_checkpoint` and `--resume_training` to train a model longer.
 The pytorch lightning trainer will initialize the lightning module from the given checkpoint corresponding to the best
 validation loss epoch as set in the following comandline.
 
 ```bash
-himl-runner --model=Mycontainer --cluster=my_cluster_name --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt
+himl-runner --model=Mycontainer --cluster=my_cluster_name --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt --resume_training
 ```
 
 Warning: When resuming training, one should make sure to set `container.max_epochs` greater than the last epoch of the
@@ -103,6 +103,7 @@ DEFAULT_ENVIRONMENT_VARIABLES = {
     "RSLEX_DIRECT_VOLUME_MOUNT": "true",
     "RSLEX_DIRECT_VOLUME_MOUNT_MAX_CACHE_SIZE": "1",
     "DATASET_MOUNT_CACHE_SIZE": "1",
     "AZUREML_COMPUTE_USE_COMMON_RUNTIME": "false",
 }
 
 
@@ -29,6 +29,10 @@ pip_from_conda:
	sed -e '1,/pip:/ d' environment.yml | grep -v "#" | cut -d "-" -f 2- > temp_requirements.txt
	pip install -r temp_requirements.txt
 
+# Lock the current Conda environment's secondary dependency versions
+lock_env:
+	./create_and_lock_environment.sh
+
 # clean build artifacts
 clean:
	rm -rf `find . -type d -name __pycache__`
@@ -78,6 +82,7 @@ pytest_coverage:
	pytest --cov=health_cpath --cov SSL --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc
 
 SSL_CKPT_RUN_ID_CRCK := CRCK_SimCLR_1655731022_85790606
+SRC_CKPT_RUN_ID_CRCK := TcgaCrckSSLMIL_1667832811_4e855804
 
 # Run regression tests and compare performance
 define BASE_CPATH_RUNNER_COMMAND
@@ -86,15 +91,19 @@ python hi-ml/src/health_ml/runner.py --mount_in_azureml --conda_env=hi-ml-cpath/
 endef
 
 define DEEPSMILEPANDASLIDES_ARGS
---model=health_cpath.SlidesPandaImageNetMILBenchmark --is_finetune
+--model=health_cpath.SlidesPandaImageNetMILBenchmark --tune_encoder
 endef
 
 define DEEPSMILEPANDATILES_ARGS
---model=health_cpath.TilesPandaImageNetMIL --is_finetune --batch_size=2
+--model=health_cpath.TilesPandaImageNetMIL --tune_encoder --batch_size=2
 endef
 
 define TCGACRCKSSLMIL_ARGS
---model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint_run_id=${SSL_CKPT_RUN_ID_CRCK}
+--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint=${SSL_CKPT_RUN_ID_CRCK}
 endef
 
+define TCGACRCKIMANEGETMIL_ARGS
+--model=health_cpath.TcgaCrckImageNetMIL --is_caching=False --batch_size=2 --max_num_gpus=1
+endef
+
 define CRCKSIMCLR_ARGS
@@ -158,12 +167,7 @@ define AML_MULTIPLE_DEVICE_ARGS
 --cluster=dedicated-nc24s-v2 --wait_for_completion
 endef
 
-define DEEPSMILEPANDASLIDES_SMOKE_TEST_ARGS
---crossval_count=0 --num_top_slides=2 --num_top_tiles=2 --max_bag_size=3 \
---max_bag_size_inf=3
-endef
-
-define DEEPSMILEPANDATILES_SMOKE_TEST_ARGS
+define DEEPSMILEDEFAULT_SMOKE_TEST_ARGS
 --crossval_count=0 --num_top_slides=2 --num_top_tiles=2 --max_bag_size=3 \
 --max_bag_size_inf=3
 endef
@@ -176,24 +180,38 @@ define CRCKSIMCLR_SMOKE_TEST_ARGS
 --is_debug_model=True --num_workers=0
 endef
 
+define CRCK_TUNING_ARGS
+--tune_encoder=False --tune_pooling=True --tune_classifier=True --pretrained_encoder=True \
+--pretrained_pooling=True --pretrained_classifier=True --src_checkpoint=${SRC_CKPT_RUN_ID_CRCK}
+endef
+
+define LOSS_ANALYSIS_ARGS
+--analyse_loss=True --loss_analysis_patience=0 --loss_analysis_epochs_interval=1 \
+--num_slides_scatter=2 --num_slides_heatmap=2 --save_tile_ids=True
+endef
+
+define DDP_SAMPLER_ARGS
+--pl_replace_sampler_ddp=False
+endef
+
 # The following test takes around 5 minutes
 smoke_test_slidespandaimagenetmil_local:
	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
-	${DEEPSMILEPANDASLIDES_SMOKE_TEST_ARGS};}
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS};}
 
 # Once running in AML the following test takes around 6 minutes
 smoke_test_slidespandaimagenetmil_aml:
	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
-	${DEEPSMILEPANDASLIDES_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
 
 # The following test takes about 6 minutes
 smoke_test_tilespandaimagenetmil_local:
	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
-	${DEEPSMILEPANDATILES_SMOKE_TEST_ARGS};}
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS};}
 
 smoke_test_tilespandaimagenetmil_aml:
	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
-	${DEEPSMILEPANDATILES_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
 
 # The following test takes about 30 seconds
 smoke_test_tcgacrcksslmil_local:
@@ -204,6 +222,14 @@ smoke_test_tcgacrcksslmil_aml:
	{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
	${TCGACRCKSSLMIL_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
 
+smoke_test_tcgacrckimagenetmil_local:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKIMANEGETMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS};}
+
+smoke_test_tcgacrckimagenetmil_aml:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKIMANEGETMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
+
 # The following test takes about 3 minutes
 smoke_test_crck_simclr_local:
	{ ${BASE_CPATH_RUNNER_COMMAND} ${CRCKSIMCLR_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
@@ -213,6 +239,46 @@ smoke_test_crck_simclr_aml:
	{ ${BASE_CPATH_RUNNER_COMMAND} ${CRCKSIMCLR_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
	${CRCKSIMCLR_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
 
-smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local
+smoke_test_crck_flexible_finetuning_local:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${TCGACRCKSSLMIL_SMOKE_TEST_ARGS} ${CRCK_TUNING_ARGS};}
 
-smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml
+smoke_test_crck_flexible_finetuning_aml:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${TCGACRCKSSLMIL_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS} ${CRCK_TUNING_ARGS};}
+
+smoke_test_crck_loss_analysis_local:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${TCGACRCKSSLMIL_SMOKE_TEST_ARGS} ${LOSS_ANALYSIS_ARGS};}
+
+smoke_test_crck_loss_analysis_aml:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${TCGACRCKSSLMIL_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS} ${LOSS_ANALYSIS_ARGS};}
+
+smoke_test_slides_panda_loss_analysis_local:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${LOSS_ANALYSIS_ARGS};}
+
+smoke_test_slides_panda_loss_analysis_aml:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${LOSS_ANALYSIS_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
+
+smoke_test_slides_panda_no_ddp_sampler_local:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS};}
+
+smoke_test_slides_panda_no_ddp_sampler_aml:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
+
+smoke_test_tiles_panda_no_ddp_sampler_local:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS};}
+
+smoke_test_tiles_panda_no_ddp_sampler_aml:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
+
+smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local smoke_test_slides_panda_loss_analysis_local smoke_test_slides_panda_no_ddp_sampler_local smoke_test_tiles_panda_no_ddp_sampler_local
+
+smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml smoke_test_slides_panda_loss_analysis_aml smoke_test_slides_panda_no_ddp_sampler_aml smoke_test_tiles_panda_no_ddp_sampler_aml
@@ -21,11 +21,11 @@ jobs:
   - name: TilesPandaImageNetMIL
     sku: G1
     command:
-      - python hi-ml/src/health_ml/runner.py --model=health_cpath.TilesPandaImageNetMIL --is_finetune --batch_size=2
+      - python hi-ml/src/health_ml/runner.py --model=health_cpath.TilesPandaImageNetMIL --tune_encoder --batch_size=2
         --crossval_count=0 --num_top_slides=2 --num_top_tiles=2 --max_bag_size=3 --max_bag_size_inf=3
         --mount_in_azureml=True --pl_limit_train_batches=2 --pl_limit_val_batches=2 --pl_limit_test_batches=2
 
   - name: SlidesPandaImageNetMIL
     sku: G1
     command:
-      - python hi-ml/src/health_ml/runner.py --model=health_cpath.SlidesPandaImageNetMIL --is_finetune --crossval_count=0 --num_top_slides=2 --num_top_tiles=2 --max_bag_size=3 --max_bag_size_inf=3
+      - python hi-ml/src/health_ml/runner.py --model=health_cpath.SlidesPandaImageNetMIL --tune_encoder --crossval_count=0 --num_top_slides=2 --num_top_tiles=2 --max_bag_size=3 --max_bag_size_inf=3
@@ -1,3 +1,4 @@
 azure-storage-blob==12.5.0
 coloredlogs==15.0.1
 cucim==22.04.00
 girder-client==3.1.14
@@ -14,33 +14,34 @@ from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 
 from health_azure.utils import create_from_matching_params
+from health_cpath.utils.callbacks import LossAnalysisCallback, LossCallbackParams
 
 from health_ml.utils import fixed_paths
 from health_ml.deep_learning_config import OptimizerParams
 from health_ml.lightning_container import LightningContainer
-from health_ml.utils.checkpoint_utils import get_best_checkpoint_path
+from health_ml.utils.checkpoint_utils import get_best_checkpoint_path, CheckpointParser
 from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
 
 from health_cpath.datamodules.base_module import CacheLocation, CacheMode, HistoDataModule
 from health_cpath.datasets.base_dataset import SlidesDataset
 from health_cpath.models.deepmil import TilesDeepMILModule, SlidesDeepMILModule, BaseDeepMILModule
 from health_cpath.models.transforms import EncodeTilesBatchd, LoadTilesBatchd
-from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
+from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
 from health_cpath.utils.output_utils import DeepMILOutputsHandler
 from health_cpath.utils.naming import MetricsKey, PlotOption, SlideKey, ModelKey
 from health_cpath.utils.tiles_selection_utils import TilesSelector
 
 
-class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
+class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams, LossCallbackParams):
     """BaseMIL is an abstract container defining basic functionality for running MIL experiments in both slides and
     tiles settings. It is responsible for instantiating the encoder and pooling layer. Subclasses should define the
     full DeepMIL model depending on the type of dataset (tiles/slides based).
     """
-    dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
     class_names: Optional[Sequence[str]] = param.List(None, item_type=str, doc="List of class names. If `None`, "
                                                       "defaults to `('0', '1', ...)`.")
     # Data module parameters:
     batch_size: int = param.Integer(16, bounds=(1, None), doc="Number of slides to load per batch.")
+    batch_size_inf: int = param.Integer(16, bounds=(1, None), doc="Number of slides per batch during inference.")
     max_bag_size: int = param.Integer(1000, bounds=(0, None),
                                       doc="Upper bound on number of tiles in each loaded bag during training stage. "
                                           "If 0 (default), will return all samples in each bag. "
@@ -71,12 +72,13 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
                                      "generating outputs.")
     maximise_primary_metric: bool = param.Boolean(True, doc="Whether the primary validation metric should be "
                                                             "maximised (otherwise minimised).")
-    ssl_checkpoint_run_id: str = param.String(default="", doc="Optional run id from which to load checkpoint if "
-                                                              "using SSLEncoder")
     max_num_workers: int = param.Integer(10, bounds=(0, None),
                                          doc="The maximum number of worker processes for dataloaders. Dataloaders use"
                                              "a heuristic num_cpus/num_gpus to set the number of workers, which can be"
                                              "very high for small num_gpus. This parameters sets an upper bound.")
+    wsi_has_mask: bool = param.Boolean(default=True,
+                                       doc="Whether the WSI has a mask. If True, will use the mask to load a specific"
+                                           "region of the WSI. If False, will load the whole WSI.")
 
     def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)
@@ -84,6 +86,42 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
         metric_optim = "max" if self.maximise_primary_metric else "min"
         self.best_checkpoint_filename = f"checkpoint_{metric_optim}_val_{self.primary_val_metric.value}"
         self.best_checkpoint_filename_with_suffix = self.best_checkpoint_filename + ".ckpt"
+        self.validate()
+        if not self.pl_replace_sampler_ddp and self.max_num_gpus > 1:
+            logging.info(
+                "Replacing sampler with `DistributedSampler` is disabled. Make sure to set your own DDP sampler"
+            )
+
+    def validate(self) -> None:
+        super().validate()
+        EncoderParams.validate(self)
+        if not any([self.tune_encoder, self.tune_pooling, self.tune_classifier]) and not self.run_inference_only:
+            raise ValueError(
+                "At least one of the encoder, pooling or classifier should be fine tuned. Turn on one of the tune "
+                "arguments `tune_encoder`, `tune_pooling`, `tune_classifier`. Otherwise, activate inference only "
+                "mode via `run_inference_only` flag."
+            )
+        if (
+            any([self.pretrained_encoder, self.pretrained_pooling, self.pretrained_classifier])
+            and not self.src_checkpoint
+        ):
+            raise ValueError(
+                "You need to specify a source checkpoint, to use a pretrained encoder, pooling or classifier."
+                f" {CheckpointParser.INFO_MESSAGE}"
+            )
+        if (
+            self.tune_encoder and self.encoding_chunk_size < self.max_bag_size
+            and self.pl_sync_batchnorm and self.max_num_gpus > 1
+        ):
+            raise ValueError(
+                "The encoding chunk size should be at least as large as the maximum bag size when fine tuning the "
+                "encoder. You might encounter Batch Norm synchronization issues if the chunk size is smaller than "
+                "the maximum bag size causing the processes to hang silently. This is due to the encoder being called "
+                "different number of times on each device, which cause Batch Norm running statistics to be updated "
+                "inconsistently across processes. In case you can't increase the `encoding_chunk_size` any further, set"
+                " `pl_sync_batchnorm=False` to simply skip Batch Norm synchronization across devices. Note that this "
+                "might come with some performance penalty."
+            )
 
     @property
     def cache_dir(self) -> Path:
@@ -95,6 +133,8 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
             options.add(PlotOption.TOP_BOTTOM_TILES)
         return options
 
+    # overwrite this method if you want to produce validation plots at each epoch. By default, at the end of the
+    # training an extra validation epoch is run where val_plot_options = test_plot_options
     def get_val_plot_options(self) -> Set[PlotOption]:
         return set()
 
@@ -110,6 +150,8 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
             maximise=self.maximise_primary_metric,
             val_plot_options=self.get_val_plot_options(),
             test_plot_options=self.get_test_plot_options(),
+            wsi_has_mask=self.wsi_has_mask,
+            val_set_is_dist=self.pl_replace_sampler_ddp and self.max_num_gpus > 1,
         )
         if self.num_top_slides > 0:
             outputs_handler.tiles_selector = TilesSelector(
@@ -118,12 +160,26 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
         return outputs_handler
 
     def get_callbacks(self) -> List[Callback]:
-        return [*super().get_callbacks(),
-                ModelCheckpoint(dirpath=self.checkpoint_folder,
-                                monitor=f"{ModelKey.VAL}/{self.primary_val_metric}",
-                                filename=self.best_checkpoint_filename,
-                                auto_insert_metric_name=False,
-                                mode="max" if self.maximise_primary_metric else "min")]
+        callbacks = [*super().get_callbacks(),
+                     ModelCheckpoint(dirpath=self.checkpoint_folder,
+                                     monitor=f"{ModelKey.VAL}/{self.primary_val_metric}",
+                                     filename=self.best_checkpoint_filename,
+                                     auto_insert_metric_name=False,
+                                     mode="max" if self.maximise_primary_metric else "min")]
+        if self.analyse_loss:
+            callbacks.append(LossAnalysisCallback(outputs_folder=self.outputs_folder,
+                                                  max_epochs=self.max_epochs,
+                                                  patience=self.loss_analysis_patience,
+                                                  epochs_interval=self.loss_analysis_epochs_interval,
+                                                  num_slides_scatter=self.num_slides_scatter,
+                                                  num_slides_heatmap=self.num_slides_heatmap,
+                                                  save_tile_ids=self.save_tile_ids,
+                                                  log_exceptions=self.log_exceptions,
+                                                  val_set_is_dist=(
+                                                      self.pl_replace_sampler_ddp and self.max_num_gpus > 1),
+                                                  )
+                             )
+        return callbacks
 
     def get_checkpoint_to_test(self) -> Path:
         """
@@ -144,6 +200,10 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
         if absolute_checkpoint_path_parent.is_file():
             return absolute_checkpoint_path_parent
 
+        checkpoint_path = Path(self.checkpoint_folder, self.best_checkpoint_filename_with_suffix)
+        if checkpoint_path.is_file():
+            return checkpoint_path
+
         checkpoint_path = get_best_checkpoint_path(self.checkpoint_folder)
         if checkpoint_path.is_file():
             return checkpoint_path
@@ -197,8 +257,8 @@ class BaseMILTiles(BaseMIL):
     def setup(self) -> None:
         super().setup()
         # Fine-tuning requires tiles to be loaded on-the-fly, hence, caching is disabled by default.
-        # When is_finetune and is_caching are both set, below lines should disable caching automatically.
-        if self.is_finetune:
+        # When tune_encoder and is_caching are both set, below lines should disable caching automatically.
+        if self.tune_encoder:
             self.is_caching = False
         if not self.is_caching:
             self.cache_mode = CacheMode.NONE
@@ -213,8 +273,7 @@ class BaseMILTiles(BaseMIL):
 
     def get_transforms_dict(self, image_key: str) -> Dict[ModelKey, Union[Callable, None]]:
         if self.is_caching:
-            encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_checkpoint_run_id,
-                                                                                   self.outputs_folder)
+            encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.outputs_folder)
             transform = Compose([
                 LoadTilesBatchd(image_key, progress=True),
                 EncodeTilesBatchd(image_key, encoder, chunk_size=self.encoding_chunk_size)  # type: ignore
@@ -231,13 +290,15 @@ class BaseMILTiles(BaseMIL):
             n_classes=self.data_module.train_dataset.n_classes,
             class_names=self.class_names,
             class_weights=self.data_module.class_weights,
-            dropout_rate=self.dropout_rate,
-            outputs_folder=self.outputs_folder,
-            ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
             encoder_params=create_from_matching_params(self, EncoderParams),
             pooling_params=create_from_matching_params(self, PoolingParams),
+            classifier_params=create_from_matching_params(self, ClassifierParams),
             optimizer_params=create_from_matching_params(self, OptimizerParams),
-            outputs_handler=outputs_handler)
+            outputs_folder=self.outputs_folder,
+            outputs_handler=outputs_handler,
+            analyse_loss=self.analyse_loss,
+        )
         deepmil_module.transfer_weights(self.trained_weights_path)
         outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
         return deepmil_module
 
@@ -271,12 +332,14 @@ class BaseMILSlides(BaseMIL):
             n_classes=self.data_module.train_dataset.n_classes,
             class_names=self.class_names,
             class_weights=self.data_module.class_weights,
-            dropout_rate=self.dropout_rate,
-            outputs_folder=self.outputs_folder,
-            ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
             encoder_params=create_from_matching_params(self, EncoderParams),
             pooling_params=create_from_matching_params(self, PoolingParams),
+            classifier_params=create_from_matching_params(self, ClassifierParams),
             optimizer_params=create_from_matching_params(self, OptimizerParams),
-            outputs_handler=outputs_handler)
+            outputs_folder=self.outputs_folder,
+            outputs_handler=outputs_handler,
+            analyse_loss=self.analyse_loss,
+        )
         deepmil_module.transfer_weights(self.trained_weights_path)
         outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
         return deepmil_module
@@ -11,7 +11,7 @@ Reference:
 - Schirris (2021). DeepSMILE: Self-supervised heterogeneity-aware multiple instance learning for DNA
   damage response defect classification directly from H&E whole-slide images. arXiv:2107.09405
 """
-from typing import Any
+from typing import Any, Set
 
 from health_ml.networks.layers.attention_layers import AttentionLayer
 from health_cpath.configs.run_ids import innereye_ssl_checkpoint_crck_4ws
@@ -20,12 +20,14 @@ from health_cpath.datamodules.tcga_crck_module import TcgaCrckTilesDataModule
 from health_cpath.datasets.default_paths import TCGA_CRCK_DATASET_ID
 from health_cpath.models.encoders import (
     HistoSSLEncoder,
-    ImageNetEncoder,
     ImageNetSimCLREncoder,
+    Resnet18,
     SSLEncoder,
 )
 from health_cpath.configs.classification.BaseMIL import BaseMILTiles
 from health_cpath.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
 from health_cpath.utils.naming import PlotOption
+from health_ml.utils.checkpoint_utils import CheckpointParser
 
 
 class DeepSMILECrck(BaseMILTiles):
@@ -37,7 +39,7 @@ class DeepSMILECrck(BaseMILTiles):
             num_transformer_pool_layers=4,
             num_transformer_pool_heads=4,
             encoding_chunk_size=60,
-            is_finetune=False,
+            tune_encoder=False,
             is_caching=True,
             num_top_slides=0,
             azure_datasets=[TCGA_CRCK_DATASET_ID],
@@ -55,14 +57,13 @@ class DeepSMILECrck(BaseMILTiles):
 
     def setup(self) -> None:
         super().setup()
-        # If no SSL checkpoint is provided, use the default one
-        self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_crck_4ws
 
     def get_data_module(self) -> TilesDataModule:
         return TcgaCrckTilesDataModule(
             root_path=self.local_datasets[0],
             max_bag_size=self.max_bag_size,
             batch_size=self.batch_size,
+            batch_size_inf=self.batch_size_inf,
             max_bag_size_inf=self.max_bag_size_inf,
             transforms_dict=self.get_transforms_dict(TcgaCrck_TilesDataset.IMAGE_COLUMN),
             cache_mode=self.cache_mode,
@@ -72,12 +73,18 @@ class DeepSMILECrck(BaseMILTiles):
             crossval_index=self.crossval_index,
             dataloader_kwargs=self.get_dataloader_kwargs(),
             seed=self.get_effective_random_seed(),
+            pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
         )
 
+    def get_test_plot_options(self) -> Set[PlotOption]:
+        plot_options = super().get_test_plot_options()
+        plot_options.add(PlotOption.PR_CURVE)
+        return plot_options
+
 
 class TcgaCrckImageNetMIL(DeepSMILECrck):
     def __init__(self, **kwargs: Any) -> None:
-        super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs)
+        super().__init__(encoder_type=Resnet18.__name__, **kwargs)
 
 
 class TcgaCrckImageNetSimCLRMIL(DeepSMILECrck):
@@ -87,6 +94,8 @@ class TcgaCrckImageNetSimCLRMIL(DeepSMILECrck):
 
 class TcgaCrckSSLMIL(DeepSMILECrck):
     def __init__(self, **kwargs: Any) -> None:
+        # If no SSL checkpoint is provided, use the default one
+        self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_crck_4ws)
         super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)
 
 
@@ -13,8 +13,8 @@ from health_cpath.datamodules.panda_module import (
 from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
 from health_cpath.models.encoders import (
     HistoSSLEncoder,
-    ImageNetEncoder,
     ImageNetSimCLREncoder,
+    Resnet18,
     SSLEncoder)
 from health_cpath.configs.classification.BaseMIL import BaseMILSlides, BaseMILTiles, BaseMIL
 from health_cpath.datasets.panda_dataset import PandaDataset
@@ -22,6 +22,7 @@ from health_cpath.datasets.default_paths import (
     PANDA_DATASET_ID,
     PANDA_5X_TILES_DATASET_ID)
 from health_cpath.utils.naming import PlotOption
+from health_ml.utils.checkpoint_utils import CheckpointParser
 
 
 class BaseDeepSMILEPanda(BaseMIL):
@@ -33,7 +34,7 @@ class BaseDeepSMILEPanda(BaseMIL):
             pool_type=AttentionLayer.__name__,
             num_transformer_pool_layers=4,
             num_transformer_pool_heads=4,
-            is_finetune=False,
+            tune_encoder=False,
             # average number of tiles is 56 for PANDA
             encoding_chunk_size=60,
             max_bag_size=56,
@@ -55,7 +56,7 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
     """ DeepSMILETilesPanda is derived from BaseMILTiles and BaseDeepSMILEPanda to inherit common behaviors from both
     tiles basemil and panda specific configuration.
 
-    `is_finetune` sets the fine-tuning mode. `is_finetune` sets the fine-tuning mode. For fine-tuning, batch_size = 2
+    `tune_encoder` sets the fine-tuning mode of the encoder. For fine-tuning the encoder, batch_size = 2
     runs on multiple GPUs with ~ 6:24 min/epoch (train) and ~ 00:50 min/epoch (validation).
     """
 
@@ -64,20 +65,20 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
             # declared in BaseMILTiles:
             is_caching=False,
             batch_size=8,
+            batch_size_inf=8,
             azure_datasets=[PANDA_5X_TILES_DATASET_ID, PANDA_DATASET_ID])
         default_kwargs.update(kwargs)
         super().__init__(**default_kwargs)
 
     def setup(self) -> None:
         BaseMILTiles.setup(self)
-        # If no SSL checkpoint is provided, use the default one
-        self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary
 
     def get_data_module(self) -> PandaTilesDataModule:
         return PandaTilesDataModule(
             root_path=self.local_datasets[0],
-            max_bag_size=self.max_bag_size,
             batch_size=self.batch_size,
+            batch_size_inf=self.batch_size_inf,
+            max_bag_size=self.max_bag_size,
             max_bag_size_inf=self.max_bag_size_inf,
             transforms_dict=self.get_transforms_dict(PandaTilesDataset.IMAGE_COLUMN),
             cache_mode=self.cache_mode,
@@ -87,6 +88,7 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
             crossval_index=self.crossval_index,
             dataloader_kwargs=self.get_dataloader_kwargs(),
             seed=self.get_effective_random_seed(),
+            pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
         )
 
     def get_slides_dataset(self) -> Optional[PandaDataset]:
@@ -94,13 +96,13 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
 
     def get_test_plot_options(self) -> Set[PlotOption]:
         plot_options = super().get_test_plot_options()
-        plot_options.add(PlotOption.SLIDE_THUMBNAIL_HEATMAP)
+        plot_options.update([PlotOption.SLIDE_THUMBNAIL, PlotOption.ATTENTION_HEATMAP])
         return plot_options
 
 
 class TilesPandaImageNetMIL(DeepSMILETilesPanda):
     def __init__(self, **kwargs: Any) -> None:
-        super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs)
+        super().__init__(encoder_type=Resnet18.__name__, **kwargs)
 
 
 class TilesPandaImageNetSimCLRMIL(DeepSMILETilesPanda):
@@ -110,6 +112,8 @@ class TilesPandaImageNetSimCLRMIL(DeepSMILETilesPanda):
 
 class TilesPandaSSLMIL(DeepSMILETilesPanda):
     def __init__(self, **kwargs: Any) -> None:
+        # If no SSL checkpoint is provided, use the default one
+        self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_binary)
         super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)
 
 
@@ -136,8 +140,6 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
 
     def setup(self) -> None:
         BaseMILSlides.setup(self)
-        # If no SSL checkpoint is provided, use the default one
-        self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary
 
     def get_dataloader_kwargs(self) -> dict:
         return dict(
@@ -149,6 +151,7 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
         return PandaSlidesDataModule(
             root_path=self.local_datasets[0],
             batch_size=self.batch_size,
+            batch_size_inf=self.batch_size_inf,
             level=self.level,
             max_bag_size=self.max_bag_size,
             max_bag_size_inf=self.max_bag_size_inf,
@@ -163,6 +166,7 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
             crossval_count=self.crossval_count,
             crossval_index=self.crossval_index,
             dataloader_kwargs=self.get_dataloader_kwargs(),
+            pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
         )
 
     def get_slides_dataset(self) -> PandaDataset:
@@ -171,7 +175,7 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
 
 class SlidesPandaImageNetMIL(DeepSMILESlidesPanda):
     def __init__(self, **kwargs: Any) -> None:
-        super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs)
+        super().__init__(encoder_type=Resnet18.__name__, **kwargs)
 
 
 class SlidesPandaImageNetSimCLRMIL(DeepSMILESlidesPanda):
@@ -181,6 +185,8 @@ class SlidesPandaImageNetSimCLRMIL(DeepSMILESlidesPanda):
 
 class SlidesPandaSSLMIL(DeepSMILESlidesPanda):
     def __init__(self, **kwargs: Any) -> None:
+        # If no SSL checkpoint is provided, use the default one
+        self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_binary)
         super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)
 
 
@@ -7,24 +7,25 @@
 from typing import Any, Dict, Callable, Union
 from torch import optim
 from monai.transforms import Compose, ScaleIntensityRanged, RandRotate90d, RandFlipd
 
 from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary
 from health_azure.utils import create_from_matching_params
 from health_ml.networks.layers.attention_layers import (
     TransformerPooling,
     TransformerPoolingBenchmark
 )
+from health_ml.utils.checkpoint_utils import CheckpointParser
 from health_ml.deep_learning_config import OptimizerParams
 from health_cpath.datasets.panda_dataset import PandaDataset
 from health_cpath.datamodules.panda_module_benchmark import PandaSlidesDataModuleBenchmark
 from health_cpath.models.encoders import (
     HistoSSLEncoder,
-    ImageNetEncoder_Resnet50,
     ImageNetSimCLREncoder,
+    Resnet50_NoPreproc,
     SSLEncoder,
 )
 from health_cpath.configs.classification.DeepSMILEPanda import DeepSMILESlidesPanda
 from health_cpath.models.deepmil import SlidesDeepMILModule
-from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
+from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
 from health_cpath.utils.naming import MetricsKey, ModelKey, SlideKey
 
 
@@ -50,9 +51,8 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
     """
     Configuration for PANDA experiments from Myronenko et al. 2021:
     (https://link.springer.com/chapter/10.1007/978-3-030-87237-3_32)
-    `is_finetune` sets the fine-tuning mode. For fine-tuning,
-    batch_size = 2 runs on 8 GPUs with
-    ~ 6:24 min/epoch (train) and ~ 00:50 min/epoch (validation).
+    `tune_encoder` sets the fine-tuning mode of the encoder. For fine-tuning, batch_size = 2 runs on 8 GPUs
+    with ~ 6:24 min/epoch (train) and ~ 00:50 min/epoch (validation).
     """
 
     def __init__(self, **kwargs: Any) -> None:
@@ -64,6 +64,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
             encoding_chunk_size=60,
             max_bag_size=56,
             batch_size=8,  # effective batch size = batch_size * num_GPUs
+            batch_size_inf=8,
             max_epochs=50,
             l_rate=3e-4,
             weight_decay=0,
@@ -77,8 +78,9 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
             self.l_rate = 3e-5
             self.weight_decay = 0.1
         # Params specific to fine-tuning
-        if self.is_finetune:
+        if self.tune_encoder:
             self.batch_size = 2
+            self.batch_size_inf = 2
         super().setup()
 
     def get_transforms_dict(self, image_key: str) -> Dict[ModelKey, Union[Callable, None]]:
@@ -100,8 +102,9 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
         # Hence, inherited `PandaSlidesDataModuleBenchmark` from `SlidesDataModule`
         return PandaSlidesDataModuleBenchmark(
             root_path=self.local_datasets[0],
-            max_bag_size=self.max_bag_size,
             batch_size=self.batch_size,
+            batch_size_inf=self.batch_size_inf,
+            max_bag_size=self.max_bag_size,
             max_bag_size_inf=self.max_bag_size_inf,
             level=self.level,
             tile_size=self.tile_size,
@@ -115,6 +118,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
             crossval_count=self.crossval_count,
             crossval_index=self.crossval_index,
             dataloader_kwargs=self.get_dataloader_kwargs(),
+            pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
         )
 
     def create_model(self) -> SlidesDeepMILModule:
@@ -126,13 +130,13 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
             n_classes=self.data_module.train_dataset.n_classes,
             class_names=self.class_names,
             class_weights=self.data_module.class_weights,
-            dropout_rate=self.dropout_rate,
             outputs_folder=self.outputs_folder,
-            ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
             encoder_params=create_from_matching_params(self, EncoderParams),
             pooling_params=create_from_matching_params(self, PoolingParams),
+            classifier_params=create_from_matching_params(self, ClassifierParams),
             optimizer_params=create_from_matching_params(self, OptimizerParams),
             outputs_handler=outputs_handler,
+            analyse_loss=self.analyse_loss,
         )
         outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
         return deepmil_module
@@ -140,7 +144,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
 
 class SlidesPandaImageNetMILBenchmark(DeepSMILESlidesPandaBenchmark):
     def __init__(self, **kwargs: Any) -> None:
-        super().__init__(encoder_type=ImageNetEncoder_Resnet50.__name__, **kwargs)
+        super().__init__(encoder_type=Resnet50_NoPreproc.__name__, **kwargs)
 
 
 class SlidesPandaImageNetSimCLRMILBenchmark(DeepSMILESlidesPandaBenchmark):
@@ -150,6 +154,8 @@ class SlidesPandaImageNetSimCLRMILBenchmark(DeepSMILESlidesPandaBenchmark):
 
 class SlidesPandaSSLMILBenchmark(DeepSMILESlidesPandaBenchmark):
     def __init__(self, **kwargs: Any) -> None:
+        # If no SSL checkpoint is provided, use the default one
+        self.ssl_checkpoint = self.ssl_checkpoint or CheckpointParser(innereye_ssl_checkpoint_binary)
         super().__init__(encoder_type=SSLEncoder.__name__, **kwargs)
 
 
@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, Generic, Optional, Sequence, Tuple, Type
|
|||
|
||||
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
|
||||
from pytorch_lightning import LightningDataModule
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
|
||||
from health_ml.utils.bag_utils import BagDataset, multibag_collate
|
||||
from health_ml.utils.common_utils import _create_generator
|
||||
|
@ -47,17 +47,21 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
|
|||
self,
|
||||
root_path: Path,
|
||||
batch_size: int = 1,
|
||||
batch_size_inf: Optional[int] = None,
|
||||
max_bag_size: int = 0,
|
||||
max_bag_size_inf: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
transforms_dict: Optional[Dict[ModelKey, Union[Callable, None]]] = None,
|
||||
crossval_count: int = 0,
|
||||
crossval_index: int = 0,
|
||||
pl_replace_sampler_ddp: bool = True,
|
||||
dataloader_kwargs: Optional[Dict[str, Any]] = None,
|
||||
dataframe_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param root_path: Root directory of the source dataset.
|
||||
:param batch_size: Number of slides to load per batch.
|
||||
:param batch_size_inf: Number of slides to load per batch during inference. If None, use batch_size.
|
||||
:param max_bag_size: Upper bound on number of tiles in each loaded bag during training stage. If 0 (default),
|
||||
will return all samples in each bag. If > 0 , bags larger than `max_bag_size` will yield
|
||||
random subsets of instances. For SlideDataModule, this parameter is used in TileOnGridd Transform to set the
|
||||
|
@ -73,21 +77,25 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
|
|||
By default (`None`).
|
||||
:param crossval_count: Number of folds to perform.
|
||||
:param crossval_index: Index of the cross validation split to be performed.
|
||||
:param pl_replace_sampler_ddp: If True, replace the sampler with a DistributedSampler when using DDP.
|
||||
:param dataloader_kwargs: Additional keyword arguments for the training, validation, and test dataloaders.
|
||||
:param dataframe_kwargs: Keyword arguments to pass to `pd.read_csv()` when loading the dataset CSV.
|
||||
"""
|
||||
|
||||
batch_size_inf = batch_size_inf or batch_size
|
||||
super().__init__()
|
||||
|
||||
self.root_path = root_path
|
||||
self.transforms_dict = transforms_dict
|
||||
self.batch_size = batch_size
|
||||
self.max_bag_size = max_bag_size
|
||||
self.max_bag_size_inf = max_bag_size_inf
|
||||
self.batch_sizes = {ModelKey.TRAIN: batch_size, ModelKey.VAL: batch_size_inf, ModelKey.TEST: batch_size_inf}
|
||||
self.bag_sizes = {ModelKey.TRAIN: max_bag_size, ModelKey.VAL: max_bag_size_inf, ModelKey.TEST: max_bag_size_inf}
|
||||
self.crossval_count = crossval_count
|
||||
self.crossval_index = crossval_index
|
||||
self.pl_replace_sampler_ddp = pl_replace_sampler_ddp
|
||||
self.train_dataset: _SlidesOrTilesDataset
|
||||
self.val_dataset: _SlidesOrTilesDataset
|
||||
self.test_dataset: _SlidesOrTilesDataset
|
||||
self.dataframe_kwargs = dataframe_kwargs or {}
|
||||
self.train_dataset, self.val_dataset, self.test_dataset = self.get_splits()
|
||||
self.class_weights = self.train_dataset.get_class_weights()
|
||||
self.seed = seed
|
||||
|
@ -97,6 +105,18 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
|
|||
"""Create the training, validation, and test datasets"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_dataloader(
|
||||
self, dataset: _SlidesOrTilesDataset, stage: ModelKey, shuffle: bool, **dataloader_kwargs: Any
|
||||
) -> DataLoader:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_ddp_sampler(self, dataset: Dataset, stage: ModelKey) -> Optional[DistributedSampler]:
|
||||
is_distributed = torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1
|
||||
if is_distributed and not self.pl_replace_sampler_ddp and stage == ModelKey.TRAIN:
|
||||
assert self.seed is not None, "seed must be set when using distributed training for reproducibility"
|
||||
return DistributedSampler(dataset, shuffle=True, seed=self.seed)
|
||||
return None
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
return self._get_dataloader(self.train_dataset, # type: ignore
|
||||
shuffle=True,
|
||||
|
@ -198,16 +218,11 @@ class TilesDataModule(HistoDataModule[TilesDataset]):
|
|||
|
||||
generator = _create_generator(self.seed)
|
||||
|
||||
if stage in [ModelKey.VAL, ModelKey.TEST]:
|
||||
eff_max_bag_size = self.max_bag_size_inf
|
||||
else:
|
||||
eff_max_bag_size = self.max_bag_size
|
||||
|
||||
bag_dataset = BagDataset(
|
||||
tiles_dataset, # type: ignore
|
||||
bag_ids=tiles_dataset.slide_ids,
|
||||
max_bag_size=eff_max_bag_size,
|
||||
shuffle_samples=shuffle,
|
||||
max_bag_size=self.bag_sizes[stage],
|
||||
shuffle_samples=True,
|
||||
generator=generator,
|
||||
)
|
||||
if self.transforms_dict and self.transforms_dict[stage]:
|
||||
|
@ -233,11 +248,14 @@ class TilesDataModule(HistoDataModule[TilesDataset]):
|
|||
transformed_bag_dataset = self._load_dataset(dataset, stage=stage, shuffle=shuffle)
|
||||
bag_dataset: BagDataset = transformed_bag_dataset.data # type: ignore
|
||||
generator = bag_dataset.bag_sampler.generator
|
||||
sampler = self._get_ddp_sampler(transformed_bag_dataset, stage)
|
||||
return DataLoader(
|
||||
transformed_bag_dataset,
|
||||
batch_size=self.batch_size,
|
||||
batch_size=self.batch_sizes[stage],
|
||||
collate_fn=multibag_collate,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
# sampler option is mutually exclusive with shuffle
|
||||
shuffle=shuffle if sampler is None else None, # type: ignore
|
||||
generator=generator,
|
||||
**dataloader_kwargs,
|
||||
)
|
||||
|
@ -252,13 +270,13 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
level: Optional[int] = 1,
|
||||
tile_size: Optional[int] = 224,
|
||||
level: int = 1,
|
||||
tile_size: int = 224,
|
||||
step: Optional[int] = None,
|
||||
random_offset: Optional[bool] = True,
|
||||
pad_full: Optional[bool] = False,
|
||||
background_val: Optional[int] = 255,
|
||||
filter_mode: Optional[str] = "min",
|
||||
random_offset: bool = True,
|
||||
pad_full: bool = False,
|
||||
background_val: int = 255,
|
||||
filter_mode: str = "min",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
|
@ -290,8 +308,9 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
|
|||
self.filter_mode = filter_mode
|
||||
# TileOnGridd transform expects None to select all foreground tile so we hardcode max_bag_size and
|
||||
# max_bag_size_inf to None if set to 0
|
||||
self.max_bag_size = None if self.max_bag_size == 0 else self.max_bag_size # type: ignore
|
||||
self.max_bag_size_inf = None if self.max_bag_size_inf == 0 else self.max_bag_size_inf # type: ignore
|
||||
for stage_key, max_bag_size in self.bag_sizes.items():
|
||||
if max_bag_size == 0:
|
||||
self.bag_sizes[stage_key] = None # type: ignore
|
||||
|
||||
def _load_dataset(self, slides_dataset: SlidesDataset, stage: ModelKey) -> Dataset:
|
||||
base_transform = Compose(
|
||||
|
@ -306,7 +325,7 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
            ),
            TileOnGridd(
                keys=slides_dataset.IMAGE_COLUMN,
                tile_count=self.max_bag_size if stage == ModelKey.TRAIN else self.max_bag_size_inf,
                tile_count=self.bag_sizes[stage],
                tile_size=self.tile_size,
                step=self.step,
                random_offset=self.random_offset if stage == ModelKey.TRAIN else False,
@ -331,11 +350,14 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
                        **dataloader_kwargs: Any) -> DataLoader:
        transformed_slides_dataset = self._load_dataset(dataset, stage)
        generator = _create_generator(self.seed)
        sampler = self._get_ddp_sampler(transformed_slides_dataset, stage)
        return DataLoader(
            transformed_slides_dataset,
            batch_size=self.batch_size,
            batch_size=self.batch_sizes[stage],
            collate_fn=image_collate,
            shuffle=shuffle,
            sampler=sampler,
            # sampler option is mutually exclusive with shuffle
            shuffle=shuffle if not sampler else None,  # type: ignore
            generator=generator,
            **dataloader_kwargs,
        )

@ -12,7 +12,7 @@ import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import Dataset

from health_cpath.utils.naming import SlideKey
from health_cpath.utils.naming import SlideKey, TileKey


DEFAULT_TRAIN_SPLIT_LABEL = "train"  # Value used to indicate the training split in `SPLIT_COLUMN`
@ -50,7 +50,8 @@ class TilesDataset(Dataset):
                 train: Optional[bool] = None,
                 validate_columns: bool = True,
                 label_column: str = DEFAULT_LABEL_COLUMN,
                 n_classes: int = 1) -> None:
                 n_classes: int = 1,
                 dataframe_kwargs: Dict[str, Any] = {}) -> None:
        """
        :param root: Root directory of the dataset.
        :param dataset_csv: Full path to a dataset CSV file, containing at least
@ -65,6 +66,7 @@ class TilesDataset(Dataset):
            for this class
        :param label_column: CSV column name for tile label. Defaults to `DEFAULT_LABEL_COLUMN="label"`.
        :param n_classes: Number of classes indexed in `label_column`. Default is 1 for binary classification.
        :param dataframe_kwargs: Keyword arguments to pass to `pd.read_csv()` when loading the dataset CSV.
        """
        if self.SPLIT_COLUMN is None and train is not None:
            raise ValueError("Train/test split was specified but dataset has no split column")
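A hypothetical sketch of the new dataframe_kwargs hook; the root path, column names and dtypes below are illustrative, not taken from the diff:

from health_cpath.datasets.base_dataset import TilesDataset

dataset = TilesDataset(
    root="/data/PANDA_tiles",  # illustrative path
    dataframe_kwargs={"usecols": ["tile_id", "slide_id", "label", "image"],  # forwarded to pd.read_csv
                      "dtype": {"label": "int64"}},
)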
@ -77,9 +79,10 @@ class TilesDataset(Dataset):
            self.dataset_csv = None
        else:
            self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
            dataset_df = pd.read_csv(self.dataset_csv)
            dataset_df = pd.read_csv(self.dataset_csv, **dataframe_kwargs)

        dataset_df = dataset_df.set_index(self.TILE_ID_COLUMN)
        if dataset_df.index.name != self.TILE_ID_COLUMN:
            dataset_df = dataset_df.set_index(self.TILE_ID_COLUMN)
        if train is None:
            self.dataset_df = dataset_df
        else:
@ -131,6 +134,19 @@ class TilesDataset(Dataset):
        class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=slide_labels)
        return torch.as_tensor(class_weights)

    def copy_coordinates_columns(self) -> None:
        """Copy columns "left" --> "tile_x" and "top" --> "tile_y" to be consistent with TilesDataset `TILE_X_COLUMN`
        and `TILE_Y_COLUMN`."""
        if TileKey.TILE_LEFT in self.dataset_df.columns:
            self.dataset_df = self.dataset_df.assign(
                **{TilesDataset.TILE_X_COLUMN: self.dataset_df[TileKey.TILE_LEFT]}
            )
        if TileKey.TILE_TOP in self.dataset_df.columns:
            self.dataset_df = self.dataset_df.assign(
                **{TilesDataset.TILE_Y_COLUMN: self.dataset_df[TileKey.TILE_TOP]}
            )


class SlidesDataset(Dataset):
    """Base class for datasets of WSIs, iterating dictionaries of image paths and metadata.
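A minimal sketch of the assign(**{...}) idiom used above, which writes a column whose name is held in a variable:

import pandas as pd

df = pd.DataFrame({"left": [0, 224], "top": [0, 448]})
tile_x_column = "tile_x"  # stands in for TilesDataset.TILE_X_COLUMN
df = df.assign(**{tile_x_column: df["left"]})  # returns a new dataframe with the extra column
assert list(df["tile_x"]) == [0, 224]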
@ -158,7 +174,8 @@ class SlidesDataset(Dataset):
                 train: Optional[bool] = None,
                 validate_columns: bool = True,
                 label_column: str = DEFAULT_LABEL_COLUMN,
                 n_classes: int = 1) -> None:
                 n_classes: int = 1,
                 dataframe_kwargs: Dict[str, Any] = {}) -> None:
        """
        :param root: Root directory of the dataset.
        :param dataset_csv: Full path to a dataset CSV file, containing at least
@ -173,6 +190,7 @@ class SlidesDataset(Dataset):
            for this class
        :param label_column: CSV column name for the slide label. Default is `DEFAULT_LABEL_COLUMN="label"`.
        :param n_classes: Number of classes indexed in `label_column`. Default is 1 for binary classification.
        :param dataframe_kwargs: Keyword arguments to pass to `pd.read_csv()` when loading the dataset CSV.
        """
        if self.SPLIT_COLUMN is None and train is not None:
            raise ValueError("Train/test split was specified but dataset has no split column")
@ -180,14 +198,16 @@ class SlidesDataset(Dataset):
        self.root_dir = Path(root)
        self.label_column = label_column
        self.n_classes = n_classes
        self.dataframe_kwargs = dataframe_kwargs

        if dataset_df is not None:
            self.dataset_csv = None
        else:
            self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
            dataset_df = pd.read_csv(self.dataset_csv)
            dataset_df = pd.read_csv(self.dataset_csv, **self.dataframe_kwargs)

        dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN)
        if dataset_df.index.name != self.SLIDE_ID_COLUMN:
            dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN)
        if train is None:
            self.dataset_df = dataset_df
        else:
@ -203,12 +223,14 @@ class SlidesDataset(Dataset):
        If the constructor is overloaded in a subclass, you can pass `validate_columns=False` and
        call `validate_columns()` after creating derived columns, for example.
        """
        columns = [self.IMAGE_COLUMN, self.label_column, self.MASK_COLUMN,
                   self.SPLIT_COLUMN] + list(self.METADATA_COLUMNS)
        columns_not_found = []
        for column in columns:
            if column is not None and column not in self.dataset_df.columns:
                columns_not_found.append(column)
        mandatory_columns = {self.IMAGE_COLUMN, self.label_column, self.MASK_COLUMN, self.SPLIT_COLUMN}
        optional_columns = (
            set(self.dataframe_kwargs["usecols"]) if "usecols" in self.dataframe_kwargs else set(self.METADATA_COLUMNS)
        )
        columns = mandatory_columns.union(optional_columns)
        # SLIDE_ID_COLUMN is used for indexing and is not in df.columns anymore
        # None might be in columns if SPLIT_COLUMN is None
        columns_not_found = columns - set(self.dataset_df.columns) - {None, self.SLIDE_ID_COLUMN}
        if len(columns_not_found) > 0:
            raise ValueError(f"Expected columns '{columns_not_found}' not found in the dataframe")

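A minimal sketch of the set-difference validation above, with illustrative column names:

present_columns = {"image", "label", "metadata_1"}
expected_columns = {"image", "label", "mask", None}  # SPLIT_COLUMN may be None
missing = expected_columns - present_columns - {None, "slide_id"}  # index column excluded
assert missing == {"mask"}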
@ -10,6 +10,7 @@ import pandas as pd
from monai.config import KeysCollection
from monai.data.image_reader import ImageReader, WSIReader
from monai.transforms import MapTransform
from health_cpath.utils.naming import SlideKey

from health_ml.utils import box_utils

@ -42,9 +43,10 @@ class PandaDataset(SlidesDataset):
                 dataset_csv: Optional[Union[str, Path]] = None,
                 dataset_df: Optional[pd.DataFrame] = None,
                 label_column: str = "isup_grade",
                 n_classes: int = 6) -> None:
                 n_classes: int = 6,
                 dataframe_kwargs: Dict[str, Any] = {}) -> None:
        super().__init__(root, dataset_csv, dataset_df, validate_columns=False, label_column=label_column,
                         n_classes=n_classes)
                         n_classes=n_classes, dataframe_kwargs=dataframe_kwargs)
        # PANDA CSV does not come with paths for image and mask files
        slide_ids = self.dataset_df.index
        self.dataset_df[self.IMAGE_COLUMN] = "train_images/" + slide_ids + ".tiff"
@ -120,8 +122,9 @@ class LoadPandaROId(MapTransform):
        # but relative region size in pixels at the chosen level
        scale = mask_obj.resolutions['level_downsamples'][self.level]
        scaled_bbox = level0_bbox / scale
        origin = (level0_bbox.y, level0_bbox.x)
        get_data_kwargs = dict(
            location=(level0_bbox.y, level0_bbox.x),
            location=origin,
            size=(scaled_bbox.h, scaled_bbox.w),
            level=self.level,
        )
@ -129,7 +132,8 @@ class LoadPandaROId(MapTransform):
        data[self.mask_key] = mask[:1]  # PANDA segmentation mask is in 'R' channel
        data[self.image_key], _ = self.reader.get_data(image_obj, **get_data_kwargs)  # type: ignore
        data.update(get_data_kwargs)
        data['scale'] = scale
        data[SlideKey.SCALE] = scale
        data[SlideKey.ORIGIN] = origin

        mask_obj.close()
        image_obj.close()

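A minimal sketch of the coordinate bookkeeping above, with made-up numbers; the dict stands in for the Box type from health_ml.utils.box_utils:

scale = 4.0  # mask_obj.resolutions['level_downsamples'][level]
level0_bbox = {"x": 1024, "y": 2048, "w": 8192, "h": 8192}  # level-0 pixels
scaled_bbox = {k: v / scale for k, v in level0_bbox.items()}  # size at the chosen level
origin = (level0_bbox["y"], level0_bbox["x"])  # (y, x) order, as passed to reader.get_data above
assert scaled_bbox["w"] == 2048.0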
@ -11,7 +11,6 @@ from torchvision.datasets.vision import VisionDataset

from health_cpath.datasets.base_dataset import TilesDataset
from health_cpath.models.transforms import load_pil_image
from health_cpath.utils.naming import TileKey

from health_cpath.datasets.dataset_return_index import DatasetWithReturnIndex

@ -73,12 +72,7 @@ class PandaTilesDataset(TilesDataset):
            dataset_df_filtered = self.dataset_df.sample(n=df_length_random_subset_fraction)
            self.dataset_df = dataset_df_filtered

        # Copy columns "left" --> "tile_x" and "top" --> "tile_y"
        # to be consistent with TilesDataset `TILE_X_COLUMN` and `TILE_Y_COLUMN`
        if TileKey.TILE_LEFT in self.dataset_df.columns:
            self.dataset_df[TilesDataset.TILE_X_COLUMN] = self.dataset_df[TileKey.TILE_LEFT]
        if TileKey.TILE_TOP in self.dataset_df.columns:
            self.dataset_df[TilesDataset.TILE_Y_COLUMN] = self.dataset_df[TileKey.TILE_TOP]
        self.copy_coordinates_columns()
        self.validate_columns()

@ -6,20 +6,20 @@ import torch
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pathlib import Path

from pytorch_lightning import LightningModule
from torch import Tensor, argmax, mode, nn, optim, round, set_grad_enabled
from torchmetrics import AUROC, F1Score, Accuracy, ConfusionMatrix, Precision, Recall, CohenKappa  # type: ignore
from torch import Tensor, argmax, mode, nn, optim, round
from torchmetrics import (AUROC, F1Score, Accuracy, ConfusionMatrix, Precision,
                          Recall, CohenKappa, AveragePrecision, Specificity)

from health_ml.utils import log_on_epoch
from health_ml.deep_learning_config import OptimizerParams
from health_cpath.models.encoders import IdentityEncoder
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
from health_cpath.datasets.base_dataset import TilesDataset
from health_cpath.utils.naming import MetricsKey, ResultsKey, SlideKey, ModelKey, TileKey
from health_cpath.utils.naming import DeepMILSubmodules, MetricsKey, ResultsKey, SlideKey, ModelKey, TileKey
from health_cpath.utils.output_utils import (BatchResultsType, DeepMILOutputsHandler, EpochResultsType,
                                             validate_class_names)
                                             validate_class_names, EXTRA_PREFIX)


RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB,
                ResultsKey.CLASS_PROBS, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN]
@ -39,31 +39,31 @@ class BaseDeepMILModule(LightningModule):
                 n_classes: int,
                 class_weights: Optional[Tensor] = None,
                 class_names: Optional[Sequence[str]] = None,
                 dropout_rate: Optional[float] = None,
                 verbose: bool = False,
                 ssl_ckpt_run_id: Optional[str] = None,
                 outputs_folder: Optional[Path] = None,
                 encoder_params: EncoderParams = EncoderParams(),
                 pooling_params: PoolingParams = PoolingParams(),
                 classifier_params: ClassifierParams = ClassifierParams(),
                 optimizer_params: OptimizerParams = OptimizerParams(),
                 outputs_handler: Optional[DeepMILOutputsHandler] = None) -> None:
                 outputs_folder: Optional[Path] = None,
                 outputs_handler: Optional[DeepMILOutputsHandler] = None,
                 analyse_loss: Optional[bool] = False,
                 verbose: bool = False,
                 ) -> None:
        """
        :param label_column: Label key for input batch dictionary.
        :param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be
            set to 1.
        :param class_weights: Tensor containing class weights (default=None).
        :param class_names: The names of the classes if available (default=None).
        :param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
        :param verbose: If True, statements about memory usage are output at each step.
        :param ssl_ckpt_run_id: Optional parameter to provide the AML run id from where to download the checkpoint
            if using `SSLEncoder`.
        :param outputs_folder: Path to output folder where encoder checkpoint is downloaded.
        :param encoder_params: Encoder parameters that specify all encoder specific attributes.
        :param pooling_params: Pooling layer parameters that specify all pooling specific attributes.
        :param classifier_params: Classifier parameters that specify all classifier specific attributes.
        :param optimizer_params: Optimizer parameters that specify all attributes to be used for optimization.
        :param outputs_folder: Path to output folder where encoder checkpoint is downloaded.
        :param outputs_handler: A configured :py:class:`DeepMILOutputsHandler` object to save outputs for the best
            validation epoch and test stage. If omitted (default), no outputs will be saved to disk (aside from usual
            metrics logging).
        :param analyse_loss: If True, the loss is computed per sample and analysed with LossAnalysisCallback.
        :param verbose: If True, statements about memory usage are output at each step.
        """
        super().__init__()

@ -73,8 +73,9 @@ class BaseDeepMILModule(LightningModule):
        self.class_weights = class_weights
        self.class_names = validate_class_names(class_names, self.n_classes)

        self.dropout_rate = dropout_rate
        self.encoder_params = encoder_params
        self.pooling_params = pooling_params
        self.classifier_params = classifier_params
        self.optimizer_params = optimizer_params

        self.save_hyperparameters()
@ -82,44 +83,84 @@ class BaseDeepMILModule(LightningModule):
        self.outputs_handler = outputs_handler

        # This flag can be switched on before invoking trainer.validate() to enable saving additional time/memory
        # consuming validation outputs
        self.run_extra_val_epoch = False
        # consuming validation outputs via calling self.on_run_extra_validation_epoch()
        self._on_extra_val_epoch = False

        # Model components
        self.encoder = encoder_params.get_encoder(ssl_ckpt_run_id, outputs_folder)
        self.encoder = encoder_params.get_encoder(outputs_folder)
        self.aggregation_fn, self.num_pooling = pooling_params.get_pooling_layer(self.encoder.num_encoding)
        self.classifier_fn = self.get_classifier()
        self.classifier_fn = classifier_params.get_classifier(self.num_pooling, self.n_classes)
        self.activation_fn = self.get_activation()

        self.loss_fn = self.get_loss()
        # Loss function
        self.loss_fn = self.get_loss(reduction="mean")
        self.loss_fn_no_reduction = self.get_loss(reduction="none")
        self.analyse_loss = analyse_loss

        # Metrics Objects
        self.train_metrics = self.get_metrics()
        self.val_metrics = self.get_metrics()
        self.test_metrics = self.get_metrics()

    def get_classifier(self) -> Callable:
        classifier_layer = nn.Linear(in_features=self.num_pooling,
                                     out_features=self.n_classes)
        if self.dropout_rate is None:
            return classifier_layer
        elif 0 <= self.dropout_rate < 1:
            return nn.Sequential(nn.Dropout(self.dropout_rate), classifier_layer)
        else:
            raise ValueError(f"Dropout rate should be in [0, 1), got {self.dropout_rate}")
    @staticmethod
    def copy_weights(
        current_submodule: nn.Module, pretrained_submodule: nn.Module, submodule_name: DeepMILSubmodules
    ) -> None:
        """Copy weights from pretrained submodule to current submodule.

    def get_loss(self) -> Callable:
        :param current_submodule: Submodule to copy weights to.
        :param pretrained_submodule: Submodule to copy weights from.
        :param submodule_name: Name of the submodule.
        """

        def _total_params(submodule: nn.Module) -> int:
            return sum(p.numel() for p in submodule.parameters())

        pre_total_params = _total_params(pretrained_submodule)
        cur_total_params = _total_params(current_submodule)

        if pre_total_params != cur_total_params:
            raise ValueError(f"Submodule {submodule_name} has different number of parameters "
                             f"({cur_total_params} vs {pre_total_params}) from pretrained model.")

        for param, pretrained_param in zip(
            current_submodule.state_dict().values(), pretrained_submodule.state_dict().values()
        ):
            try:
                param.data.copy_(pretrained_param.data)
            except Exception as e:
                raise ValueError(f"Failed to copy weights for {submodule_name} because of the following exception: {e}")

    def transfer_weights(self, pretrained_checkpoint_path: Optional[Path]) -> None:
        """Transfer weights from pretrained checkpoint if provided."""
        if pretrained_checkpoint_path:
            pretrained_model = self.load_from_checkpoint(checkpoint_path=str(pretrained_checkpoint_path))

            if self.encoder_params.pretrained_encoder:
                self.copy_weights(self.encoder, pretrained_model.encoder, DeepMILSubmodules.ENCODER)

            if self.pooling_params.pretrained_pooling:
                self.copy_weights(self.aggregation_fn, pretrained_model.aggregation_fn, DeepMILSubmodules.POOLING)

            if self.classifier_params.pretrained_classifier:
                if pretrained_model.n_classes != self.n_classes:
                    raise ValueError(f"Number of classes in pretrained model {pretrained_model.n_classes} "
                                     f"does not match number of classes in current model {self.n_classes}.")
                self.copy_weights(self.classifier_fn, pretrained_model.classifier_fn, DeepMILSubmodules.CLASSIFIER)
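A minimal standalone sketch of the copy mechanism this method relies on (two throwaway Linear layers stand in for the DeepMIL submodules):

import torch
from torch import nn

src, dst = nn.Linear(8, 2), nn.Linear(8, 2)
assert sum(p.numel() for p in src.parameters()) == sum(p.numel() for p in dst.parameters())
for dst_param, src_param in zip(dst.state_dict().values(), src.state_dict().values()):
    dst_param.data.copy_(src_param.data)  # in-place copy, mirrors copy_weights above
assert torch.equal(dst.weight, src.weight)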

    def get_loss(self, reduction: str = "mean") -> Callable:
        if self.n_classes > 1:
            if self.class_weights is None:
                return nn.CrossEntropyLoss()
                return nn.CrossEntropyLoss(reduction=reduction)
            else:
                class_weights = self.class_weights.float()
                return nn.CrossEntropyLoss(weight=class_weights)
                return nn.CrossEntropyLoss(weight=class_weights, reduction=reduction)
        else:
            pos_weight = None
            if self.class_weights is not None:
                pos_weight = Tensor([self.class_weights[1] / (self.class_weights[0] + 1e-5)])
            return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            return nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=reduction)

    def get_activation(self) -> Callable:
        if self.n_classes > 1:
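A short sketch of why the second, unreduced criterion exists: with reduction="none" the loss comes back per bag, which is what the per-sample loss analysis stores:

import torch
from torch import nn

logits = torch.randn(4, 6)            # 4 bags, 6 classes
labels = torch.randint(0, 6, (4,))
loss_mean = nn.CrossEntropyLoss(reduction="mean")(logits, labels)  # scalar
loss_each = nn.CrossEntropyLoss(reduction="none")(logits, labels)  # shape [4]
assert loss_each.shape == (4,) and torch.isclose(loss_each.mean(), loss_mean)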
@ -133,26 +174,38 @@ class BaseDeepMILModule(LightningModule):

    def get_metrics(self) -> nn.ModuleDict:
        if self.n_classes > 1:
            return nn.ModuleDict({MetricsKey.ACC: Accuracy(num_classes=self.n_classes),
                                  MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'),
                                  MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted'),
                                  MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
                                  # Quadratic Weighted Kappa (QWK) used in PANDA challenge
                                  # is calculated using Cohen's Kappa with quadratic weights
                                  # https://www.kaggle.com/code/reighns/understanding-the-quadratic-weighted-kappa/
                                  MetricsKey.COHENKAPPA: CohenKappa(num_classes=self.n_classes, weights='quadratic'),
                                  MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes)})
            return nn.ModuleDict({
                MetricsKey.ACC: Accuracy(num_classes=self.n_classes),
                MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
                MetricsKey.AVERAGE_PRECISION: AveragePrecision(num_classes=self.n_classes),
                # Quadratic Weighted Kappa (QWK) used in PANDA challenge
                # is calculated using Cohen's Kappa with quadratic weights
                # https://www.kaggle.com/code/reighns/understanding-the-quadratic-weighted-kappa/
                MetricsKey.COHENKAPPA: CohenKappa(num_classes=self.n_classes, weights='quadratic'),
                MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes),
                # Metrics below are computed for multi-class case only
                MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'),
                MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted')})
        else:
            threshold = 0.5
            return nn.ModuleDict({MetricsKey.ACC: Accuracy(threshold=threshold),
                                  MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
                                  MetricsKey.PRECISION: Precision(threshold=threshold),
                                  MetricsKey.RECALL: Recall(threshold=threshold),
                                  MetricsKey.F1: F1Score(threshold=threshold),
                                  MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2, threshold=threshold)})
            return nn.ModuleDict({
                MetricsKey.ACC: Accuracy(),
                MetricsKey.AUROC: AUROC(num_classes=None),
                # Average precision is a measure of area under the PR curve
                MetricsKey.AVERAGE_PRECISION: AveragePrecision(),
                MetricsKey.COHENKAPPA: CohenKappa(num_classes=2, weights='quadratic'),
                MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2),
                # Metrics below are computed for binary case only
                MetricsKey.F1: F1Score(),
                MetricsKey.PRECISION: Precision(),
                MetricsKey.RECALL: Recall(),
                MetricsKey.SPECIFICITY: Specificity()})

    def log_metrics(self, stage: str) -> None:
        valid_stages = [stage for stage in ModelKey]
    def get_extra_prefix(self) -> str:
        """Get extra prefix for the metrics name to avoid overriding best validation metrics."""
        return EXTRA_PREFIX if self._on_extra_val_epoch else ""

    def log_metrics(self, stage: str, prefix: str = '') -> None:
        valid_stages = set([stage for stage in ModelKey])
        if stage not in valid_stages:
            raise Exception(f"Invalid stage. Choose one of {valid_stages}")
        for metric_name, metric_object in self.get_metrics_dict(stage).items():
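A small sketch of the quadratic-weighted kappa referenced in the comment above, assuming the pre-1.0 torchmetrics constructor style used throughout this diff:

import torch
from torchmetrics import CohenKappa

qwk = CohenKappa(num_classes=6, weights='quadratic')
preds = torch.tensor([0, 2, 3, 5])
target = torch.tensor([0, 1, 3, 4])
print(qwk(preds, target))  # penalises far-off grades more than near misses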
@ -160,29 +213,46 @@ class BaseDeepMILModule(LightningModule):
            metric_value = metric_object.compute()
            metric_value_n = metric_value / metric_value.sum(axis=1, keepdims=True)
            for i in range(metric_value_n.shape[0]):
                log_on_epoch(self, f'{stage}/{self.class_names[i]}', metric_value_n[i, i])
                log_on_epoch(self, f'{prefix}{stage}/{self.class_names[i]}', metric_value_n[i, i])
        else:
            log_on_epoch(self, f'{stage}/{metric_name}', metric_object)
            log_on_epoch(self, f'{prefix}{stage}/{metric_name}', metric_object)

    def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
        should_enable_encoder_grad = torch.is_grad_enabled() and self.encoder_params.is_finetune
        with set_grad_enabled(should_enable_encoder_grad):
            if self.encoder_params.encoding_chunk_size > 0:
                embeddings = []
                chunks = torch.split(instances, self.encoder_params.encoding_chunk_size)
                for chunk in chunks:
                    chunk_embeddings = self.encoder(chunk)
                    embeddings.append(chunk_embeddings)
                instance_features = torch.cat(embeddings)
            else:
                instance_features = self.encoder(instances)  # N X L x 1 x 1
    def get_instance_features(self, instances: Tensor) -> Tensor:
        if not self.encoder_params.tune_encoder:
            self.encoder.eval()
        if self.encoder_params.encoding_chunk_size > 0:
            embeddings = []
            chunks = torch.split(instances, self.encoder_params.encoding_chunk_size)
            for chunk in chunks:
                chunk_embeddings = self.encoder(chunk)
                embeddings.append(chunk_embeddings)
            instance_features = torch.cat(embeddings)
        else:
            instance_features = self.encoder(instances)  # N X L x 1 x 1
        return instance_features

    def get_attentions_and_bag_features(self, instance_features: Tensor) -> Tuple[Tensor, Tensor]:
        if not self.pooling_params.tune_pooling:
            self.aggregation_fn.eval()
        attentions, bag_features = self.aggregation_fn(instance_features)  # K x N | K x L
        bag_features = bag_features.view(1, -1)
        return attentions, bag_features

    def get_bag_logit(self, bag_features: Tensor) -> Tensor:
        if not self.classifier_params.tune_classifier:
            self.classifier_fn.eval()
        bag_logit = self.classifier_fn(bag_features)
        return bag_logit

    def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
        instance_features = self.get_instance_features(instances)
        attentions, bag_features = self.get_attentions_and_bag_features(instance_features)
        bag_logit = self.get_bag_logit(bag_features)
        return bag_logit, attentions

    def configure_optimizers(self) -> optim.Optimizer:
        return optim.Adam(self.parameters(), lr=self.optimizer_params.l_rate,
        return optim.Adam(filter(lambda p: p.requires_grad, self.parameters()),
                          lr=self.optimizer_params.l_rate,
                          weight_decay=self.optimizer_params.weight_decay,
                          betas=self.optimizer_params.adam_betas)

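A minimal sketch of the frozen-submodule pattern behind the filtered optimizer above (module names are illustrative):

import torch
from torch import nn, optim

encoder, classifier = nn.Linear(4, 4), nn.Linear(4, 2)
for p in encoder.parameters():
    p.requires_grad = False  # e.g. tune_encoder=False
model = nn.Sequential(encoder, classifier)
opt = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
assert len(opt.param_groups[0]["params"]) == 2  # classifier weight and bias only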
@ -212,21 +282,32 @@ class BaseDeepMILModule(LightningModule):
        """Update training results with data specific info. This can be either tiles or slides related metadata."""
        raise NotImplementedError

    def update_slides_selection(self, stage: str, batch: Dict, results: Dict) -> None:
        if (
            (stage == ModelKey.TEST or (stage == ModelKey.VAL and self._on_extra_val_epoch))
            and self.outputs_handler
            and self.outputs_handler.tiles_selector
        ):
            self.outputs_handler.tiles_selector.update_slides_selection(batch, results)

    def _compute_loss(self, loss_fn: Callable, bag_logits: Tensor, bag_labels: Tensor) -> Tensor:
        if self.n_classes > 1:
            return loss_fn(bag_logits, bag_labels.long())
        else:
            return loss_fn(bag_logits.squeeze(1), bag_labels.float())

    def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> BatchResultsType:

        bag_logits, bag_labels, bag_attn_list = self.compute_bag_labels_logits_and_attn_maps(batch)

        if self.n_classes > 1:
            loss = self.loss_fn(bag_logits, bag_labels.long())
        else:
            loss = self.loss_fn(bag_logits.squeeze(1), bag_labels.float())
        loss = self._compute_loss(self.loss_fn, bag_logits, bag_labels)

        predicted_probs = self.activation_fn(bag_logits)
        if self.n_classes > 1:
            predicted_labels = argmax(predicted_probs, dim=1)
            probs_perclass = predicted_probs
        else:
            predicted_labels = round(predicted_probs)
            predicted_labels = round(predicted_probs).int()
            probs_perclass = Tensor([[1.0 - predicted_probs[i][0].item(), predicted_probs[i][0].item()]
                                     for i in range(len(predicted_probs))])

@ -237,9 +318,14 @@ class BaseDeepMILModule(LightningModule):
        if self.n_classes == 1:
            predicted_probs = predicted_probs.squeeze(dim=1)

        results = dict()

        if self.analyse_loss and stage in [ModelKey.TRAIN, ModelKey.VAL]:
            loss_per_sample = self._compute_loss(self.loss_fn_no_reduction, bag_logits, bag_labels)
            results[ResultsKey.LOSS_PER_SAMPLE] = loss_per_sample.detach().cpu().numpy()

        bag_labels = bag_labels.view(-1, 1)

        results = dict()
        for metric_object in self.get_metrics_dict(stage).values():
            metric_object.update(predicted_probs, bag_labels.view(batch_size,).int())
        results.update({ResultsKey.LOSS: loss,
@ -250,56 +336,46 @@ class BaseDeepMILModule(LightningModule):
                        ResultsKey.BAG_ATTN: bag_attn_list
                        })
        self.update_results_with_data_specific_info(batch=batch, results=results)
        if (
            (stage == ModelKey.TEST or (stage == ModelKey.VAL and self.run_extra_val_epoch))
            and self.outputs_handler
            and self.outputs_handler.tiles_selector
        ):
            self.outputs_handler.tiles_selector.update_slides_selection(batch, results)
        self.update_slides_selection(stage=stage, batch=batch, results=results)
        return results

    def training_step(self, batch: Dict, batch_idx: int) -> Tensor:  # type: ignore
    def training_step(self, batch: Dict, batch_idx: int) -> BatchResultsType:  # type: ignore
        train_result = self._shared_step(batch, batch_idx, ModelKey.TRAIN)
        self.log('train/loss', train_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
                 sync_dist=True)
        self.log('train/loss', train_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, sync_dist=True)
        if self.verbose:
            print(f"After loading images batch {batch_idx} -", _format_cuda_memory_stats())
        return train_result[ResultsKey.LOSS]
        results = {ResultsKey.LOSS: train_result[ResultsKey.LOSS]}
        if self.analyse_loss:
            results.update({ResultsKey.LOSS_PER_SAMPLE: train_result[ResultsKey.LOSS_PER_SAMPLE],
                            ResultsKey.CLASS_PROBS: train_result[ResultsKey.CLASS_PROBS],
                            ResultsKey.TILE_ID: train_result[ResultsKey.TILE_ID]})
        return results

    def validation_step(self, batch: Dict, batch_idx: int) -> BatchResultsType:  # type: ignore
        val_result = self._shared_step(batch, batch_idx, ModelKey.VAL)
        self.log('val/loss', val_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
                 sync_dist=True)
        name = f'{self.get_extra_prefix()}val/loss'
        self.log(name, val_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, sync_dist=True)
        return val_result

    def test_step(self, batch: Dict, batch_idx: int) -> BatchResultsType:  # type: ignore
        test_result = self._shared_step(batch, batch_idx, ModelKey.TEST)
        self.log('test/loss', test_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
                 sync_dist=True)
        self.log('test/loss', test_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, sync_dist=True)
        return test_result

    def training_epoch_end(self, outputs: EpochResultsType) -> None:  # type: ignore
        self.log_metrics(ModelKey.TRAIN)

    def validation_epoch_end(self, epoch_results: EpochResultsType) -> None:  # type: ignore
        self.log_metrics(ModelKey.VAL)
        self.log_metrics(stage=ModelKey.VAL, prefix=self.get_extra_prefix())
        if self.outputs_handler:
            if self.run_extra_val_epoch:
                self.outputs_handler.val_plots_handler.plot_options = (
                    self.outputs_handler.test_plots_handler.plot_options
                )
            self.outputs_handler.save_validation_outputs(
                epoch_results=epoch_results,
                metrics_dict=self.get_metrics_dict(ModelKey.VAL),  # type: ignore
                epoch=self.current_epoch,
                is_global_rank_zero=self.global_rank == 0,
                run_extra_val_epoch=self.run_extra_val_epoch
                on_extra_val=self._on_extra_val_epoch
            )

            # Reset the top and bottom slides heaps
            if self.outputs_handler.tiles_selector is not None:
                self.outputs_handler.tiles_selector._clear_cached_slides_heaps()

    def test_epoch_end(self, epoch_results: EpochResultsType) -> None:  # type: ignore
        self.log_metrics(ModelKey.TEST)
        if self.outputs_handler:
@ -308,6 +384,13 @@ class BaseDeepMILModule(LightningModule):
                is_global_rank_zero=self.global_rank == 0
            )

    def on_run_extra_validation_epoch(self) -> None:
        """Hook to be called at the beginning of an extra validation epoch to set validation plots options to the same
        as the test plots options."""
        self._on_extra_val_epoch = True
        if self.outputs_handler:
            self.outputs_handler.val_plots_handler.plot_options = self.outputs_handler.test_plots_handler.plot_options


class TilesDeepMILModule(BaseDeepMILModule):
    """Base class for Tiles based deep multiple-instance learning."""

@ -10,8 +10,8 @@ import numpy as np
import torch
from pl_bolts.models.self_supervised import SimCLR
from torch import Tensor as T, nn
from torchvision.models import resnet18
from torchvision.transforms import Compose
from torchvision.models import resnet18, resnet50
from monai.transforms import Compose

from health_cpath.utils.layer_utils import (get_imagenet_preprocessing,
                                            load_weights_to_model,
@ -62,24 +62,49 @@ class ImageNetEncoder(TileEncoder):
    """Feature extractor pretrained for classification on ImageNet"""

    def __init__(self, feature_extraction_model: Callable[..., nn.Module],
                 tile_size: int, n_channels: int = 3) -> None:
                 tile_size: int, n_channels: int = 3, apply_imagenet_preprocessing: bool = True) -> None:
        """
        :param feature_extraction_model: A function accepting a `pretrained` keyword argument that
            returns a classifier pretrained on ImageNet, such as the ones from `torchvision.models.*`.
        :param tile_size: Tile width/height, in pixels.
        :param n_channels: Number of channels in the tile (default=3).
        :param apply_imagenet_preprocessing: Whether to apply ImageNet preprocessing to the input.
        """
        self.create_feature_extractor_fn = feature_extraction_model
        self.apply_imagenet_preprocessing = apply_imagenet_preprocessing
        super().__init__(tile_size=tile_size, n_channels=n_channels)

    def _get_preprocessing(self) -> Callable:
        return get_imagenet_preprocessing()
        base_preprocessing = super()._get_preprocessing()
        if self.apply_imagenet_preprocessing:
            return Compose([get_imagenet_preprocessing(), base_preprocessing]).flatten()
        return base_preprocessing

    def _get_encoder(self) -> Tuple[torch.nn.Module, int]:
        pretrained_model = self.create_feature_extractor_fn(pretrained=True)
        return setup_feature_extractor(pretrained_model, self.input_dim)


class Resnet18(ImageNetEncoder):
    def __init__(self, tile_size: int, n_channels: int = 3) -> None:
        super().__init__(resnet18, tile_size, n_channels, apply_imagenet_preprocessing=True)


class Resnet18_NoPreproc(ImageNetEncoder):
    def __init__(self, tile_size: int, n_channels: int = 3) -> None:
        super().__init__(resnet18, tile_size, n_channels, apply_imagenet_preprocessing=False)


class Resnet50(ImageNetEncoder):
    def __init__(self, tile_size: int, n_channels: int = 3) -> None:
        super().__init__(resnet50, tile_size, n_channels, apply_imagenet_preprocessing=True)


class Resnet50_NoPreproc(ImageNetEncoder):
    def __init__(self, tile_size: int, n_channels: int = 3) -> None:
        super().__init__(resnet50, tile_size, n_channels, apply_imagenet_preprocessing=False)


class ImageNetSimCLREncoder(TileEncoder):
    """SimCLR encoder pretrained on ImageNet"""

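A hedged usage sketch of the new wrapper classes (tile_size is illustrative):

from health_cpath.models.encoders import Resnet50, Resnet50_NoPreproc

encoder = Resnet50(tile_size=224)                # ImageNet normalisation prepended
encoder_raw = Resnet50_NoPreproc(tile_size=224)  # raw intensities, as in Myronenko et al. 2021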
@ -143,23 +168,3 @@ class HistoSSLEncoder(TileEncoder):
        resnet18_model = resnet18(pretrained=False)
        histossl_encoder = load_weights_to_model(self.WEIGHTS_URL, resnet18_model)
        return setup_feature_extractor(histossl_encoder, self.input_dim)


class ImageNetEncoder_Resnet50(TileEncoder):
    # Myronenko et al. 2021 uses intensity scaling (0-255)-->(0-1), and no ImageNet preprocessing is used.
    # ResNet50 CNN encoder without ImageNet preprocessing is defined below.

    def __init__(self, feature_extraction_model: Callable[..., nn.Module],
                 tile_size: int, n_channels: int = 3) -> None:
        """
        :param feature_extraction_model: A function accepting a `pretrained` keyword argument that
            returns a classifier pretrained on ImageNet, such as the ones from `torchvision.models.*`.
        :param tile_size: Tile width/height, in pixels.
        :param n_channels: Number of channels in the tile (default=3).
        """
        self.create_feature_extractor_fn = feature_extraction_model
        super().__init__(tile_size=tile_size, n_channels=n_channels)

    def _get_encoder(self) -> Tuple[nn.Module, int]:
        pretrained_model = self.create_feature_extractor_fn(pretrained=True)
        return setup_feature_extractor(pretrained_model, self.input_dim)

@ -26,16 +26,35 @@ def load_pil_image(image_path: PathOrString) -> PIL.Image.Image:
    return image


def load_image_as_tensor(image_path: PathOrString) -> torch.Tensor:
    """Load an image as a tensor from the given path"""
    pil_image = load_pil_image(image_path)
    return to_tensor(pil_image)
def load_image_as_tensor(image_path: PathOrString, scale_intensity: bool = True) -> torch.Tensor:
    """Load an image as a tensor from the given path

    :param image_path: path to the image
    :param scale_intensity: if True, use `to_tensor` from torchvision, which scales the image pixel intensities to
        [0, 1] by default and returns [C, H, W] tensors. Otherwise, only transpose the image to [C, H, W] format and
        return it as a torch tensor.
    """
    pil_image = load_pil_image(image_path)  # pil_image is in channels-last format [H, W, C]
    if scale_intensity:
        return to_tensor(pil_image)  # to_tensor scales the image pixel intensities to [0, 1] as [C, H, W] tensors
    else:
        return torch.from_numpy(np.asarray(pil_image).transpose((2, 0, 1))).contiguous()  # only transpose to [C, H, W]


def load_image_stack_as_tensor(image_paths: Sequence[PathOrString],
                               progress: bool = False) -> torch.Tensor:
    """Load a batch of images of the same size as a tensor from the given paths"""
    loading_generator = (load_image_as_tensor(path) for path in image_paths)
                               progress: bool = False,
                               scale_intensity: bool = True) -> torch.Tensor:
    """Load a batch of images of the same size as a tensor from the given paths

    :param image_paths: paths to the images
    :param progress: if True, show a progress bar
    :param scale_intensity: if True, use `to_tensor` from torchvision, which scales the image pixel intensities to
        [0, 1] by default and returns [C, H, W] tensors. Otherwise, only transpose the image to [C, H, W] format and
        return it as a torch tensor.
    """
    loading_generator = (load_image_as_tensor(path, scale_intensity) for path in image_paths)
    if progress:
        from tqdm import tqdm
        loading_generator = tqdm(loading_generator, desc="Loading image stack",
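A minimal sketch contrasting the two loading modes:

import numpy as np
import torch
from torchvision.transforms.functional import to_tensor

hwc = np.full((2, 2, 3), 255, dtype=np.uint8)  # [H, W, C] image
scaled = to_tensor(hwc)  # float32 in [0, 1], [C, H, W]
raw = torch.from_numpy(hwc.transpose((2, 0, 1))).contiguous()  # uint8 unchanged, [C, H, W]
assert scaled.max() == 1.0 and raw.max() == 255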
@ -98,20 +117,27 @@ class LoadTilesBatchd(MapTransform):

    # Cannot reuse MONAI readers because they support stacking only images with no channels
    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False,
                 progress: bool = False) -> None:
                 progress: bool = False, scale_intensity: bool = True) -> None:
        """
        :param keys: Key(s) for the image path(s) in the input dictionary.
        :param allow_missing_keys: If `False` (default), raises an exception when an input
            dictionary is missing any of the specified keys.
        :param progress: Whether to display a tqdm progress bar.
        :param scale_intensity: if True, use `to_tensor` from torchvision, which scales the image pixel intensities to
            [0, 1] by default and returns [C, H, W] tensors. Otherwise, only transpose the image to [C, H, W] format
            and return it as a torch tensor.
        """
        super().__init__(keys, allow_missing_keys)
        self.progress = progress
        self.scale_intensity = scale_intensity

    def __call__(self, data: Mapping) -> Mapping:
        out_data = dict(data)  # create shallow copy
        for key in self.key_iterator(out_data):
            out_data[key] = load_image_stack_as_tensor(data[key], progress=self.progress)
            out_data[key] = load_image_stack_as_tensor(
                data[key], progress=self.progress, scale_intensity=self.scale_intensity
            )
        return out_data

@ -0,0 +1,51 @@
from azure.storage.blob import generate_blob_sas, BlobSasPermissions
from azureml.core import Workspace
from datetime import datetime, timedelta
from health_azure import get_workspace
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from typing import Optional


def get_checkpoint_url_from_aml_run(
    run_id: str,
    checkpoint_filename: str,
    expiry_days: int = 1,
    aml_workspace: Optional[Workspace] = None,
    sas_token: Optional[str] = None,
) -> str:
    """Generate a SAS URL for the checkpoint file in the given run.

    :param run_id: The run ID of the checkpoint.
    :param checkpoint_filename: The filename of the checkpoint.
    :param expiry_days: The number of days the SAS URL is valid for, defaults to 1.
    :param aml_workspace: The Azure ML workspace to use, defaults to the default workspace.
    :param sas_token: The SAS token to use, defaults to None.
    :return: The SAS URL for the checkpoint.
    """
    datastore = get_workspace(aml_workspace=aml_workspace).get_default_datastore()
    account_name = datastore.account_name
    container_name = 'azureml'
    blob_name = f'ExperimentRun/dcid.{run_id}/{DEFAULT_AML_CHECKPOINT_DIR}/{checkpoint_filename}'

    if not sas_token:
        sas_token = generate_blob_sas(account_name=datastore.account_name,
                                      container_name=container_name,
                                      blob_name=blob_name,
                                      account_key=datastore.account_key,
                                      permission=BlobSasPermissions(read=True),
                                      expiry=datetime.utcnow() + timedelta(days=expiry_days))

    return f'https://{account_name}.blob.core.windows.net/{container_name}/{blob_name}?{sas_token}'


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--run_id', type=str, help='The run id of the model checkpoint')
    parser.add_argument('--checkpoint_filename', type=str, default='last.ckpt',
                        help='The filename of the model checkpoint. Default: last.ckpt')
    parser.add_argument('--expiry_days', type=int, default=30,
                        help='The number of days for which the SAS token is valid. Default: 30 days (about 1 month)')
    args = parser.parse_args()
    url = get_checkpoint_url_from_aml_run(args.run_id, args.checkpoint_filename, args.expiry_days)
    print(f'Checkpoint URL: {url}')
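A hypothetical invocation of this script (the script name and run ID are illustrative); the printed URL follows the blob layout built above, i.e. https://<account>.blob.core.windows.net/azureml/ExperimentRun/dcid.<run_id>/<checkpoint_dir>/<filename>?<sas_token>:

python get_checkpoint_url_from_aml_run.py --run_id HD_0123abc_0 --checkpoint_filename last.ckpt --expiry_days 7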
@ -12,22 +12,25 @@ from matplotlib import pyplot as plt
from health_azure.utils import get_aml_run_from_run_id, get_workspace
from health_ml.utils.reports import HTMLReport
from health_cpath.utils.analysis_plot_utils import (add_training_curves_legend, plot_confusion_matrices,
                                                    plot_crossval_roc_and_pr_curves,
                                                    plot_crossval_training_curves)
                                                    plot_hyperdrive_roc_and_pr_curves,
                                                    plot_hyperdrive_training_curves)
from health_cpath.utils.output_utils import (AML_LEGACY_TEST_OUTPUTS_CSV, AML_TEST_OUTPUTS_CSV,
                                             AML_VAL_OUTPUTS_CSV)
from health_cpath.utils.report_utils import (collect_crossval_metrics, collect_crossval_outputs,
                                             crossval_runs_have_val_and_test_outputs, get_best_epoch_metrics,
                                             get_best_epochs, get_crossval_metrics_table, get_formatted_run_info,
                                             collect_class_info)
from health_cpath.utils.report_utils import (collect_hyperdrive_metrics, collect_hyperdrive_outputs,
                                             child_runs_have_val_and_test_outputs, get_best_epoch_metrics,
                                             get_best_epochs, get_child_runs_hyperparams, get_hyperdrive_metrics_table,
                                             get_formatted_run_info, collect_class_info, get_max_epochs,
                                             download_hyperdrive_metrics_if_required)
from health_cpath.utils.naming import MetricsKey, ModelKey


def generate_html_report(parent_run_id: str, output_dir: Path,
                         workspace_config_path: Optional[Path] = None,
                         include_test: bool = False, overwrite: bool = False) -> None:
                         include_test: bool = False, overwrite: bool = False,
                         hyperdrive_arg_name: str = "crossval_index",
                         primary_metric: str = MetricsKey.AUROC) -> None:
    """
    Function to generate an HTML report of a cross validation AML run.
    Function to generate an HTML report of a Hyperdrive AML run (e.g., cross validation, different random seeds, ...).

    :param parent_run_id: The parent Hyperdrive run ID.
    :param output_dir: Directory where to download Azure ML data and save the report.
@ -35,6 +38,9 @@ def generate_html_report(parent_run_id: str, output_dir: Path,
        If omitted, will try to load default workspace.
    :param include_test: Include test results in the generated report.
    :param overwrite: Forces (re)download of metrics and output files, even if they already exist locally.
    :param hyperdrive_arg_name: Name of the Hyperdrive argument used for indexing the child runs.
        Default `crossval_index`.
    :param primary_metric: Name of the reference metric to optimise. Default `MetricsKey.AUROC`.
    """
    aml_workspace = get_workspace(workspace_config_path=workspace_config_path)
    parent_run = get_aml_run_from_run_id(parent_run_id, aml_workspace=aml_workspace)
@ -48,26 +54,38 @@ def generate_html_report(parent_run_id: str, output_dir: Path,
    report.add_heading("Azure ML metrics", level=2)

    # Download metrics from AML. Can take several seconds for each child run
    metrics_df = collect_crossval_metrics(parent_run_id, report_dir, aml_workspace, overwrite=overwrite)
    best_epochs = get_best_epochs(metrics_df, f'{ModelKey.VAL}/{MetricsKey.AUROC}', maximise=True)
    metrics_json = download_hyperdrive_metrics_if_required(parent_run_id, report_dir, aml_workspace,
                                                           overwrite=overwrite, hyperdrive_arg_name=hyperdrive_arg_name)

    # Get metrics dataframe from the downloaded json file
    metrics_df = collect_hyperdrive_metrics(metrics_json=metrics_json)

    hyperparameters_children = get_child_runs_hyperparams(metrics_df)
    max_epochs_dict = get_max_epochs(hyperparams_children=hyperparameters_children)
    best_epochs = get_best_epochs(metrics_df=metrics_df, primary_metric=f'{ModelKey.VAL}/{primary_metric}',
                                  max_epochs_dict=max_epochs_dict, maximise=True)

    # Add training curves for loss and AUROC (train and val.)
    render_training_curves(report, heading="Training curves", level=3,
                           metrics_df=metrics_df, best_epochs=best_epochs, report_dir=report_dir)
                           metrics_df=metrics_df, best_epochs=best_epochs, report_dir=report_dir,
                           primary_metric=primary_metric)

    # Get metrics list with class names
    num_classes, class_names = collect_class_info(metrics_df=metrics_df)
    num_classes, class_names = collect_class_info(hyperparams_children=hyperparameters_children)

    base_metrics_list: List[str]
    base_metrics_list: List[str] = [MetricsKey.ACC, MetricsKey.AUROC, MetricsKey.AVERAGE_PRECISION,
                                    MetricsKey.COHENKAPPA]
    if num_classes > 1:
        base_metrics_list = [MetricsKey.ACC, MetricsKey.AUROC]
        base_metrics_list += [MetricsKey.ACC_MACRO, MetricsKey.ACC_WEIGHTED]
    else:
        base_metrics_list = [MetricsKey.ACC, MetricsKey.AUROC, MetricsKey.PRECISION, MetricsKey.RECALL, MetricsKey.F1]
        base_metrics_list += [MetricsKey.F1, MetricsKey.PRECISION, MetricsKey.RECALL, MetricsKey.SPECIFICITY]

    base_metrics_list += class_names

    # Add tables with relevant metrics (val. and test)
    render_metrics_table(report, heading="Validation metrics (best epoch based on maximum validation AUROC)", level=3,
    render_metrics_table(report,
                         heading=f"Validation metrics (best epoch based on maximum validation {primary_metric})",
                         level=3,
                         metrics_df=metrics_df, best_epochs=best_epochs,
                         base_metrics_list=base_metrics_list, metrics_prefix=f'{ModelKey.VAL}/')

@ -76,54 +94,58 @@ def generate_html_report(parent_run_id: str, output_dir: Path,
                         metrics_df=metrics_df, best_epochs=None,
                         base_metrics_list=base_metrics_list, metrics_prefix=f'{ModelKey.TEST}/')

    has_val_and_test_outputs = crossval_runs_have_val_and_test_outputs(parent_run)

    # Get output data frames
    if has_val_and_test_outputs:
        output_filename_val = AML_VAL_OUTPUTS_CSV
        outputs_dfs_val = collect_crossval_outputs(parent_run_id=parent_run_id, download_dir=report_dir,
                                                   aml_workspace=aml_workspace,
                                                   output_filename=output_filename_val, overwrite=overwrite)
    if include_test:
        output_filename_test = AML_TEST_OUTPUTS_CSV if has_val_and_test_outputs else AML_LEGACY_TEST_OUTPUTS_CSV
        outputs_dfs_test = collect_crossval_outputs(parent_run_id=parent_run_id, download_dir=report_dir,
                                                    aml_workspace=aml_workspace,
                                                    output_filename=output_filename_test, overwrite=overwrite)

    if num_classes == 1:
        # Currently ROC and PR curves rendered only for binary case
        # TODO: Enable rendering of multi-class ROC and PR curves
        report.add_heading("ROC and PR curves", level=2)
    # Get output data frames if available
    try:
        has_val_and_test_outputs = child_runs_have_val_and_test_outputs(parent_run)
        if has_val_and_test_outputs:
            # Add val. ROC and PR curves
            render_roc_and_pr_curves(report=report, heading="Validation ROC and PR curves", level=3,
                                     report_dir=report_dir,
                                     outputs_dfs=outputs_dfs_val,
                                     prefix=f'{ModelKey.VAL}_')
            output_filename_val = AML_VAL_OUTPUTS_CSV
            outputs_dfs_val = collect_hyperdrive_outputs(parent_run_id=parent_run_id, download_dir=report_dir,
                                                         aml_workspace=aml_workspace,
                                                         output_filename=output_filename_val, overwrite=overwrite,
                                                         hyperdrive_arg_name=hyperdrive_arg_name)
        if include_test:
            output_filename_test = AML_TEST_OUTPUTS_CSV if has_val_and_test_outputs else AML_LEGACY_TEST_OUTPUTS_CSV
            outputs_dfs_test = collect_hyperdrive_outputs(parent_run_id=parent_run_id, download_dir=report_dir,
                                                          aml_workspace=aml_workspace,
                                                          output_filename=output_filename_test, overwrite=overwrite,
                                                          hyperdrive_arg_name=hyperdrive_arg_name)

        if num_classes == 1:
            # Currently ROC and PR curves rendered only for binary case
            # TODO: Enable rendering of multi-class ROC and PR curves
            report.add_heading("ROC and PR curves", level=2)
            if has_val_and_test_outputs:
                # Add val. ROC and PR curves
                render_roc_and_pr_curves(report=report, heading="Validation ROC and PR curves", level=3,
                                         report_dir=report_dir,
                                         outputs_dfs=outputs_dfs_val,
                                         prefix=f'{ModelKey.VAL}_')
            if include_test:
                # Add test ROC and PR curves
                render_roc_and_pr_curves(report=report, heading="Test ROC and PR curves", level=3,
                                         report_dir=report_dir,
                                         outputs_dfs=outputs_dfs_test,
                                         prefix=f'{ModelKey.TEST}_')

        # Add confusion matrices for each fold
        report.add_heading("Confusion matrices", level=2)
        if has_val_and_test_outputs:
            # Add val. confusion matrices
            render_confusion_matrices(report=report, heading="Validation confusion matrices", level=3,
                                      class_names=class_names,
                                      report_dir=report_dir, outputs_dfs=outputs_dfs_val,
                                      prefix=f'{ModelKey.VAL}_')

        if include_test:
            # Add test ROC and PR curves
            render_roc_and_pr_curves(report=report, heading="Test ROC and PR curves", level=3,
                                     report_dir=report_dir,
                                     outputs_dfs=outputs_dfs_test,
                                     prefix=f'{ModelKey.TEST}_')
            # Add test confusion matrices
            render_confusion_matrices(report=report, heading="Test confusion matrices", level=3,
                                      class_names=class_names,
                                      report_dir=report_dir, outputs_dfs=outputs_dfs_test,
                                      prefix=f'{ModelKey.TEST}_')

    # Add confusion matrices for each fold

    report.add_heading("Confusion matrices", level=2)

    if has_val_and_test_outputs:
        # Add val. confusion matrices
        render_confusion_matrices(report=report, heading="Validation confusion matrices", level=3,
                                  class_names=class_names,
                                  report_dir=report_dir, outputs_dfs=outputs_dfs_val,
                                  prefix=f'{ModelKey.VAL}_')

    if include_test:
        # Add test confusion matrices
        render_confusion_matrices(report=report, heading="Test confusion matrices", level=3,
                                  class_names=class_names,
                                  report_dir=report_dir, outputs_dfs=outputs_dfs_test,
                                  prefix=f'{ModelKey.TEST}_')
    except ValueError as e:
        print(e)
        print("Not all expected output files were found; skipping ROC-PR curves and confusion matrices.")

    # TODO: Add qualitative model outputs
    # report.add_heading("Qualitative model outputs", level=2)
@ -134,26 +156,26 @@ def generate_html_report(parent_run_id: str, output_dir: Path,

def render_training_curves(report: HTMLReport, heading: str, level: int,
                           metrics_df: pd.DataFrame, best_epochs: Optional[Dict[int, int]],
                           report_dir: Path) -> None:
                           report_dir: Path, primary_metric: str = MetricsKey.AUROC) -> None:
    """
    Function to render training curves for HTML reports.

    :param report: HTML report to perform the rendering.
    :param heading: Heading of the section.
    :param level: Level of HTML heading (e.g. sub-section, sub-sub-section) corresponding to HTML heading levels.
    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
        :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
    :param best_epochs: Dictionary mapping each cross-validation index to its best epoch.
    :param best_epochs: Dictionary mapping each hyperdrive child index to its best epoch.
    :param report_dir: Directory of the HTML report.
    :param primary_metric: Primary metric name. Default is AUROC.
    """
    report.add_heading(heading, level=level)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
    plot_crossval_training_curves(metrics_df, train_metric=f'{ModelKey.TRAIN}/loss_epoch',
                                  val_metric=f'{ModelKey.VAL}/loss_epoch',
                                  ylabel="Loss", best_epochs=best_epochs, ax=ax1)
    plot_crossval_training_curves(metrics_df, train_metric=f'{ModelKey.TRAIN}/{MetricsKey.AUROC}',
                                  val_metric=f'{ModelKey.VAL}/{MetricsKey.AUROC}',
                                  ylabel="AUROC", best_epochs=best_epochs, ax=ax2)
    metrics = {"loss_epoch", MetricsKey.AUROC.value, primary_metric}
    fig, axs = plt.subplots(1, len(metrics), figsize=(5 * len(metrics), 4))
    for i, metric in enumerate(metrics):
        plot_hyperdrive_training_curves(metrics_df, train_metric=f'{ModelKey.TRAIN}/{metric}',
                                        val_metric=f'{ModelKey.VAL}/{metric}',
                                        ylabel=metric, best_epochs=best_epochs, ax=axs[i])
    add_training_curves_legend(fig, include_best_epoch=True)
    training_curves_fig_path = report_dir / "training_curves.png"
    fig.savefig(training_curves_fig_path, bbox_inches='tight')
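Note that because `metrics` above is a set, the primary metric is deduplicated when it is already AUROC, at the cost of an unspecified subplot order. A two-line illustration:

metrics = {"loss_epoch", "auroc", "auroc"}  # e.g. primary_metric already equals AUROC
assert len(metrics) == 2  # deduplicated; set iteration order is not guaranteed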
@ -169,17 +191,17 @@ def render_metrics_table(report: HTMLReport, heading: str, level: int,
    :param report: HTML report to perform the rendering.
    :param heading: Heading of the section.
    :param level: Level of HTML heading (e.g. sub-section, sub-sub-section) corresponding to HTML heading levels.
    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
        :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
    :param base_metrics_list: List of metric names to include in the table.
    :param best_epochs: Dictionary mapping each cross-validation index to its best epoch.
    :param best_epochs: Dictionary mapping each hyperdrive child index to its best epoch.
    :param metrics_prefix: Prefix to add to the metrics names (e.g. `val`, `test`)
    """
    report.add_heading(heading, level=level)
    metrics_list = [metrics_prefix + metric for metric in base_metrics_list]
    if best_epochs:
        metrics_df = get_best_epoch_metrics(metrics_df, metrics_list, best_epochs)
    metrics_table = get_crossval_metrics_table(metrics_df, metrics_list)
    metrics_table = get_hyperdrive_metrics_table(metrics_df, metrics_list)
    report.add_tables([metrics_table])

@@ -192,11 +214,11 @@ def render_roc_and_pr_curves(report: HTMLReport, heading: str, level: int, repor

    :param heading: Heading of the section.
    :param level: Level of HTML heading (e.g. sub-section, sub-sub-section) corresponding to HTML heading levels.
    :param report_dir: Local directory where the report is stored.
    :param outputs_dfs: A dictionary of dataframes with the sorted cross-validation indices as keys.
    :param outputs_dfs: A dictionary of dataframes with the sorted hyperdrive child runs indices as keys.
    :param prefix: Prefix to add to the figures saved (e.g. `val`, `test`).
    """
    report.add_heading(heading, level=level)
    fig = plot_crossval_roc_and_pr_curves(outputs_dfs, scores_column='prob_class1')
    fig = plot_hyperdrive_roc_and_pr_curves(outputs_dfs, scores_column='prob_class1')
    roc_pr_curves_fig_path = report_dir / f"{prefix}roc_pr_curves.png"
    fig.savefig(roc_pr_curves_fig_path, bbox_inches='tight')
    report.add_images([roc_pr_curves_fig_path], base64_encode=True)
@@ -212,11 +234,11 @@ def render_confusion_matrices(report: HTMLReport, heading: str, level: int, clas

    :param level: Level of HTML heading (e.g. sub-section, sub-sub-section) corresponding to HTML heading levels.
    :param class_names: Names of classes.
    :param report_dir: Local directory where the report is stored.
    :param outputs_dfs: A dictionary of dataframes with the sorted cross-validation indices as keys.
    :param outputs_dfs: A dictionary of dataframes with the sorted hyperdrive child runs indices as keys.
    :param prefix: Prefix to add to the figures saved (e.g. `val`, `test`).
    """
    report.add_heading(heading, level=level)
    fig = plot_confusion_matrices(crossval_dfs=outputs_dfs, class_names=class_names)
    fig = plot_confusion_matrices(hyperdrive_dfs=outputs_dfs, class_names=class_names)
    confusion_matrices_fig_path = report_dir / f"{prefix}confusion_matrices.png"
    fig.savefig(confusion_matrices_fig_path, bbox_inches='tight')
    report.add_images([confusion_matrices_fig_path], base64_encode=True)
@@ -225,7 +247,7 @@ def render_confusion_matrices(report: HTMLReport, heading: str, level: int, clas

if __name__ == "__main__":
    """
    Usage example from CLI:
    python generate_crossval_report.py \
    python generate_hyperdrive_report.py \
        --run_id <insert AML run ID here> \
        --output_dir outputs \
        --include_test
@@ -242,6 +264,9 @@ if __name__ == "__main__":

                             "in the generated report.")
    parser.add_argument('--overwrite', action='store_true', help="Forces (re)download of metrics and output files, "
                                                                 "even if they already exist locally.")
    parser.add_argument("--hyper_arg_name", default="crossval_index",
                        help="Name of the Hyperdrive argument used for indexing the child runs.")
    parser.add_argument("--primary_metric", default=MetricsKey.AUROC, help="Name of the reference metric to optimise.")
    args = parser.parse_args()

    if args.output_dir is None:
@@ -258,4 +283,6 @@ if __name__ == "__main__":

                         output_dir=Path(args.output_dir),
                         workspace_config_path=workspace_config,
                         include_test=args.include_test,
                         overwrite=args.overwrite)
                         overwrite=args.overwrite,
                         hyperdrive_arg_name=args.hyper_arg_name,
                         primary_metric=args.primary_metric)
@@ -120,32 +120,32 @@ def plot_histogram(data: List[Any], title: str = "") -> None:

    plt.gca().set(title=title, xlabel='Values', ylabel='Frequency')


def plot_roc_curve(labels: Sequence, scores: Sequence, label: str, ax: Axes) -> None:
def plot_roc_curve(labels: Sequence, scores: Sequence, legend_label: str, ax: Axes) -> None:
    """Plot ROC curve for the given labels and scores, with AUROC in the line legend.

    :param labels: The true binary labels.
    :param scores: Scores predicted by the model.
    :param label: A line identifier to be displayed in the legend.
    :param legend_label: A line identifier to be displayed in the legend.
    :param ax: `Axes` object onto which to plot.
    """
    fpr, tpr, _ = roc_curve(labels, scores)
    auroc = auc(fpr, tpr)
    label = f"{label} (AUROC: {auroc:.3f})"
    ax.plot(fpr, tpr, label=label)
    legend_label = f"{legend_label} (AUROC: {auroc:.3f})"
    ax.plot(fpr, tpr, label=legend_label)


def plot_pr_curve(labels: Sequence, scores: Sequence, label: str, ax: Axes) -> None:
    """Plot precision-recall curve for the given labels and scores, with AUROC in the line legend.
def plot_pr_curve(labels: Sequence, scores: Sequence, legend_label: str, ax: Axes) -> None:
    """Plot precision-recall curve for the given labels and scores, with AUPR in the line legend.

    :param labels: The true binary labels.
    :param scores: Scores predicted by the model.
    :param label: A line identifier to be displayed in the legend.
    :param legend_label: A line identifier to be displayed in the legend.
    :param ax: `Axes` object onto which to plot.
    """
    precision, recall, _ = precision_recall_curve(labels, scores)
    aupr = auc(recall, precision)
    label = f"{label} (AUPR: {aupr:.3f})"
    ax.plot(recall, precision, label=label)
    legend_label = f"{legend_label} (AUPR: {aupr:.3f})"
    ax.plot(recall, precision, label=legend_label)


def format_pr_or_roc_axes(plot_type: str, ax: Axes) -> None:
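As a quick sanity check of the two helpers above, here is a minimal, self-contained sketch on synthetic data (the data and the "Child 0" legend label are invented for illustration; only scikit-learn and matplotlib are assumed, which this module already imports):

    import numpy as np
    import matplotlib.pyplot as plt

    rng = np.random.default_rng(0)
    labels = rng.integers(0, 2, size=200)                             # synthetic binary ground truth
    scores = np.clip(0.6 * labels + rng.normal(0.2, 0.3, 200), 0, 1)  # noisy scores correlated with labels

    fig, (roc_ax, pr_ax) = plt.subplots(1, 2, figsize=(8, 4))
    plot_roc_curve(labels, scores, legend_label="Child 0", ax=roc_ax)
    plot_pr_curve(labels, scores, legend_label="Child 0", ax=pr_ax)
    roc_ax.legend()
    pr_ax.legend()
    fig.savefig("roc_pr_example.png", bbox_inches="tight")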
@@ -168,18 +168,18 @@ def format_pr_or_roc_axes(plot_type: str, ax: Axes) -> None:

    ax.grid(color='0.9')


def _plot_crossval_roc_and_pr_curves(crossval_dfs: Dict[int, pd.DataFrame], roc_ax: Axes, pr_ax: Axes,
                                     scores_column: str = ResultsKey.PROB) -> None:
    """Plot ROC and precision-recall curves for multiple cross-validation runs onto provided axes.
def _plot_hyperdrive_roc_and_pr_curves(hyperdrive_dfs: Dict[int, pd.DataFrame], roc_ax: Axes, pr_ax: Axes,
                                       scores_column: str = ResultsKey.PROB) -> None:
    """Plot ROC and precision-recall curves for multiple hyperdrive runs onto provided axes.

    This is called by :py:func:`plot_crossval_roc_and_pr_curves()`, which additionally creates a figure and the axes.
    This is called by :py:func:`plot_hyperdrive_roc_and_pr_curves()`, which additionally creates a figure and the axes.

    :param crossval_dfs: Dictionary of dataframes with cross-validation indices as keys,
        as returned by :py:func:`collect_crossval_outputs()`.
    :param hyperdrive_dfs: Dictionary of dataframes with hyperdrive child runs indices as keys,
        as returned by :py:func:`collect_hyperdrive_outputs()`.
    :param roc_ax: `Axes` object onto which to plot ROC curves.
    :param pr_ax: `Axes` object onto which to plot precision-recall curves.
    """
    for k, tiles_df in crossval_dfs.items():
    for k, tiles_df in hyperdrive_dfs.items():
        slides_groupby = tiles_df.groupby(ResultsKey.SLIDE_ID)

        tile_labels = slides_groupby[ResultsKey.TRUE_LABEL]
@@ -197,8 +197,8 @@ def _plot_crossval_roc_and_pr_curves(crossval_dfs: Dict[int, pd.DataFrame], roc_

        # assert len(non_unique_slides) == 0
        scores = tile_scores.first()

        plot_roc_curve(labels, scores, label=f"Fold {k}", ax=roc_ax)
        plot_pr_curve(labels, scores, label=f"Fold {k}", ax=pr_ax)
        plot_roc_curve(labels, scores, legend_label=f"Child {k}", ax=roc_ax)
        plot_pr_curve(labels, scores, legend_label=f"Child {k}", ax=pr_ax)
    legend_kwargs = dict(edgecolor='none', fontsize='small')
    roc_ax.legend(**legend_kwargs)
    pr_ax.legend(**legend_kwargs)
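The slide-level aggregation in this function relies on every tile of a slide carrying the same true label and the same slide-level score, so taking the first row per slide is enough. A toy pandas sketch of that idea (column names follow the ResultsKey conventions, but the data is made up):

    import pandas as pd

    tiles_df = pd.DataFrame({
        "slide_id": ["s1", "s1", "s2", "s2"],
        "true_label": [1, 1, 0, 0],           # identical within a slide
        "prob_class1": [0.9, 0.9, 0.2, 0.2],  # slide-level score repeated on every tile
    })
    slides_groupby = tiles_df.groupby("slide_id")
    labels = slides_groupby["true_label"].first()
    scores = slides_groupby["prob_class1"].first()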
@@ -206,26 +206,26 @@ def _plot_crossval_roc_and_pr_curves(crossval_dfs: Dict[int, pd.DataFrame], roc_

    format_pr_or_roc_axes('pr', pr_ax)


def plot_crossval_roc_and_pr_curves(crossval_dfs: Dict[int, pd.DataFrame],
                                    scores_column: str = ResultsKey.PROB) -> Figure:
    """Plot ROC and precision-recall curves for multiple cross-validation runs.
def plot_hyperdrive_roc_and_pr_curves(hyperdrive_dfs: Dict[int, pd.DataFrame],
                                      scores_column: str = ResultsKey.PROB) -> Figure:
    """Plot ROC and precision-recall curves for multiple hyperdrive child runs.

    This will create a new figure with two subplots (left: ROC, right: PR).

    :param crossval_dfs: Dictionary of dataframes with cross-validation indices as keys,
        as returned by :py:func:`collect_crossval_outputs()`.
    :param hyperdrive_dfs: Dictionary of dataframes with hyperdrive child indices as keys,
        as returned by :py:func:`collect_hyperdrive_outputs()`.
    :return: The created `Figure` object.
    """
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    _plot_crossval_roc_and_pr_curves(crossval_dfs, scores_column=scores_column, roc_ax=axs[0], pr_ax=axs[1])
    _plot_hyperdrive_roc_and_pr_curves(hyperdrive_dfs, scores_column=scores_column, roc_ax=axs[0], pr_ax=axs[1])
    return fig


def plot_crossval_training_curves(metrics_df: pd.DataFrame, train_metric: str, val_metric: str, ax: Axes,
                                  best_epochs: Optional[Dict[int, int]] = None, ylabel: Optional[str] = None) -> None:
    """Plot paired training and validation metrics for every training epoch of cross-validation runs.
def plot_hyperdrive_training_curves(metrics_df: pd.DataFrame, train_metric: str, val_metric: str, ax: Axes,
                                    best_epochs: Optional[Dict[int, int]] = None, ylabel: Optional[str] = None) -> None:
    """Plot paired training and validation metrics for every training epoch of hyperdrive child runs.

    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
        :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
    :param train_metric: Name of the training metric to plot.
    :param val_metric: Name of the validation metric to plot.
@@ -236,14 +236,17 @@ def plot_crossval_training_curves(metrics_df: pd.DataFrame, train_metric: str, v

    for k in sorted(metrics_df.columns):
        train_values = metrics_df.loc[train_metric, k]
        val_values = metrics_df.loc[val_metric, k]
        line, = ax.plot(train_values, **TRAIN_STYLE, label=f"Fold {k}")
        color = line.get_color()
        ax.plot(val_values, color=color, **VAL_STYLE)
        if train_values is not None:
            line, = ax.plot(train_values, **TRAIN_STYLE, label=f"Child {k}")
            color = line.get_color()
        if val_values is not None:
            ax.plot(val_values, color=color, **VAL_STYLE)
        if best_epochs is not None:
            best_epoch = best_epochs[k]
            ax.plot(best_epoch, train_values[best_epoch], color=color, zorder=1000, **BEST_TRAIN_MARKER_STYLE)
            ax.plot(best_epoch, val_values[best_epoch], color=color, zorder=1000, **BEST_VAL_MARKER_STYLE)
            ax.axvline(best_epoch, color=color, **BEST_EPOCH_LINE_STYLE)
            if best_epoch is not None:
                ax.plot(best_epoch, train_values[best_epoch], color=color, zorder=1000, **BEST_TRAIN_MARKER_STYLE)
                ax.plot(best_epoch, val_values[best_epoch], color=color, zorder=1000, **BEST_VAL_MARKER_STYLE)
                ax.axvline(best_epoch, color=color, **BEST_EPOCH_LINE_STYLE)
    ax.grid(color='0.9')
    ax.set_xlabel("Epoch")
    if ylabel:
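A self-contained sketch of the train/val pairing with a best-epoch marker, using inline line styles instead of the module's TRAIN_STYLE/VAL_STYLE constants (the dash conventions here are assumed, for illustration only):

    import matplotlib.pyplot as plt

    train_values = [1.0, 0.7, 0.5, 0.45, 0.44]
    val_values = [1.1, 0.8, 0.6, 0.62, 0.65]
    best_epoch = 2  # e.g. the epoch with the best validation metric

    fig, ax = plt.subplots()
    line, = ax.plot(train_values, linestyle="-", label="Child 0")  # solid: training
    ax.plot(val_values, color=line.get_color(), linestyle="--")    # dashed: validation
    ax.plot(best_epoch, val_values[best_epoch], "o", color=line.get_color())
    ax.axvline(best_epoch, color=line.get_color(), linestyle=":")
    ax.set_xlabel("Epoch")
    ax.legend()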
@@ -251,15 +254,15 @@ def plot_crossval_training_curves(metrics_df: pd.DataFrame, train_metric: str, v


def add_training_curves_legend(fig: Figure, include_best_epoch: bool = False) -> None:
    """Add a legend to a training curves figure, indicating cross-validation indices and train/val.
    """Add a legend to a training curves figure, indicating hyperdrive child indices and train/val.

    :param fig: `Figure` object onto which to add the legend.
    :param include_best_epoch: If `True`, adds legend items for the best epoch indicators from
        :py:func:`plot_crossval_training_curves()`.
        :py:func:`plot_hyperdrive_training_curves()`.
    """
    legend_kwargs = dict(edgecolor='none', fontsize='small', borderpad=.2)

    # Add primary legend for main lines (crossval folds)
    # Add primary legend for main lines (hyperdrive runs)
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    fig.legend(by_label.values(), by_label.keys(), **legend_kwargs, loc='lower center',
@@ -277,17 +280,18 @@ def add_training_curves_legend(fig: Figure, include_best_epoch: bool = False) ->

               bbox_to_anchor=(0.5, -0.1), ncol=len(legend_handles))


def plot_confusion_matrices(crossval_dfs: Dict[int, pd.DataFrame], class_names: List[str]) -> Figure:
def plot_confusion_matrices(hyperdrive_dfs: Dict[int, pd.DataFrame], class_names: List[str]) -> Figure:
    """
    Plot normalized confusion matrices from HyperDrive child runs.
    :param crossval_dfs: Dictionary of dataframes with cross-validation indices as keys,
        as returned by :py:func:`collect_crossval_outputs()`.
    :param hyperdrive_dfs: Dictionary of dataframes with hyperdrive indices as keys,
        as returned by :py:func:`collect_hyperdrive_outputs()`.
    :param class_names: Names of classes.
    :return: The created `Figure` object.
    """
    crossval_count = len(crossval_dfs)
    fig, axs = plt.subplots(1, crossval_count, figsize=(crossval_count * 6, 5))
    for k, tiles_df in crossval_dfs.items():
    hyperdrive_count = len(hyperdrive_dfs)
    fig, axs = plt.subplots(1, hyperdrive_count, figsize=(hyperdrive_count * 6, 5))
    ax_index = 0
    for k, tiles_df in hyperdrive_dfs.items():
        slides_groupby = tiles_df.groupby(ResultsKey.SLIDE_ID)
        tile_labels_true = slides_groupby[ResultsKey.TRUE_LABEL]
        # True slide label is guaranteed unique

@@ -303,9 +307,10 @@ def plot_confusion_matrices(crossval_dfs: Dict[int, pd.DataFrame], class_names:

        labels_pred = tile_labels_pred.first()

        cf_matrix_n = confusion_matrix(y_true=labels_true, y_pred=labels_pred, normalize='true')
        sns.heatmap(cf_matrix_n, annot=True, cmap='Blues', fmt=".2%", ax=axs[k],
        sns.heatmap(cf_matrix_n, annot=True, cmap='Blues', fmt=".2%", ax=axs[ax_index],
                    xticklabels=class_names, yticklabels=class_names)
        axs[k].set_xlabel('Predicted')
        axs[k].set_ylabel('True')
        axs[k].set_title(f'Fold {k}')
        axs[ax_index].set_xlabel('Predicted')
        axs[ax_index].set_ylabel('True')
        axs[ax_index].set_title(f'Child {k}')
        ax_index += 1
    return fig
@@ -0,0 +1,543 @@

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------

import os
import torch
import param
import logging
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

from health_cpath.models.deepmil import BaseDeepMILModule
from health_cpath.utils.naming import ModelKey, ResultsKey
from health_cpath.utils.output_utils import BatchResultsType

LossCacheDictType = Dict[Union[ResultsKey, str], List]
LossDictType = Dict[str, List]
AnomalyDictType = Dict[ModelKey, List[str]]
class LossCallbackParams(param.Parameterized):
    """Parameters class to group all attributes for loss values analysis callback"""

    analyse_loss: bool = param.Boolean(
        False,
        doc="If True, will use `LossValuesAnalysisCallback` to cache loss values per slide/epoch for further "
            "analysis. See `LossValuesAnalysisCallback` for more details.",
    )
    loss_analysis_patience: int = param.Integer(
        0,
        bounds=(0, None),
        doc="Number of epochs to wait before starting loss values per slide analysis. Default: 0, it will start "
            "caching loss values per epoch immediately. Use loss_analysis_patience=n>0 to wait for a few epochs "
            "before starting the analysis.",
    )
    loss_analysis_epochs_interval: int = param.Integer(
        1,
        bounds=(1, None),
        doc="Epochs interval to save loss values. Default: 1, it will save loss values every epoch. Use "
            "loss_analysis_epochs_interval=n>1 to save loss values every n epochs.",
    )
    num_slides_scatter: int = param.Integer(
        20,
        bounds=(1, None),
        doc="Number of slides to plot in the scatter plot. Default: 20, it will plot a scatter of the 20 slides "
            "with highest/lowest loss values across epochs.",
    )
    num_slides_heatmap: int = param.Integer(
        20,
        bounds=(1, None),
        doc="Number of slides to plot in the heatmap plot. Default: 20, it will plot the loss values for the 20 "
            "slides with highest/lowest loss values.",
    )
    save_tile_ids: bool = param.Boolean(
        True,
        doc="If True, will save the tile ids for each bag in the loss cache. Default: True. If False, will only save "
            "the slide ids and their loss values.",
    )
    log_exceptions: bool = param.Boolean(
        True,
        doc="If True, will log exceptions raised during loss values analysis. Default: True. If False, will raise "
            "the intercepted exceptions.",
    )
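For orientation, a hedged sketch of how these parameters could be funneled into the callback defined next; the Trainer wiring is assumed rather than shown in this diff:

    from pathlib import Path

    params = LossCallbackParams(analyse_loss=True,
                                loss_analysis_patience=2,
                                loss_analysis_epochs_interval=5)
    if params.analyse_loss:
        callback = LossAnalysisCallback(outputs_folder=Path("outputs"),
                                        max_epochs=30,
                                        patience=params.loss_analysis_patience,
                                        epochs_interval=params.loss_analysis_epochs_interval,
                                        num_slides_scatter=params.num_slides_scatter,
                                        num_slides_heatmap=params.num_slides_heatmap,
                                        save_tile_ids=params.save_tile_ids,
                                        log_exceptions=params.log_exceptions)
        # the callback would then be appended to the Trainer's callbacks list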
class LossAnalysisCallback(Callback):
    """Callback to analyse loss values per slide across epochs. It saves the loss values per slide in a csv file every n
    epochs and plots the evolution of the loss values per slide in a heatmap as well as the slides with the
    highest/lowest loss values per epoch in a scatter plot."""

    TILES_JOIN_TOKEN = "$"
    X_LABEL, Y_LABEL = "Epoch", "Slide ids"
    TOP, BOTTOM = "top", "bottom"
    HIGHEST, LOWEST = "highest", "lowest"

    def __init__(
        self,
        outputs_folder: Path,
        max_epochs: int = 30,
        patience: int = 0,
        epochs_interval: int = 1,
        num_slides_scatter: int = 10,
        num_slides_heatmap: int = 20,
        save_tile_ids: bool = False,
        log_exceptions: bool = True,
        create_outputs_folders: bool = True,
        val_set_is_dist: bool = True,
    ) -> None:
        """
        :param outputs_folder: Path to the folder where the outputs will be saved.
        :param patience: Number of epochs to wait before starting to cache loss values, defaults to 0.
        :param epochs_interval: Epochs interval to save loss values, defaults to 1.
        :param max_epochs: Maximum number of epochs to train, defaults to 30.
        :param num_slides_scatter: Number of slides to plot in the scatter plot, defaults to 10.
        :param num_slides_heatmap: Number of slides to plot in the heatmap, defaults to 20.
        :param save_tile_ids: If True, will save the tile ids of the tiles in the bag in the loss cache, defaults to
            False. This is useful to analyse the tiles that are contributing to the loss value of a slide.
        :param log_exceptions: If True, will log exceptions raised during loss values analysis, defaults to True. If
            False, will raise the intercepted exceptions.
        :param create_outputs_folders: If True, will create the output folders if they don't exist, defaults to True.
        :param val_set_is_dist: If True, will assume that the validation set is distributed, defaults to True.
        """

        self.patience = patience
        self.epochs_interval = epochs_interval
        self.max_epochs = max_epochs
        self.num_slides_scatter = num_slides_scatter
        self.num_slides_heatmap = num_slides_heatmap
        self.save_tile_ids = save_tile_ids
        self.log_exceptions = log_exceptions
        self.val_set_is_dist = val_set_is_dist

        self.outputs_folder = outputs_folder / "loss_analysis_callback"
        if create_outputs_folders:
            self.create_outputs_folders()

        self.loss_cache: Dict[ModelKey, LossCacheDictType] = {
            stage: self.get_empty_loss_cache() for stage in [ModelKey.TRAIN, ModelKey.VAL]
        }
        self.epochs_range = list(range(self.patience, self.max_epochs, self.epochs_interval))

        self.nan_slides: AnomalyDictType = {stage: [] for stage in [ModelKey.TRAIN, ModelKey.VAL]}
        self.anomaly_slides: AnomalyDictType = {stage: [] for stage in [ModelKey.TRAIN, ModelKey.VAL]}
    def get_cache_folder(self, stage: ModelKey) -> Path:
        return self.outputs_folder / f"{stage}/loss_cache"

    def get_scatter_folder(self, stage: ModelKey) -> Path:
        return self.outputs_folder / f"{stage}/loss_scatter"

    def get_heatmap_folder(self, stage: ModelKey) -> Path:
        return self.outputs_folder / f"{stage}/loss_heatmap"

    def get_stats_folder(self, stage: ModelKey) -> Path:
        return self.outputs_folder / f"{stage}/loss_stats"

    def get_anomalies_folder(self, stage: ModelKey) -> Path:
        return self.outputs_folder / f"{stage}/loss_anomalies"

    def create_outputs_folders(self) -> None:
        folders = [
            self.get_cache_folder,
            self.get_scatter_folder,
            self.get_heatmap_folder,
            self.get_stats_folder,
            self.get_anomalies_folder,
        ]
        stages = [ModelKey.TRAIN, ModelKey.VAL]
        for folder in folders:
            for stage in stages:
                os.makedirs(folder(stage), exist_ok=True)
    def get_empty_loss_cache(self) -> LossCacheDictType:
        keys = [ResultsKey.SLIDE_ID, ResultsKey.LOSS, ResultsKey.ENTROPY]
        if self.save_tile_ids:
            keys.append(ResultsKey.TILE_ID)
        return {key: [] for key in keys}

    def _format_epoch(self, epoch: int) -> str:
        return str(epoch).zfill(len(str(self.max_epochs)))

    def get_loss_cache_file(self, epoch: int, stage: ModelKey) -> Path:
        return self.get_cache_folder(stage) / f"epoch_{self._format_epoch(epoch)}.csv"

    def get_all_epochs_loss_cache_file(self, stage: ModelKey) -> Path:
        return self.get_cache_folder(stage) / f"all_epochs_{stage}.csv"

    def get_loss_stats_file(self, stage: ModelKey) -> Path:
        return self.get_stats_folder(stage) / f"loss_stats_{stage}.csv"

    def get_loss_ranks_file(self, stage: ModelKey) -> Path:
        return self.get_stats_folder(stage) / f"loss_ranks_{stage}.csv"

    def get_loss_ranks_stats_file(self, stage: ModelKey) -> Path:
        return self.get_stats_folder(stage) / f"loss_ranks_stats_{stage}.csv"

    def get_nan_slides_file(self, stage: ModelKey) -> Path:
        return self.get_anomalies_folder(stage) / f"nan_slides_{stage}.txt"

    def get_anomaly_slides_file(self, stage: ModelKey) -> Path:
        return self.get_anomalies_folder(stage) / f"anomaly_slides_{stage}.txt"

    def get_scatter_plot_file(self, order: str, stage: ModelKey) -> Path:
        return self.get_scatter_folder(stage) / f"slides_with_{order}_loss_values_{stage}.png"

    def get_heatmap_plot_file(self, epoch: int, order: str, stage: ModelKey) -> Path:
        return self.get_heatmap_folder(stage) / f"epoch_{self._format_epoch(epoch)}_{order}_slides_{stage}.png"

    def read_loss_cache(self, epoch: int, stage: ModelKey, idx_col: Optional[ResultsKey] = None) -> pd.DataFrame:
        columns = [ResultsKey.SLIDE_ID, ResultsKey.LOSS, ResultsKey.ENTROPY]
        return pd.read_csv(self.get_loss_cache_file(epoch, stage), index_col=idx_col, usecols=columns)

    def should_cache_loss_values(self, current_epoch: int) -> bool:
        if current_epoch >= self.max_epochs:
            return False  # Don't cache loss values for the extra validation epoch
        current_epoch = current_epoch + 1
        first_time = (current_epoch - self.patience) == 1
        return first_time or (
            current_epoch > self.patience and (current_epoch - self.patience) % self.epochs_interval == 0
        )
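To make the schedule concrete, a small worked example of this predicate under assumed settings (patience=2, epochs_interval=3, max_epochs=10); the shift by one converts Lightning's zero-based current_epoch into a count of completed epochs:

    patience, interval, max_epochs = 2, 3, 10

    def should_cache(current_epoch: int) -> bool:
        if current_epoch >= max_epochs:
            return False
        current_epoch = current_epoch + 1
        first_time = (current_epoch - patience) == 1
        return first_time or (current_epoch > patience and (current_epoch - patience) % interval == 0)

    print([e for e in range(12) if should_cache(e)])  # [2, 4, 7]: first epoch after patience, then every 3rd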
    def merge_loss_caches(self, loss_caches: List[LossCacheDictType]) -> LossCacheDictType:
        """Merges the loss caches from all the workers into a single loss cache"""
        loss_cache = self.get_empty_loss_cache()
        for loss_cache_per_device in loss_caches:
            for key in loss_cache.keys():
                loss_cache[key].extend(loss_cache_per_device[key])
        return loss_cache

    def gather_loss_cache(self, rank: int, stage: ModelKey) -> None:
        """Gathers the loss cache from all the workers"""
        if self.val_set_is_dist and torch.distributed.is_initialized():
            world_size = torch.distributed.get_world_size()
            if world_size > 1:
                loss_caches = [None] * world_size
                torch.distributed.all_gather_object(loss_caches, self.loss_cache[stage])
                if rank == 0:
                    self.loss_cache[stage] = self.merge_loss_caches(loss_caches)  # type: ignore

    def save_loss_cache(self, current_epoch: int, stage: ModelKey) -> None:
        """Saves the loss cache to a csv file"""
        loss_cache_df = pd.DataFrame(self.loss_cache[stage])
        # Some slides may appear multiple times in the loss cache in DDP mode. The distributed sampler duplicates
        # some slides to even out the number of samples per device, so we only keep the first occurrence.
        loss_cache_df.drop_duplicates(subset=ResultsKey.SLIDE_ID, inplace=True, keep="first")
        loss_cache_df = loss_cache_df.sort_values(by=ResultsKey.LOSS, ascending=False)
        loss_cache_df.to_csv(self.get_loss_cache_file(current_epoch, stage), index=False)
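The deduplication step matters because PyTorch's DistributedSampler pads the dataset so that every rank receives the same number of batches. A toy illustration of the keep="first" behaviour (made-up slide ids):

    import pandas as pd

    df = pd.DataFrame({"slide_id": ["s1", "s2", "s1"], "loss": [0.9, 0.4, 0.9]})
    df = df.drop_duplicates(subset="slide_id", keep="first").sort_values("loss", ascending=False)
    print(df)  # only s1 and s2 remain; the padded duplicate of s1 is dropped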
    def _select_values_for_epoch(
        self,
        keys: List[ResultsKey],
        epoch: int,
        stage: ModelKey,
        high: Optional[bool] = None,
        num_values: Optional[int] = None
    ) -> List[np.ndarray]:
        """Selects the values corresponding to keys from a dataframe for the given epoch and stage"""
        loss_cache = self.read_loss_cache(epoch, stage)
        return_values = []
        for key in keys:
            values = loss_cache[key].values
            if high is not None:
                assert num_values is not None, "num_values must be specified if high is specified"
                if high:
                    return_values.append(values[:num_values])
                elif not high:
                    return_values.append(values[-num_values:])
            else:
                return_values.append(values)
        return return_values

    def select_slides_for_epoch(
        self, epoch: int, stage: ModelKey, high: Optional[bool] = None, num_slides: Optional[int] = None
    ) -> np.ndarray:
        """Selects slides in ascending/descending order of loss values at a given epoch

        :param epoch: The epoch to select the slides from.
        :param stage: The model's stage (train/val).
        :param high: If True, selects the slides with the highest loss values, else selects the slides with the lowest
            loss values. If None, selects all slides.
        :param num_slides: The number of slides to select. If None, selects all slides.
        """
        return self._select_values_for_epoch([ResultsKey.SLIDE_ID], epoch, stage, high, num_slides)[0]

    def select_slides_entropy_for_epoch(
        self, epoch: int, stage: ModelKey, high: Optional[bool] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Selects slides and their entropy values, ordered by loss, at a given epoch

        :param epoch: The epoch to select the slides from.
        :param stage: The model's stage (train/val).
        :param high: If True, selects the slides with the highest loss values, else selects the slides with the lowest
            loss values. If None, selects all slides.
        """
        keys = [ResultsKey.SLIDE_ID, ResultsKey.ENTROPY]
        return_values = self._select_values_for_epoch(keys, epoch, stage, high, self.num_slides_scatter)
        return return_values[0], return_values[1]
    def select_all_losses_for_selected_slides(self, slides: np.ndarray, stage: ModelKey) -> LossDictType:
        """Selects the loss values for a given set of slides

        :param slides: The slides to select the loss values for.
        :param stage: The model's stage (train/val).
        """

        slides_loss_values: LossDictType = {slide_id: [] for slide_id in slides}
        for epoch in self.epochs_range:
            loss_cache = self.read_loss_cache(epoch, stage, idx_col=ResultsKey.SLIDE_ID)
            for slide_id in slides:
                slides_loss_values[slide_id].append(loss_cache.loc[slide_id, ResultsKey.LOSS])
        return slides_loss_values

    def select_slides_and_entropy_across_epochs(
        self, stage: ModelKey, high: bool = True
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Selects the slides with the highest/lowest loss values across epochs and their entropy values

        :param stage: The model's stage (train/val).
        :param high: If True, selects the slides with the highest loss values, else selects the slides with the lowest
            loss values.
        """
        slides = []
        slides_entropy = []
        for epoch in self.epochs_range:
            epoch_slides, epoch_slides_entropy = self.select_slides_entropy_for_epoch(epoch, stage, high)
            slides.append(epoch_slides)
            slides_entropy.append(epoch_slides_entropy)

        return np.array(slides).T, np.array(slides_entropy).T

    def save_slide_ids(self, slide_ids: List[str], path: Path) -> None:
        """Dumps the slide ids into a txt file."""
        if slide_ids:
            with open(path, "w") as f:
                for slide_id in slide_ids:
                    f.write(f"{slide_id}\n")
    def sanity_check_loss_values(self, loss_values: LossDictType, stage: ModelKey) -> None:
        """Checks if there are any NaNs or any other potential anomalies in the loss values.

        :param loss_values: The loss values for all slides.
        :param stage: The model's stage (train/val).
        """
        # We don't want any of these exceptions to interrupt validation. So we catch them and log them.
        loss_values_copy = loss_values.copy()
        for slide_id, loss in loss_values_copy.items():
            try:
                if np.isnan(loss).any():
                    logging.warning(f"NaNs found in loss values for slide {slide_id}.")
                    self.nan_slides[stage].append(slide_id)
                    loss_values.pop(slide_id)
            except Exception as e:
                logging.warning(f"Error while checking for NaNs in loss values for slide {slide_id} with error {e}.")
                logging.warning(f"Loss values that caused the issue: {loss}")
                self.anomaly_slides[stage].append(slide_id)
                loss_values.pop(slide_id)
        self.save_slide_ids(self.nan_slides[stage], self.get_nan_slides_file(stage))
        self.save_slide_ids(self.anomaly_slides[stage], self.get_anomaly_slides_file(stage))
    def save_loss_ranks(self, slides_loss_values: LossDictType, stage: ModelKey) -> None:
        """Saves the loss ranks for each slide across all epochs and their respective statistics in csv files.

        :param slides_loss_values: The loss values for all slides.
        :param stage: The model's stage (train/val).
        """

        loss_df = pd.DataFrame(slides_loss_values).T
        loss_df.index.names = [ResultsKey.SLIDE_ID.value]
        loss_df.to_csv(self.get_all_epochs_loss_cache_file(stage))

        loss_stats = loss_df.T.describe().T.sort_values(by="mean", ascending=False)
        loss_stats.to_csv(self.get_loss_stats_file(stage))

        loss_ranks = loss_df.rank(ascending=False)
        loss_ranks.to_csv(self.get_loss_ranks_file(stage))

        loss_ranks_stats = loss_ranks.T.describe().T.sort_values("mean", ascending=True)
        loss_ranks_stats.to_csv(self.get_loss_ranks_stats_file(stage))
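Here loss_df has one row per slide and one column per cached epoch, so rank(ascending=False) yields each slide's loss rank within every epoch (rank 1 = hardest slide). A toy illustration with invented values:

    import pandas as pd

    loss_df = pd.DataFrame({"epoch_0": [0.9, 0.2, 0.5], "epoch_1": [0.8, 0.3, 0.4]},
                           index=["s1", "s2", "s3"])
    print(loss_df.rank(ascending=False))
    #     epoch_0  epoch_1
    # s1      1.0      1.0
    # s2      3.0      3.0
    # s3      2.0      2.0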
    def plot_slides_loss_scatter(
        self,
        slides: np.ndarray,
        slides_entropy: np.ndarray,
        stage: ModelKey,
        high: bool = True,
        figsize: Tuple[float, float] = (30, 30),
    ) -> None:
        """Plots the slides with the highest/lowest loss values across epochs in a scatter plot

        :param slides: The slide ids.
        :param slides_entropy: The entropy values for each slide.
        :param stage: The model's stage (train/val).
        :param figsize: The figure size, defaults to (30, 30)
        :param high: If True, plots the slides with the highest loss values, else plots the slides with the lowest loss.
        """
        label = self.TOP if high else self.BOTTOM
        plt.figure(figsize=figsize)
        markers_size = [15 * i for i in range(1, self.num_slides_scatter + 1)]
        markers_size = markers_size[::-1] if high else markers_size
        colors = cm.rainbow(np.linspace(0, 1, self.num_slides_scatter))
        for i in range(self.num_slides_scatter - 1, -1, -1):
            plt.scatter(self.epochs_range, slides[i], label=f"{label}_{i+1}", s=markers_size[i], color=colors[i])
            for entropy, epoch, slide in zip(slides_entropy[i], self.epochs_range, slides[i]):
                plt.annotate(f"{entropy:.3f}", (epoch, slide))
        plt.xlabel(self.X_LABEL)
        plt.ylabel(self.Y_LABEL)
        order = self.HIGHEST if high else self.LOWEST
        plt.title(f"Slides with {order} loss values per epoch and their entropy values")
        plt.xticks(self.epochs_range)
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        plt.grid(True, linestyle="--")
        plt.savefig(self.get_scatter_plot_file(order, stage), bbox_inches="tight")
    def plot_loss_heatmap_for_slides_of_epoch(
        self,
        slides_loss_values: LossDictType,
        epoch: int,
        stage: ModelKey,
        high: bool,
        figsize: Tuple[float, float] = (15, 15)
    ) -> None:
        """Plots the loss values for each slide across all epochs in a heatmap.

        :param slides_loss_values: The loss values for each slide across all epochs.
        :param epoch: The epoch used to select the slides.
        :param stage: The model's stage (train/val).
        :param high: If True, plots the slides with the highest loss values, else plots the slides with the lowest loss.
        :param figsize: The figure size, defaults to (15, 15)
        """
        order = self.HIGHEST if high else self.LOWEST
        loss_values = np.array(list(slides_loss_values.values()))
        slides = list(slides_loss_values.keys())
        plt.figure(figsize=figsize)
        _ = sns.heatmap(loss_values, linewidth=0.5, annot=True, yticklabels=slides)
        plt.xlabel(self.X_LABEL)
        plt.ylabel(self.Y_LABEL)
        plt.title(f"Loss values evolution for {order} slides of epoch {epoch}")
        plt.savefig(self.get_heatmap_plot_file(epoch, order, stage), bbox_inches="tight")
    @staticmethod
    def compute_entropy(class_probs: torch.Tensor) -> List[float]:
        """Computes the entropy of the class probabilities.

        :param class_probs: The class probabilities.
        """
        return (-torch.sum(class_probs * torch.log(class_probs), dim=1)).tolist()
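For intuition: a confident prediction has entropy near 0, while a uniform one has entropy ln(2) ≈ 0.693 for two classes. A quick check of the formula above with made-up probabilities:

    import torch

    class_probs = torch.tensor([[0.99, 0.01], [0.5, 0.5]])
    entropy = (-torch.sum(class_probs * torch.log(class_probs), dim=1)).tolist()
    print(entropy)  # ≈ [0.056, 0.693]: low for the confident row, ln(2) for the uniform row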
    def update_loss_cache(self, trainer: Trainer, outputs: BatchResultsType, batch: Dict, stage: ModelKey) -> None:
        """Updates the loss cache with the loss values for the current batch."""
        if self.should_cache_loss_values(trainer.current_epoch):
            self.loss_cache[stage][ResultsKey.LOSS].extend(outputs[ResultsKey.LOSS_PER_SAMPLE])
            self.loss_cache[stage][ResultsKey.SLIDE_ID].extend([slides[0] for slides in batch[ResultsKey.SLIDE_ID]])
            self.loss_cache[stage][ResultsKey.ENTROPY].extend(self.compute_entropy(outputs[ResultsKey.CLASS_PROBS]))
            if self.save_tile_ids:
                self.loss_cache[stage][ResultsKey.TILE_ID].extend(
                    [self.TILES_JOIN_TOKEN.join(tiles) for tiles in outputs[ResultsKey.TILE_ID]]
                )
    def synchronise_processes_and_reset(self, trainer: Trainer, pl_module: BaseDeepMILModule, stage: ModelKey) -> None:
        """Synchronises the processes in DDP mode and resets the loss cache for the next epoch."""
        if self.should_cache_loss_values(trainer.current_epoch):
            self.gather_loss_cache(rank=pl_module.global_rank, stage=stage)
            if pl_module.global_rank == 0:
                self.save_loss_cache(trainer.current_epoch, stage)
            self.loss_cache[stage] = self.get_empty_loss_cache()  # reset loss cache for all processes
    def save_loss_outliers_analaysis_results(self, stage: ModelKey) -> None:
        """Saves the loss outliers analysis results."""
        all_slides = self.select_slides_for_epoch(epoch=0, stage=stage)
        all_loss_values_per_slides = self.select_all_losses_for_selected_slides(all_slides, stage=stage)

        self.sanity_check_loss_values(all_loss_values_per_slides, stage=stage)
        self.save_loss_ranks(all_loss_values_per_slides, stage=stage)

        top_slides, top_slides_entropy = self.select_slides_and_entropy_across_epochs(stage, high=True)
        self.plot_slides_loss_scatter(top_slides, top_slides_entropy, stage, high=True)

        bottom_slides, bottom_slides_entropy = self.select_slides_and_entropy_across_epochs(stage, high=False)
        self.plot_slides_loss_scatter(bottom_slides, bottom_slides_entropy, stage, high=False)

        for epoch in self.epochs_range:
            epoch_slides = self.select_slides_for_epoch(epoch, stage=stage)

            top_slides = epoch_slides[:self.num_slides_heatmap]
            top_slides_loss_values = self.select_all_losses_for_selected_slides(top_slides, stage=stage)
            self.plot_loss_heatmap_for_slides_of_epoch(top_slides_loss_values, epoch, stage, high=True)

            bottom_slides = epoch_slides[-self.num_slides_heatmap:]
            bottom_slides_loss_values = self.select_all_losses_for_selected_slides(bottom_slides, stage=stage)
            self.plot_loss_heatmap_for_slides_of_epoch(bottom_slides_loss_values, epoch, stage, high=False)
        self.loss_cache[stage] = self.get_empty_loss_cache()  # reset loss cache
    def handle_loss_exceptions(self, stage: ModelKey, exception: Exception) -> None:
        """Handles the loss exceptions."""
        if self.log_exceptions:
            # If something goes wrong, we don't want to crash the training. We just log the error and carry on
            # with validation.
            logging.warning(f"Error while detecting {stage} loss values outliers: {exception}")
        else:
            # If we want to debug the error, we raise it. This will crash the training. This is useful when
            # running smoke tests.
            raise Exception(f"Error while detecting {stage} loss values outliers: {exception}")
    def on_train_batch_end(  # type: ignore
        self,
        trainer: Trainer,
        pl_module: BaseDeepMILModule,
        outputs: BatchResultsType,
        batch: Dict,
        batch_idx: int,
        unused: int = 0
    ) -> None:
        """Caches train loss values per slide at each training step in a local variable self.loss_cache."""
        self.update_loss_cache(trainer, outputs, batch, stage=ModelKey.TRAIN)

    def on_validation_batch_end(  # type: ignore
        self,
        trainer: Trainer,
        pl_module: BaseDeepMILModule,
        outputs: BatchResultsType,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int
    ) -> None:
        """Caches validation loss values per slide at each validation step in a local variable self.loss_cache."""
        self.update_loss_cache(trainer, outputs, batch, stage=ModelKey.VAL)

    def on_train_epoch_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None:  # type: ignore
        """Gathers loss values per slide from all processes at the end of each epoch and saves them to a csv file."""
        self.synchronise_processes_and_reset(trainer, pl_module, ModelKey.TRAIN)

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None:  # type: ignore
        """Gathers loss values per slide from all processes at the end of each epoch and saves them to a csv file."""
        self.synchronise_processes_and_reset(trainer, pl_module, ModelKey.VAL)
    def on_train_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None:  # type: ignore
        """Hook called at the end of training. Plots the loss heatmaps and scatter plots after ranking the slides
        by loss values."""
        if pl_module.global_rank == 0:
            try:
                self.save_loss_outliers_analaysis_results(stage=ModelKey.TRAIN)
            except Exception as e:
                self.handle_loss_exceptions(stage=ModelKey.TRAIN, exception=e)

    def on_validation_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None:  # type: ignore
        """Hook called at the end of validation. Plots the loss heatmaps and scatter plots after ranking the slides
        by loss values."""
        epoch = trainer.current_epoch
        if pl_module.global_rank == 0 and not pl_module._on_extra_val_epoch and epoch == (self.max_epochs - 1):
            try:
                self.save_loss_outliers_analaysis_results(stage=ModelKey.VAL)
            except Exception as e:
                self.handle_loss_exceptions(stage=ModelKey.VAL, exception=e)
@@ -7,16 +7,16 @@ import param

from torch import nn
from pathlib import Path
from typing import Optional, Tuple
from torchvision.models.resnet import resnet18, resnet50
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CheckpointDownloader
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_cpath.models.encoders import (
    HistoSSLEncoder,
    ImageNetEncoder,
    ImageNetEncoder_Resnet50,
    ImageNetSimCLREncoder,
    SSLEncoder,
    TileEncoder,
    Resnet18,
    Resnet50,
    Resnet18_NoPreproc,
    Resnet50_NoPreproc,
)
from health_ml.networks.layers.attention_layers import (
    AttentionLayer,
@@ -28,14 +28,28 @@ from health_ml.networks.layers.attention_layers import (

)


def set_module_gradients_enabled(model: nn.Module, tuning_flag: bool) -> None:
    """Given a model, enable or disable gradients for all parameters.

    :param model: A PyTorch model.
    :param tuning_flag: A boolean indicating whether to enable or disable gradients for the model parameters.
    """
    for params in model.parameters():
        params.requires_grad = tuning_flag
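A quick illustration of the helper in isolation (torchvision's resnet18 stands in here for any encoder or pooling module):

    from torchvision.models import resnet18

    model = resnet18()
    set_module_gradients_enabled(model, tuning_flag=False)  # freeze: no parameter receives gradients
    assert not any(p.requires_grad for p in model.parameters())
    set_module_gradients_enabled(model, tuning_flag=True)   # unfreeze for fine-tuning
    assert all(p.requires_grad for p in model.parameters())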
class EncoderParams(param.Parameterized):
    """Parameters class to group all encoder specific attributes for deepmil module. """

    encoder_type: str = param.String(doc="Name of the encoder class to use.")
    tile_size: int = param.Integer(default=224, bounds=(1, None), doc="Tile width/height, in pixels.")
    n_channels: int = param.Integer(default=3, bounds=(1, None), doc="Number of channels in the tile.")
    is_finetune: bool = param.Boolean(
        False, doc="If True, fine-tune the encoder during training. If False (default), keep the encoder frozen."
    tune_encoder: bool = param.Boolean(
        False, doc="If True, fine-tune the encoder during training. If False (default), keep the encoder frozen."
    )
    pretrained_encoder = param.Boolean(
        False, doc="If True, transfer weights from the pretrained model (specified in `src_checkpoint`) to the encoder."
        "Else (False), keep the encoder weights as defined by the `encoder_type`."
    )
    is_caching: bool = param.Boolean(
        default=False,
@@ -45,26 +59,35 @@ class EncoderParams(param.Parameterized):

    encoding_chunk_size: int = param.Integer(
        default=0, doc="If > 0 performs encoding in chunks, by encoding_chunk_size tiles per chunk"
    )
    ssl_checkpoint: CheckpointParser = param.ClassSelector(class_=CheckpointParser, default=None,
                                                           instantiate=False, doc=CheckpointParser.DOC)

    def get_encoder(self, ssl_ckpt_run_id: Optional[str], outputs_folder: Optional[Path]) -> TileEncoder:
    def validate(self) -> None:
        """Validate the encoder parameters."""
        if self.encoder_type == SSLEncoder.__name__ and not self.ssl_checkpoint:
            raise ValueError("SSLEncoder requires an ssl_checkpoint. Please specify a valid checkpoint. "
                             f"{CheckpointParser.INFO_MESSAGE}")

    def get_encoder(self, outputs_folder: Optional[Path]) -> TileEncoder:
        """Given the current encoder parameters, returns the encoder object.

        :param ssl_ckpt_run_id: The AML run id for SSL checkpoint download.
        :param outputs_folder: The output folder where SSL checkpoint should be saved.
        :param encoder_params: The encoder arguments that define the encoder class object depending on the encoder type.
        :raises ValueError: If the encoder type is not supported.
        :return: A TileEncoder instance for deepmil module.
        """
        encoder: TileEncoder
        if self.encoder_type == ImageNetEncoder.__name__:
            encoder = ImageNetEncoder(
                feature_extraction_model=resnet18, tile_size=self.tile_size, n_channels=self.n_channels,
            )
        elif self.encoder_type == ImageNetEncoder_Resnet50.__name__:
            # Myronenko et al. 2021 uses Resnet50 CNN encoder
            encoder = ImageNetEncoder_Resnet50(
                feature_extraction_model=resnet50, tile_size=self.tile_size, n_channels=self.n_channels,
            )
        if self.encoder_type == Resnet18.__name__:
            encoder = Resnet18(tile_size=self.tile_size, n_channels=self.n_channels)

        elif self.encoder_type == Resnet18_NoPreproc.__name__:
            encoder = Resnet18_NoPreproc(tile_size=self.tile_size, n_channels=self.n_channels)

        elif self.encoder_type == Resnet50.__name__:
            encoder = Resnet50(tile_size=self.tile_size, n_channels=self.n_channels)

        elif self.encoder_type == Resnet50_NoPreproc.__name__:
            encoder = Resnet50_NoPreproc(tile_size=self.tile_size, n_channels=self.n_channels)

        elif self.encoder_type == ImageNetSimCLREncoder.__name__:
            encoder = ImageNetSimCLREncoder(tile_size=self.tile_size, n_channels=self.n_channels)
@@ -73,24 +96,15 @@ class EncoderParams(param.Parameterized):

            encoder = HistoSSLEncoder(tile_size=self.tile_size, n_channels=self.n_channels)

        elif self.encoder_type == SSLEncoder.__name__:
            assert ssl_ckpt_run_id and outputs_folder, "SSLEncoder requires ssl_ckpt_run_id and outputs_folder"
            downloader = CheckpointDownloader(run_id=ssl_ckpt_run_id,
                                              download_dir=outputs_folder,
                                              checkpoint_filename=LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
                                              remote_checkpoint_dir=Path(DEFAULT_AML_CHECKPOINT_DIR))
            assert outputs_folder is not None, "outputs_folder cannot be None for SSLEncoder"
            encoder = SSLEncoder(
                pl_checkpoint_path=downloader.local_checkpoint_path,
                pl_checkpoint_path=self.ssl_checkpoint.get_path(outputs_folder),
                tile_size=self.tile_size,
                n_channels=self.n_channels,
            )
        else:
            raise ValueError(f"Unsupported encoder type: {self.encoder_type}")

        if self.is_finetune:
            for params in encoder.parameters():
                params.requires_grad = True
        else:
            encoder.eval()
        set_module_gradients_enabled(encoder, tuning_flag=self.tune_encoder)
        return encoder
@@ -106,9 +120,20 @@ class PoolingParams(param.Parameterized):

        default=4, doc="If transformer pooling is chosen, this defines the number of encoding layers.",
    )
    num_transformer_pool_heads: int = param.Integer(
        4,
        doc="If transformer pooling is chosen, this defines the number of attention heads.",
        default=4, doc="If transformer pooling is chosen, this defines the number of attention heads.",
    )
    tune_pooling: bool = param.Boolean(
        default=True,
        doc="If True (default), fine-tune the pooling layer during training. If False, keep the pooling layer frozen.",
    )
    pretrained_pooling = param.Boolean(
        default=False,
        doc="If True, transfer weights from the pretrained model (specified in `src_checkpoint`) to the pooling "
            "layer. Else (False), initialize the pooling layer randomly.",
    )
    transformer_dropout: float = param.Number(
        default=0.0,
        doc="If transformer pooling is chosen, this defines the dropout of the transformer encoder layers.",
    )

    def get_pooling_layer(self, num_encoding: int) -> Tuple[nn.Module, int]:
@@ -121,24 +146,53 @@ class PoolingParams(param.Parameterized):

        """
        pooling_layer: nn.Module
        if self.pool_type == AttentionLayer.__name__:
            pooling_layer = AttentionLayer(num_encoding, self.pool_hidden_dim, self.pool_out_dim)
            pooling_layer = AttentionLayer(input_dims=num_encoding,
                                           hidden_dims=self.pool_hidden_dim,
                                           attention_dims=self.pool_out_dim)
        elif self.pool_type == GatedAttentionLayer.__name__:
            pooling_layer = GatedAttentionLayer(num_encoding, self.pool_hidden_dim, self.pool_out_dim)
            pooling_layer = GatedAttentionLayer(input_dims=num_encoding,
                                                hidden_dims=self.pool_hidden_dim,
                                                attention_dims=self.pool_out_dim)
        elif self.pool_type == MeanPoolingLayer.__name__:
            pooling_layer = MeanPoolingLayer()
        elif self.pool_type == MaxPoolingLayer.__name__:
            pooling_layer = MaxPoolingLayer()
        elif self.pool_type == TransformerPooling.__name__:
            pooling_layer = TransformerPooling(
                self.num_transformer_pool_layers, self.num_transformer_pool_heads, num_encoding
            )
                num_layers=self.num_transformer_pool_layers,
                num_heads=self.num_transformer_pool_heads,
                dim_representation=num_encoding,
                transformer_dropout=self.transformer_dropout)
            self.pool_out_dim = 1  # currently this is hardcoded in forward of the TransformerPooling
        elif self.pool_type == TransformerPoolingBenchmark.__name__:
            pooling_layer = TransformerPoolingBenchmark(
                self.num_transformer_pool_layers, self.num_transformer_pool_heads, num_encoding, self.pool_hidden_dim
            )
                num_layers=self.num_transformer_pool_layers,
                num_heads=self.num_transformer_pool_heads,
                dim_representation=num_encoding,
                hidden_dim=self.pool_hidden_dim,
                transformer_dropout=self.transformer_dropout)
            self.pool_out_dim = 1  # currently this is hardcoded in forward of the TransformerPooling
        else:
            raise ValueError(f"Unsupported pooling type: {self.pooling_type}")
            raise ValueError(f"Unsupported pooling type: {self.pool_type}")
        num_features = num_encoding * self.pool_out_dim
        set_module_gradients_enabled(pooling_layer, tuning_flag=self.tune_pooling)
        return pooling_layer, num_features


class ClassifierParams(param.Parameterized):
    dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
    tune_classifier: bool = param.Boolean(
        default=True,
        doc="If True (default), fine-tune the classifier during training. If False, keep the classifier frozen.")
    pretrained_classifier: bool = param.Boolean(
        default=False,
        doc="If True, will use classifier weights from pretrained model specified in src_checkpoint. If False, will "
            "initialize classifier with random weights.")

    def get_classifier(self, in_features: int, out_features: int) -> nn.Module:
        classifier_layer = nn.Linear(in_features=in_features,
                                     out_features=out_features)
        set_module_gradients_enabled(classifier_layer, tuning_flag=self.tune_classifier)
        if self.dropout_rate is None:
            return classifier_layer
        return nn.Sequential(nn.Dropout(self.dropout_rate), classifier_layer)
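A sketch of what get_classifier returns for a binary deep-MIL head with dropout (the feature dimension and batch size are illustrative):

    import torch

    params = ClassifierParams(dropout_rate=0.2, tune_classifier=True)
    head = params.get_classifier(in_features=128, out_features=1)
    # -> nn.Sequential(nn.Dropout(0.2), nn.Linear(128, 1))
    logits = head(torch.randn(4, 128))  # shape (4, 1)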
@@ -50,6 +50,7 @@ class ResultsKey(str, Enum):

    FEATURES = 'features'
    IMAGE_PATH = 'image_path'
    LOSS = 'loss'
    LOSS_PER_SAMPLE = 'loss_per_sample'
    PROB = 'prob'
    CLASS_PROBS = 'prob_class'
    PRED_LABEL = 'pred_label'

@@ -59,6 +60,7 @@ class ResultsKey(str, Enum):

    TILE_TOP = 'top'
    TILE_RIGHT = 'right'
    TILE_BOTTOM = 'bottom'
    ENTROPY = 'entropy'


class MetricsKey(str, Enum):

@@ -71,6 +73,8 @@ class MetricsKey(str, Enum):

    RECALL = 'recall'
    F1 = 'f1score'
    COHENKAPPA = 'cohenkappa'
    AVERAGE_PRECISION = 'average_precision'
    SPECIFICITY = 'specificity'


class ModelKey(str, Enum):

@@ -85,10 +89,19 @@ class AMLMetricsJsonKey(str, Enum):

    VALUE = 'value'
    N_CLASSES = 'n_classes'
    CLASS_NAMES = 'class_names'
    MAX_EPOCHS = 'max_epochs'


class PlotOption(Enum):
    TOP_BOTTOM_TILES = "top_bottom_tiles"
    SLIDE_THUMBNAIL_HEATMAP = "slide_thumbnail_heatmap"
    SLIDE_THUMBNAIL = "slide_thumbnail"
    ATTENTION_HEATMAP = "attention_heatmap"
    CONFUSION_MATRIX = "confusion_matrix"
    HISTOGRAM = "histogram"
    PR_CURVE = "pr_curve"


class DeepMILSubmodules(str, Enum):
    ENCODER = 'encoder'
    POOLING = 'aggregation_fn'
    CLASSIFIER = 'classifier_fn'
@@ -25,6 +25,8 @@ OUTPUTS_CSV_FILENAME = "test_output.csv"

VAL_OUTPUTS_SUBDIR = "val"
PREV_VAL_OUTPUTS_SUBDIR = "val_old"
TEST_OUTPUTS_SUBDIR = "test"
EXTRA_VAL_OUTPUTS_SUBDIR = "extra_val"
EXTRA_PREFIX = "extra_"

AML_OUTPUTS_DIR = "outputs"
AML_LEGACY_TEST_OUTPUTS_CSV = "/".join([AML_OUTPUTS_DIR, OUTPUTS_CSV_FILENAME])

@@ -55,10 +57,13 @@ def validate_class_names(class_names: Optional[Sequence[str]], n_classes: int) -

def validate_slide_datasets_for_plot_options(
    plot_options: Collection[PlotOption], slides_dataset: Optional[SlidesDataset]
) -> None:
    if PlotOption.SLIDE_THUMBNAIL_HEATMAP in plot_options and not slides_dataset:
        raise ValueError("You can not plot slide thumbnails and heatmaps without setting a slides_dataset. "
                         "Please remove `PlotOption.SLIDE_THUMBNAIL_HEATMAP` from your plot options or provide "
                         "a slide dataset.")

    def _validate_slide_plot_option(plot_option: PlotOption) -> None:
        if plot_option in plot_options and not slides_dataset:
            raise ValueError(f"Plot option {plot_option} requires a slides dataset")

    _validate_slide_plot_option(PlotOption.SLIDE_THUMBNAIL)
    _validate_slide_plot_option(PlotOption.ATTENTION_HEATMAP)


def normalize_dict_for_df(dict_old: Dict[ResultsKey, Any]) -> Dict[str, Any]:
@ -194,7 +199,7 @@ class OutputsPolicy:
        YAML().dump(contents, self.best_metric_file_path)

    def should_save_validation_outputs(self, metrics_dict: Mapping[MetricsKey, Metric], epoch: int,
                                       is_global_rank_zero: bool = True) -> bool:
                                       is_global_rank_zero: bool = True, on_extra_val: bool = False) -> bool:
        """Determine whether validation outputs should be saved given the current epoch's metrics.

        :param metrics_dict: Current epoch's metrics dictionary from

@ -202,8 +207,11 @@ class OutputsPolicy:
        :param epoch: Current epoch number.
        :param is_global_rank_zero: Whether this is the global rank-0 process in distributed scenarios.
            Set to `True` (default) if running a single process.
        :param on_extra_val: Whether this is an extra validation epoch (e.g. after training).
        :return: Whether this is the best validation epoch so far.
        """
        if on_extra_val:
            return False
        metric = metrics_dict[self.primary_val_metric]
        # If the metric hasn't been updated we don't want to save it
        if not metric._update_called:

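A minimal sketch of the save-only-on-improvement behaviour that should_save_validation_outputs implements, including the new early return for extra validation epochs; BestMetricPolicy is a hypothetical stand-in, not the library's OutputsPolicy:

class BestMetricPolicy:
    def __init__(self, maximise: bool = True) -> None:
        self.maximise = maximise
        self.best_value: float = float('-inf') if maximise else float('inf')

    def should_save(self, value: float, on_extra_val: bool = False) -> bool:
        if on_extra_val:
            return False  # extra validation epochs never overwrite best-epoch outputs
        improved = value > self.best_value if self.maximise else value < self.best_value
        if improved:
            self.best_value = value
        return improved

policy = BestMetricPolicy()
assert policy.should_save(0.7) and not policy.should_save(0.6)
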
@ -244,7 +252,8 @@ class DeepMILOutputsHandler:
    def __init__(self, outputs_root: Path, n_classes: int, tile_size: int, level: int,
                 class_names: Optional[Sequence[str]], primary_val_metric: MetricsKey,
                 maximise: bool, val_plot_options: Collection[PlotOption],
                 test_plot_options: Collection[PlotOption]) -> None:
                 test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True,
                 val_set_is_dist: bool = True) -> None:
        """
        :param outputs_root: Root directory where to save all produced outputs.
        :param n_classes: Number of MIL classes (set `n_classes=1` for binary).

@ -256,6 +265,11 @@ class DeepMILOutputsHandler:
        :param maximise: Whether higher is better for `primary_val_metric`.
        :param val_plot_options: The desired plot options for validation time.
        :param test_plot_options: The desired plot options for test time.
        :param wsi_has_mask: Whether the whole slides have a mask to crop specific ROIs.
        :param val_set_is_dist: If True, the validation set is distributed across processes. Otherwise, the validation
            set is replicated on each process. This shouldn't affect the results, as we take the mean of the validation
            set metrics across processes. This is only relevant for the outputs_handler, which needs to know whether to
            gather the validation set outputs across processes or not before saving them.
        """
        self.outputs_root = outputs_root
        self.n_classes = n_classes

@ -273,19 +287,28 @@ class DeepMILOutputsHandler:
            plot_options=val_plot_options,
            level=self.level,
            tile_size=self.tile_size,
            class_names=self.class_names
            class_names=self.class_names,
            stage=ModelKey.VAL,
            wsi_has_mask=wsi_has_mask
        )
        self.test_plots_handler = DeepMILPlotsHandler(
            plot_options=test_plot_options,
            level=self.level,
            tile_size=self.tile_size,
            class_names=self.class_names
            class_names=self.class_names,
            stage=ModelKey.TEST,
            wsi_has_mask=wsi_has_mask
        )
        self.val_set_is_dist = val_set_is_dist

    @property
    def validation_outputs_dir(self) -> Path:
        return self.outputs_root / VAL_OUTPUTS_SUBDIR

    @property
    def extra_validation_outputs_dir(self) -> Path:
        return self.outputs_root / EXTRA_VAL_OUTPUTS_SUBDIR

    @property
    def previous_validation_outputs_dir(self) -> Path:
        return self.validation_outputs_dir.with_name(PREV_VAL_OUTPUTS_SUBDIR)

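A quick illustration of why previous_validation_outputs_dir is built with Path.with_name: it yields a sibling of the validation outputs directory, not a child:

from pathlib import Path

val_dir = Path("outputs") / "val"
assert val_dir.with_name("val_old") == Path("outputs") / "val_old"
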
@ -300,11 +323,15 @@ class DeepMILOutputsHandler:
        self.test_plots_handler.slides_dataset = slides_dataset
        self.val_plots_handler.slides_dataset = slides_dataset

    def should_gather_tiles(self, plots_handler: DeepMILPlotsHandler) -> bool:
        return PlotOption.TOP_BOTTOM_TILES in plots_handler.plot_options and self.tiles_selector is not None

    def _save_outputs(self, epoch_results: EpochResultsType, outputs_dir: Path, stage: ModelKey = ModelKey.VAL) -> None:
        """Trigger the rendering and saving of DeepMIL outputs and figures.

        :param epoch_results: Aggregated results from all epoch batches.
        :param outputs_dir: Specific directory into which outputs should be saved (different for validation and test).
        :param stage: The stage of the model (e.g. `ModelKey.VAL` or `ModelKey.TEST`).
        """
        # outputs object consists of a list of dictionaries (of metadata and results, including encoded features)
        # It can be indexed as outputs[batch_idx][batch_key][bag_idx][tile_idx]

@ -321,8 +348,7 @@
        plots_handler.save_plots(outputs_dir, self.tiles_selector, results)

    def save_validation_outputs(self, epoch_results: EpochResultsType, metrics_dict: Mapping[MetricsKey, Metric],
                                epoch: int, is_global_rank_zero: bool = True, run_extra_val_epoch: bool = False
                                ) -> None:
                                epoch: int, is_global_rank_zero: bool = True, on_extra_val: bool = False) -> None:
        """Render and save validation epoch outputs, according to the configured :py:class:`OutputsPolicy`.

        :param epoch_results: Aggregated results from all epoch batches, as passed to :py:meth:`validation_epoch_end()`.

@ -331,29 +357,35 @@
        :param is_global_rank_zero: Whether this is the global rank-0 process in distributed scenarios.
            Set to `True` (default) if running a single process.
        :param epoch: Current epoch number.
        :param on_extra_val: Whether this is an extra validation epoch (e.g. after training).
        """
        # All DDP processes must reach this point to allow synchronising epoch results
        gathered_epoch_results = gather_results(epoch_results)
        if PlotOption.TOP_BOTTOM_TILES in self.val_plots_handler.plot_options and self.tiles_selector:
            self.tiles_selector.gather_selected_tiles_across_devices()
        # All DDP processes must reach this point to allow synchronising epoch results if val_set_is_dist is True
        if self.val_set_is_dist:
            epoch_results = gather_results(epoch_results)

        if self.should_gather_tiles(self.val_plots_handler):
            self.tiles_selector.gather_selected_tiles_across_devices()  # type: ignore

        # Only global rank-0 process should actually render and save the outputs
        # We also want to save the plots of the extra validation epoch
        if (
            self.outputs_policy.should_save_validation_outputs(metrics_dict, epoch, is_global_rank_zero)
            or (run_extra_val_epoch and is_global_rank_zero)
        ):
        if self.outputs_policy.should_save_validation_outputs(metrics_dict, epoch, is_global_rank_zero, on_extra_val):
            # First move existing outputs to a temporary directory, to avoid mixing
            # outputs of different epochs in case writing fails halfway through
            if self.validation_outputs_dir.exists():
                replace_directory(source=self.validation_outputs_dir,
                                  target=self.previous_validation_outputs_dir)

            self._save_outputs(gathered_epoch_results, self.validation_outputs_dir, ModelKey.VAL)
            self._save_outputs(epoch_results, self.validation_outputs_dir, ModelKey.VAL)

            # Writing completed successfully; delete temporary back-up
            if self.previous_validation_outputs_dir.exists():
                shutil.rmtree(self.previous_validation_outputs_dir, ignore_errors=True)
        elif on_extra_val and is_global_rank_zero:
            self._save_outputs(epoch_results, self.extra_validation_outputs_dir, ModelKey.VAL)

        # Reset the top and bottom slides heaps
        if self.should_gather_tiles(self.val_plots_handler):
            self.tiles_selector._clear_cached_slides_heaps()  # type: ignore

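A self-contained sketch of the backup-then-delete pattern used in save_validation_outputs, assuming replace_directory behaves like shutil.move onto a fresh target:

import shutil
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
val_dir, backup_dir = root / "val", root / "val_old"
val_dir.mkdir()
(val_dir / "epoch_1.txt").write_text("old outputs")

# 1) Move the previous outputs out of the way before writing new ones.
shutil.move(str(val_dir), str(backup_dir))
# 2) Write the new outputs into a fresh directory.
val_dir.mkdir()
(val_dir / "epoch_2.txt").write_text("new outputs")
# 3) Writing succeeded, so the backup can be discarded.
shutil.rmtree(backup_dir, ignore_errors=True)
assert (val_dir / "epoch_2.txt").exists() and not backup_dir.exists()
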
    def save_test_outputs(self, epoch_results: EpochResultsType, is_global_rank_zero: bool = True) -> None:
        """Render and save test epoch outputs.

@ -365,8 +397,8 @@
        """
        # All DDP processes must reach this point to allow synchronising epoch results
        gathered_epoch_results = gather_results(epoch_results)
        if PlotOption.TOP_BOTTOM_TILES in self.test_plots_handler.plot_options and self.tiles_selector:
            self.tiles_selector.gather_selected_tiles_across_devices()
        if self.should_gather_tiles(self.test_plots_handler):
            self.tiles_selector.gather_selected_tiles_across_devices()  # type: ignore

        # Only global rank-0 process should actually render and save the outputs
        if self.outputs_policy.should_save_test_outputs(is_global_rank_zero):


@ -8,6 +8,7 @@ from typing import Any, Collection, List, Optional, Sequence, Tuple, Dict

from sklearn.metrics import confusion_matrix
from torch import Tensor
import matplotlib.pyplot as plt

from health_cpath.datasets.base_dataset import SlidesDataset
from health_cpath.utils.viz_utils import (

@ -17,12 +18,14 @@ from health_cpath.utils.viz_utils import (
    plot_scores_hist,
    plot_slide,
)
from health_cpath.utils.analysis_plot_utils import plot_pr_curve, format_pr_or_roc_axes
from health_cpath.utils.naming import PlotOption, ResultsKey, SlideKey
from health_cpath.utils.tiles_selection_utils import SlideNode, TilesSelector
from health_cpath.utils.viz_utils import load_image_dict, save_figure


ResultsType = Dict[ResultsKey, List[Any]]
SlideDictType = Dict[SlideKey, Any]


def validate_class_names_for_plot_options(

@ -37,21 +40,44 @@ def validate_class_names_for_plot_options(
    return class_names


def save_scores_histogram(results: ResultsType, figures_dir: Path) -> None:
def save_scores_histogram(results: ResultsType, figures_dir: Path, stage: str = '') -> None:
    """Plots and saves histogram scores figure in its dedicated directory.

    :param results: List that contains slide_level dicts
    :param results: Dict of lists that contains slide_level results
    :param figures_dir: The path to the directory where to save the histogram scores.
    :param stage: Test or validation, used to name the figure. Empty string by default.
    """
    fig = plot_scores_hist(results)
    save_figure(fig=fig, figpath=figures_dir / "hist_scores.png")
    save_figure(fig=fig, figpath=figures_dir / f"hist_scores_{stage}.png")


def save_confusion_matrix(results: ResultsType, class_names: Sequence[str], figures_dir: Path) -> None:
def save_pr_curve(results: ResultsType, figures_dir: Path, stage: str = '') -> None:
    """Plots and saves PR curve figure in its dedicated directory. This implementation
    only works for binary classification.

    :param results: Dict of lists that contains slide_level results
    :param figures_dir: The path to the directory where to save the PR curve.
    :param stage: Test or validation, used to name the figure. Empty string by default.
    """
    true_labels = [i.item() if isinstance(i, Tensor) else i for i in results[ResultsKey.TRUE_LABEL]]
    if len(set(true_labels)) == 2:
        scores = [i.item() if isinstance(i, Tensor) else i for i in results[ResultsKey.PROB]]
        fig, ax = plt.subplots()
        plot_pr_curve(true_labels, scores, legend_label=stage, ax=ax)
        ax.legend()
        format_pr_or_roc_axes(plot_type='pr', ax=ax)
        save_figure(fig=fig, figpath=figures_dir / f"pr_curve_{stage}.png")
    else:
        raise Warning("The PR curve plot implementation works only for binary cases, this plot will be skipped.")


def save_confusion_matrix(results: ResultsType, class_names: Sequence[str], figures_dir: Path, stage: str = '') -> None:
    """Plots and saves confusion matrix figure in its dedicated directory.

    :param results: Dict of lists that contains slide_level results
    :param class_names: List of class names.
    :param figures_dir: The path to the directory where to save the confusion matrix.
    :param stage: Test or validation, used to name the figure. Empty string by default.
    """
    true_labels = [i.item() if isinstance(i, Tensor) else i for i in results[ResultsKey.TRUE_LABEL]]
    pred_labels = [i.item() if isinstance(i, Tensor) else i for i in results[ResultsKey.PRED_LABEL]]

@ -68,11 +94,11 @@ def save_confusion_matrix(results: ResultsType, class_names: Sequence[str], figu
        true_labels,
        pred_labels,
        labels=all_potential_labels,
        normalize="pred"
        normalize="true"
    )

    fig = plot_normalized_confusion_matrix(cm=cf_matrix_n, class_names=(class_names))
    save_figure(fig=fig, figpath=figures_dir / "normalized_confusion_matrix.png")
    save_figure(fig=fig, figpath=figures_dir / f"normalized_confusion_matrix_{stage}.png")

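For reference, the effect of the normalize switch in sklearn's confusion_matrix: "true" makes each row sum to 1 (a per-class recall view), while the previous "pred" made each column sum to 1 (a precision view):

from sklearn.metrics import confusion_matrix

y_true = [0, 0, 0, 1]
y_pred = [0, 0, 1, 1]
print(confusion_matrix(y_true, y_pred, normalize="true"))   # rows sum to 1
print(confusion_matrix(y_true, y_pred, normalize="pred"))   # columns sum to 1
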
def save_top_and_bottom_tiles(

@ -95,12 +121,24 @@ def save_top_and_bottom_tiles(
    save_figure(fig=bottom_tiles_fig, figpath=figures_dir / f"{slide_node.slide_id}_bottom.png")


def save_slide_thumbnail_and_heatmap(
def save_slide_thumbnail(case: str, slide_node: SlideNode, slide_dict: SlideDictType, figures_dir: Path) -> None:
    """Plots and saves a slide thumbnail

    :param case: The report case (e.g., TP, FN, ...)
    :param slide_node: The slide node that encapsulates the slide metadata.
    :param slide_dict: The slide dictionary that contains the slide image and other metadata.
    :param figures_dir: The path to the directory where to save the plots.
    """
    fig = plot_slide(case=case, slide_node=slide_node, slide_image=slide_dict[SlideKey.IMAGE], scale=1.0)
    save_figure(fig=fig, figpath=figures_dir / f"{slide_node.slide_id}_thumbnail.png")


def save_attention_heatmap(
    case: str,
    slide_node: SlideNode,
    slide_dict: SlideDictType,
    figures_dir: Path,
    results: ResultsType,
    slides_dataset: SlidesDataset,
    tile_size: int = 224,
    level: int = 1,
) -> None:

@ -108,29 +146,19 @@ def save_slide_thumbnail_and_heatmap(

    :param case: The report case (e.g., TP, FN, ...)
    :param slide_node: The slide node that encapsulates the slide metadata.
    :param slide_dict: The slide dictionary that contains the slide image and other metadata.
    :param figures_dir: The path to the directory where to save the plots.
    :param results: Dict containing ResultsKey keys (e.g. slide id) and values as lists of output slides.
    :param slides_dataset: The slides dataset from where to pick the slide.
    :param tile_size: Size of each tile. Default 224.
    :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original,
        1 for 4x downsampled, 2 for 16x downsampled). Default 1.
    """
    slide_index = slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
    assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
    slide_dict = slides_dataset[slide_index]
    slide_dict = load_image_dict(slide_dict, level=level, margin=0)
    slide_image = slide_dict[SlideKey.IMAGE]
    location_bbox = slide_dict[SlideKey.LOCATION]

    fig = plot_slide(case=case, slide_node=slide_node, slide_image=slide_image, scale=1.0)
    save_figure(fig=fig, figpath=figures_dir / f"{slide_node.slide_id}_thumbnail.png")

    fig = plot_heatmap_overlay(
        case=case,
        slide_node=slide_node,
        slide_image=slide_image,
        slide_image=slide_dict[SlideKey.IMAGE],
        results=results,
        location_bbox=location_bbox,
        location_bbox=slide_dict[SlideKey.ORIGIN],
        tile_size=tile_size,
        level=level,
    )

@ -152,7 +180,9 @@ class DeepMILPlotsHandler:
        tile_size: int = 224,
        num_columns: int = 4,
        figsize: Tuple[int, int] = (10, 10),
        stage: str = '',
        class_names: Optional[Sequence[str]] = None,
        wsi_has_mask: bool = True,
    ) -> None:
        """_summary_

@ -162,6 +192,7 @@ class DeepMILPlotsHandler:
        :param tile_size: _description_, defaults to 224
        :param num_columns: Number of columns to create the subfigures grid, defaults to 4
        :param figsize: The figure size of tiles attention plots, defaults to (10, 10)
        :param stage: Test or Validation, used to name the plots
        :param class_names: List of class names, defaults to None
        :param slides_dataset: The slides dataset from where to load the whole slide images, defaults to None
        """

@ -171,34 +202,38 @@ class DeepMILPlotsHandler:
        self.tile_size = tile_size
        self.num_columns = num_columns
        self.figsize = figsize
        self.stage = stage
        self.wsi_has_mask = wsi_has_mask
        self.slides_dataset: Optional[SlidesDataset] = None

    def get_slide_dict(self, slide_node: SlideNode) -> SlideDictType:
        """Returns the slide dictionary for a given slide node"""
        assert self.slides_dataset is not None, "Cannot plot attention heatmap or wsi without slides dataset"
        slide_index = self.slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
        assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
        slide_dict = self.slides_dataset[slide_index]
        slide_dict = load_image_dict(slide_dict, level=self.level, margin=0, wsi_has_mask=self.wsi_has_mask)
        return slide_dict

    def save_slide_node_figures(
        self, case: str, slide_node: SlideNode, outputs_dir: Path, results: ResultsType
    ) -> None:
        """Plots and saves all slide related figures, e.g., `TOP_BOTTOM_TILES` and `SLIDE_THUMBNAIL_HEATMAP`"""

        """Plots and saves all slide related figures: `TOP_BOTTOM_TILES`, `SLIDE_THUMBNAIL` and `ATTENTION_HEATMAP`."""
        case_dir = make_figure_dirs(subfolder=case, parent_dir=outputs_dir)

        if PlotOption.TOP_BOTTOM_TILES in self.plot_options:
            save_top_and_bottom_tiles(
                case=case,
                slide_node=slide_node,
                figures_dir=case_dir,
                num_columns=self.num_columns,
                figsize=self.figsize,
            )
        if PlotOption.SLIDE_THUMBNAIL_HEATMAP in self.plot_options:
            assert self.slides_dataset
            save_slide_thumbnail_and_heatmap(
                case=case,
                slide_node=slide_node,
                figures_dir=case_dir,
                results=results,
                slides_dataset=self.slides_dataset,
                tile_size=self.tile_size,
                level=self.level,
            )
            save_top_and_bottom_tiles(case, slide_node, case_dir, self.num_columns, self.figsize)

        if PlotOption.ATTENTION_HEATMAP in self.plot_options or PlotOption.SLIDE_THUMBNAIL in self.plot_options:
            slide_dict = self.get_slide_dict(slide_node=slide_node)

            if PlotOption.SLIDE_THUMBNAIL in self.plot_options:
                save_slide_thumbnail(case=case, slide_node=slide_node, slide_dict=slide_dict, figures_dir=case_dir)

            if PlotOption.ATTENTION_HEATMAP in self.plot_options:
                save_attention_heatmap(
                    case, slide_node, slide_dict, case_dir, results, self.tile_size, level=self.level
                )

    def save_plots(self, outputs_dir: Path, tiles_selector: Optional[TilesSelector], results: ResultsType) -> None:
        """Plots and saves all selected plot options during inference (validation or test) time.

@ -214,12 +249,15 @@ class DeepMILPlotsHandler:
        )
        figures_dir = make_figure_dirs(subfolder="fig", parent_dir=outputs_dir)

        if PlotOption.PR_CURVE in self.plot_options:
            save_pr_curve(results=results, figures_dir=figures_dir, stage=self.stage)

        if PlotOption.HISTOGRAM in self.plot_options:
            save_scores_histogram(results=results, figures_dir=figures_dir)
            save_scores_histogram(results=results, figures_dir=figures_dir, stage=self.stage,)

        if PlotOption.CONFUSION_MATRIX in self.plot_options:
            assert self.class_names
            save_confusion_matrix(results, class_names=self.class_names, figures_dir=figures_dir)
            save_confusion_matrix(results, class_names=self.class_names, figures_dir=figures_dir, stage=self.stage)

        if tiles_selector:
            for class_id in range(tiles_selector.n_classes):


@ -4,7 +4,7 @@
# -------------------------------------------------------------------------------------------

from pathlib import Path
from typing import Dict, List, Sequence, Tuple
from typing import Any, Dict, List, Sequence, Tuple

import dateutil.parser
import numpy as np

@ -39,8 +39,8 @@ def run_has_val_and_test_outputs(run: Run) -> bool:
                         f"{AML_TEST_OUTPUTS_CSV}): {available_files}")


def crossval_runs_have_val_and_test_outputs(parent_run: Run) -> bool:
    """Checks whether all child cross-validation runs have both validation and test outputs files.
def child_runs_have_val_and_test_outputs(parent_run: Run) -> bool:
    """Checks whether all child hyperdrive runs have both validation and test outputs files.

    :param parent_run: The parent Hyperdrive run.
    :raises ValueError: If any of the child runs does not have the expected output files, or if

@ -58,30 +58,30 @@ def crossval_runs_have_val_and_test_outputs(parent_run: Run) -> bool:
                     "test-only outputs and with both validation and test outputs")


def collect_crossval_outputs(parent_run_id: str, download_dir: Path, aml_workspace: Workspace,
                             crossval_arg_name: str = "crossval_index",
                             output_filename: str = "test_output.csv",
                             overwrite: bool = False) -> Dict[int, pd.DataFrame]:
    """Fetch output CSV files from cross-validation runs as dataframes.
def collect_hyperdrive_outputs(parent_run_id: str, download_dir: Path, aml_workspace: Workspace,
                               hyperdrive_arg_name: str = "crossval_index",
                               output_filename: str = "test_output.csv",
                               overwrite: bool = False) -> Dict[int, pd.DataFrame]:
    """Fetch output CSV files from Hyperdrive child runs as dataframes.

    Will only download the CSV files if they do not already exist locally.

    :param parent_run_id: Azure ML run ID for the parent Hyperdrive run.
    :param download_dir: Base directory where to download the CSV files. A new sub-directory will
        be created for each child run (e.g. `<download_dir>/<crossval index>/*.csv`).
        be created for each child run (e.g. `<download_dir>/<hyperdrive_arg_name>/*.csv`).
    :param aml_workspace: Azure ML workspace in which the runs were executed.
    :param crossval_arg_name: Name of the Hyperdrive argument used for indexing the child runs.
    :param hyperdrive_arg_name: Name of the Hyperdrive argument used for indexing the child runs.
    :param output_filename: Filename of the output CSVs to download.
    :param overwrite: Whether to force the download even if each file already exists locally.
    :return: A dictionary of dataframes with the sorted cross-validation indices as keys.
    :return: A dictionary of dataframes with the sorted hyperdrive_arg_name indices as keys.
    """
    parent_run = get_aml_run_from_run_id(parent_run_id, aml_workspace)

    all_outputs_dfs = {}
    for child_run in parent_run.get_children():
        child_run_index = get_tags_from_hyperdrive_run(child_run, crossval_arg_name)
        child_run_index = get_tags_from_hyperdrive_run(child_run, hyperdrive_arg_name)
        if child_run_index is None:
            raise ValueError(f"Child run expected to have the tag '{crossval_arg_name}'")
            raise ValueError(f"Child run expected to have the tag '{hyperdrive_arg_name}'")
        child_dir = download_dir / str(child_run_index)
        try:
            child_csv = download_file_if_necessary(child_run, output_filename, child_dir / output_filename,

@ -92,10 +92,10 @@ def collect_crossval_outputs(parent_run_id: str, download_dir: Path, aml_workspa
    return dict(sorted(all_outputs_dfs.items()))  # type: ignore


def collect_crossval_metrics(parent_run_id: str, download_dir: Path, aml_workspace: Workspace,
                             crossval_arg_name: str = "crossval_index",
                             overwrite: bool = False) -> pd.DataFrame:
    """Fetch metrics logged to Azure ML from cross-validation runs as a dataframe.
def download_hyperdrive_metrics_if_required(parent_run_id: str, download_dir: Path, aml_workspace: Workspace,
                                            hyperdrive_arg_name: str = "crossval_index",
                                            overwrite: bool = False) -> Path:
    """Fetch metrics logged to Azure ML from hyperdrive runs.

    Will only download the metrics if they do not already exist locally, as this can take several
    seconds for each child run.

@ -103,89 +103,111 @@ def collect_crossval_metrics(parent_run_id: str, download_dir: Path, aml_workspa
    :param parent_run_id: Azure ML run ID for the parent Hyperdrive run.
    :param download_dir: Directory where to save the downloaded metrics as `aml_metrics.json`.
    :param aml_workspace: Azure ML workspace in which the runs were executed.
    :param crossval_arg_name: Name of the Hyperdrive argument used for indexing the child runs.
    :param hyperdrive_arg_name: Name of the Hyperdrive argument used for indexing the child runs.
    :param overwrite: Whether to force the download even if metrics are already saved locally.
    :return: A dataframe in the format returned by :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
    :return: The path of the downloaded json file.
    """
    metrics_json = download_dir / "aml_metrics.json"
    if not overwrite and metrics_json.is_file():
        print(f"AML metrics file already exists at {metrics_json}")
        metrics_df = pd.read_json(metrics_json)
    else:
        metrics_df = aggregate_hyperdrive_metrics(run_id=parent_run_id,
                                                  child_run_arg_name=crossval_arg_name,
                                                  child_run_arg_name=hyperdrive_arg_name,
                                                  aml_workspace=aml_workspace)
        metrics_json.parent.mkdir(parents=True, exist_ok=True)
        print(f"Writing AML metrics file to {metrics_json}")
        df_to_json(metrics_df, metrics_json)
    return metrics_df.sort_index(axis='columns')
    return metrics_json

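The function above follows a download-once caching idiom. A generic sketch, with a hypothetical fetch_cached helper and a plain JSON payload standing in for the Azure ML metrics:

import json
from pathlib import Path

def fetch_cached(cache_file: Path, overwrite: bool = False) -> Path:
    if not overwrite and cache_file.is_file():
        print(f"File already exists at {cache_file}")
        return cache_file
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    payload = {"accuracy": [0.8, 0.9]}  # stand-in for the expensive download
    cache_file.write_text(json.dumps(payload))
    return cache_file

path = fetch_cached(Path("downloads") / "aml_metrics.json")
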
def get_crossval_metrics_table(metrics_df: pd.DataFrame, metrics_list: Sequence[str]) -> pd.DataFrame:
    """Format raw cross-validation metrics into a table with a summary "Mean ± Std" column.
def collect_hyperdrive_metrics(metrics_json: Path) -> pd.DataFrame:
    """
    Collect the hyperdrive metrics from the downloaded metrics json file in a dataframe.

    :param metrics_json: Path of the downloaded metrics file `aml_metrics.json`.
    :return: A dataframe in the format returned by :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
    """
    metrics_df = pd.read_json(metrics_json).sort_index(axis='columns')
    return metrics_df


def get_hyperdrive_metrics_table(metrics_df: pd.DataFrame, metrics_list: Sequence[str]) -> pd.DataFrame:
    """Format raw hyperdrive metrics into a table with a summary "Mean ± Std" column.

    Note that this function only supports scalar metrics. To format metrics that are logged
    throughout training, you should call :py:func:`get_best_epoch_metrics()` first.

    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
        :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
    :param metrics_list: The list of metrics to include in the table.
    :return: A dataframe with the values of the selected metrics formatted as strings, including a
        header and a summary column.
    """
    header = ["Metric"] + [f"Split {k}" for k in metrics_df.columns] + ["Mean ± Std"]
    header = ["Metric"] + [f"Child {k}" for k in metrics_df.columns] + ["Mean ± Std"]
    metrics_rows = []
    for metric in metrics_list:
        values: pd.Series = metrics_df.loc[metric]
        mean = values.mean()
        std = values.std()
        row = [metric] + [f"{v:.3f}" for v in values] + [f"{mean:.3f} ± {std:.3f}"]
        round_values: List[str] = [f"{v:.3f}" if v is not None else str(np.nan) for v in values]
        agg_values: List[str] = [f"{mean:.3f} ± {std:.3f}"]
        row = [metric] + round_values + agg_values
        metrics_rows.append(row)
    table = pd.DataFrame(metrics_rows, columns=header).set_index(header[0])
    return table


def get_best_epochs(metrics_df: pd.DataFrame, primary_metric: str, maximise: bool = True) -> Dict[int, int]:
    """Determine the best epoch for each cross-validation run based on a given metric.
def get_best_epochs(metrics_df: pd.DataFrame, primary_metric: str, max_epochs_dict: Dict[int, int],
                    maximise: bool = True) -> Dict[int, Any]:
    """Determine the best epoch for each hyperdrive child run based on a given metric.

    The returned epoch indices are relative to the logging frequency of the chosen metric, i.e.
    should not be mixed between pipeline stages that log metrics at different epoch intervals.

    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
        :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
    :param primary_metric: Name of the reference metric to optimise.
    :param max_epochs_dict: A dictionary of the maximum number of epochs in each cross-validation round.
    :param maximise: Whether the given metric should be maximised (minimised if `False`).
    :return: Dictionary mapping each cross-validation index to its best epoch.
    :return: Dictionary mapping each hyperdrive child index to its best epoch.
    """
    best_fn = np.argmax if maximise else np.argmin
    best_epochs = metrics_df.loc[primary_metric].apply(best_fn)
    return best_epochs.to_dict()
    best_epochs: Dict[int, Any] = {}
    for child_index in metrics_df.columns:
        primary_metric_list = metrics_df[child_index][primary_metric]
        if primary_metric_list is not None:
            # If extra validation epoch was logged (N+1), return only the first N elements
            primary_metric_list = primary_metric_list[:-1] \
                if (len(primary_metric_list) == max_epochs_dict[child_index] + 1) else primary_metric_list
            best_epochs[child_index] = int(np.argmax(primary_metric_list)
                                           if maximise else np.argmin(primary_metric_list))
        else:
            best_epochs[child_index] = None
    return best_epochs

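A toy numpy illustration of the trimming rule in get_best_epochs: when an extra validation epoch appended one more value than max_epochs, the last entry is dropped before taking the argmax:

import numpy as np

max_epochs = 3
val_auroc = [0.61, 0.82, 0.74, 0.70]  # length 4 == max_epochs + 1
if len(val_auroc) == max_epochs + 1:
    val_auroc = val_auroc[:-1]  # ignore the extra validation epoch
best_epoch = int(np.argmax(val_auroc))
assert best_epoch == 1
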
def get_best_epoch_metrics(metrics_df: pd.DataFrame, metrics_list: Sequence[str],
                           best_epochs: Dict[int, int]) -> pd.DataFrame:
    """Extract the values of the selected cross-validation metrics at the given best epochs.
                           best_epochs: Dict[int, Any]) -> pd.DataFrame:
    """Extract the values of the selected hyperdrive metrics at the given best epochs.

    The `best_epoch` indices are relative to the logging frequency of the chosen primary metric,
    i.e. the metrics in `metrics_list` must have been logged at the same epoch intervals.

    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
        :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
    :param metrics_list: Names of the metrics to index by the best epoch indices provided. Their
        values in `metrics_df` should be lists.
    :param best_epochs: Dictionary of cross-validation indices to best epochs, as returned by
    :param best_epochs: Dictionary of hyperdrive child runs indices to best epochs, as returned by
        :py:func:`get_best_epochs()`.
    :return: Dataframe with the same columns as `metrics_df` and rows specified by `metrics_list`,
        containing only scalar values.
    """
    best_metrics = [metrics_df.loc[metrics_list, k].apply(lambda values: values[epoch])
                    for k, epoch in best_epochs.items()]
                    if epoch is not None else metrics_df.loc[metrics_list, k] for k, epoch in best_epochs.items()]
    best_metrics_df = pd.DataFrame(best_metrics).T
    return best_metrics_df


def get_formatted_run_info(parent_run: Run) -> str:
    """Format Azure ML cross-validation run information as HTML.
    """Format Azure ML hyperdrive run information as HTML.

    Includes details of the parent and child runs, as well as submission information.

@ -216,20 +238,30 @@ def get_formatted_run_info(parent_run: Run) -> str:
    return html


def collect_class_info(metrics_df: pd.DataFrame) -> Tuple[int, List[str]]:
def get_child_runs_hyperparams(metrics_df: pd.DataFrame) -> Dict[int, Dict]:
    """
    Get the class names from metrics dataframe
    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_crossval_metrics()` and
    Get the hyperparameters of each child run from the metrics dataframe.
    :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
        :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
    :return: Number of classes and list of class names
    :return: A dictionary of hyperparameter dictionaries for the child runs.
    """
    hyperparams = metrics_df[0][AMLMetricsJsonKey.HYPERPARAMS]
    hyperparams_name = hyperparams[AMLMetricsJsonKey.NAME]
    hyperparams_value = hyperparams[AMLMetricsJsonKey.VALUE]
    num_classes_index = hyperparams_name.index(AMLMetricsJsonKey.N_CLASSES)
    num_classes = int(hyperparams_value[num_classes_index])
    class_names_index = hyperparams_name.index(AMLMetricsJsonKey.CLASS_NAMES)
    class_names = hyperparams_value[class_names_index]
    hyperparams_children = {}
    for child_index in metrics_df.columns:
        hyperparams = metrics_df[child_index][AMLMetricsJsonKey.HYPERPARAMS]
        hyperparams_dict = dict(zip(hyperparams[AMLMetricsJsonKey.NAME], hyperparams[AMLMetricsJsonKey.VALUE]))
        hyperparams_children[child_index] = hyperparams_dict
    return hyperparams_children


def collect_class_info(hyperparams_children: Dict[int, Dict]) -> Tuple[int, List[str]]:
    """
    Get the class names from the hyperparameters of child runs.
    :param hyperparams_children: Dict of hyperparameter dicts, as returned by :py:func:`get_child_runs_hyperparams()`.
    :return: Number of classes and list of class names.
    """
    hyperparams_single_run = list(hyperparams_children.values())[0]
    num_classes = int(hyperparams_single_run[AMLMetricsJsonKey.N_CLASSES])
    class_names = hyperparams_single_run[AMLMetricsJsonKey.CLASS_NAMES]
    if class_names == "None":
        class_names = None
    else:

@ -237,3 +269,15 @@ def collect_class_info(metrics_df: pd.DataFrame) -> Tuple[int, List[str]]:
        class_names = [name.lstrip() for name in class_names[1:-1].replace("'", "").split(',')]
    class_names = validate_class_names(class_names=class_names, n_classes=num_classes)
    return (num_classes, list(class_names))


def get_max_epochs(hyperparams_children: Dict[int, Dict]) -> Dict[int, int]:
    """
    Get the maximum number of epochs for each round from the child runs' hyperparameters.
    :param hyperparams_children: Dict of hyperparameter dicts, as returned by :py:func:`get_child_runs_hyperparams()`.
    :return: Dictionary with the number of epochs in each hyperdrive run.
    """
    max_epochs_dict = {}
    for child_index in hyperparams_children.keys():
        max_epochs_dict[child_index] = int(hyperparams_children[child_index][AMLMetricsJsonKey.MAX_EPOCHS])
    return max_epochs_dict

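A small sketch of the name/value pairing in get_child_runs_hyperparams: Azure ML logs hyperparameters as two parallel lists, which zip() turns back into one dictionary per child run (note that values arrive as strings):

names = ["n_classes", "class_names", "max_epochs"]
values = ["2", "None", "50"]
hyperparams = dict(zip(names, values))
assert hyperparams["max_epochs"] == "50"  # still a string until cast with int()
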
@ -42,23 +42,27 @@ class SlideNode:
    """Data structure class for slide nodes used by `TopBottomTilesHandler` to store top and bottom slides by
    probability score. """

    def __init__(self, slide_id: str, prob_score: float, true_label: int, pred_label: int) -> None:
    def __init__(
        self, slide_id: str, gt_prob_score: float, pred_prob_score: float, true_label: int, pred_label: int
    ) -> None:
        """
        :param prob_score: The probability score assigned to the slide node. This scalar defines the order of the
            slide_node among the nodes in the max/min heap.
        :param gt_prob_score: The probability score assigned to ground truth label of slide node. This scalar defines
            the order of the slide_node among the nodes in the max/min heap.
        :param pred_prob_score: The probability score assigned to predicted label of slide node.
        :param slide_id: The slide id in the data cohort.
        :param true_label: The ground truth label of the slide node.
        :param pred_label: The label predicted by the model.
        """
        self.slide_id = slide_id
        self.prob_score = prob_score
        self.gt_prob_score = gt_prob_score
        self.pred_prob_score = pred_prob_score
        self.true_label = true_label
        self.pred_label = pred_label
        self.top_tiles: List[TileNode] = []
        self.bottom_tiles: List[TileNode] = []

    def __lt__(self, other: "SlideNode") -> bool:
        return self.prob_score < other.prob_score
        return self.gt_prob_score < other.gt_prob_score

    def update_selected_tiles(self, tiles: Tensor, attn_scores: Tensor, num_top_tiles: int) -> None:
        """Update top and bottom k tiles values from a set of tiles and their assigned attention scores.

@ -77,7 +81,7 @@ class SlideNode:

    def _shallow_copy(self) -> "SlideNode":
        """Returns a shallow copy of the current slide node containing only the slide_id and its probability score."""
        return SlideNode(self.slide_id, self.prob_score, self.true_label, self.pred_label)
        return SlideNode(self.slide_id, self.gt_prob_score, self.pred_prob_score, self.true_label, self.pred_label)


SlideOrTileKey = Union[SlideKey, TileKey]

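A minimal sketch of the heap ordering that __lt__ enables: heapq keeps the smallest gt_prob_score at the root, so popping on overflow retains the highest-scoring slides, and negating the score (as done for bottom_slides_heaps) reverses the order. Node here is a stripped-down stand-in for SlideNode:

import heapq

class Node:
    def __init__(self, slide_id: str, gt_prob_score: float) -> None:
        self.slide_id = slide_id
        self.gt_prob_score = gt_prob_score

    def __lt__(self, other: "Node") -> bool:
        return self.gt_prob_score < other.gt_prob_score

heap: list = []
for slide_id, score in [("s1", 0.9), ("s2", 0.2), ("s3", 0.7)]:
    heapq.heappush(heap, Node(slide_id, score))
    if len(heap) > 2:        # keep only the 2 highest-scoring slides
        heapq.heappop(heap)  # pops the node with the smallest score
assert sorted(n.slide_id for n in heap) == ["s1", "s3"]
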
@ -106,23 +110,10 @@ class TilesSelector:
        self.num_tiles = num_tiles
        self.top_slides_heaps: SlideDict = self._initialise_slide_heaps()
        self.bottom_slides_heaps: SlideDict = self._initialise_slide_heaps()
        self.report_cases_slide_ids = self.init_report_cases()

    def _initialise_slide_heaps(self) -> SlideDict:
        return {class_id: [] for class_id in range(self.n_classes)}

    def init_report_cases(self) -> Dict[str, List[str]]:
        """ Initializes the report cases dictionary to store slide_ids per case.
        Possible cases are set such as class 0 is considered to be the negative class.

        :return: A dictionary that maps TN/FP TP_{i}/FN_{i}, i={1,..., self.n_classes+1} cases to an empty list to be
            filled with corresponding slide ids.
        """
        report_cases: Dict[str, List[str]] = {"TN": [], "FP": []}
        report_cases.update({f"TP_{class_id}": [] for class_id in range(1, self.n_classes)})
        report_cases.update({f"FN_{class_id}": [] for class_id in range(1, self.n_classes)})
        return report_cases

    def _clear_cached_slides_heaps(self) -> None:
        self.top_slides_heaps = self._initialise_slide_heaps()
        self.bottom_slides_heaps = self._initialise_slide_heaps()

@ -136,7 +127,7 @@ class TilesSelector:
    ) -> None:
        """Update the selected slides of a given class label on the fly by updating the content of class_slides_heap.
        First, we push a shallow slide_node into the slides_heaps[gt_label]. The order in slides_heaps[gt_label] is
        defined by the slide_node.prob_score that is positive in top_slides_heaps nodes and negative in
        defined by the slide_node.gt_prob_score that is positive in top_slides_heaps nodes and negative in
        bottom_slides_heaps nodes.
        Second, we check if we exceeded self.num_top_slides to be selected.
        If so, we update the slide_node's top and bottom tiles only if it has been kept in the heap.

@ -147,7 +138,7 @@ class TilesSelector:
        :param tiles: Tiles of a given whole slide retrieved from the current validation or test batch.
            (n_tiles, channels, height, width)
        :param attn_scores: The tiles attention scores to determine top and bottom tiles. (n_tiles, )
        :param slide_node: A shallow version of slide_node that contains only slide_id and its assigned prob_score.
        :param slide_node: A shallow version of slide_node that contains only slide_id and additional metadata.
        """
        heapq.heappush(class_slides_heap, slide_node)
        if len(class_slides_heap) == self.num_slides + 1:

@ -170,33 +161,39 @@ class TilesSelector:
        slide_ids = [slide_id[0] for slide_id in slide_ids]  # to account for repetitions in tiles pipeline
        batch_size = len(batch[SlideKey.IMAGE])
        for i in range(batch_size):
            label = results[ResultsKey.TRUE_LABEL][i].item()
            probs_gt_label = results[ResultsKey.CLASS_PROBS][:, label][i].item()
            gt_label = results[ResultsKey.TRUE_LABEL][i].item()
            pred_label = results[ResultsKey.PRED_LABEL][i].item()
            gt_prob_score = results[ResultsKey.CLASS_PROBS][:, gt_label][i].item()
            pred_prob_score = results[ResultsKey.CLASS_PROBS][:, pred_label][i].item()
            tiles = batch[SlideKey.IMAGE][i]
            attn_scores = results[ResultsKey.BAG_ATTN][i].squeeze(0)
            pred_label = results[ResultsKey.PRED_LABEL][i].item()
            self._update_label_slides(
                class_slides_heap=self.top_slides_heaps[label],
                tiles=tiles,
                attn_scores=attn_scores,
                slide_node=SlideNode(
                    slide_id=slide_ids[i],
                    prob_score=probs_gt_label,
                    true_label=label,
                    pred_label=pred_label,
                ),
            )
            self._update_label_slides(
                class_slides_heap=self.bottom_slides_heaps[label],
                tiles=tiles,
                attn_scores=attn_scores,
                slide_node=SlideNode(
                    slide_id=slide_ids[i],
                    prob_score=-probs_gt_label,  # negative score for bottom slides to reverse order in max heap
                    true_label=label,
                    pred_label=pred_label,
                ),
            )
            if pred_label == gt_label:
                self._update_label_slides(
                    class_slides_heap=self.top_slides_heaps[gt_label],
                    tiles=tiles,
                    attn_scores=attn_scores,
                    slide_node=SlideNode(
                        slide_id=slide_ids[i],
                        gt_prob_score=gt_prob_score,
                        pred_prob_score=pred_prob_score,
                        true_label=gt_label,
                        pred_label=pred_label,
                    ),
                )
            elif pred_label != gt_label:
                self._update_label_slides(
                    class_slides_heap=self.bottom_slides_heaps[gt_label],
                    tiles=tiles,
                    attn_scores=attn_scores,
                    slide_node=SlideNode(
                        slide_id=slide_ids[i],
                        # negative score for bottom slides to reverse order in max heap
                        gt_prob_score=-gt_prob_score,
                        pred_prob_score=pred_prob_score,
                        true_label=gt_label,
                        pred_label=pred_label,
                    ),
                )

    def _shallow_copy_slides_heaps(self, slides_heaps: SlideDict) -> SlideDict:
        """Returns a shallow copy of slides heaps to be synchronised across devices.


@ -18,6 +18,7 @@ from typing import Sequence, List, Any, Dict, Optional, Union, Tuple
from monai.data.dataset import Dataset
from monai.data.image_reader import WSIReader
from torch.utils.data import DataLoader
from health_cpath.preprocessing.loading import LoadROId

from health_cpath.utils.naming import SlideKey
from health_cpath.utils.naming import ResultsKey

@ -26,7 +27,7 @@ from health_cpath.utils.tiles_selection_utils import SlideNode
from health_cpath.datasets.panda_dataset import PandaDataset, LoadPandaROId


def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any]:
def load_image_dict(sample: dict, level: int, margin: int, wsi_has_mask: bool = True) -> Dict[SlideKey, Any]:
    """
    Load image from metadata dictionary
    :param sample: dict describing image metadata. Example:

@ -40,7 +41,8 @@ def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any
    :param margin: margin to be included
    :return: a dict containing the image data and metadata
    """
    loader = LoadPandaROId(WSIReader("cuCIM"), level=level, margin=margin)
    transform = LoadPandaROId if wsi_has_mask else LoadROId
    loader = transform(WSIReader("cuCIM"), level=level, margin=margin)
    img = loader(sample)
    return img

@ -76,7 +78,7 @@ def plot_panda_data_sample(
def plot_scores_hist(
    results: Dict, prob_col: str = ResultsKey.CLASS_PROBS, gt_col: str = ResultsKey.TRUE_LABEL
) -> plt.Figure:
    """Plot scores as a historgram.
    """Plot scores as a histogram.

    :param results: List that contains slide_level dicts
    :param prob_col: column name that contains the scores

@ -95,6 +97,18 @@ def plot_scores_hist(
    return fig


def _get_histo_plot_title(case: str, slide_node: SlideNode) -> str:
    """Return the standard title for histopathology plots.

    :param case: case id e.g., TP, FN, FP, TN
    :param slide_node: SlideNode object that encapsulates the slide information
    """
    return (
        f"{case}: {slide_node.slide_id} P={slide_node.pred_prob_score:.2f} \n Predicted label: {slide_node.pred_label} "
        f"True label: {slide_node.true_label}"
    )


def plot_attention_tiles(
    case: str, slide_node: SlideNode, top: bool, num_columns: int, figsize: Tuple[int, int]
) -> Optional[plt.Figure]:

@ -116,11 +130,7 @@ def plot_attention_tiles(
        return None

    fig, axs = plt.subplots(nrows=num_rows, ncols=num_columns, figsize=figsize)
    fig.suptitle(
        f"{case}: {slide_node.slide_id} P={abs(slide_node.prob_score):.2f} \n Predicted label: {slide_node.pred_label} "
        f"True label: {slide_node.true_label}"
    )

    fig.suptitle(_get_histo_plot_title(case, slide_node))
    for ax, tile_node in zip(axs.flat, tile_nodes):
        ax.imshow(np.transpose(tile_node.data.numpy(), (1, 2, 0)), clim=(0, 255), cmap="gray")
        ax.set_title("%.6f" % tile_node.attn)

@ -141,10 +151,7 @@ def plot_slide(case: str, slide_node: SlideNode, slide_image: np.ndarray, scale:
    fig, ax = plt.subplots()
    slide_image = slide_image.transpose(1, 2, 0)
    ax.imshow(slide_image)
    fig.suptitle(
        f"{case}: {slide_node.slide_id} P={abs(slide_node.prob_score):.2f} \n Predicted label: {slide_node.pred_label} "
        f"True label: {slide_node.true_label}"
    )
    fig.suptitle(_get_histo_plot_title(case, slide_node))
    ax.set_axis_off()
    original_size = fig.get_size_inches()
    fig.set_size_inches((original_size[0] * scale, original_size[1] * scale))

@ -173,10 +180,8 @@ def plot_heatmap_overlay(
    :return: matplotlib figure of the heatmap of the given tiles on slide.
    """
    fig, ax = plt.subplots()
    fig.suptitle(
        f"{case}: {slide_node.slide_id} P={abs(slide_node.prob_score):.2f} \n Predicted label: {slide_node.pred_label} "
        f"True label: {slide_node.true_label}"
    )
    fig.suptitle(_get_histo_plot_title(case, slide_node))

    slide_image = slide_image.transpose(1, 2, 0)
    ax.imshow(slide_image)
    ax.set_xlim(0, slide_image.shape[1])


@ -16,7 +16,10 @@ def image_collate(batch: List) -> Any:

    for i, item in enumerate(batch):
        data = item[0]
        data[SlideKey.IMAGE] = torch.tensor(np.array([ix[SlideKey.IMAGE] for ix in item]))
        if isinstance(data[SlideKey.IMAGE], torch.Tensor):
            data[SlideKey.IMAGE] = torch.stack([ix[SlideKey.IMAGE] for ix in item], dim=0)
        else:
            data[SlideKey.IMAGE] = torch.tensor(np.array([ix[SlideKey.IMAGE] for ix in item]))
        data[SlideKey.LABEL] = torch.tensor(data[SlideKey.LABEL])
        batch[i] = data
    return multibag_collate(batch)

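A sketch of the two collate branches above: torch.stack applies when the images are already tensors, while the numpy path first builds a single array to avoid the slow per-element conversion of a list of arrays:

import numpy as np
import torch

images_as_tensors = [torch.zeros(3, 4, 4), torch.ones(3, 4, 4)]
batched = torch.stack(images_as_tensors, dim=0)
assert batched.shape == (2, 3, 4, 4)

images_as_arrays = [np.zeros((3, 4, 4)), np.ones((3, 4, 4))]
batched = torch.tensor(np.array(images_as_arrays))
assert batched.shape == (2, 3, 4, 4)
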
@ -36,7 +36,7 @@ from SSL.configs.CXR_SSL_configs import CXRImageClassifier, NIH_RSNA_SimCLR

from health_ml.runner import Runner
from health_ml.utils import AzureMLProgressBar
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME
from health_ml.utils.fixed_paths import repository_root_directory, OutputFolderForTests
from health_ml.utils.lightning_loggers import StoringLogger


@ -247,7 +247,7 @@ def test_ssl_container_rsna() -> None:
    _compare_stored_metrics(runner, expected_metrics)

    # Check that we are able to load the checkpoint and create classifier model
    checkpoint_path = loaded_config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
    checkpoint_path = loaded_config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME
    model_namespace_cxr = "SSL.configs.CXRImageClassifier"
    args = common_test_args + [f"--model={model_namespace_cxr}",
                               f"--local_datasets={str(path_to_cxr_test_dataset)}",


@ -8,11 +8,10 @@ import logging
import shutil
import sys
import uuid
import pytest
from pathlib import Path
from typing import Generator

import pytest

# temporary workaround until these hi-ml packages are released
testhisto_root_dir = Path(__file__).parent
print(f"Adding {testhisto_root_dir} to sys path")

@ -29,6 +28,9 @@ for package, subpackages in packages.items():
        sys.path.insert(0, str(himl_root / package / subpackage))

from health_ml.utils.fixed_paths import OutputFolderForTests  # noqa: E402
from testhisto.mocks.base_data_generator import MockHistoDataType  # noqa: E402
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator  # noqa: E402
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPositioningType  # noqa: E402


def remove_and_create_folder(folder: Path) -> None:

@ -73,3 +75,47 @@ def tmp_path_to_pathmnist_dataset(tmp_path_factory: pytest.TempPathFactory) -> G
    download_azure_dataset(tmp_dir, dataset_id=MockHistoDataType.PATHMNIST.value)
    yield tmp_dir
    shutil.rmtree(tmp_dir)


@pytest.fixture(scope="session")
def mock_panda_tiles_root_dir(
    tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
    tmp_root_dir = tmp_path_factory.mktemp("mock_tiles")
    tiles_generator = MockPandaTilesGenerator(
        dest_data_path=tmp_root_dir,
        src_data_path=tmp_path_to_pathmnist_dataset,
        mock_type=MockHistoDataType.PATHMNIST,
        n_tiles=4,
        n_slides=15,
        n_channels=3,
        tile_size=28,
        img_size=224,
    )
    logging.info("Generating temporary mock tiles that will be deleted at the end of the session.")
    tiles_generator.generate_mock_histo_data()
    yield tmp_root_dir
    shutil.rmtree(tmp_root_dir)


@pytest.fixture(scope="session")
def mock_panda_slides_root_dir(
    tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
    tmp_root_dir = tmp_path_factory.mktemp("mock_slides")
    wsi_generator = MockPandaSlidesGenerator(
        dest_data_path=tmp_root_dir,
        src_data_path=tmp_path_to_pathmnist_dataset,
        mock_type=MockHistoDataType.PATHMNIST,
        n_tiles=4,
        n_slides=15,
        n_channels=3,
        n_levels=3,
        tile_size=28,
        background_val=255,
        tiles_pos_type=TilesPositioningType.RANDOM
    )
    logging.info("Generating temporary mock slides that will be deleted at the end of the session.")
    wsi_generator.generate_mock_histo_data()
    yield tmp_root_dir
    shutil.rmtree(tmp_root_dir)

Binary data (image changed: before 310 KiB, after 311 KiB)
Binary data: hi-ml-cpath/testhisto/test_data/histo_heatmaps/slide_0.1.png (before 474 KiB, after 474 KiB)
Binary data: hi-ml-cpath/testhisto/test_data/histo_heatmaps/slide_1.2.png (before 474 KiB, after 474 KiB)
Binary data: hi-ml-cpath/testhisto/test_data/histo_heatmaps/slide_2.4.png (before 474 KiB, after 474 KiB)
Binary data: hi-ml-cpath/testhisto/test_data/histo_heatmaps/slide_3.6.png (before 474 KiB, after 474 KiB)
Binary data (image changed: before 1.1 MiB, after 1.1 MiB)
Binary data: hi-ml-cpath/testhisto/test_data/top_bottom_tiles/slide_0_top.png (before 1.1 MiB, after 1.1 MiB)

@ -0,0 +1,184 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pytest
import torch
from pathlib import Path
from typing import List, Optional
from unittest.mock import MagicMock, patch
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler

from health_cpath.datamodules.base_module import HistoDataModule
from health_cpath.datamodules.panda_module import PandaSlidesDataModule, PandaTilesDataModule
from health_cpath.utils.naming import ModelKey, SlideKey
from health_ml.utils.common_utils import is_gpu_available
from testhisto.utils.utils_testhisto import run_distributed


no_gpu = not is_gpu_available()


def _assert_correct_bag_sizes(datamodule: HistoDataModule, max_bag_size: int, max_bag_size_inf: Optional[int],
                              true_bag_sizes: List[int]) -> None:
    # True bag sizes are the bag sizes generated by the mock data generator for a fixed seed, as the tiles count
    # (and therefore the bag sizes) are random to reflect real data with varying numbers of tiles per slide.
    for stage_key, bag_size in zip([m for m in ModelKey], [max_bag_size, max_bag_size_inf, max_bag_size_inf]):
        assert datamodule.bag_sizes[stage_key] == bag_size

    def _assert_bag_size_matching(dataloader: DataLoader, expected_bag_sizes: List[int]) -> None:
        sample = next(iter(dataloader))
        for i, slide in enumerate(sample[SlideKey.IMAGE]):
            assert slide.shape[0] == expected_bag_sizes[i]

    _assert_bag_size_matching(datamodule.train_dataloader(), [max_bag_size, max_bag_size])
    expected_bag_sizes = true_bag_sizes if not max_bag_size_inf else [max_bag_size_inf, max_bag_size_inf]
    _assert_bag_size_matching(datamodule.val_dataloader(), expected_bag_sizes)
    _assert_bag_size_matching(datamodule.test_dataloader(), expected_bag_sizes)


def _assert_correct_batch_sizes(datamodule: HistoDataModule, batch_size: int, batch_size_inf: Optional[int]) -> None:
    batch_size_inf = batch_size_inf if batch_size_inf is not None else batch_size
    for stage_key, _batch_size in zip([m for m in ModelKey], [batch_size, batch_size_inf, batch_size_inf]):
        assert datamodule.batch_sizes[stage_key] == _batch_size

    def _assert_batch_size_matching(dataloader: DataLoader, expected_batch_size: int) -> None:
        sample = next(iter(dataloader))
        assert len(sample[SlideKey.IMAGE]) == expected_batch_size

    _assert_batch_size_matching(datamodule.train_dataloader(), batch_size)
    _assert_batch_size_matching(datamodule.val_dataloader(), batch_size_inf)
    _assert_batch_size_matching(datamodule.test_dataloader(), batch_size_inf)


@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("max_bag_size, max_bag_size_inf", [(2, 0), (2, 3)])
def test_slides_datamodule_different_bag_sizes(
    mock_panda_slides_root_dir: Path, max_bag_size: int, max_bag_size_inf: int
) -> None:
    datamodule = PandaSlidesDataModule(
        root_path=mock_panda_slides_root_dir,
        batch_size=2,
        max_bag_size=max_bag_size,
        max_bag_size_inf=max_bag_size_inf,
        tile_size=28,
        level=0,
    )
    # To account for the fact that the slides datamodule formats 0 to None so that it's compatible with the
    # TileOnGrid transform
    max_bag_size_inf = max_bag_size_inf if max_bag_size_inf != 0 else None  # type: ignore
    # For slides datamodule, the true bag sizes [4, 4] are the same as requested to TileOnGrid transform
    _assert_correct_bag_sizes(datamodule, max_bag_size, max_bag_size_inf, true_bag_sizes=[4, 4])


@pytest.mark.parametrize("max_bag_size, max_bag_size_inf", [(2, 0), (2, 3)])
def test_tiles_datamodule_different_bag_sizes(
    mock_panda_tiles_root_dir: Path, max_bag_size: int, max_bag_size_inf: int
) -> None:
    datamodule = PandaTilesDataModule(
        root_path=mock_panda_tiles_root_dir,
        batch_size=2,
        max_bag_size=max_bag_size,
        max_bag_size_inf=max_bag_size_inf,
    )
    # For tiles datamodule, the true bag sizes [4, 5] were generated by the mock data generator for a fixed seed 42
    # If test fails, check if the mock data generator has changed and update the true bag sizes
    _assert_correct_bag_sizes(datamodule, max_bag_size, max_bag_size_inf, true_bag_sizes=[4, 5])


@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("batch_size, batch_size_inf", [(2, 2), (2, 1), (2, None)])
def test_slides_datamodule_different_batch_sizes(
    mock_panda_slides_root_dir: Path, batch_size: int, batch_size_inf: Optional[int],
) -> None:
    datamodule = PandaSlidesDataModule(
        root_path=mock_panda_slides_root_dir,
        batch_size=batch_size,
        batch_size_inf=batch_size_inf,
        max_bag_size=16,
        max_bag_size_inf=16,
        tile_size=28,
        level=0,
    )
    _assert_correct_batch_sizes(datamodule, batch_size, batch_size_inf)


@pytest.mark.parametrize("batch_size, batch_size_inf", [(2, 2), (2, 1), (2, None)])
def test_tiles_datamodule_different_batch_sizes(
    mock_panda_tiles_root_dir: Path, batch_size: int, batch_size_inf: Optional[int],
) -> None:
    datamodule = PandaTilesDataModule(
        root_path=mock_panda_tiles_root_dir,
        batch_size=batch_size,
        batch_size_inf=batch_size_inf,
        max_bag_size=16,
        max_bag_size_inf=16,
    )
    _assert_correct_batch_sizes(datamodule, batch_size, batch_size_inf)


def _validate_sampler_type(datamodule: HistoDataModule, stages: List[ModelKey], expected_none: bool) -> None:
    expected_sampler_types = {
        ModelKey.TRAIN: RandomSampler if expected_none else DistributedSampler,
        ModelKey.VAL: SequentialSampler,
        ModelKey.TEST: SequentialSampler,
    }
    for stage in stages:
        datamodule_sampler = datamodule._get_ddp_sampler(getattr(datamodule, f'{stage}_dataset'), stage)
|
||||
assert (datamodule_sampler is None) == expected_none
|
||||
dataloader = getattr(datamodule, f'{stage.value}_dataloader')()
|
||||
assert isinstance(dataloader.sampler, expected_sampler_types[stage])
|
||||
|
||||
|
||||
def _test_datamodule_pl_ddp_sampler_true(
|
||||
datamodule: HistoDataModule, rank: int = 0, world_size: int = 1, device: str = "cpu"
|
||||
) -> None:
|
||||
datamodule.setup()
|
||||
_validate_sampler_type(datamodule, [ModelKey.TRAIN, ModelKey.VAL, ModelKey.TEST], expected_none=True)
|
||||
|
||||
|
||||
def _test_datamodule_pl_ddp_sampler_false(
|
||||
datamodule: HistoDataModule, rank: int = 0, world_size: int = 1, device: str = "cpu"
|
||||
) -> None:
|
||||
datamodule.setup()
|
||||
_validate_sampler_type(datamodule, [ModelKey.VAL, ModelKey.TEST], expected_none=True)
|
||||
_validate_sampler_type(datamodule, [ModelKey.TRAIN], expected_none=False)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
|
||||
@pytest.mark.gpu
|
||||
def test_slides_datamodule_pl_replace_sampler_ddp(mock_panda_slides_root_dir: Path) -> None:
|
||||
slides_datamodule = PandaSlidesDataModule(root_path=mock_panda_slides_root_dir,
|
||||
pl_replace_sampler_ddp=True,
|
||||
seed=42)
|
||||
run_distributed(_test_datamodule_pl_ddp_sampler_true, [slides_datamodule], world_size=2)
|
||||
slides_datamodule.pl_replace_sampler_ddp = False
|
||||
run_distributed(_test_datamodule_pl_ddp_sampler_false, [slides_datamodule], world_size=2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
|
||||
@pytest.mark.gpu
|
||||
def test_tiles_datamodule_pl_replace_sampler_ddp_true(mock_panda_tiles_root_dir: Path) -> None:
|
||||
tiles_datamodule = PandaTilesDataModule(root_path=mock_panda_tiles_root_dir, seed=42, pl_replace_sampler_ddp=True)
|
||||
run_distributed(_test_datamodule_pl_ddp_sampler_true, [tiles_datamodule], world_size=2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
|
||||
@pytest.mark.gpu
|
||||
def test_tiles_datamodule_pl_replace_sampler_ddp_false(mock_panda_tiles_root_dir: Path) -> None:
|
||||
tiles_datamodule = PandaTilesDataModule(root_path=mock_panda_tiles_root_dir, seed=42, pl_replace_sampler_ddp=False)
|
||||
run_distributed(_test_datamodule_pl_ddp_sampler_false, [tiles_datamodule], world_size=2)
|
||||
|
||||
|
||||
def test_assertion_error_missing_seed(mock_panda_slides_root_dir: Path) -> None:
|
||||
with pytest.raises(AssertionError, match="seed must be set when using distributed training for reproducibility"):
|
||||
with patch("torch.distributed.is_initialized", return_value=True):
|
||||
with patch("torch.distributed.get_world_size", return_value=2):
|
||||
slides_datamodule = PandaSlidesDataModule(
|
||||
root_path=mock_panda_slides_root_dir, pl_replace_sampler_ddp=False
|
||||
)
|
||||
slides_datamodule._get_ddp_sampler(MagicMock(), ModelKey.TRAIN)
|
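The sampler assertions above pivot on one rule: when Lightning's `pl_replace_sampler_ddp` is switched off, the datamodule has to build its own seeded `DistributedSampler` for training, while validation and test keep a plain `SequentialSampler`. A minimal sketch of that decision logic follows; `get_ddp_sampler_sketch` is a hypothetical stand-in, not the hi-ml `_get_ddp_sampler` implementation, and only the assertion message is taken from the test above.

# Minimal sketch, not the hi-ml implementation: return a seeded DistributedSampler
# for the training stage when DDP is active, otherwise None so that PyTorch's
# default RandomSampler / SequentialSampler is used by the DataLoader.
from typing import Optional

import torch
from torch.utils.data import Dataset, DistributedSampler


def get_ddp_sampler_sketch(dataset: Dataset, is_training: bool,
                           seed: Optional[int]) -> Optional[DistributedSampler]:
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() \
        and torch.distributed.get_world_size() > 1
    if not (is_distributed and is_training):
        return None  # DataLoader falls back to its default sampler
    assert seed is not None, "seed must be set when using distributed training for reproducibility"
    return DistributedSampler(dataset, shuffle=True, seed=seed)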
@ -1,17 +1,17 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import shutil
from typing import Generator, Dict, Callable, Union, Tuple
import pytest
import logging
import numpy as np
import torch
from pathlib import Path
from monai.transforms import RandFlipd
from typing import Generator, Dict, Callable, Union, Tuple
from torch.utils.data import DataLoader

from health_ml.utils.common_utils import is_gpu_available
from health_cpath.datamodules.base_module import SlidesDataModule

@ -29,7 +29,7 @@ no_gpu = not is_gpu_available()

@pytest.fixture(scope="session")
def mock_panda_slides_root_dir(
def mock_panda_slides_root_dir_diagonal(
    tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
    tmp_root_dir = tmp_path_factory.mktemp("mock_wsi")

@ -38,7 +38,7 @@ def mock_panda_slides_root_dir(
        src_data_path=tmp_path_to_pathmnist_dataset,
        mock_type=MockHistoDataType.PATHMNIST,
        n_tiles=1,
        n_slides=10,
        n_slides=16,
        n_repeat_diag=4,
        n_repeat_tile=2,
        n_channels=3,

@ -83,7 +83,7 @@ def get_original_tile(mock_dir: Path, wsi_id: str) -> np.ndarray:

@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_tiling_on_the_fly(mock_panda_slides_root_dir: Path) -> None:
def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
    batch_size = 1
    tile_count = 16
    tile_size = 28

@ -91,7 +91,7 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir: Path) -> None:
    channels = 3
    assert_batch_index = 0
    datamodule = PandaSlidesDataModule(
        root_path=mock_panda_slides_root_dir,
        root_path=mock_panda_slides_root_dir_diagonal,
        batch_size=batch_size,
        max_bag_size=tile_count,
        tile_size=tile_size,

@ -105,14 +105,14 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir: Path) -> None:
        assert tiles[assert_batch_index].shape == (tile_count, channels, tile_size, tile_size)

        # check tiling on the fly
        original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
        original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
        for i in range(tile_count):
            assert (original_tile == tiles[assert_batch_index][i].numpy()).all()


@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir: Path) -> None:
def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Path) -> None:
    batch_size = 1
    tile_count = None
    tile_size = 28

@ -120,7 +120,7 @@ def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir: Path) -> No
    assert_batch_index = 0
    min_expected_tile_count = 16
    datamodule = PandaSlidesDataModule(
        root_path=mock_panda_slides_root_dir,
        root_path=mock_panda_slides_root_dir_diagonal,
        batch_size=batch_size,
        max_bag_size=tile_count,
        tile_size=tile_size,

@ -135,14 +135,14 @@ def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir: Path) -> No
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("level", [0, 1, 2])
def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir: Path) -> None:
def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
    batch_size = 1
    tile_count = 16
    channels = 3
    tile_size = 28 // 2 ** level
    assert_batch_index = 0
    datamodule = PandaSlidesDataModule(
        root_path=mock_panda_slides_root_dir,
        root_path=mock_panda_slides_root_dir_diagonal,
        batch_size=batch_size,
        max_bag_size=tile_count,
        tile_size=tile_size,

@ -155,7 +155,7 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir: Path) -
        assert tiles[assert_batch_index].shape == (tile_count, channels, tile_size, tile_size)

        # check tiling on the fly at different resolutions
        original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
        original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
        for i in range(tile_count):
            # multi resolution mock data has been created via 2 factor downsampling
            assert (original_tile[:, :: 2 ** level, :: 2 ** level] == tiles[assert_batch_index][i].numpy()).all()

@ -164,7 +164,7 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir: Path) -
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("batch_size", [1, 2])
def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) -> None:
def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
    tile_size = 28
    level = 0
    step = 14

@ -172,7 +172,7 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) ->
    min_expected_tile_count = 32
    assert_batch_index = 0
    datamodule = PandaSlidesDataModule(
        root_path=mock_panda_slides_root_dir,
        root_path=mock_panda_slides_root_dir_diagonal,
        max_bag_size=None,
        batch_size=batch_size,
        tile_size=tile_size,

@ -184,7 +184,7 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) ->
        tiles, wsi_id = sample[SlideKey.IMAGE], sample[SlideKey.SLIDE_ID][assert_batch_index]
        assert tiles[assert_batch_index].shape[0] >= min_expected_tile_count

        original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
        original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
        tile_matches = 0
        for _, tile in enumerate(tiles[assert_batch_index]):
            tile_matches += int((tile.numpy() == original_tile).all())

@ -193,12 +193,12 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) ->

@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_train_test_transforms(mock_panda_slides_root_dir: Path) -> None:
def test_train_test_transforms(mock_panda_slides_root_dir_diagonal: Path) -> None:
    def get_transforms_dict() -> Dict[ModelKey, Union[Callable, None]]:
        train_transform = RandFlipd(keys=[SlideKey.IMAGE], spatial_axis=0, prob=1.0)
        return {ModelKey.TRAIN: train_transform, ModelKey.VAL: None, ModelKey.TEST: None}  # type: ignore

    def retrieve_tiles(dataloader: torch.utils.data.DataLoader) -> Dict[str, torch.Tensor]:
    def retrieve_tiles(dataloader: DataLoader) -> Dict[str, torch.Tensor]:
        tiles_dict = {}
        assert_batch_index = 0
        for sample in dataloader:

@ -211,7 +211,7 @@ def test_train_test_transforms(mock_panda_slides_root_dir: Path) -> None:
    tile_size = 28
    level = 0
    flipdatamodule = PandaSlidesDataModule(
        root_path=mock_panda_slides_root_dir,
        root_path=mock_panda_slides_root_dir_diagonal,
        batch_size=batch_size,
        max_bag_size=tile_count,
        max_bag_size_inf=0,

@ -224,20 +224,20 @@ def test_train_test_transforms(mock_panda_slides_root_dir: Path) -> None:
    flip_test_tiles = retrieve_tiles(flipdatamodule.test_dataloader())

    for wsi_id in flip_train_tiles.keys():
        original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
        original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
        # the first dimension is the channel, flipping happened on the horizontal axis of the image
        transformed_original_tile = np.flip(original_tile, axis=1)
        for tile in flip_train_tiles[wsi_id]:
            assert (tile.numpy() == transformed_original_tile).all()

    for wsi_id in flip_val_tiles.keys():
        original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
        original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
        for tile in flip_val_tiles[wsi_id]:
            # no transformation has been applied to val tiles
            assert (tile.numpy() == original_tile).all()

    for wsi_id in flip_test_tiles.keys():
        original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
        original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
        for tile in flip_test_tiles[wsi_id]:
            # no transformation has been applied to test tiles
            assert (tile.numpy() == original_tile).all()

@ -251,7 +251,6 @@ class MockPandaSlidesDataModule(SlidesDataModule):
    """

    def get_splits(self) -> Tuple[PandaDataset, PandaDataset, PandaDataset]:

        return (PandaDataset(self.root_path), PandaDataset(self.root_path), PandaDataset(self.root_path))

@ -278,7 +277,7 @@ def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_diff
        tiles = sample[SlideKey.IMAGE]
        assert tiles[assert_batch_index].shape[0] == tile_count

    def assert_whole_slide_inference_with_all_tiles(dataloader: torch.utils.data.DataLoader) -> None:
    def assert_whole_slide_inference_with_all_tiles(dataloader: DataLoader) -> None:
        for i, sample in enumerate(dataloader):
            tiles = sample[SlideKey.IMAGE]
            assert tiles[assert_batch_index].shape[0] == n_tiles_list[i * batch_size]
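The multi-resolution test above relies on the mock pyramid being built by factor-2 downsampling, so a tile at a given level must equal a strided view of the level-0 tile. A self-contained NumPy illustration of that invariant (mock shapes only, not hi-ml code):

# Sketch: a 2x-downsampled pyramid level equals a strided view of level 0.
import numpy as np

level0 = np.arange(3 * 28 * 28).reshape(3, 28, 28)  # channels-first mock tile at full resolution
for level in range(3):
    downsampled = level0[:, :: 2 ** level, :: 2 ** level]
    assert downsampled.shape == (3, 28 // 2 ** level, 28 // 2 ** level)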
@ -14,6 +14,7 @@ import torch
from pytorch_lightning import seed_everything

from health_cpath.configs.classification.DeepSMILESlidesPandaBenchmark import DeepSMILESlidesPandaBenchmark
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.utils.naming import SlideKey
from testhisto.mocks.base_data_generator import MockHistoDataType
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPositioningType

@ -89,3 +90,21 @@ def test_panda_reproducibility(tmp_path: Path) -> None:
    # When using a fixed seed, all 3 dataloaders should return identical items. Validation and test dataloaders
    # are at present not randomized, but checking those as well just in case.
    test_data_items_are_equal(["train_dataloader", "val_dataloader", "test_dataloader"])


def test_validate_columns(tmp_path: Path) -> None:
    _ = MockPandaSlidesGenerator(
        dest_data_path=tmp_path,
        mock_type=MockHistoDataType.FAKE,
        n_tiles=4,
        n_slides=10,
        n_channels=3,
        n_levels=3,
        tile_size=28,
        background_val=255,
        tiles_pos_type=TilesPositioningType.RANDOM,
    )
    usecols = [PandaDataset.SLIDE_ID_COLUMN, PandaDataset.MASK_COLUMN]
    with pytest.raises(ValueError, match=r"Expected columns"):
        _ = PandaDataset(root=tmp_path, dataframe_kwargs={"usecols": usecols})
    _ = PandaDataset(root=tmp_path, dataframe_kwargs={"usecols": usecols + [PandaDataset.METADATA_COLUMNS[1]]})
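test_validate_columns exercises a fail-fast column check: reading the dataset CSV with a `usecols` subset that drops a required column should raise. A generic sketch of that validation pattern, with a hypothetical `validate_columns` helper rather than the PandaDataset internals:

# Sketch of a dataset-style column check: fail fast if the loaded frame
# is missing required columns, e.g. after a too-restrictive `usecols`.
import pandas as pd


def validate_columns(df: pd.DataFrame, required_columns: set) -> None:
    missing = required_columns - set(df.columns)
    if missing:
        raise ValueError(f"Expected columns {sorted(required_columns)}, missing {sorted(missing)}")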
@ -0,0 +1,47 @@
from pathlib import Path
import pytest
import pandas as pd

from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
from health_cpath.utils.naming import TileKey
from testhisto.mocks.base_data_generator import MockHistoDataType
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator


@pytest.mark.parametrize("tiling_version", [0, 1])
def test_panda_tiles_dataset(tiling_version: int, tmp_path: Path) -> None:
    """Test the PandaTilesDataset class.

    :param tiling_version: The version of the tiles dataset, defaults to 0. This is used to support both the old
        and new tiling scheme where coordinates are stored as tile_x and tile_y in v0 and as tile_left and tile_top
        in v1.
    :param tmp_path: The temporary path where to store the mock dataset.
    """
    _ = MockPandaTilesGenerator(
        dest_data_path=tmp_path,
        mock_type=MockHistoDataType.FAKE,
        n_tiles=4,
        n_slides=10,
        n_channels=3,
        tile_size=28,
        img_size=224,
        tiling_version=tiling_version,
    )
    base_df = pd.read_csv(tmp_path / PandaTilesDataset.DEFAULT_CSV_FILENAME).set_index(PandaTilesDataset.TILE_ID_COLUMN)
    dataset = PandaTilesDataset(root=tmp_path)

    coordinates_columns_v0 = {PandaTilesDataset.TILE_X_COLUMN, PandaTilesDataset.TILE_Y_COLUMN}
    coordinates_columns_v1 = {TileKey.TILE_LEFT, TileKey.TILE_TOP}
    dataset_columns = set(dataset.dataset_df.columns)
    base_df_columns = set(base_df.columns)

    assert coordinates_columns_v0.issubset(dataset_columns)  # v0 columns are always present

    if tiling_version == 0:
        assert coordinates_columns_v0.issubset(dataset_columns)
        assert not coordinates_columns_v1.issubset(dataset_columns)
        assert base_df_columns == dataset_columns
    elif tiling_version == 1:
        assert coordinates_columns_v1.issubset(dataset_columns)
        assert not coordinates_columns_v0.issubset(base_df_columns)
        assert dataset_columns == base_df_columns.union(coordinates_columns_v0)
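The v1 assertions imply that the dataset derives the legacy v0 coordinate names from the new ones, so `tile_x`/`tile_y` are always present. A plausible sketch of such a normalization step; the column names come from the test, but `add_v0_coordinates` itself is hypothetical, not the PandaTilesDataset implementation:

# Sketch: mirror v1 coordinate columns into the legacy v0 names so that
# downstream code can always rely on tile_x / tile_y being present.
import pandas as pd


def add_v0_coordinates(df: pd.DataFrame) -> pd.DataFrame:
    if "tile_left" in df.columns and "tile_x" not in df.columns:
        df = df.assign(tile_x=df["tile_left"], tile_y=df["tile_top"])
    return df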
@ -8,23 +8,22 @@ from typing import Any, Optional, Set
from health_ml.networks.layers.attention_layers import AttentionLayer
from health_cpath.configs.classification.DeepSMILEPanda import DeepSMILESlidesPanda, DeepSMILETilesPanda
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.models.encoders import ImageNetEncoder
from health_cpath.models.encoders import Resnet18
from health_cpath.datamodules.base_module import CacheMode, CacheLocation
from health_cpath.utils.naming import PlotOption


class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
    def __init__(self, tmp_path: Path, **kwargs: Any) -> None:
    def __init__(self, tmp_path: Path, analyse_loss: bool = False, **kwargs: Any) -> None:
        default_kwargs = dict(
            # Model parameters:
            pool_type=AttentionLayer.__name__,
            pool_hidden_dim=16,
            num_transformer_pool_layers=1,
            num_transformer_pool_heads=1,
            is_finetune=False,
            class_names=["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"],
            # Encoder parameters
            encoder_type=ImageNetEncoder.__name__,
            encoder_type=Resnet18.__name__,
            tile_size=28,
            # Data Module parameters
            batch_size=2,

@ -38,6 +37,8 @@ class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
            # declared in TrainerParams:
            max_epochs=2,
            crossval_count=1,
            ssl_checkpoint=None,
            analyse_loss=analyse_loss,
        )
        default_kwargs.update(kwargs)
        super().__init__(**default_kwargs)

@ -62,10 +63,9 @@ class MockDeepSMILESlidesPanda(DeepSMILESlidesPanda):
            pool_hidden_dim=16,
            num_transformer_pool_layers=1,
            num_transformer_pool_heads=1,
            is_finetune=True,
            class_names=["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"],
            # Encoder parameters
            encoder_type=ImageNetEncoder.__name__,
            encoder_type=Resnet18.__name__,
            tile_size=28,
            # Data Module parameters
            batch_size=2,
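Both mock containers rely on the same override pattern: a dict of cheap test defaults that caller-supplied kwargs replace before the parent constructor runs. Reduced to its essentials (hypothetical classes, not the DeepSMILE configs):

# Sketch of the defaults-then-override pattern used by the mock containers above.
from typing import Any


class Base:
    def __init__(self, **kwargs: Any) -> None:
        self.params = kwargs


class MockContainer(Base):
    def __init__(self, **kwargs: Any) -> None:
        default_kwargs = dict(batch_size=2, max_epochs=2)  # small values keep tests fast
        default_kwargs.update(kwargs)  # caller-supplied kwargs win over the defaults
        super().__init__(**default_kwargs)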
@ -79,9 +79,12 @@ class MockPandaSlidesGenerator(MockHistoDataGenerator):
    def create_mock_metadata_dataframe(self) -> pd.DataFrame:
        """Create a mock dataframe with random metadata."""
        isup_grades = np.tile(list(self.ISUP_GRADE_MAPPING.keys()), self.n_slides // PANDA_N_CLASSES + 1,)
        mock_metadata: dict = {col: [] for col in [PandaDataset.SLIDE_ID_COLUMN, *PandaDataset.METADATA_COLUMNS]}
        mock_metadata: dict = {
            col: [] for col in [PandaDataset.SLIDE_ID_COLUMN, PandaDataset.MASK_COLUMN, *PandaDataset.METADATA_COLUMNS]
        }
        for slide_id in range(self.n_slides):
            mock_metadata[PandaDataset.SLIDE_ID_COLUMN].append(f"_{slide_id}")
            mock_metadata[PandaDataset.MASK_COLUMN].append(f"_{slide_id}_mask")
            mock_metadata[self.DATA_PROVIDER].append(np.random.choice(self.DATA_PROVIDERS_VALUES))
            mock_metadata[self.ISUP_GRADE].append(isup_grades[slide_id])
            mock_metadata[self.GLEASON_SCORE].append(np.random.choice(self.ISUP_GRADE_MAPPING[isup_grades[slide_id]]))
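The `np.tile` call above is what keeps the mock ISUP grades roughly balanced: the label set is repeated until it covers all slides, then indexed per slide. In isolation (a standalone sketch with made-up sizes, not the generator itself):

# Sketch: repeat the class labels so that slide indices always stay in range
# and every class appears with near-equal frequency across the mock slides.
import numpy as np

n_slides, n_classes = 10, 6
labels = np.tile(np.arange(n_classes), n_slides // n_classes + 1)
assert len(labels) >= n_slides
grades = labels[:n_slides]  # one grade per mock slide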
@ -10,18 +10,23 @@ import torch
from torchvision.utils import save_image

from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
from health_cpath.utils.naming import TileKey
from testhisto.mocks.base_data_generator import MockHistoDataGenerator, MockHistoDataType, PANDA_N_CLASSES


class MockPandaTilesGenerator(MockHistoDataGenerator):
    """Generator class to create mock tiles dataset on the fly. The tiles are positioned randomly in a wsi grid."""

    def __init__(self, img_size: int = 224, **kwargs: Any) -> None:
    def __init__(self, img_size: int = 224, tiling_version: int = 0, **kwargs: Any) -> None:
        """
        :param img_size: The whole slide image resolution, defaults to 224.
        :param tiling_version: The version of the tiles dataset, defaults to 0. This is used to support both the old
            and new tiling scheme where coordinates are stored as tile_x and tile_y in v0 and as tile_left and tile_top
            in v1.
        :param kwargs: Same params passed to MockHistoDataGenerator.
        """
        self.img_size = img_size
        self.tiling_version = tiling_version
        super().__init__(**kwargs)

    def validate(self) -> None:

@ -32,16 +37,19 @@ class MockPandaTilesGenerator(MockHistoDataGenerator):
            f"The image of size {self.img_size} can't contain more than {(self.img_size // self.tile_size)**2} tiles."
            f"Choose a number of tiles 0 < n_tiles <= {(self.img_size // self.tile_size)**2} "
        )
        assert self.tiling_version in [0, 1], f"Tiling version should be 0 or 1, got {self.tiling_version}"

    def create_mock_metadata_dataframe(self) -> pd.DataFrame:
        """Create a mock dataframe with random metadata."""
        x_column = TileKey.TILE_LEFT if self.tiling_version == 1 else PandaTilesDataset.TILE_X_COLUMN
        y_column = TileKey.TILE_TOP if self.tiling_version == 1 else PandaTilesDataset.TILE_Y_COLUMN
        csv_columns = [
            PandaTilesDataset.SLIDE_ID_COLUMN,
            PandaTilesDataset.TILE_ID_COLUMN,
            PandaTilesDataset.IMAGE_COLUMN,
            self.MASK_COLUMN,
            PandaTilesDataset.TILE_X_COLUMN,
            PandaTilesDataset.TILE_Y_COLUMN,
            x_column,
            y_column,
            self.OCCUPANCY,
            self.DATA_PROVIDER,
            self.ISUP_GRADE,

@ -81,8 +89,9 @@ class MockPandaTilesGenerator(MockHistoDataGenerator):
                f"_{slide_id}/train_images/{tile_x}x_{tile_y}y.png"
            )
            mock_metadata[self.MASK_COLUMN].append(f"_{slide_id}/train_label_masks/{tile_x}x_{tile_y}y_mask.png")
            mock_metadata[PandaTilesDataset.TILE_X_COLUMN].append(tile_x)
            mock_metadata[PandaTilesDataset.TILE_Y_COLUMN].append(tile_y)
            mock_metadata[x_column].append(tile_x)
            mock_metadata[y_column].append(tile_y)

            mock_metadata[self.OCCUPANCY].append(1.0)
            mock_metadata[self.DATA_PROVIDER].append(data_provider)
            mock_metadata[self.ISUP_GRADE].append(isup_grade)

@ -114,7 +123,7 @@ class MockPandaTilesGenerator(MockHistoDataGenerator):
            raise NotImplementedError

        tile_filename = self.dest_data_path / row[PandaTilesDataset.IMAGE_COLUMN]
        save_image(tile.float(), tile_filename)
        save_image(tile.div(255.), tile_filename)
        random_mask = torch.randint(0, 256, size=(self.n_channels, self.tile_size, self.tile_size))
        mask_filename = self.dest_data_path / row[self.MASK_COLUMN]
        save_image(random_mask.float(), mask_filename)
        save_image(random_mask.div(255), mask_filename)
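The switch from `tile.float()` to `tile.div(255.)` matters because torchvision's `save_image` expects float tensors in the [0, 1] range and rescales them to 8-bit on write; raw 0-255 float values would mostly clip to white. A small standalone demonstration (writes a throwaway file to the working directory):

# Sketch: save_image maps [0, 1] floats to [0, 255] pixels, so tensors holding
# integer pixel values must be divided by 255 before saving.
import torch
from torchvision.utils import save_image

tile = torch.randint(0, 256, size=(3, 28, 28))  # integer pixel values in [0, 255]
save_image(tile.div(255.), "tile.png")          # correct: values rescaled into [0, 1]
# save_image(tile.float(), "tile.png")          # wrong: almost every pixel clips to white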
@ -2,45 +2,61 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from copy import deepcopy
import os
import shutil
from pytorch_lightning import Trainer
import torch
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type
from typing import Any, Callable, Dict, Iterable, List, Optional, Type

from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
from torch.utils.data._utils.collate import default_collate
from health_cpath.configs.classification.DeepSMILESlidesPandaBenchmark import SlidesPandaSSLMILBenchmark
from health_cpath.datamodules.panda_module import PandaTilesDataModule

from health_ml.networks.layers.attention_layers import AttentionLayer
from health_cpath.configs.classification.BaseMIL import BaseMILTiles
from health_ml.networks.layers.attention_layers import AttentionLayer, TransformerPoolingBenchmark
from health_cpath.configs.classification.BaseMIL import BaseMIL, BaseMILTiles

from health_cpath.configs.classification.DeepSMILECrck import DeepSMILECrck
from health_cpath.configs.classification.DeepSMILEPanda import BaseDeepSMILEPanda, DeepSMILETilesPanda
from health_cpath.configs.classification.DeepSMILECrck import DeepSMILECrck, TcgaCrckSSLMIL
from health_cpath.configs.classification.DeepSMILEPanda import (
    BaseDeepSMILEPanda, DeepSMILETilesPanda, SlidesPandaSSLMIL, TilesPandaSSLMIL
)
from health_cpath.datamodules.base_module import HistoDataModule, TilesDataModule
from health_cpath.datasets.base_dataset import DEFAULT_LABEL_COLUMN, TilesDataset
from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, TCGA_CRCK_DATASET_DIR
from health_cpath.models.deepmil import BaseDeepMILModule, TilesDeepMILModule
from health_cpath.models.encoders import IdentityEncoder, ImageNetEncoder, TileEncoder
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
from health_cpath.utils.naming import MetricsKey, ResultsKey
from health_cpath.models.encoders import IdentityEncoder, ImageNetEncoder, Resnet18, TileEncoder
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
from health_cpath.utils.naming import DeepMILSubmodules, MetricsKey, ResultsKey
from testhisto.mocks.base_data_generator import MockHistoDataType
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPositioningType
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator
from testhisto.mocks.container import MockDeepSMILETilesPanda, MockDeepSMILESlidesPanda
from health_ml.utils.common_utils import is_gpu_available
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_crck_4ws, innereye_ssl_checkpoint_binary
from testhisto.models.test_encoders import TEST_SSL_RUN_ID

no_gpu = not is_gpu_available()


def get_supervised_imagenet_encoder_params() -> EncoderParams:
    return EncoderParams(encoder_type=ImageNetEncoder.__name__)
def get_supervised_imagenet_encoder_params(tune_encoder: bool = True) -> EncoderParams:
    return EncoderParams(encoder_type=Resnet18.__name__, tune_encoder=tune_encoder)


def get_attention_pooling_layer_params(pool_out_dim: int = 1) -> PoolingParams:
    return PoolingParams(pool_type=AttentionLayer.__name__, pool_out_dim=pool_out_dim, pool_hidden_dim=5)
def get_attention_pooling_layer_params(pool_out_dim: int = 1, tune_pooling: bool = True) -> PoolingParams:
    return PoolingParams(pool_type=AttentionLayer.__name__, pool_out_dim=pool_out_dim, pool_hidden_dim=5,
                         tune_pooling=tune_pooling)


def get_transformer_pooling_layer_params(num_layers: int, num_heads: int,
                                         hidden_dim: int, transformer_dropout: float) -> PoolingParams:
    return PoolingParams(pool_type=TransformerPoolingBenchmark.__name__,
                         num_transformer_pool_layers=num_layers,
                         num_transformer_pool_heads=num_heads,
                         pool_hidden_dim=hidden_dim,
                         transformer_dropout=transformer_dropout)


def _test_lightningmodule(
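For readers unfamiliar with the `AttentionLayer` pooling these helpers configure: attention-based MIL pooling learns a weight per tile and aggregates tile features into a single bag feature. A compact sketch in the spirit of Ilse et al. (2018); this is illustrative only, not the hi-ml `AttentionLayer` implementation:

# Sketch of attention-based MIL pooling: score each instance, softmax the
# scores over the bag, and return the attention-weighted mean feature.
import torch
from torch import nn


class AttentionPooling(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.attention = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.Tanh(),
                                       nn.Linear(hidden_dim, 1))

    def forward(self, bag: torch.Tensor) -> torch.Tensor:
        # bag: (num_instances, input_dim) -> (input_dim,) bag-level feature
        weights = torch.softmax(self.attention(bag), dim=0)  # (num_instances, 1)
        return (weights * bag).sum(dim=0)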
@ -56,7 +72,7 @@ def _test_lightningmodule(
    module = TilesDeepMILModule(
        label_column=DEFAULT_LABEL_COLUMN,
        n_classes=n_classes,
        dropout_rate=dropout_rate,
        classifier_params=ClassifierParams(dropout_rate=dropout_rate),
        encoder_params=get_supervised_imagenet_encoder_params(),
        pooling_params=get_attention_pooling_layer_params(pool_out_dim)
    )

@ -114,58 +130,17 @@ def _test_lightningmodule(
            # A NaN value could result due to a division-by-zero error
            assert torch.all(score[~score.isnan()] >= -1)
            assert torch.all(score[~score.isnan()] <= 1)
        elif metric_name == MetricsKey.AVERAGE_PRECISION:
            assert torch.all(score[~score.isnan()] >= 0)
            assert torch.all(score[~score.isnan()] <= 1)
        else:
            assert torch.all(score >= 0)
            assert torch.all(score <= 1)


@pytest.fixture(scope="session")
def mock_panda_tiles_root_dir(
    tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
    tmp_root_dir = tmp_path_factory.mktemp("mock_tiles")
    tiles_generator = MockPandaTilesGenerator(
        dest_data_path=tmp_root_dir,
        src_data_path=tmp_path_to_pathmnist_dataset,
        mock_type=MockHistoDataType.PATHMNIST,
        n_tiles=4,
        n_slides=10,
        n_channels=3,
        tile_size=28,
        img_size=224,
    )
    logging.info("Generating temporary mock tiles that will be deleted at the end of the session.")
    tiles_generator.generate_mock_histo_data()
    yield tmp_root_dir
    shutil.rmtree(tmp_root_dir)


@pytest.fixture(scope="session")
def mock_panda_slides_root_dir(
    tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
    tmp_root_dir = tmp_path_factory.mktemp("mock_slides")
    wsi_generator = MockPandaSlidesGenerator(
        dest_data_path=tmp_root_dir,
        src_data_path=tmp_path_to_pathmnist_dataset,
        mock_type=MockHistoDataType.PATHMNIST,
        n_tiles=4,
        n_slides=10,
        n_channels=3,
        n_levels=3,
        tile_size=28,
        background_val=255,
        tiles_pos_type=TilesPositioningType.RANDOM
    )
    logging.info("Generating temporary mock slides that will be deleted at the end of the session.")
    wsi_generator.generate_mock_histo_data()
    yield tmp_root_dir
    shutil.rmtree(tmp_root_dir)


@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("max_bag_size", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_attention(

@ -203,9 +178,7 @@ def add_callback(fn: Callable, callback: Callable) -> Callable:
def test_metrics(n_classes: int) -> None:
    input_dim = (128,)

    def _mock_get_encoder(  # type: ignore
        self, ssl_ckpt_run_id: Optional[str], outputs_folder: Optional[Path]
    ) -> TileEncoder:
    def _mock_get_encoder(self, outputs_folder: Optional[Path]) -> TileEncoder:  # type: ignore
        return IdentityEncoder(input_dim=input_dim)

    with patch("health_cpath.models.deepmil.EncoderParams.get_encoder", new=_mock_get_encoder):

@ -274,7 +247,7 @@ def assert_train_step(module: BaseDeepMILModule, data_module: HistoDataModule, u
    train_data_loader = data_module.train_dataloader()
    for batch_idx, batch in enumerate(train_data_loader):
        batch = move_batch_to_expected_device(batch, use_gpu)
        loss = module.training_step(batch, batch_idx)
        loss = module.training_step(batch, batch_idx)[ResultsKey.LOSS]
        loss.retain_grad()
        loss.backward()
        assert loss.grad is not None

@ -327,9 +300,10 @@ def test_container(container_type: Type[BaseMILTiles], use_gpu: bool) -> None:
    container = container_type()

    container.setup()
    container.batch_size = 10
    container.batch_size_inf = 10

    data_module: TilesDataModule = container.get_data_module()  # type: ignore
    data_module.max_bag_size = 10

    module = container.create_model()
    module.outputs_handler = MagicMock()

@ -361,7 +335,7 @@ def _test_mock_panda_container(use_gpu: bool, mock_container: BaseDeepSMILEPanda


def test_mock_tiles_panda_container_cpu(mock_panda_tiles_root_dir: Path) -> None:
    _test_mock_panda_container(use_gpu=False, mock_container=MockDeepSMILETilesPanda,
    _test_mock_panda_container(use_gpu=False, mock_container=MockDeepSMILETilesPanda,  # type: ignore
                               tmp_path=mock_panda_tiles_root_dir)
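assert_train_step's pattern of `loss.retain_grad()` followed by `loss.backward()` is a generic way to prove a training step is differentiable end to end: gradients on a non-leaf tensor are discarded unless retained explicitly. Stripped of the MIL specifics:

# Sketch of the gradient-flow check used in assert_train_step: run one forward
# pass, backpropagate, and verify a gradient actually reached the loss tensor.
import torch
from torch import nn

model = nn.Linear(4, 1)
loss = model(torch.randn(2, 4)).mean()
loss.retain_grad()             # keep the gradient on the non-leaf loss tensor
loss.backward()
assert loss.grad is not None   # the step is differentiable end to end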
@ -425,6 +399,268 @@ def test_class_weights_multiclass() -> None:
    assert allclose(loss_weighted, loss_unweighted)


def test_wrong_tuning_options() -> None:
    with pytest.raises(ValueError,
                       match=r"At least one of the encoder, pooling or classifier should be fine tuned"):
        _ = MockDeepSMILETilesPanda(
            tmp_path=Path("foo"),
            tune_encoder=False,
            tune_pooling=False,
            tune_classifier=False
        )


def _get_datamodule(tmp_path: Path) -> PandaTilesDataModule:
    tiles_generator = MockPandaTilesGenerator(
        dest_data_path=tmp_path,
        mock_type=MockHistoDataType.FAKE,
        n_tiles=4,
        n_slides=10,
        n_channels=3,
        tile_size=28,
        img_size=224,
    )
    tiles_generator.generate_mock_histo_data()
    datamodule = PandaTilesDataModule(root_path=tmp_path, batch_size=2, max_bag_size=4)
    return datamodule


@pytest.mark.parametrize("tune_classifier", [False, True])
@pytest.mark.parametrize("tune_pooling", [False, True])
@pytest.mark.parametrize("tune_encoder", [False, True])
def test_finetuning_options(
    tune_encoder: bool, tune_pooling: bool, tune_classifier: bool, tmp_path: Path
) -> None:
    module = TilesDeepMILModule(
        n_classes=1,
        label_column=DEFAULT_LABEL_COLUMN,
        encoder_params=get_supervised_imagenet_encoder_params(tune_encoder=tune_encoder),
        pooling_params=get_attention_pooling_layer_params(pool_out_dim=1, tune_pooling=tune_pooling),
        classifier_params=ClassifierParams(tune_classifier=tune_classifier),
    )

    assert module.encoder_params.tune_encoder == tune_encoder
    assert module.pooling_params.tune_pooling == tune_pooling
    assert module.classifier_params.tune_classifier == tune_classifier

    for params in module.encoder.parameters():
        assert params.requires_grad == tune_encoder
    for params in module.aggregation_fn.parameters():
        assert params.requires_grad == tune_pooling
    for params in module.classifier_fn.parameters():
        assert params.requires_grad == tune_classifier

    instances = torch.randn(4, 3, 224, 224)

    def _assert_existing_gradients_fn(tensor: Tensor, tuning_flag: bool) -> None:
        assert tensor.requires_grad == tuning_flag
        if tuning_flag:
            assert tensor.grad_fn is not None
        else:
            assert tensor.grad_fn is None

    with torch.enable_grad():
        instance_features = module.get_instance_features(instances)
        _assert_existing_gradients_fn(instance_features, tuning_flag=tune_encoder)
        assert module.encoder.training == tune_encoder

        attentions, bag_features = module.get_attentions_and_bag_features(instances)
        _assert_existing_gradients_fn(attentions, tuning_flag=tune_pooling)
        _assert_existing_gradients_fn(bag_features, tuning_flag=tune_pooling)
        assert module.aggregation_fn.training == tune_pooling

        bag_logit = module.get_bag_logit(bag_features)
        # bag_logit gradients are required for pooling layer gradients computation, hence
        # "tuning_flag=tune_classifier or tune_pooling"
        _assert_existing_gradients_fn(bag_logit, tuning_flag=tune_classifier or tune_pooling)
        assert module.classifier_fn.training == tune_classifier


@pytest.mark.parametrize("tune_classifier", [False, True])
@pytest.mark.parametrize("tune_pooling", [False, True])
@pytest.mark.parametrize("tune_encoder", [False, True])
def test_training_with_different_finetuning_options(
    tune_encoder: bool, tune_pooling: bool, tune_classifier: bool, tmp_path: Path
) -> None:
    if any([tune_encoder, tune_pooling, tune_classifier]):
        module = TilesDeepMILModule(
            n_classes=6,
            label_column=MockPandaTilesGenerator.ISUP_GRADE,
            encoder_params=get_supervised_imagenet_encoder_params(tune_encoder=tune_encoder),
            pooling_params=get_attention_pooling_layer_params(pool_out_dim=1, tune_pooling=tune_pooling),
            classifier_params=ClassifierParams(tune_classifier=tune_classifier),
        )

        def _assert_existing_gradients(module: nn.Module, tuning_flag: bool) -> None:
            for param in module.parameters():
                if tuning_flag:
                    assert param.grad is not None
                else:
                    assert param.grad is None

        with patch.object(module, "validation_step"):
            trainer = Trainer(max_epochs=1)
            trainer.fit(module, datamodule=_get_datamodule(tmp_path))

        _assert_existing_gradients(module.classifier_fn, tuning_flag=tune_classifier)
        _assert_existing_gradients(module.aggregation_fn, tuning_flag=tune_pooling)
        _assert_existing_gradients(module.encoder, tuning_flag=tune_encoder)


def test_missing_src_checkpoint_with_pretraining_flags() -> None:
    with pytest.raises(ValueError, match=r"You need to specify a source checkpoint, to use a pretrained"):
        _ = MockDeepSMILETilesPanda(tmp_path=Path("foo"), pretrained_classifier=True, pretrained_encoder=True)


@pytest.mark.parametrize("pretrained_classifier", [False, True])
@pytest.mark.parametrize("pretrained_pooling", [False, True])
@pytest.mark.parametrize("pretrained_encoder", [False, True])
def test_init_weights_options(pretrained_encoder: bool, pretrained_pooling: bool, pretrained_classifier: bool) -> None:
    n_classes = 1
    module = BaseDeepMILModule(
        n_classes=n_classes,
        label_column=DEFAULT_LABEL_COLUMN,
        encoder_params=get_supervised_imagenet_encoder_params(),
        pooling_params=get_attention_pooling_layer_params(pool_out_dim=1),
    )
    module.encoder_params.pretrained_encoder = pretrained_encoder
    module.pooling_params.pretrained_pooling = pretrained_pooling
    module.classifier_params.pretrained_classifier = pretrained_classifier

    with patch.object(module, "load_from_checkpoint") as mock_load_from_checkpoint:
        with patch.object(module, "copy_weights") as mock_copy_weights:
            mock_load_from_checkpoint.return_value = MagicMock(n_classes=n_classes)
            module.transfer_weights(Path("foo"))
            assert mock_copy_weights.call_count == sum(
                [int(pretrained_encoder), int(pretrained_pooling), int(pretrained_classifier)]
            )


def _get_tiles_deepmil_module(
    pretrained_encoder: bool = True,
    pretrained_pooling: bool = True,
    pretrained_classifier: bool = True,
    n_classes: int = 3,
    num_layers: int = 2,
    num_heads: int = 1,
    hidden_dim: int = 8,
    transformer_dropout: float = 0.1
) -> TilesDeepMILModule:
    module = TilesDeepMILModule(
        n_classes=n_classes,
        label_column=MockPandaTilesGenerator.ISUP_GRADE,
        encoder_params=get_supervised_imagenet_encoder_params(),
        pooling_params=get_transformer_pooling_layer_params(num_layers, num_heads, hidden_dim, transformer_dropout),
        classifier_params=ClassifierParams(pretrained_classifier=pretrained_classifier),
    )
    module.encoder_params.pretrained_encoder = pretrained_encoder
    module.pooling_params.pretrained_pooling = pretrained_pooling
    module.classifier_params.pretrained_classifier = pretrained_classifier
    return module


def get_pretrained_module(encoder_val: int = 5, pooling_val: int = 6, classifier_val: int = 7) -> nn.Module:
    module = _get_tiles_deepmil_module()

    def _fix_sub_module_weights(submodule: nn.Module, constant_val: int) -> None:
        for param in submodule.state_dict().values():
            param.data.fill_(constant_val)

    _fix_sub_module_weights(module.encoder, encoder_val)
    _fix_sub_module_weights(module.aggregation_fn, pooling_val)
    _fix_sub_module_weights(module.classifier_fn, classifier_val)

    return module


@pytest.mark.parametrize("pretrained_classifier", [False, True])
@pytest.mark.parametrize("pretrained_pooling", [False, True])
@pytest.mark.parametrize("pretrained_encoder", [False, True])
def test_transfer_weights_same_config(
    pretrained_encoder: bool, pretrained_pooling: bool, pretrained_classifier: bool,
) -> None:
    encoder_val = 5
    pooling_val = 6
    classifier_val = 7
    module = _get_tiles_deepmil_module(pretrained_encoder, pretrained_pooling, pretrained_classifier)
    pretrained_module = get_pretrained_module(encoder_val, pooling_val, classifier_val)

    encoder_random_weights = deepcopy(module.encoder.state_dict())
    pooling_random_weights = deepcopy(module.aggregation_fn.state_dict())
    classification_random_weights = deepcopy(module.classifier_fn.state_dict())

    with patch.object(module, "load_from_checkpoint") as mock_load_from_checkpoint:
        mock_load_from_checkpoint.return_value = pretrained_module
        module.transfer_weights(Path("foo"))

    encoder_transfer_weights = module.encoder.state_dict()
    pooling_transfer_weights = module.aggregation_fn.state_dict()
    classification_transfer_weights = module.classifier_fn.state_dict()

    def _assert_weights_equal(
        random_weights: Dict, transfer_weights: Dict, pretrained_flag: bool, expected_val: int
    ) -> None:
        for r_param_name, t_param_name in zip(random_weights, transfer_weights):
            assert r_param_name == t_param_name, "Param names do not match"
            r_param = random_weights[r_param_name]
            t_param = transfer_weights[t_param_name]
            if pretrained_flag:
                assert torch.equal(t_param.data, torch.full_like(t_param.data, expected_val))
            else:
                assert torch.equal(t_param.data, r_param.data)

    _assert_weights_equal(encoder_random_weights, encoder_transfer_weights, pretrained_encoder, encoder_val)
    _assert_weights_equal(pooling_random_weights, pooling_transfer_weights, pretrained_pooling, pooling_val)
    _assert_weights_equal(
        classification_random_weights, classification_transfer_weights, pretrained_classifier, classifier_val
    )


def test_transfer_weights_different_encoder() -> None:
    module = _get_tiles_deepmil_module(pretrained_encoder=True)
    pretrained_module = _get_tiles_deepmil_module()
    pretrained_module.encoder = IdentityEncoder(tile_size=224)

    with patch.object(module, "load_from_checkpoint") as mock_load_from_checkpoint:
        mock_load_from_checkpoint.return_value = pretrained_module
        with pytest.raises(
            ValueError, match=rf"Submodule {DeepMILSubmodules.ENCODER} has different number of parameters "
        ):
            module.transfer_weights(Path("foo"))


def test_transfer_weights_different_pooling() -> None:
    module = _get_tiles_deepmil_module(num_heads=2, hidden_dim=24, pretrained_pooling=True)
    pretrained_module = _get_tiles_deepmil_module(num_heads=1, hidden_dim=8)

    with patch.object(module, "load_from_checkpoint") as mock_load_from_checkpoint:
        mock_load_from_checkpoint.return_value = pretrained_module
        with pytest.raises(
            ValueError, match=rf"Submodule {DeepMILSubmodules.POOLING} has different number of parameters "
        ):
            module.transfer_weights(Path("foo"))


def test_transfer_weights_different_classifier() -> None:
    module = _get_tiles_deepmil_module(n_classes=4, pretrained_classifier=True)
    pretrained_module = _get_tiles_deepmil_module(n_classes=3)

    with patch.object(module, "load_from_checkpoint") as mock_load_from_checkpoint:
        mock_load_from_checkpoint.return_value = pretrained_module
        with pytest.raises(
            ValueError,
            match=r"Number of classes in pretrained model 3 does not match number of classes in current model 4."
        ):
            module.transfer_weights(Path("foo"))


def test_wrong_encoding_chunk_size() -> None:
    with pytest.raises(
        ValueError, match=r"The encoding chunk size should be at least as large as the maximum bag size"
    ):
        _ = BaseMIL(encoding_chunk_size=1, max_bag_size=4, tune_encoder=True, max_num_gpus=2, pl_sync_batchnorm=True)


@pytest.mark.parametrize("container_type", [DeepSMILETilesPanda,
                                            DeepSMILECrck])
@pytest.mark.parametrize("primary_val_metric", [m for m in MetricsKey])
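All the `tune_*` assertions above reduce to one mechanism: toggling `requires_grad` on a whole submodule and keeping frozen parts out of training mode, so that dropout and batch-norm statistics stay fixed. A minimal sketch of that mechanism (the `set_tuning` helper is hypothetical):

# Sketch: freeze or unfreeze an entire submodule, as the tune_* flags do.
from torch import nn


def set_tuning(submodule: nn.Module, tune: bool) -> None:
    for param in submodule.parameters():
        param.requires_grad = tune   # gradients flow only when tuning
    submodule.train(mode=tune)       # frozen parts stay in eval mode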
@ -444,3 +680,35 @@ def test_checkpoint_name(container_type: Type[BaseMILTiles], primary_val_metric:

    metric_optim = "max" if maximise_primary_metric else "min"
    assert container.best_checkpoint_filename == f"checkpoint_{metric_optim}_val_{primary_val_metric.value}"


def test_on_run_extra_val_epoch(mock_panda_tiles_root_dir: Path) -> None:
    container = MockDeepSMILETilesPanda(tmp_path=mock_panda_tiles_root_dir)
    container.setup()
    container.data_module = MagicMock()
    container.create_lightning_module_and_store()
    assert not container.model._on_extra_val_epoch
    assert (
        container.model.outputs_handler.test_plots_handler.plot_options  # type: ignore
        != container.model.outputs_handler.val_plots_handler.plot_options  # type: ignore
    )
    container.on_run_extra_validation_epoch()
    assert container.model._on_extra_val_epoch
    assert (
        container.model.outputs_handler.test_plots_handler.plot_options  # type: ignore
        == container.model.outputs_handler.val_plots_handler.plot_options  # type: ignore
    )


@pytest.mark.parametrize(
    "container_type", [TcgaCrckSSLMIL, TilesPandaSSLMIL, SlidesPandaSSLMIL, SlidesPandaSSLMILBenchmark]
)
def test_ssl_containers_default_checkpoint(container_type: BaseMIL) -> None:
    if container_type == TcgaCrckSSLMIL:
        default_checkpoint = innereye_ssl_checkpoint_crck_4ws
    else:
        default_checkpoint = innereye_ssl_checkpoint_binary
    assert container_type().ssl_checkpoint.checkpoint == default_checkpoint

    container = container_type(ssl_checkpoint=CheckpointParser(TEST_SSL_RUN_ID))
    assert container.ssl_checkpoint.checkpoint != default_checkpoint
@ -12,8 +12,8 @@ from torch import Tensor, float32, nn, rand
from torchvision.models import resnet18

from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CheckpointDownloader
from health_cpath.models.encoders import (TileEncoder, HistoSSLEncoder, ImageNetEncoder,
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME, CheckpointDownloader
from health_cpath.models.encoders import (Resnet18, TileEncoder, HistoSSLEncoder,
                                          ImageNetSimCLREncoder, SSLEncoder)
from health_cpath.utils.layer_utils import setup_feature_extractor
from testhiml.utils_testhiml import DEFAULT_WORKSPACE

@ -25,7 +25,7 @@ TEST_SSL_RUN_ID = "CRCK_SimCLR_1654677598_49a66020"


def get_supervised_imagenet_encoder() -> TileEncoder:
    return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=TILE_SIZE)
    return Resnet18(tile_size=TILE_SIZE)


def get_simclr_imagenet_encoder() -> TileEncoder:

@ -35,8 +35,9 @@ def get_simclr_imagenet_encoder() -> TileEncoder:
def get_ssl_encoder(download_dir: Path) -> TileEncoder:
    downloader = CheckpointDownloader(run_id=TEST_SSL_RUN_ID,
                                      download_dir=download_dir,
                                      checkpoint_filename=LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
                                      checkpoint_filename=LAST_CHECKPOINT_FILE_NAME,
                                      remote_checkpoint_dir=Path(DEFAULT_AML_CHECKPOINT_DIR))
    downloader.download_checkpoint_if_necessary()
    return SSLEncoder(pl_checkpoint_path=downloader.local_checkpoint_path, tile_size=TILE_SIZE)
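The `Resnet18` encoder these tests migrate to follows a standard feature-extractor recipe: take torchvision's resnet18 and swap its classification head for an identity so the forward pass yields 512-dimensional features per tile. A hedged sketch of that recipe, not the hi-ml class itself:

# Sketch: resnet18 as a tile feature extractor, classification head removed.
import torch
from torch import nn
from torchvision.models import resnet18


def make_resnet18_feature_extractor() -> nn.Module:
    model = resnet18()           # no pretrained weights by default, which keeps tests offline
    model.fc = nn.Identity()     # emit 512-dim features instead of class logits
    return model


features = make_resnet18_feature_extractor()(torch.rand(2, 3, 224, 224))
assert features.shape == (2, 512)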
@ -14,15 +14,15 @@ from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
from monai.transforms import Compose
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import Subset
from torchvision.models import resnet18
from torchvision.transforms import RandomHorizontalFlip

from health_ml.utils.bag_utils import BagDataset
from health_ml.utils.data_augmentations import HEDJitter

from health_cpath.datasets.default_paths import TCGA_CRCK_DATASET_DIR
from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
from health_cpath.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from health_cpath.models.encoders import ImageNetEncoder
from health_cpath.models.encoders import Resnet18
from health_cpath.models.transforms import (EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, Subsampled,
                                            transform_dict_adaptor)

@ -89,6 +89,32 @@ def test_load_tiles_batch() -> None:
    assert_dicts_equal(bagged_loaded_batch, loaded_bagged_batch)


@pytest.mark.parametrize("scale_intensity", [True, False])
def test_itensity_scaling_load_tiles_batch(scale_intensity: bool, mock_panda_tiles_root_dir: Path) -> None:
    tiles_dataset = PandaTilesDataset(mock_panda_tiles_root_dir)
    image_key = tiles_dataset.IMAGE_COLUMN
    max_bag_size = 4
    bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids,  # type: ignore
                                max_bag_size=max_bag_size)
    load_batch_transform = LoadTilesBatchd(image_key, scale_intensity=scale_intensity)
    index = 0

    # Test that the transform returns images in the expected range: [0, 255] raw, or [0, 1] when scaled
    bagged_batch = bagged_dataset[index]
    manually_loaded_batch = load_batch_transform(bagged_batch)

    pixels_dtype = torch.uint8 if not scale_intensity else torch.float32
    max_val = 255 if not scale_intensity else 1.

    for tile in manually_loaded_batch[image_key]:
        assert tile.dtype == pixels_dtype
        assert tile.max() <= max_val
        assert tile.min() >= 0
        if not scale_intensity:
            assert manually_loaded_batch[image_key][index].max() > 1
        assert tile.unique().shape[0] > 1


def _test_cache_and_persistent_datasets(tmp_path: Path,
                                        base_dataset: TorchDataset,
                                        transform: Union[Sequence[Callable], Callable],

@ -142,7 +168,7 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
    bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids,  # type: ignore
                                max_bag_size=max_bag_size)

    encoder = ImageNetEncoder(resnet18, tile_size=224, n_channels=3)
    encoder = Resnet18(tile_size=224, n_channels=3)
    if use_gpu:
        encoder.cuda()
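The two branches of the intensity-scaling test correspond to the two common tile representations: raw uint8 pixels in [0, 255] and scaled float32 pixels in [0, 1]. The conversion between them is a single division, shown here as a standalone sketch:

# Sketch: raw uint8 tiles versus intensity-scaled float32 tiles.
import torch

raw = torch.randint(0, 256, size=(3, 28, 28), dtype=torch.uint8)
scaled = raw.float() / 255.0
assert raw.max() <= 255 and scaled.max() <= 1.0
assert scaled.dtype == torch.float32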
@ -0,0 +1,281 @@
import time
import torch
import pytest
import numpy as np
import pandas as pd

from pathlib import Path

from typing import Callable, List
from unittest.mock import MagicMock, patch
from health_cpath.configs.classification.BaseMIL import BaseMIL

from health_cpath.utils.naming import ModelKey, ResultsKey
from health_cpath.utils.callbacks import LossAnalysisCallback, LossCacheDictType
from testhisto.mocks.container import MockDeepSMILETilesPanda
from testhisto.utils.utils_testhisto import run_distributed


def _assert_is_sorted(array: np.ndarray) -> None:
    assert np.all(np.diff(array) <= 0)

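# Note: despite its generic name, _assert_is_sorted checks for *descending* order,
# since np.diff(array) <= 0 holds only when each element is <= its predecessor:
#   _assert_is_sorted(np.array([3.0, 2.0, 2.0, 1.0]))  # passes
#   _assert_is_sorted(np.array([1.0, 2.0]))            # raises AssertionError
# The loss cache is presumably saved ranked from highest to lowest loss.
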
def _assert_loss_cache_contains_n_elements(loss_cache: LossCacheDictType, n: int) -> None:
    for key in loss_cache:
        assert len(loss_cache[key]) == n


def get_loss_cache(n_slides: int = 4, rank: int = 0) -> LossCacheDictType:
    return {
        ResultsKey.LOSS: list(range(1, n_slides + 1)),
        ResultsKey.ENTROPY: list(range(1, n_slides + 1)),
        ResultsKey.SLIDE_ID: [f"id_{i}" for i in range(rank * n_slides, (rank + 1) * n_slides)],
        ResultsKey.TILE_ID: [f"a${i * (rank + 1)}$b" for i in range(rank * n_slides, (rank + 1) * n_slides)],
    }

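# For reference, a sketch of what get_loss_cache(n_slides=2, rank=1) produces:
#   {ResultsKey.LOSS: [1, 2], ResultsKey.ENTROPY: [1, 2],
#    ResultsKey.SLIDE_ID: ["id_2", "id_3"], ResultsKey.TILE_ID: ["a$4$b", "a$6$b"]}
# The rank offsets the slide ids so that each process contributes distinct
# slides in the distributed tests below.
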
def dump_loss_cache_for_epochs(loss_callback: LossAnalysisCallback, epochs: int, stage: ModelKey) -> None:
    for epoch in range(epochs):
        loss_callback.loss_cache[stage] = get_loss_cache(n_slides=4, rank=0)
        loss_callback.save_loss_cache(epoch, stage)


@pytest.mark.parametrize("create_outputs_folders", [True, False])
def test_loss_callback_outputs_folder_exist(create_outputs_folders: bool, tmp_path: Path) -> None:
    outputs_folder = tmp_path / "outputs"
    callback = LossAnalysisCallback(outputs_folder=outputs_folder, create_outputs_folders=create_outputs_folders)
    for stage in [ModelKey.TRAIN, ModelKey.VAL]:
        for folder in [
            callback.outputs_folder,
            callback.get_cache_folder(stage),
            callback.get_scatter_folder(stage),
            callback.get_heatmap_folder(stage),
            callback.get_anomalies_folder(stage),
        ]:
            assert folder.exists() == create_outputs_folders


@pytest.mark.parametrize("analyse_loss", [True, False])
def test_analyse_loss_param(analyse_loss: bool) -> None:
    container = BaseMIL(analyse_loss=analyse_loss)
    container.data_module = MagicMock()
    callbacks = container.get_callbacks()
    assert isinstance(callbacks[-1], LossAnalysisCallback) == analyse_loss


@pytest.mark.parametrize("save_tile_ids", [True, False])
def test_save_tile_ids_param(save_tile_ids: bool) -> None:
    callback = LossAnalysisCallback(outputs_folder=Path("foo"), save_tile_ids=save_tile_ids)
    assert callback.save_tile_ids == save_tile_ids
    assert (ResultsKey.TILE_ID in callback.loss_cache[ModelKey.TRAIN]) == save_tile_ids
    assert (ResultsKey.TILE_ID in callback.loss_cache[ModelKey.VAL]) == save_tile_ids


@pytest.mark.parametrize("patience", [0, 1, 2])
def test_loss_analysis_patience(patience: int) -> None:
    callback = LossAnalysisCallback(outputs_folder=Path("foo"), patience=patience, max_epochs=10)
    assert callback.patience == patience
    assert callback.epochs_range[0] == patience

    current_epoch = 0
    if patience > 0:
        assert callback.should_cache_loss_values(current_epoch) is False
    else:
        assert callback.should_cache_loss_values(current_epoch)
    current_epoch = 5
    assert callback.should_cache_loss_values(current_epoch)

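# The assertions above and below are consistent with epochs_range being built as
# range(patience, max_epochs, epochs_interval) — an assumption about the
# callback's internals, not verified here — e.g. patience=2, max_epochs=10,
# epochs_interval=1 would cache losses at epochs 2 through 9.
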
@pytest.mark.parametrize("epochs_interval", [1, 2])
def test_loss_analysis_epochs_interval(epochs_interval: int) -> None:
    max_epochs = 10
    callback = LossAnalysisCallback(
        outputs_folder=Path("foo"), patience=0, max_epochs=max_epochs, epochs_interval=epochs_interval
    )
    assert callback.epochs_interval == epochs_interval
    assert len(callback.epochs_range) == max_epochs // epochs_interval

    # First time to cache loss values
    current_epoch = 0
    assert callback.should_cache_loss_values(current_epoch)

    current_epoch = 4  # Note that PL starts counting epochs from 0, 4th epoch is actually the 5th
    if epochs_interval == 2:
        assert callback.should_cache_loss_values(current_epoch) is False
    else:
        assert callback.should_cache_loss_values(current_epoch)

    current_epoch = 5
    assert callback.should_cache_loss_values(current_epoch)

    current_epoch = max_epochs  # no loss caching for extra validation epoch
    assert callback.should_cache_loss_values(current_epoch) is False


def test_on_train_and_val_batch_end(tmp_path: Path, mock_panda_tiles_root_dir: Path) -> None:
    batch_size = 2
    container = MockDeepSMILETilesPanda(tmp_path=mock_panda_tiles_root_dir, analyse_loss=True, batch_size=batch_size)
    container.setup()
    container.create_lightning_module_and_store()

    current_epoch = 5
    trainer = MagicMock(current_epoch=current_epoch)

    callback = LossAnalysisCallback(outputs_folder=tmp_path)
    _assert_loss_cache_contains_n_elements(callback.loss_cache[ModelKey.TRAIN], 0)
    _assert_loss_cache_contains_n_elements(callback.loss_cache[ModelKey.VAL], 0)
    dataloader = iter(container.data_module.train_dataloader())

    def _call_on_batch_end_hook(on_batch_end_hook: Callable, batch_idx: int) -> None:
        batch = next(dataloader)
        outputs = container.model.training_step(batch, batch_idx)
        on_batch_end_hook(trainer, container.model, outputs, batch, batch_idx, 0)  # type: ignore

    stages = [ModelKey.TRAIN, ModelKey.VAL]
    hooks: List[Callable] = [callback.on_train_batch_end, callback.on_validation_batch_end]
    for stage, on_batch_end_hook in zip(stages, hooks):
        _call_on_batch_end_hook(on_batch_end_hook, batch_idx=0)
        _assert_loss_cache_contains_n_elements(callback.loss_cache[stage], batch_size)

        _call_on_batch_end_hook(on_batch_end_hook, batch_idx=1)
        _assert_loss_cache_contains_n_elements(callback.loss_cache[stage], 2 * batch_size)


def test_on_train_and_val_epoch_end(
    tmp_path: Path, duplicate: bool = False, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
    current_epoch = 2
    n_slides_per_process = 4
    trainer = MagicMock(current_epoch=current_epoch)
    pl_module = MagicMock(global_rank=rank)

    loss_callback = LossAnalysisCallback(
        outputs_folder=tmp_path, num_slides_heatmap=2, num_slides_scatter=2, max_epochs=10
    )
    stages = [ModelKey.TRAIN, ModelKey.VAL]
    hooks = [loss_callback.on_train_epoch_end, loss_callback.on_validation_epoch_end]
    for stage, on_epoch_hook in zip(stages, hooks):
        loss_callback.loss_cache[stage] = get_loss_cache(rank=rank, n_slides=n_slides_per_process)

        if duplicate:
            # Duplicate slide "id_0" to test that the duplicates are removed
            loss_callback.loss_cache[stage][ResultsKey.SLIDE_ID][0] = "id_0"

        _assert_loss_cache_contains_n_elements(loss_callback.loss_cache[stage], n_slides_per_process)
        on_epoch_hook(trainer, pl_module)
        # Loss cache is flushed after each epoch
        _assert_loss_cache_contains_n_elements(loss_callback.loss_cache[stage], 0)

        if rank > 0:
            time.sleep(10)  # Wait for rank 0 to save the loss cache in a csv file

        loss_cache_path = loss_callback.get_loss_cache_file(current_epoch, stage)
        assert loss_callback.get_cache_folder(stage).exists()
        assert loss_cache_path.exists()
        assert loss_cache_path.parent == loss_callback.get_cache_folder(stage)

        loss_cache = pd.read_csv(loss_cache_path)
        total_slides = n_slides_per_process * world_size if not duplicate else n_slides_per_process * world_size - 1
        _assert_loss_cache_contains_n_elements(loss_cache, total_slides)
        _assert_is_sorted(loss_cache[ResultsKey.LOSS].values)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_on_train_epoch_end_distributed(tmp_path: Path) -> None:
    # Test that the loss cache is saved correctly when using multiple GPUs
    # First scenario: no duplicates
    run_distributed(test_on_train_and_val_epoch_end, [tmp_path, False], world_size=2)
    # Second scenario: introduce duplicates
    run_distributed(test_on_train_and_val_epoch_end, [tmp_path, True], world_size=2)

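# run_distributed is assumed to spawn world_size processes and append each
# process's rank and the world size to the given argument list, so the
# [tmp_path, duplicate] args above line up with the
# test_on_train_and_val_epoch_end(tmp_path, duplicate, rank, world_size, ...) signature.
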
def test_on_train_and_val_end(tmp_path: Path) -> None:
    pl_module = MagicMock(global_rank=0, _on_extra_val_epoch=False)
    max_epochs = 4
    trainer = MagicMock(current_epoch=max_epochs - 1)

    loss_callback = LossAnalysisCallback(
        outputs_folder=tmp_path, max_epochs=max_epochs, num_slides_heatmap=2, num_slides_scatter=2
    )
    stages = [ModelKey.TRAIN, ModelKey.VAL]
    hooks = [loss_callback.on_train_end, loss_callback.on_validation_end]
    for stage, on_end_hook in zip(stages, hooks):
        dump_loss_cache_for_epochs(loss_callback, max_epochs, stage)
        on_end_hook(trainer, pl_module)

        for epoch in range(max_epochs):
            assert loss_callback.get_loss_cache_file(epoch, stage).exists()

        # check save_loss_ranks outputs
        assert loss_callback.get_all_epochs_loss_cache_file(stage).exists()
        assert loss_callback.get_loss_stats_file(stage).exists()
        assert loss_callback.get_loss_ranks_file(stage).exists()
        assert loss_callback.get_loss_ranks_stats_file(stage).exists()

        # check plot_slides_loss_scatter outputs
        assert loss_callback.get_scatter_plot_file(loss_callback.HIGHEST, stage).exists()
        assert loss_callback.get_scatter_plot_file(loss_callback.LOWEST, stage).exists()

        # check plot_loss_heatmap_for_slides_of_epoch outputs
        for epoch in range(max_epochs):
            assert loss_callback.get_heatmap_plot_file(epoch, loss_callback.HIGHEST, stage).exists()
            assert loss_callback.get_heatmap_plot_file(epoch, loss_callback.LOWEST, stage).exists()


def test_on_validation_end_not_called_if_extra_val_epoch(tmp_path: Path) -> None:
    pl_module = MagicMock(global_rank=0, _on_extra_val_epoch=True)
    max_epochs = 4
    trainer = MagicMock(current_epoch=0)
    loss_callback = LossAnalysisCallback(
        outputs_folder=tmp_path, max_epochs=max_epochs, num_slides_heatmap=2, num_slides_scatter=2
    )
    with patch.object(loss_callback, "save_loss_outliers_analaysis_results") as mock_func:
        loss_callback.on_validation_end(trainer, pl_module)
        mock_func.assert_not_called()


def test_nans_detection(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
    max_epochs = 2
    n_slides = 4
    loss_callback = LossAnalysisCallback(
        outputs_folder=tmp_path, max_epochs=max_epochs, num_slides_heatmap=2, num_slides_scatter=2
    )
    stages = [ModelKey.TRAIN, ModelKey.VAL]
    for stage in stages:
        for epoch in range(max_epochs):
            loss_callback.loss_cache[stage] = get_loss_cache(n_slides)
            loss_callback.loss_cache[stage][ResultsKey.LOSS][epoch] = np.nan  # Introduce NaNs
            loss_callback.save_loss_cache(epoch, stage)

        all_slides = loss_callback.select_slides_for_epoch(epoch=0, stage=stage)
        all_loss_values_per_slides = loss_callback.select_all_losses_for_selected_slides(all_slides, stage)
        loss_callback.sanity_check_loss_values(all_loss_values_per_slides, stage)

        assert "NaNs found in loss values for slide id_0" in caplog.records[-1].getMessage()
        assert "NaNs found in loss values for slide id_1" in caplog.records[0].getMessage()

        assert loss_callback.nan_slides[stage] == ["id_1", "id_0"]
        assert loss_callback.get_nan_slides_file(stage).exists()
        assert loss_callback.get_nan_slides_file(stage).parent == loss_callback.get_anomalies_folder(stage)


@pytest.mark.parametrize("log_exceptions", [True, False])
def test_log_exceptions_flag(log_exceptions: bool, tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
    max_epochs = 3
    trainer = MagicMock(current_epoch=max_epochs - 1)
    pl_module = MagicMock(global_rank=0, _on_extra_val_epoch=False)
    loss_callback = LossAnalysisCallback(
        outputs_folder=tmp_path, max_epochs=max_epochs,
        num_slides_heatmap=2, num_slides_scatter=2, log_exceptions=log_exceptions
    )
    stages = [ModelKey.TRAIN, ModelKey.VAL]
    hooks = [loss_callback.on_train_end, loss_callback.on_validation_end]
    for stage, on_end_hook in zip(stages, hooks):
        message = "Error while detecting " + stage.value + " loss values outliers"
        if log_exceptions:
            on_end_hook(trainer, pl_module)
            assert message in caplog.records[-1].getMessage()
        else:
            with pytest.raises(Exception, match=fr"{message}"):
                on_end_hook(trainer, pl_module)

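# In the non-logging branch above, the callback is expected to wrap the original
# error and re-raise it with the "Error while detecting ... loss values outliers"
# message, which is why pytest.raises matches on that prefix rather than on the
# underlying exception type.
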
@ -0,0 +1,70 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch

from health_cpath.models.encoders import SSLEncoder
from health_cpath.scripts.generate_checkpoint_url import get_checkpoint_url_from_aml_run
from health_cpath.utils.deepmil_utils import EncoderParams
from health_ml.utils.checkpoint_utils import CheckpointParser, LAST_CHECKPOINT_FILE_NAME, MODEL_WEIGHTS_DIR_NAME
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from testhiml.utils.fixed_paths_for_tests import full_test_data_path
from testhisto.models.test_encoders import TEST_SSL_RUN_ID
from testhiml.utils_testhiml import DEFAULT_WORKSPACE


LAST_CHECKPOINT = f"{DEFAULT_AML_CHECKPOINT_DIR}/{LAST_CHECKPOINT_FILE_NAME}"


def test_validate_encoder_params() -> None:
    with pytest.raises(ValueError, match=r"SSLEncoder requires an ssl_checkpoint"):
        encoder = EncoderParams(encoder_type=SSLEncoder.__name__)
        encoder.validate()


def test_load_ssl_checkpoint_from_local_file(tmp_path: Path) -> None:
    checkpoint_filename = "hello_world_checkpoint.ckpt"
    local_checkpoint_path = full_test_data_path(suffix=checkpoint_filename)
    encoder_params = EncoderParams(
        encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(str(local_checkpoint_path))
    )
    assert encoder_params.ssl_checkpoint.is_local_file
    ssl_checkpoint_path = encoder_params.ssl_checkpoint.get_path(tmp_path)
    assert ssl_checkpoint_path.exists()
    assert ssl_checkpoint_path == local_checkpoint_path
    with patch("health_cpath.models.encoders.SSLEncoder._get_encoder") as mock_get_encoder:
        mock_get_encoder.return_value = (MagicMock(), MagicMock())
        encoder = encoder_params.get_encoder(tmp_path)
        assert isinstance(encoder, SSLEncoder)


@pytest.mark.skip(reason="This test is failing because of issue #655")
def test_load_ssl_checkpoint_from_url(tmp_path: Path) -> None:
    blob_url = get_checkpoint_url_from_aml_run(
        run_id=TEST_SSL_RUN_ID,
        checkpoint_filename=LAST_CHECKPOINT_FILE_NAME,
        expiry_days=1,
        aml_workspace=DEFAULT_WORKSPACE.workspace)
    encoder_params = EncoderParams(encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(blob_url))
    assert encoder_params.ssl_checkpoint.is_url
    ssl_checkpoint_path = encoder_params.ssl_checkpoint.get_path(tmp_path)
    assert ssl_checkpoint_path.exists()
    assert ssl_checkpoint_path == tmp_path / MODEL_WEIGHTS_DIR_NAME / LAST_CHECKPOINT_FILE_NAME
    encoder = encoder_params.get_encoder(tmp_path)
    assert isinstance(encoder, SSLEncoder)


@pytest.mark.skip(reason="This test is failing because of issue #655")
def test_load_ssl_checkpoint_from_run_id(tmp_path: Path) -> None:
    encoder_params = EncoderParams(encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(TEST_SSL_RUN_ID))
    assert encoder_params.ssl_checkpoint.is_aml_run_id
    with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
        mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
        ssl_checkpoint_path = encoder_params.ssl_checkpoint.get_path(tmp_path)
        assert ssl_checkpoint_path.exists()
        assert ssl_checkpoint_path == tmp_path / TEST_SSL_RUN_ID / LAST_CHECKPOINT
        encoder = encoder_params.get_encoder(tmp_path)
        assert isinstance(encoder, SSLEncoder)

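# CheckpointParser appears to dispatch on the raw checkpoint string: an existing
# file path sets is_local_file, an http(s) URL sets is_url, and an AzureML run id
# sets is_aml_run_id; get_path(download_dir) then resolves (and, for remote
# sources, downloads) the checkpoint. A minimal usage sketch under that assumption:
#   parser = CheckpointParser("some/run_id_or_path_or_url")
#   ckpt_path = parser.get_path(Path("/tmp/checkpoints"))
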
@ -1,12 +1,13 @@
from pathlib import Path
from typing import Dict, List
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.distributed
import torch.multiprocessing
from ruamel.yaml import YAML
from health_cpath.utils.tiles_selection_utils import TilesSelector
from testhisto.utils.utils_testhisto import run_distributed
from torch.testing import assert_close
from torchmetrics.metric import Metric

@ -251,3 +252,27 @@ def test_collate_results_multigpu() -> None:
    epoch_results = _create_epoch_results(batch_size, num_batches, rank=0, device='cuda:0') \
        + _create_epoch_results(batch_size, num_batches, rank=1, device='cuda:1')
    _test_collate_results(epoch_results, total_num_samples=2 * num_batches * batch_size)


@pytest.mark.parametrize('val_set_is_dist', [True, False])
def test_results_gathering_with_val_set_is_dist_flag(val_set_is_dist: bool, tmp_path: Path) -> None:
    outputs_handler = _create_outputs_handler(tmp_path)
    outputs_handler.tiles_selector = TilesSelector(2, num_slides=2, num_tiles=2)
    outputs_handler.val_set_is_dist = val_set_is_dist
    outputs_handler._save_outputs = MagicMock()  # type: ignore
    metric_value = 0.5
    with patch("health_cpath.utils.output_utils.gather_results") as mock_gather_results:
        with patch.object(outputs_handler.tiles_selector, "gather_selected_tiles_across_devices") as mock_gather_tiles:
            with patch.object(outputs_handler.tiles_selector, "_clear_cached_slides_heaps") as mock_clear_cache:
                with patch.object(outputs_handler, "should_gather_tiles") as mock_should_gather_tiles:
                    mock_should_gather_tiles.return_value = True
                    for rank in range(2):
                        epoch_results = [{_PRIMARY_METRIC_KEY: [metric_value] * 5, _RANK_KEY: rank}]
                        outputs_handler.save_validation_outputs(
                            epoch_results=epoch_results,  # type: ignore
                            metrics_dict=_get_mock_metrics_dict(metric_value),
                            epoch=0,
                            is_global_rank_zero=rank == 0)
                    assert mock_gather_results.called == val_set_is_dist
                    assert mock_gather_tiles.called == val_set_is_dist
                    mock_clear_cache.assert_called()

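# val_set_is_dist presumably marks whether the validation set is sharded across
# DDP ranks: only then must epoch results and selected tiles be gathered across
# devices before rank zero saves outputs, which is exactly what the two gather
# mocks assert above.
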
@ -3,33 +3,42 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import os
from pathlib import Path
from typing import Collection
from typing import Any, Collection, Dict, List
from unittest.mock import MagicMock, patch
import pytest

from health_cpath.utils.naming import PlotOption, ResultsKey
from health_cpath.utils.plots_utils import DeepMILPlotsHandler, save_confusion_matrix
from health_cpath.utils.plots_utils import DeepMILPlotsHandler, save_confusion_matrix, save_pr_curve
from health_cpath.utils.tiles_selection_utils import SlideNode, TilesSelector
from testhisto.mocks.container import MockDeepSMILETilesPanda


def test_plots_handler_wrong_class_names() -> None:
    plot_options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX}
    with pytest.raises(ValueError) as ex:
    with pytest.raises(ValueError, match=r"No class_names were provided while activating confusion matrix plotting."):
        _ = DeepMILPlotsHandler(plot_options, class_names=[])
    assert "No class_names were provided while activating confusion matrix plotting." in str(ex)


def test_plots_handler_slide_thumbnails_without_slide_dataset() -> None:
    with pytest.raises(ValueError) as ex:
@pytest.mark.parametrize(
    "slide_plot_options",
    [
        [PlotOption.SLIDE_THUMBNAIL],
        [PlotOption.ATTENTION_HEATMAP],
        [PlotOption.SLIDE_THUMBNAIL, PlotOption.ATTENTION_HEATMAP]
    ],
)
def test_plots_handler_slide_plot_options_without_slide_dataset(slide_plot_options: List[PlotOption]) -> None:
    exception_prompt = f"Plot option {slide_plot_options[0]} requires a slides dataset"
    with pytest.raises(ValueError, match=rf"{exception_prompt}"):
        container = MockDeepSMILETilesPanda(tmp_path=Path("foo"))
        container.setup()
        container.data_module = MagicMock()
        container.data_module.train_dataset.n_classes = 6
        outputs_handler = container.get_outputs_handler()
        outputs_handler.test_plots_handler.plot_options = {PlotOption.SLIDE_THUMBNAIL_HEATMAP}
        outputs_handler.test_plots_handler.plot_options = slide_plot_options
        outputs_handler.set_slides_dataset_for_plots_handlers(container.get_slides_dataset())
    assert "You can not plot slide thumbnails and heatmaps without setting a slides_dataset." in str(ex)


def assert_plot_func_called_if_among_plot_options(

@ -48,42 +57,46 @@ def assert_plot_func_called_if_among_plot_options(
    "plot_options",
    [
        {},
        {PlotOption.HISTOGRAM},
        {PlotOption.HISTOGRAM, PlotOption.PR_CURVE},
        {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX},
        {PlotOption.HISTOGRAM, PlotOption.TOP_BOTTOM_TILES, PlotOption.SLIDE_THUMBNAIL_HEATMAP},
        {PlotOption.HISTOGRAM, PlotOption.TOP_BOTTOM_TILES, PlotOption.ATTENTION_HEATMAP},
        {
            PlotOption.HISTOGRAM,
            PlotOption.PR_CURVE,
            PlotOption.CONFUSION_MATRIX,
            PlotOption.TOP_BOTTOM_TILES,
            PlotOption.SLIDE_THUMBNAIL_HEATMAP,
            PlotOption.SLIDE_THUMBNAIL,
            PlotOption.ATTENTION_HEATMAP,
        },
    ],
)
def test_plots_handler_plots_only_desired_plot_options(plot_options: Collection[PlotOption]) -> None:
    plots_handler = DeepMILPlotsHandler(plot_options, class_names=["foo"])
    plots_handler = DeepMILPlotsHandler(plot_options, class_names=["foo1", "foo2"])
    plots_handler.slides_dataset = MagicMock()

    n_tiles = 4
    slide_node = SlideNode(slide_id="1", prob_score=0.5, true_label=1, pred_label=0)
    slide_node = SlideNode(slide_id="1", gt_prob_score=0.2, pred_prob_score=0.8, true_label=1, pred_label=0)
    tiles_selector = TilesSelector(n_classes=2, num_slides=4, num_tiles=2)
    tiles_selector.top_slides_heaps = {0: [slide_node] * n_tiles, 1: [slide_node] * n_tiles}
    tiles_selector.bottom_slides_heaps = {0: [slide_node] * n_tiles, 1: [slide_node] * n_tiles}

    with patch("health_cpath.utils.plots_utils.save_slide_thumbnail_and_heatmap") as mock_slide:
        with patch("health_cpath.utils.plots_utils.save_top_and_bottom_tiles") as mock_tile:
            with patch("health_cpath.utils.plots_utils.save_scores_histogram") as mock_histogram:
                with patch("health_cpath.utils.plots_utils.save_confusion_matrix") as mock_conf:
                    plots_handler.save_plots(
                        outputs_dir=MagicMock(), tiles_selector=tiles_selector, results=MagicMock()
                    )
    patchers: Dict[PlotOption, Any] = {
        PlotOption.SLIDE_THUMBNAIL: patch("health_cpath.utils.plots_utils.save_slide_thumbnail"),
        PlotOption.ATTENTION_HEATMAP: patch("health_cpath.utils.plots_utils.save_attention_heatmap"),
        PlotOption.TOP_BOTTOM_TILES: patch("health_cpath.utils.plots_utils.save_top_and_bottom_tiles"),
        PlotOption.CONFUSION_MATRIX: patch("health_cpath.utils.plots_utils.save_confusion_matrix"),
        PlotOption.HISTOGRAM: patch("health_cpath.utils.plots_utils.save_scores_histogram"),
        PlotOption.PR_CURVE: patch("health_cpath.utils.plots_utils.save_pr_curve"),
    }

    mock_funcs = {option: patcher.start() for option, patcher in patchers.items()}  # type: ignore
    with patch.object(plots_handler, "get_slide_dict"):
        plots_handler.save_plots(outputs_dir=MagicMock(), tiles_selector=tiles_selector, results=MagicMock())
    patch.stopall()

    calls_count = 0
    calls_count += assert_plot_func_called_if_among_plot_options(
        mock_slide, PlotOption.SLIDE_THUMBNAIL_HEATMAP, plot_options
    )
    calls_count += assert_plot_func_called_if_among_plot_options(mock_tile, PlotOption.TOP_BOTTOM_TILES, plot_options)
    calls_count += assert_plot_func_called_if_among_plot_options(mock_histogram, PlotOption.HISTOGRAM, plot_options)
    calls_count += assert_plot_func_called_if_among_plot_options(mock_conf, PlotOption.CONFUSION_MATRIX, plot_options)
    for option, mock_func in mock_funcs.items():
        calls_count += assert_plot_func_called_if_among_plot_options(mock_func, option, plot_options)

    assert calls_count == len(plot_options)

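# Design note: the dict of patchers replaces the earlier pyramid of nested
# `with patch(...)` blocks; every save_* plotting function gets one mock via
# patcher.start(), and patch.stopall() tears them all down, so supporting a new
# PlotOption only requires one more dictionary entry.
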
@ -97,8 +110,8 @@ def test_save_conf_matrix_integration(tmp_path: Path) -> None:
    }
    class_names = ["foo", "bar"]

    save_confusion_matrix(results, class_names, tmp_path)
    file = Path(tmp_path) / "normalized_confusion_matrix.png"
    save_confusion_matrix(results, class_names, tmp_path, stage='foo')
    file = Path(tmp_path) / "normalized_confusion_matrix_foo.png"
    assert file.exists()

    # check that an error is raised if true labels include indices greater than the expected number of classes

@ -119,7 +132,7 @@ def test_save_conf_matrix_integration(tmp_path: Path) -> None:
        save_confusion_matrix(invalid_results_2, class_names, tmp_path)
    assert "More entries were found in predicted labels than are available in class names" in str(e)

    # check that if confusion matrix still has correct shape even if results don't cover all expected labels
    # check that confusion matrix still has correct shape even if results don't cover all expected labels
    class_names_extended = ["foo", "bar", "baz"]
    num_classes = len(class_names_extended)
    expected_conf_matrix_shape = (num_classes, num_classes)

@ -129,3 +142,23 @@ def test_save_conf_matrix_integration(tmp_path: Path) -> None:
        mock_plot_conf_matrix.assert_called_once()
        actual_conf_matrix = mock_plot_conf_matrix.call_args[1].get('cm')
        assert actual_conf_matrix.shape == expected_conf_matrix_shape


def test_pr_curve_integration(tmp_path: Path) -> None:
    results = {
        ResultsKey.TRUE_LABEL: [0, 1, 0, 1, 0, 1],
        ResultsKey.PROB: [0.1, 0.8, 0.6, 0.3, 0.5, 0.4]
    }

    # check that the plot is produced and has the right filename
    save_pr_curve(results, tmp_path, stage='foo')  # type: ignore
    file = Path(tmp_path) / "pr_curve_foo.png"
    assert file.exists()
    os.remove(file)

    # check that a warning is raised and no plot is produced if NOT a binary case
    results[ResultsKey.TRUE_LABEL] = [0, 1, 0, 2, 0, 1]
    with pytest.raises(Warning) as w:
        save_pr_curve(results, tmp_path, stage='foo')  # type: ignore
    assert "The PR curve plot implementation works only for binary cases, this plot will be skipped." in str(w)
    assert not file.exists()

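# The contract exercised above, in short: save_pr_curve takes binary true labels
# plus per-sample probabilities and writes "pr_curve_<stage>.png"; any label
# outside {0, 1} makes it warn and skip the plot instead of producing a file.
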
@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Dict, List, Sequence, Union
from typing import Any, Dict, List, Sequence, Union
from unittest.mock import MagicMock, patch

import numpy as np

@ -11,10 +11,10 @@ import pytest
from health_azure.utils import download_file_if_necessary
from health_cpath.utils.output_utils import (AML_LEGACY_TEST_OUTPUTS_CSV, AML_OUTPUTS_DIR, AML_TEST_OUTPUTS_CSV,
                                             AML_VAL_OUTPUTS_CSV)
from health_cpath.utils.report_utils import (collect_crossval_metrics, collect_crossval_outputs,
                                             crossval_runs_have_val_and_test_outputs, get_best_epoch_metrics,
                                             get_best_epochs, get_crossval_metrics_table,
                                             run_has_val_and_test_outputs)
from health_cpath.utils.report_utils import (collect_hyperdrive_metrics, collect_hyperdrive_outputs,
                                             child_runs_have_val_and_test_outputs, get_best_epoch_metrics,
                                             get_best_epochs, get_hyperdrive_metrics_table,
                                             run_has_val_and_test_outputs, download_hyperdrive_metrics_if_required)


def test_run_has_val_and_test_outputs() -> None:

@ -41,7 +41,7 @@ def test_run_has_val_and_test_outputs() -> None:
        run_has_val_and_test_outputs(run)


def test_crossval_runs_have_val_and_test_outputs() -> None:
def test_hyperdrive_runs_have_val_and_test_outputs() -> None:
    legacy_run = MagicMock(display_name="legacy run", id="child1")
    legacy_run.get_file_names.return_value = [AML_LEGACY_TEST_OUTPUTS_CSV]

@ -55,22 +55,22 @@ def test_crossval_runs_have_val_and_test_outputs() -> None:
    parent_run = MagicMock()

    with patch.object(parent_run, 'get_children', return_value=[legacy_run, legacy_run]):
        assert not crossval_runs_have_val_and_test_outputs(parent_run)
        assert not child_runs_have_val_and_test_outputs(parent_run)

    with patch.object(parent_run, 'get_children', return_value=[run_with_val_and_test, run_with_val_and_test]):
        assert crossval_runs_have_val_and_test_outputs(parent_run)
        assert child_runs_have_val_and_test_outputs(parent_run)

    with patch.object(parent_run, 'get_children', return_value=[legacy_run, invalid_run]):
        with pytest.raises(ValueError, match="does not have the expected files"):
            crossval_runs_have_val_and_test_outputs(parent_run)
            child_runs_have_val_and_test_outputs(parent_run)

    with patch.object(parent_run, 'get_children', return_value=[run_with_val_and_test, invalid_run]):
        with pytest.raises(ValueError, match="does not have the expected files"):
            crossval_runs_have_val_and_test_outputs(parent_run)
            child_runs_have_val_and_test_outputs(parent_run)

    with patch.object(parent_run, 'get_children', return_value=[legacy_run, run_with_val_and_test]):
        with pytest.raises(ValueError, match="has mixed children"):
            crossval_runs_have_val_and_test_outputs(parent_run)
            child_runs_have_val_and_test_outputs(parent_run)


@pytest.mark.parametrize('overwrite', [False, True])

@ -102,9 +102,9 @@ def test_download_from_run_if_necessary(tmp_path: Path, overwrite: bool) -> None


class MockChildRun:
    def __init__(self, run_id: str, cross_val_index: int):
    def __init__(self, run_id: str, child_run_index: int):
        self.run_id = run_id
        self.tags = {"hyperparameters": json.dumps({"child_run_index": cross_val_index})}
        self.tags = {"hyperparameters": json.dumps({"child_run_index": child_run_index})}

    def get_metrics(self) -> Dict[str, Union[float, List[Union[int, float]]]]:
        num_epochs = 5

@ -127,9 +127,9 @@ class MockHyperDriveRun:
        return [MockChildRun(f"run_abc_{i}456", i) for i in self.child_indices]


def test_collect_crossval_outputs(tmp_path: Path) -> None:
def test_collect_hyperdrive_outputs(tmp_path: Path) -> None:
    download_dir = tmp_path
    crossval_arg_name = "child_run_index"
    hyperdrive_arg_name = "child_run_index"
    output_filename = "output.csv"
    child_indices = [0, 3, 1]  # Missing and unsorted children


@ -142,16 +142,16 @@ def test_collect_crossval_outputs(tmp_path: Path) -> None:

    with patch('health_cpath.utils.report_utils.get_aml_run_from_run_id',
               return_value=MockHyperDriveRun(child_indices)):
        crossval_dfs = collect_crossval_outputs(parent_run_id="",
                                                download_dir=download_dir,
                                                aml_workspace=None,
                                                crossval_arg_name=crossval_arg_name,
                                                output_filename=output_filename)
        hyperdrive_dfs = collect_hyperdrive_outputs(parent_run_id="",
                                                    download_dir=download_dir,
                                                    aml_workspace=None,
                                                    hyperdrive_arg_name=hyperdrive_arg_name,
                                                    output_filename=output_filename)

    assert set(crossval_dfs.keys()) == set(child_indices)
    assert list(crossval_dfs.keys()) == sorted(crossval_dfs.keys())
    assert set(hyperdrive_dfs.keys()) == set(child_indices)
    assert list(hyperdrive_dfs.keys()) == sorted(hyperdrive_dfs.keys())

    for child_index, child_df in crossval_dfs.items():
    for child_index, child_df in hyperdrive_dfs.items():
        assert child_df.columns.tolist() == columns
        assert child_df.loc[0, 'split'] == child_index

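# Naming note: the crossval_* helpers become hyperdrive_* because the same
# report code serves any HyperDrive sweep, with cross-validation as just one
# special case; the "child_run_index" tag key is presumably kept so that older
# runs remain readable.
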
@ -176,13 +176,24 @@ def metrics_df() -> pd.DataFrame:
        'val/auroc': [0.8, 0.9, 0.7],
        'test/accuracy': 0.9,
        'test/auroc': 0.9
        }
        },
    4: {'val/accuracy': None,
        'val/auroc': None,
        'test/accuracy': None,
        'test/auroc': None
        }
})


@pytest.fixture
def best_epochs(metrics_df: pd.DataFrame) -> Dict[int, int]:
    return get_best_epochs(metrics_df, 'val/accuracy', maximise=True)
def max_epochs_dict() -> Dict[int, int]:
    return {0: 3, 1: 10, 3: 3, 4: 3}


@pytest.fixture
def best_epochs(metrics_df: pd.DataFrame, max_epochs_dict: Dict[int, int]) -> Dict[int, Any]:
    return get_best_epochs(metrics_df=metrics_df, primary_metric='val/accuracy',
                           max_epochs_dict=max_epochs_dict, maximise=True)


@pytest.fixture

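# Split 4, with all-None metrics, presumably models a child run that crashed
# before logging anything; get_best_epochs must now tolerate such runs, hence
# the Optional best-epoch values asserted in test_get_best_epochs below.
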
@ -192,16 +203,18 @@ def best_epoch_metrics(metrics_df: pd.DataFrame, best_epochs: Dict[int, int]) ->


@pytest.mark.parametrize('overwrite', [False, True])
def test_collect_crossval_metrics(metrics_df: pd.DataFrame, tmp_path: Path, overwrite: bool) -> None:
def test_collect_hyperdrive_metrics(metrics_df: pd.DataFrame, tmp_path: Path, overwrite: bool) -> None:
    with patch('health_cpath.utils.report_utils.aggregate_hyperdrive_metrics',
               return_value=metrics_df) as mock_aggregate:
        returned_df = collect_crossval_metrics(parent_run_id="", download_dir=tmp_path,
                                               aml_workspace=None, overwrite=overwrite)
        returned_json = download_hyperdrive_metrics_if_required(parent_run_id="", download_dir=tmp_path,
                                                                aml_workspace=None, overwrite=overwrite)
        returned_df = collect_hyperdrive_metrics(metrics_json=returned_json)
        mock_aggregate.assert_called_once()
        mock_aggregate.reset_mock()

        new_returned_df = collect_crossval_metrics(parent_run_id="", download_dir=tmp_path,
                                                   aml_workspace=None, overwrite=overwrite)
        new_returned_json = download_hyperdrive_metrics_if_required(parent_run_id="", download_dir=tmp_path,
                                                                    aml_workspace=None, overwrite=overwrite)
        new_returned_df = collect_hyperdrive_metrics(metrics_json=new_returned_json)
        if overwrite:
            mock_aggregate.assert_called_once()
        else:

@ -211,12 +224,13 @@ def test_collect_crossval_metrics(metrics_df: pd.DataFrame, tmp_path: Path, over


@pytest.mark.parametrize('maximise', [True, False])
def test_get_best_epochs(metrics_df: pd.DataFrame, maximise: bool) -> None:
    best_epochs = get_best_epochs(metrics_df, 'val/accuracy', maximise=maximise)
def test_get_best_epochs(metrics_df: pd.DataFrame, max_epochs_dict: Dict[int, int], maximise: bool) -> None:
    best_epochs = get_best_epochs(metrics_df=metrics_df, primary_metric='val/accuracy',
                                  max_epochs_dict=max_epochs_dict, maximise=maximise)
    assert list(best_epochs.keys()) == list(metrics_df.columns)
    assert all(isinstance(epoch, int) for epoch in best_epochs.values())
    assert all(isinstance(epoch, (int, type(None))) for epoch in best_epochs.values())

    expected_best = {0: 0, 1: 1, 3: 2} if maximise else {0: 1, 1: 2, 3: 0}
    expected_best = {0: 0, 1: 1, 3: 2, 4: None} if maximise else {0: 1, 1: 2, 3: 0, 4: None}
    for split in metrics_df.columns:
        assert best_epochs[split] == expected_best[split]

@ -233,13 +247,15 @@ def test_get_best_epoch_metrics(metrics_df: pd.DataFrame, best_epochs: Dict[int,

@pytest.mark.parametrize('fixture_name, metrics_list', [('metrics_df', ['test/accuracy', 'test/auroc']),
                                                        ('best_epoch_metrics', ['val/accuracy', 'val/auroc'])])
def test_get_crossval_metrics_table(fixture_name: str, metrics_list: List[str], request: pytest.FixtureRequest) -> None:
def test_get_hyperdrive_metrics_table(
    fixture_name: str, metrics_list: List[str], request: pytest.FixtureRequest
) -> None:
    df = request.getfixturevalue(fixture_name)

    metrics_table = get_crossval_metrics_table(df, metrics_list)
    metrics_table = get_hyperdrive_metrics_table(df, metrics_list)
    assert list(metrics_table.index) == metrics_list
    assert len(metrics_table.columns) == len(df.columns) + 1

    original_values = df.loc[metrics_list].values
    table_values = metrics_table.iloc[:, :-1].applymap(float).values
    assert (table_values == original_values).all()
    assert (table_values[~pd.isnull(table_values)] == original_values[~pd.isnull(original_values)]).all()

@ -8,7 +8,7 @@ import pytest
import numpy as np

from unittest.mock import patch
from typing import Dict, List, Any
from typing import Dict, List, Any, Set
from testhisto.utils.utils_testhisto import run_distributed
from health_cpath.utils.naming import ResultsKey, SlideKey
from health_cpath.utils.plots_utils import TilesSelector, SlideNode

@ -42,12 +42,14 @@ def _create_mock_results(n_samples: int, n_tiles: int = 3, n_classes: int = 2, d
    :return: A dictionary containing randomly generated mock results.
    """
    diff_n_tiles = [n_tiles + i for i in range(n_samples)]
    probs = torch.rand((n_samples, n_classes), device=device)
    probs = probs / probs.sum(dim=1, keepdim=True)
    mock_results = {
        ResultsKey.SLIDE_ID: np.array([f"slide_{i}" for i in range(n_samples)]),
        ResultsKey.TRUE_LABEL: torch.randint(2, size=(n_samples,), device=device),
        ResultsKey.PRED_LABEL: torch.randint(2, size=(n_samples,), device=device),
        ResultsKey.TRUE_LABEL: torch.randint(n_classes, size=(n_samples,), device=device),
        ResultsKey.PRED_LABEL: torch.argmax(probs, dim=1),
        ResultsKey.BAG_ATTN: [torch.rand(size=(1, diff_n_tiles[i]), device=device) for i in range(n_samples)],
        ResultsKey.CLASS_PROBS: torch.rand((n_samples, n_classes), device=device),
        ResultsKey.CLASS_PROBS: probs,
    }
    return mock_results

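# With each row of `probs` normalised to sum to 1 and PRED_LABEL derived via
# argmax, the mock results are self-consistent: the predicted label always
# corresponds to the largest class probability, which the top/bottom slide
# selection tests below rely on.
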
@ -96,14 +98,14 @@ def _create_and_update_top_bottom_tiles_selector(

def _get_expected_slides_by_probability(
    results: Dict[ResultsKey, Any], num_top_slides: int = 2, label: int = 1, top: bool = True
) -> List[str]:
) -> Set[str]:
    """Select top or bottom slides according to their probability scores from the entire dataset.

    :param results: The results dictionary for the entire dataset.
    :param num_top_slides: The number of slides to use to select top and bottom slides, defaults to 2
    :param label: The current label to process given that top and bottom are grouped by class label, defaults to 1
    :param top: A flag to select top or bottom slides with highest (respectively, lowest) prob scores, defaults to True
    :return: A list of selected slide ids.
    :return: A set of selected slide ids.
    """

    class_indices = (results[ResultsKey.TRUE_LABEL].squeeze() == label).nonzero().squeeze(1)

@ -111,37 +113,29 @@ def _get_expected_slides_by_probability(
    assert class_prob.shape == (len(class_indices),)
    num_top_slides = min(num_top_slides, len(class_prob))

    _, sorting_indices = class_prob.topk(num_top_slides, largest=top, sorted=True)
    _, sorting_indices = class_prob.topk(num_top_slides, largest=top)
    sorted_class_indices = class_indices[sorting_indices]

    return [results[ResultsKey.SLIDE_ID][i] for i in sorted_class_indices][::-1]  # the order is inversed in the heaps


def get_expected_top_slides_by_probability(
    results: Dict[ResultsKey, Any], num_top_slides: int = 5, label: int = 1
) -> List[str]:
    """Calls `_get_expected_slides_by_probability` with `top=True` to select expected top slides for the entire dataset
    in one go. """
    return _get_expected_slides_by_probability(results, num_top_slides, label, top=True)


def get_expected_bottom_slides_by_probability(
    results: Dict[ResultsKey, Any], num_top_slides: int = 5, label: int = 1
) -> List[str]:
    """Calls `_get_expected_slides_by_probability` with `top=False` to select expected bottom slides for the entire
    dataset in one go. """
    return _get_expected_slides_by_probability(results, num_top_slides, label, top=False)
    def _selection_condition(index: int) -> bool:
        if top:
            return results[ResultsKey.PRED_LABEL][index] == results[ResultsKey.TRUE_LABEL][index]
        else:
            return results[ResultsKey.PRED_LABEL][index] != results[ResultsKey.TRUE_LABEL][index]

    return {results[ResultsKey.SLIDE_ID][i] for i in sorted_class_indices if _selection_condition(i)}


@pytest.mark.parametrize("num_top_slides", [2, 10])
@pytest.mark.parametrize("n_classes", [2, 3])  # n_classes=2 represents the binary case.
def test_aggregate_shallow_slide_nodes(n_classes: int, rank: int = 0, world_size: int = 1, device: str = "cpu") -> None:
def test_aggregate_shallow_slide_nodes(
    n_classes: int, num_top_slides: int, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
    """This test ensures that shallow copies of slide nodes are gathered properly across devices in a ddp context."""
    n_tiles = 3
    batch_size = 2
    n_batches = 10
    total_batches = n_batches * world_size
    num_top_tiles = 2
    num_top_slides = 2

    torch.manual_seed(42)
    data = _create_mock_data(n_samples=batch_size * total_batches, n_tiles=n_tiles, device=device)

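# Selection rule, as encoded by _selection_condition above: "top" slides are the
# highest-probability *correctly* classified slides per label, "bottom" slides
# the highest-probability *misclassified* ones; returning sets instead of
# ordered lists removes the dependence on heap ordering that made the previous
# list comparison brittle.
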
@ -167,13 +161,21 @@ def test_aggregate_shallow_slide_nodes(n_classes: int, rank: int = 0, world_size

    if rank == 0:
        for label in range(n_classes):
            expected_top_slides_ids = _get_expected_slides_by_probability(results, num_top_slides, label, top=True)
            assert expected_top_slides_ids == [slide_node.slide_id for slide_node in shallow_top_slides_heaps[label]]

            assert all(slide_node.pred_label == slide_node.true_label for slide_node in shallow_top_slides_heaps[label])
            selected_top_slides_ids = {slide_node.slide_id for slide_node in shallow_top_slides_heaps[label]}
            expected_top_slides_ids = _get_expected_slides_by_probability(results, num_top_slides, label, top=True)
            assert expected_top_slides_ids == selected_top_slides_ids

            assert all(
                slide_node.pred_label != slide_node.true_label for slide_node in shallow_bottom_slides_heaps[label]
            )
            selected_bottom_slides_ids = {slide_node.slide_id for slide_node in shallow_bottom_slides_heaps[label]}
            expected_bottom_slides_ids = _get_expected_slides_by_probability(results, num_top_slides, label, top=False)
            assert expected_bottom_slides_ids == [
                slide_node.slide_id for slide_node in shallow_bottom_slides_heaps[label]
            ]
            assert expected_bottom_slides_ids == selected_bottom_slides_ids

            # Make sure that the top and bottom slides are disjoint.
            assert not selected_top_slides_ids.intersection(selected_bottom_slides_ids)


@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")

@ -181,27 +183,30 @@ def test_aggregate_shallow_slide_nodes(n_classes: int, rank: int = 0, world_size
@pytest.mark.gpu
def test_aggregate_shallow_slide_nodes_distributed() -> None:
    """These tests need to be called sequentially to prevent them from running in parallel."""
    # test with n_classes = 2
    run_distributed(test_aggregate_shallow_slide_nodes, [2], world_size=1)
    run_distributed(test_aggregate_shallow_slide_nodes, [2], world_size=2)
    # test with n_classes = 3
    run_distributed(test_aggregate_shallow_slide_nodes, [3], world_size=1)
    run_distributed(test_aggregate_shallow_slide_nodes, [3], world_size=2)
    # test with n_classes = 2, n_slides = 2
    run_distributed(test_aggregate_shallow_slide_nodes, [2, 2], world_size=1)
    run_distributed(test_aggregate_shallow_slide_nodes, [2, 2], world_size=2)
    # test with n_classes = 3, n_slides = 2
    run_distributed(test_aggregate_shallow_slide_nodes, [3, 2], world_size=1)
    run_distributed(test_aggregate_shallow_slide_nodes, [3, 2], world_size=2)


def assert_equal_top_bottom_attention_tiles(
    slide_ids: List[str], data: Dict, results: Dict, num_top_tiles: int, slide_nodes: List[SlideNode]
    slide_ids: Set[str], data: Dict, results: Dict, num_top_tiles: int, slide_nodes: List[SlideNode]
) -> None:
    """Asserts that top and bottom tiles selected on the fly by the top bottom tiles selector are equal to the expected
    top and bottom tiles in the mock dataset.

    :param slide_ids: A list of expected slide ids.
    :param slide_ids: A set of expected slide ids.
    :param data: A dictionary containing the entire dataset.
    :param results: A dictionary of data results.
    :param num_top_tiles: The number of tiles to select as top and bottom tiles for each top/bottom slide.
    :param slide_nodes: The top or bottom slide nodes selected on the fly by the selector.
    """
    for i, slide_id in enumerate(slide_ids):

    slide_nodes_dict = {slide_node.slide_id: slide_node for slide_node in slide_nodes}

    for slide_id in slide_ids:
        slide_batch_idx = int(slide_id.split("_")[1])
        tiles = data[SlideKey.IMAGE][slide_batch_idx]

@ -215,8 +220,8 @@ def assert_equal_top_bottom_attention_tiles(
        expected_top_tiles: List[torch.Tensor] = [tiles[tile_id] for tile_id in top_tiles_ids]
        expected_bottom_tiles: List[torch.Tensor] = [tiles[tile_id] for tile_id in bottom_tiles_ids]

        top_tiles = slide_nodes[i].top_tiles
        bottom_tiles = slide_nodes[i].bottom_tiles
        top_tiles = slide_nodes_dict[slide_id].top_tiles
        bottom_tiles = slide_nodes_dict[slide_id].bottom_tiles

        for j, expected_top_tile in enumerate(expected_top_tiles):
            assert torch.equal(expected_top_tile.cpu(), top_tiles[j].data)

@ -227,9 +232,10 @@ def assert_equal_top_bottom_attention_tiles(
            assert expected_bottom_attns[j].item() == bottom_tiles[j].attn


@pytest.mark.parametrize("num_top_slides", [2, 10])
@pytest.mark.parametrize("n_classes", [2, 3])  # n_classes=2 represents the binary case.
def test_select_k_top_bottom_tiles_on_the_fly(
    n_classes: int, rank: int = 0, world_size: int = 1, device: str = "cpu"
    n_classes: int, num_top_slides: int, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
    """This test checks that k top and bottom tiles are selected properly `on the fly`:
    1- Create a mock dataset and corresponding mock results that are small enough to fit in memory


@ -264,34 +270,46 @@ def test_select_k_top_bottom_tiles_on_the_fly(

    if rank == 0:
        for label in range(n_classes):
            expected_top_slides_ids = get_expected_top_slides_by_probability(results, num_top_slides, label)
            assert expected_top_slides_ids == [
                slide_node.slide_id for slide_node in tiles_selector.top_slides_heaps[label]
            ]

            assert all(
                slide_node.pred_label == slide_node.true_label for slide_node in tiles_selector.top_slides_heaps[label]
            )
            selected_top_slides_ids = {slide_node.slide_id for slide_node in tiles_selector.top_slides_heaps[label]}
            expected_top_slides_ids = _get_expected_slides_by_probability(results, num_top_slides, label, top=True)
            assert expected_top_slides_ids == selected_top_slides_ids
            assert_equal_top_bottom_attention_tiles(
                expected_top_slides_ids, data, results, num_top_tiles, tiles_selector.top_slides_heaps[label]
            )

            expected_bottom_slides_ids = get_expected_bottom_slides_by_probability(results, num_top_slides, label)
            assert expected_bottom_slides_ids == [
            assert all(
                slide_node.pred_label != slide_node.true_label
                for slide_node in tiles_selector.bottom_slides_heaps[label]
            )
            selected_bottom_slides_ids = {
                slide_node.slide_id for slide_node in tiles_selector.bottom_slides_heaps[label]
            ]
            }
            expected_bottom_slides_ids = _get_expected_slides_by_probability(results, num_top_slides, label, top=False)
            assert expected_bottom_slides_ids == selected_bottom_slides_ids

            assert_equal_top_bottom_attention_tiles(
                expected_bottom_slides_ids, data, results, num_top_tiles, tiles_selector.bottom_slides_heaps[label]
            )

            # Make sure that the top and bottom slides are disjoint.
            assert not selected_top_slides_ids.intersection(selected_bottom_slides_ids)


@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_select_k_top_bottom_tiles_on_the_fly_distributed() -> None:
    """These tests need to be called sequentially to prevent them from running in parallel."""
    # test with n_classes = 2
    run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [2], world_size=1)
    run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [2], world_size=2)
    # test with n_classes = 3
    run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [3], world_size=1)
    run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [3], world_size=2)
    # test with n_classes = 2, n_slides = 2
    run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [2, 2], world_size=1)
    run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [2, 2], world_size=2)
    # test with n_classes = 3, n_slides = 2
    run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [3, 2], world_size=1)
    run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [3, 2], world_size=2)


def test_disable_top_bottom_tiles_selector() -> None:

@ -8,6 +8,7 @@ import math
import random
from pathlib import Path
from typing import List, Optional
from unittest.mock import MagicMock, patch

import matplotlib
import numpy as np

@ -24,9 +25,8 @@ from health_cpath.utils.viz_utils import plot_attention_tiles, plot_scores_hist,
from health_cpath.utils.naming import ResultsKey
from health_cpath.utils.heatmap_utils import location_selected_tiles
from health_cpath.utils.tiles_selection_utils import SlideNode, TileNode
from health_cpath.utils.viz_utils import save_figure
from health_cpath.utils.viz_utils import save_figure, load_image_dict
from testhisto.utils.utils_testhisto import assert_binary_files_match, full_ml_test_data_path
# import testhisto


def set_random_seed(random_seed: int, caller_name: Optional[str] = None) -> None:

@ -124,7 +124,7 @@ def slide_node() -> SlideNode:
    set_random_seed(0)
    tile_size = (3, 224, 224)
    num_top_tiles = 12
    slide_node = SlideNode(slide_id="slide_0", prob_score=0.5, true_label=1, pred_label=1)
    slide_node = SlideNode(slide_id="slide_0", gt_prob_score=0.04, pred_prob_score=0.96, true_label=1, pred_label=0)
    top_attn_scores = [0.99, 0.98, 0.97, 0.96, 0.95, 0.94, 0.93, 0.92, 0.91, 0.90, 0.89, 0.88]
    slide_node.top_tiles = [
        TileNode(attn=top_attn_scores[i], data=torch.randint(0, 255, tile_size)) for i in range(num_top_tiles)

@ -150,11 +150,11 @@ def assert_plot_tiles_figure(tiles_fig: plt.Figure, fig_name: str, test_output_d
@pytest.mark.skipif(is_windows(), reason="Rendering is different on Windows")
def test_plot_top_bottom_tiles(slide_node: SlideNode, test_output_dirs: OutputFolderForTests) -> None:
    top_tiles_fig = plot_attention_tiles(
        case="TP", slide_node=slide_node, top=True, num_columns=4, figsize=(10, 10)
        case="FN", slide_node=slide_node, top=True, num_columns=4, figsize=(10, 10)
    )
    assert top_tiles_fig is not None
    bottom_tiles_fig = plot_attention_tiles(
        case="TP", slide_node=slide_node, top=False, num_columns=4, figsize=(10, 10)
        case="FN", slide_node=slide_node, top=False, num_columns=4, figsize=(10, 10)
    )
    assert bottom_tiles_fig is not None
    assert_plot_tiles_figure(top_tiles_fig, "slide_0_top.png", test_output_dirs)

@ -167,7 +167,7 @@ def test_plot_attention_tiles_below_min_rows(slide_node: SlideNode, caplog: LogC
    slide_node.bottom_tiles = []
    with caplog.at_level(logging.WARNING):
        bottom_tiles_fig = plot_attention_tiles(
            case="TP", slide_node=slide_node, top=False, num_columns=4, figsize=(10, 10)
            case="FN", slide_node=slide_node, top=False, num_columns=4, figsize=(10, 10)
        )
    assert bottom_tiles_fig is None
    assert expected_warning in caplog.text

@ -175,25 +175,43 @@ def test_plot_attention_tiles_below_min_rows(slide_node: SlideNode, caplog: LogC
    slide_node.top_tiles = []
    with caplog.at_level(logging.WARNING):
        top_tiles_fig = plot_attention_tiles(
            case="TP", slide_node=slide_node, top=True, num_columns=4, figsize=(10, 10)
            case="FN", slide_node=slide_node, top=True, num_columns=4, figsize=(10, 10)
        )
    assert top_tiles_fig is None
    assert expected_warning in caplog.text


@pytest.mark.parametrize("scale", [0.1, 1.2, 2.4, 3.6])
def test_plot_slide(test_output_dirs: OutputFolderForTests, scale: int) -> None:
@pytest.mark.parametrize(
    "scale, gt_prob, pred_prob, gt_label, pred_label, case",
    [
        (0.1, 0.99, 0.99, 1, 1, "TP"),
        (1.2, 0.95, 0.95, 0, 0, "TN"),
        (2.4, 0.04, 0.96, 0, 1, "FP"),
        (3.6, 0.03, 0.97, 1, 0, "FN"),
    ],
)
def test_plot_slide(
    test_output_dirs: OutputFolderForTests,
    scale: int,
    gt_prob: float,
    pred_prob: float,
    case: str,
    gt_label: int,
    pred_label: int,
) -> None:
    set_random_seed(0)
    slide_image = np.random.rand(3, 1000, 2000)
    slide_node = SlideNode(slide_id="slide_0", prob_score=0.5, true_label=1, pred_label=1)
    fig = plot_slide(case="TP", slide_node=slide_node, slide_image=slide_image, scale=scale)
    slide_node = SlideNode(
        slide_id="slide_0", gt_prob_score=gt_prob, pred_prob_score=pred_prob, true_label=gt_label, pred_label=pred_label
    )
    fig = plot_slide(case=case, slide_node=slide_node, slide_image=slide_image, scale=scale)
    assert isinstance(fig, matplotlib.figure.Figure)
    file = Path(test_output_dirs.root_dir) / "plot_slide.png"
    resize_and_save(5, 5, file)
    assert file.exists()
    expected = full_ml_test_data_path("histo_heatmaps") / f"slide_{scale}.png"
    expected = full_ml_test_data_path("histo_heatmaps") / f"slide_{scale}_{case}.png"
    # To update the stored results, uncomment this line:
    # expected.write_bytes(file.read_bytes())
    assert_binary_files_match(file, expected)

@ -201,11 +219,13 @@ def test_plot_slide(test_output_dirs: OutputFolderForTests, scale: int) -> None:
|
|||
def test_plot_heatmap_overlay(test_output_dirs: OutputFolderForTests) -> None:
|
||||
set_random_seed(0)
|
||||
slide_image = np.random.rand(3, 1000, 2000)
|
||||
slide_node = SlideNode(slide_id=1, prob_score=0.5, true_label=1, pred_label=1) # type: ignore
|
||||
slide_node = SlideNode(
|
||||
slide_id=1, gt_prob_score=0.04, pred_prob_score=0.96, true_label=1, pred_label=0 # type: ignore
|
||||
)
|
||||
location_bbox = [100, 100]
|
||||
tile_size = 224
|
||||
level = 0
|
||||
fig = plot_heatmap_overlay(case="TP",
|
||||
fig = plot_heatmap_overlay(case="FN",
|
||||
slide_node=slide_node,
|
||||
slide_image=slide_image,
|
||||
results=test_dict, # type: ignore
|
||||
|
@ -275,3 +295,12 @@ def test_location_selected_tiles(level: int) -> None:
|
|||
assert max(tile_xs) <= slide_image.shape[2] // factor
|
||||
assert min(tile_ys) >= 0
|
||||
assert max(tile_ys) <= slide_image.shape[1] // factor
|
||||
|
||||
|
||||
@pytest.mark.parametrize("wsi_has_mask", [True, False])
|
||||
def test_load_image_dict(wsi_has_mask: bool) -> None:
|
||||
with patch("health_cpath.utils.viz_utils.LoadPandaROId") as mock_load_panda_roi:
|
||||
with patch("health_cpath.utils.viz_utils.LoadROId") as mock_load_roi:
|
||||
_ = load_image_dict(sample=MagicMock(), level=0, margin=0, wsi_has_mask=wsi_has_mask) # type: ignore
|
||||
assert mock_load_panda_roi.called == wsi_has_mask
|
||||
assert mock_load_roi.called == (not wsi_has_mask)
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch
import pytest
import numpy as np

from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from typing import Sequence
from health_cpath.utils.naming import SlideKey
from health_cpath.utils.wsi_utils import image_collate

@ -15,7 +15,8 @@ class MockTiledWSIDataset(Dataset):
                 n_slides: int,
                 n_classes: int,
                 tile_size: Sequence[int],
                 random_n_tiles: bool) -> None:
                 random_n_tiles: bool,
                 img_type: str = "np") -> None:

        self.n_tiles = n_tiles
        self.n_slides = n_slides

@ -23,30 +24,38 @@ class MockTiledWSIDataset(Dataset):
        self.n_classes = n_classes
        self.random_n_tiles = random_n_tiles
        self.slide_ids = torch.arange(self.n_slides)
        self.img_type = img_type

    def __len__(self) -> int:
        return self.n_slides

    def __getitem__(self, index: int) -> List[Dict[SlideKey, Any]]:
        tile_count = np.random.randint(self.n_tiles) if self.random_n_tiles else self.n_tiles
        tile_count = np.random.randint(low=1, high=self.n_tiles) if self.random_n_tiles else self.n_tiles
        label = np.random.choice(self.n_classes)
        img: Union[np.ndarray, torch.Tensor]
        if self.img_type == "np":
            img = np.random.randint(0, 255, size=(tile_count, *self.tile_size))
        else:
            img = torch.randint(0, 255, size=(tile_count, *self.tile_size))
        return [{SlideKey.SLIDE_ID: self.slide_ids[index],
                 SlideKey.IMAGE: np.random.randint(0, 255, size=self.tile_size),
                 SlideKey.IMAGE: img,
                 SlideKey.IMAGE_PATH: f"slide_{self.slide_ids[index]}.tiff",
                 SlideKey.LABEL: label
                 } for _ in range(tile_count)
                ]


@pytest.mark.parametrize("img_type", ["np", "torch"])
@pytest.mark.parametrize("random_n_tiles", [False, True])
def test_image_collate(random_n_tiles: bool) -> None:
def test_image_collate(random_n_tiles: bool, img_type: str) -> None:
    # random_n_tiles accounts for both train and inference settings where the number of tiles is fixed (during
    # training) and None during inference (validation and test)
    dataset = MockTiledWSIDataset(n_tiles=20,
                                  n_slides=10,
                                  n_classes=4,
                                  tile_size=(1, 4, 4),
                                  random_n_tiles=random_n_tiles)
                                  random_n_tiles=random_n_tiles,
                                  img_type=img_type)

    batch_size = 5
    samples_list = [dataset[idx] for idx in range(batch_size)]

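The low=1 bound in MockTiledWSIDataset.__getitem__ above matters because np.random.randint(n) samples from [0, n) and can therefore return 0, i.e. an empty bag of tiles. A numpy-only sketch of the difference (illustrative, not part of the repo):

import numpy as np

np.random.seed(0)
# Old behavior: 0 is a possible draw, which would produce a bag with no tiles.
old_draws = [np.random.randint(20) for _ in range(1000)]
print(min(old_draws))  # almost surely 0 over 1000 draws: empty bags were possible
# New behavior: low=1 guarantees at least one tile per bag.
new_draws = [np.random.randint(low=1, high=20) for _ in range(1000)]
assert min(new_draws) >= 1
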
@ -122,7 +122,7 @@ class HelloRegression(LightningModule):
        self.model = torch.nn.Linear(in_features=1, out_features=1, bias=True)  # type: ignore
        self.test_mse: List[torch.Tensor] = []
        self.test_mae = MeanAbsoluteError()
        self.run_extra_val_epoch = False
        self._on_extra_val_epoch = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        """

@ -232,6 +232,9 @@ class HelloRegression(LightningModule):
        Path("test_mae.txt").write_text(str(self.test_mae.compute().item()))
        self.log("test_mse", average_mse, on_epoch=True, on_step=False)

    def on_run_extra_validation_epoch(self) -> None:
        self._on_extra_val_epoch = True


class HelloWorld(LightningContainer):
    """

@ -272,3 +275,6 @@ class HelloWorld(LightningContainer):
                    *super().get_callbacks()]
        else:
            return super().get_callbacks()

    def get_additional_aml_run_tags(self) -> Dict[str, str]:
        return {"max_epochs": str(self.max_epochs)}

@ -7,12 +7,10 @@ from __future__ import annotations
import logging
import os
import param
import re
from enum import Enum, unique
from param import Parameterized
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

from azureml.train.hyperdrive import HyperDriveConfig

@ -24,6 +22,7 @@ from health_azure.amulet import (ENV_AMLT_PROJECT_NAME, ENV_AMLT_INPUT_OUTPUT,
                                 is_amulet_job, get_amulet_aml_working_dir)
from health_azure.utils import (RUN_CONTEXT, PathOrString, is_global_rank_zero, is_running_in_azure_ml)
from health_ml.utils import fixed_paths
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_ml.utils.common_utils import (CHECKPOINT_FOLDER,
                                          create_unique_timestamp_id,
                                          DEFAULT_AML_UPLOAD_DIR,

@ -144,32 +143,13 @@ class ExperimentFolderHandler(Parameterized):
    )


SRC_CHECKPOINT_FORMAT_DOC = ("<AzureML_run_id>:<optional/custom/path/to/checkpoints/><filename.ckpt> "
                             "If no custom path is provided (e.g., <AzureML_run_id>:<filename.ckpt>) "
                             "the checkpoint will be downloaded from the default checkpoint folder "
                             "(e.g., 'outputs/checkpoints/'). If no filename is provided (e.g., "
                             "`src_checkpoint=<AzureML_run_id>`) the latest checkpoint (last.ckpt) "
                             "will be used to initialize the model."
                             )


class WorkflowParams(param.Parameterized):
    """
    This class contains all parameters that affect how the whole training and testing workflow is executed.
    """
    random_seed: int = param.Integer(42, doc="The seed to use for all random number generators.")
    src_checkpoint: str = param.String(default="",
                                       doc="This flag can be used in 3 different scenarios: "
                                           "1- Resume training from a checkpoint to train longer. "
                                           "2- Run inference-only using the `run_inference_only` flag jointly. "
                                           "3- Transfer learning from a pretrained model checkpoint. "
                                           "We currently support three types of checkpoints: "
                                           " a. A local checkpoint folder that contains a checkpoint file."
                                           " b. A URL to a remote checkpoint to be downloaded."
                                           " c. A previous azureml run id where the checkpoint is supposed to be "
                                           " saved ('outputs/checkpoints/' folder by default). "
                                           "For the latter case 'c', src_checkpoint should be in the format of "
                                           f"{SRC_CHECKPOINT_FORMAT_DOC}")
    src_checkpoint: CheckpointParser = param.ClassSelector(class_=CheckpointParser, default=None,
                                                           instantiate=False, doc=CheckpointParser.DOC)
    crossval_count: int = param.Integer(default=1, bounds=(0, None),
                                        doc="The number of splits to use when doing cross-validation. "
                                            "Use 1 to disable cross-validation")

@ -196,6 +176,7 @@ class WorkflowParams(param.Parameterized):
    run_inference_only: bool = param.Boolean(False, doc="If True, run only inference and skip training after loading "
                                                        "model weights from the specified checkpoint in the "
                                                        "`src_checkpoint` flag. If False, run training and inference.")
    resume_training: bool = param.Boolean(False, doc="If True, resume training from the src_checkpoint.")
    tag: str = param.String(doc="A string that will be used as the display name of the run in AzureML.")
    experiment: str = param.String(default="", doc="The name of the AzureML experiment to use for this run. If not "
                                                   "provided, the name of the model class will be used.")

@ -213,42 +194,15 @@ class WorkflowParams(param.Parameterized):
    CROSSVAL_COUNT_ARG_NAME = "crossval_count"
    RANDOM_SEED_ARG_NAME = "random_seed"

    @property
    def src_checkpoint_is_url(self) -> bool:
        try:
            result = urlparse(self.src_checkpoint)
            return all([result.scheme, result.netloc])
        except ValueError:
            return False

    @property
    def src_checkpoint_is_local_file(self) -> bool:
        return Path(self.src_checkpoint).is_file()

    @property
    def src_checkpoint_is_aml_run_id(self) -> bool:
        match = re.match(r"[_\w-]*$", self.src_checkpoint.split(":")[0])
        return match is not None and not self.src_checkpoint_is_url and not self.src_checkpoint_is_local_file

    @property
    def is_valid_src_checkpoint(self) -> bool:
        if self.src_checkpoint:
            return self.src_checkpoint_is_local_file or self.src_checkpoint_is_url or self.src_checkpoint_is_aml_run_id
        return True

    def validate(self) -> None:
        if not self.is_valid_src_checkpoint:
            raise ValueError(f"Invalid src_checkpoint: {self.src_checkpoint}. Please provide a valid URL, local file "
                             "or azureml run id.")
        if self.crossval_count > 1:
            if not (0 <= self.crossval_index < self.crossval_count):
                raise ValueError(f"Attribute crossval_index out of bounds (crossval_count = {self.crossval_count})")

        if self.run_inference_only and not self.src_checkpoint:
            raise ValueError("Cannot run inference without a src_checkpoint. Please specify a valid src_checkpoint. "
                             "You can either use a URL, a local file or an azureml run id. For custom checkpoint paths "
                             "within an azureml run (other than last.ckpt), provide a src_checkpoint in the format "
                             f"{SRC_CHECKPOINT_FORMAT_DOC}")
            raise ValueError(f"Cannot run inference without a src_checkpoint. {CheckpointParser.INFO_MESSAGE}")
        if self.resume_training and not self.src_checkpoint:
            raise ValueError(f"Cannot resume training without a src_checkpoint. {CheckpointParser.INFO_MESSAGE}")

    @property
    def is_running_in_aml(self) -> bool:

@ -533,6 +487,10 @@ class TrainerParams(param.Parameterized):
        param.String(default=None,
                     doc="The value to use for the 'profiler' argument for the Lightning trainer. "
                         "Set to either 'simple', 'advanced', or 'pytorch'")
    pl_sync_batchnorm: bool = param.Boolean(default=True,
                                            doc="PyTorch Lightning trainer flag 'sync_batchnorm': If True, "
                                                "synchronize batchnorm across all GPUs when running in ddp mode. "
                                                "If False, batchnorm is not synchronized.")
    monitor_gpu: bool = param.Boolean(default=False,
                                      doc="If True, add the GPUStatsMonitor callback to the Lightning trainer object. "
                                          "This will write GPU utilization metrics every 50 batches by default.")

@ -545,6 +503,17 @@ class TrainerParams(param.Parameterized):
                                           "any validation overheads during training time and produce "
                                           "additional time or memory consuming outputs only once after "
                                           "training is finished on the validation set.")
    pl_accumulate_grad_batches: int = param.Integer(default=1,
                                                    doc="The number of batches over which gradients are accumulated, "
                                                        "before a parameter update is done.")
    pl_log_every_n_steps: int = param.Integer(default=50,
                                              doc="PyTorch Lightning trainer flag 'log_every_n_steps': How often to "
                                                  "log within steps. Defaults to 50.")
    pl_replace_sampler_ddp: bool = param.Boolean(default=True,
                                                 doc="PyTorch Lightning trainer flag 'replace_sampler_ddp' that "
                                                     "sets the sampler for distributed training with shuffle=True "
                                                     "during training and shuffle=False during validation. Defaults "
                                                     "to True. Set to False to set your own sampler.")

    @property
    def use_gpu(self) -> bool:

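A minimal sketch of the tightened validation above, assuming WorkflowParams accepts keyword arguments at construction the same way the repo's tests use it:

from pathlib import Path
from health_ml.deep_learning_config import WorkflowParams

# Inference-only runs (and, newly, resumed training) require a src_checkpoint:
params = WorkflowParams(local_datasets=Path("foo"), run_inference_only=True)
try:
    params.validate()
except ValueError as err:
    print(err)  # "Cannot run inference without a src_checkpoint. ..."
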
@ -1,8 +1,18 @@
from enum import Enum
from pathlib import Path
from typing import Optional
import param


class DebugDDPOptions(Enum):
    OFF = "OFF"
    INFO = "INFO"
    DETAIL = "DETAIL"


DEBUG_DDP_ENV_VAR = "TORCH_DISTRIBUTED_DEBUG"


class ExperimentConfig(param.Parameterized):
    cluster: str = param.String(default="", allow_None=False,
                                doc="The name of the GPU or CPU cluster inside the AzureML workspace"

@ -26,5 +36,13 @@ class ExperimentConfig(param.Parameterized):
                                  doc="The Conda environment file that should be used when submitting the present run to "
                                      "AzureML. If not specified, the environment file in the current folder or one of its "
                                      "parents will be used.")
    debug_ddp: DebugDDPOptions = param.ClassSelector(default=DebugDDPOptions.OFF, class_=DebugDDPOptions,
                                                     doc=f"Flag to override the environment var {DEBUG_DDP_ENV_VAR} "
                                                         "that can be used to trigger logging and collective "
                                                         "synchronization checks to ensure all ranks are synchronized "
                                                         "appropriately. Default is `OFF`. It can be set to either "
                                                         "`INFO` or `DETAIL` for different levels of logging. "
                                                         "`DETAIL` may impact the application performance and thus "
                                                         "should only be used when debugging issues.")
    strictly_aml_v1: bool = param.Boolean(default=False, doc="If True, use AzureML v1 SDK. If False (default), use "
                                                             "the v2 of the SDK")

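TORCH_DISTRIBUTED_DEBUG is a standard PyTorch environment variable; the new debug_ddp option simply selects its value. A self-contained sketch of that mapping (the apply_debug_ddp helper is hypothetical, not part of the repo):

import os
from enum import Enum


class DebugDDPOptions(Enum):  # mirrors the enum added above
    OFF = "OFF"
    INFO = "INFO"
    DETAIL = "DETAIL"


DEBUG_DDP_ENV_VAR = "TORCH_DISTRIBUTED_DEBUG"


def apply_debug_ddp(option: DebugDDPOptions) -> None:
    # PyTorch reads this variable when the process group is initialized;
    # DETAIL enables collective synchronization checks and can be slow.
    os.environ[DEBUG_DDP_ENV_VAR] = option.value


apply_debug_ddp(DebugDDPOptions.DETAIL)
assert os.environ[DEBUG_DDP_ENV_VAR] == "DETAIL"
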
@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path

import param
import logging
from azureml.core import ScriptRunConfig
from azureml.train.hyperdrive import HyperDriveConfig
from pytorch_lightning import Callback, LightningDataModule, LightningModule

@ -37,6 +38,7 @@ class LightningContainer(WorkflowParams,
        self._model: Optional[LightningModule] = None
        self._model_name = type(self).__name__
        self.num_nodes = 1
        self.trained_weights_path: Optional[Path] = None

    def validate(self) -> None:
        WorkflowParams.validate(self)

@ -246,6 +248,18 @@ class LightningContainer(WorkflowParams,
        argument `experiment`, falling back to the model class name if not set."""
        return self.experiment or self.model_name

    def get_additional_aml_run_tags(self) -> Dict[str, str]:
        """Returns a dictionary of tags that should be added to the AzureML run."""
        return {}

    def on_run_extra_validation_epoch(self) -> None:
        if hasattr(self.model, "on_run_extra_validation_epoch"):
            assert self._model, "Model is not initialized."
            self.model.on_run_extra_validation_epoch()  # type: ignore
        else:
            logging.warning("Hook `on_run_extra_validation_epoch` is not implemented by the lightning module. "
                            "The extra validation epoch won't produce any extra outputs.")


class LightningModuleWithOptimizer(LightningModule):
    """

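The on_run_extra_validation_epoch plumbing above is a plain hasattr-based hook forward; the same pattern in isolation (the _Module class is a hypothetical stand-in for a LightningModule):

import logging


class _Module:  # hypothetical stand-in for a user's LightningModule
    def on_run_extra_validation_epoch(self) -> None:
        print("extra validation epoch requested")


def forward_hook(model: object) -> None:
    if hasattr(model, "on_run_extra_validation_epoch"):
        model.on_run_extra_validation_epoch()
    else:
        logging.warning("Hook `on_run_extra_validation_epoch` is not implemented by the lightning module. "
                        "The extra validation epoch won't produce any extra outputs.")


forward_hook(_Module())   # prints
forward_hook(object())    # warns
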
@ -197,16 +197,19 @@ def create_lightning_trainer(container: LightningContainer,
                      limit_test_batches=container.pl_limit_test_batches or 1.0,
                      fast_dev_run=container.pl_fast_dev_run,  # type: ignore
                      num_sanity_val_steps=container.pl_num_sanity_val_steps,
                      log_every_n_steps=container.pl_log_every_n_steps,
                      # check_val_every_n_epoch=container.pl_check_val_every_n_epoch,
                      callbacks=callbacks,
                      logger=loggers,
                      num_nodes=num_nodes,
                      devices=devices,
                      precision=precision,
                      sync_batchnorm=True,
                      sync_batchnorm=container.pl_sync_batchnorm,
                      detect_anomaly=container.detect_anomaly,
                      profiler=profiler,
                      resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
                      multiple_trainloader_mode=multiple_trainloader_mode,
                      accumulate_grad_batches=container.pl_accumulate_grad_batches,
                      replace_sampler_ddp=container.pl_replace_sampler_ddp,
                      **additional_args)
    return trainer, storing_logger

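The net effect of the new container flags, sketched directly against PyTorch Lightning (a PL 1.x signature is assumed here, matching the resume_from_checkpoint and replace_sampler_ddp arguments used above):

from pytorch_lightning import Trainer

trainer = Trainer(
    sync_batchnorm=False,       # previously hard-coded to True, now pl_sync_batchnorm
    accumulate_grad_batches=4,  # pl_accumulate_grad_batches
    log_every_n_steps=25,       # pl_log_every_n_steps
    replace_sampler_ddp=False,  # pl_replace_sampler_ddp: keep your own DDP sampler
)
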
@ -187,14 +187,15 @@ class TransformerPooling(Module):
        num_layers: Number of Transformer encoder layers.
        num_heads: Number of attention heads per layer.
        dim_representation: Dimension of input encoding.
        transformer_dropout: The dropout value of transformer encoder layers.
    """

    def __init__(self, num_layers: int, num_heads: int, dim_representation: int) -> None:
    def __init__(self, num_layers: int, num_heads: int, dim_representation: int, transformer_dropout: float) -> None:
        super(TransformerPooling, self).__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_representation = dim_representation
        self.transformer_dropout = transformer_dropout
        self.cls_token = nn.Parameter(torch.zeros([1, dim_representation]))

        self.transformer_encoder_layers = []

@ -203,8 +204,8 @@ class TransformerPooling(Module):
                CustomTransformerEncoderLayer(self.dim_representation,
                                              self.num_heads,
                                              dim_feedforward=self.dim_representation,
                                              dropout=0.1,
                                              activation=F.gelu,  # type: ignore
                                              dropout=self.transformer_dropout,
                                              activation=F.gelu,
                                              batch_first=True))
        self.transformer_encoder_layers = torch.nn.ModuleList(self.transformer_encoder_layers)  # type: ignore

@ -239,17 +240,21 @@ class TransformerPoolingBenchmark(Module):
        num_layers: Number of Transformer encoder layers.
        num_heads: Number of attention heads per layer.
        dim_representation: Dimension of input encoding.
        transformer_dropout: The dropout value of transformer encoder layers.
    """

    def __init__(self, num_layers: int, num_heads: int, dim_representation: int, hidden_dim: int) -> None:
    def __init__(self, num_layers: int, num_heads: int,
                 dim_representation: int, hidden_dim: int,
                 transformer_dropout: float) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_representation = dim_representation
        self.hidden_dim = hidden_dim
        self.transformer_dropout = transformer_dropout
        transformer_layer = nn.TransformerEncoderLayer(d_model=self.dim_representation,
                                                       nhead=self.num_heads,
                                                       dropout=0.0,
                                                       dropout=self.transformer_dropout,
                                                       batch_first=True)
        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=self.num_layers)
        self.attention = nn.Sequential(nn.Linear(self.dim_representation, self.hidden_dim),

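Stripped of the hi-ml wrapper, the dropout plumbing in TransformerPoolingBenchmark reduces to plain PyTorch, as in this sketch:

import torch.nn as nn

transformer_dropout = 0.5  # previously hard-coded (0.1 in TransformerPooling, 0.0 in the benchmark)
layer = nn.TransformerEncoderLayer(d_model=8, nhead=2,
                                   dropout=transformer_dropout, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)
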
@ -5,7 +5,7 @@
import os
import sys
import logging
import torch.multiprocessing
import torch

from pathlib import Path
from typing import Dict, List, Optional

@ -20,14 +20,12 @@ from health_azure.utils import (create_run_recovery_id, ENV_OMPI_COMM_WORLD_RANK
                                aggregate_hyperdrive_metrics, get_metrics_for_childless_run,
                                ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK,
                                is_local_rank_zero, is_global_rank_zero, create_aml_run_object)


from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.model_trainer import create_lightning_trainer, write_experiment_summary_file
from health_ml.utils import fixed_paths
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.utils.checkpoint_utils import cleanup_checkpoints
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.utils.common_utils import (
    EFFECTIVE_RANDOM_SEED_KEY_NAME,
    change_working_directory,

@ -214,7 +212,7 @@ class MLRunner:
        if not self.container.run_inference_only:
            checkpoint_path_for_recovery = self.checkpoint_handler.get_recovery_or_checkpoint_path_train()
            if not checkpoint_path_for_recovery and self.container.src_checkpoint:
            if not checkpoint_path_for_recovery and self.container.resume_training:
                # If there is no recovery checkpoint (e.g. the job hasn't been resubmitted) and a source checkpoint
                # is given, use it to resume training.
                checkpoint_path_for_recovery = self.checkpoint_handler.trained_weights_path

@ -236,6 +234,8 @@ class MLRunner:
            self.checkpoint_handler.additional_training_done()
            checkpoint_path_for_inference = self.checkpoint_handler.get_checkpoint_to_test()
            self.container.load_model_checkpoint(checkpoint_path_for_inference)
            best_epoch = torch.load(checkpoint_path_for_inference).get("epoch", -1)
            logging.info(f"Checkpoint saved at epoch: {best_epoch}")

    def after_ddp_cleanup(self, old_environ: Dict) -> None:
        """

@ -274,6 +274,27 @@ class MLRunner:
            return self.container.crossval_index == 0
        return True

    def get_trainer_for_inference(self, checkpoint_path: Optional[Path] = None) -> Trainer:
        # We run inference on a single device because distributed strategies such as DDP use DistributedSampler
        # internally, which replicates some samples to make sure all devices have the same batch size in case of
        # uneven inputs.
        mlflow_run_id = get_mlflow_run_id_from_previous_loggers(self.trainer)
        self.container.max_num_gpus = 1

        if self.container.run_inference_only:
            assert checkpoint_path is not None
        else:
            self.validate_model_weights()

        trainer, _ = create_lightning_trainer(
            container=self.container,
            resume_from_checkpoint=checkpoint_path,
            num_nodes=1,
            azureml_run_for_logging=self.azureml_run_for_logging,
            mlflow_run_for_logging=mlflow_run_id
        )
        return trainer

    def run_training(self) -> None:
        """
        The main training loop. It creates the Pytorch model based on the configuration options passed in,

@ -294,12 +315,20 @@ class MLRunner:
        """
        Run validation on the validation set for all models to save time/memory consuming outputs.
        """
        assert hasattr(self.container.model, "run_extra_val_epoch"), "Model does not have run_extra_val_epoch flag." \
            "This is required for running an additional validation epoch to save plots."
        self.container.model.run_extra_val_epoch = True  # type: ignore
        self.container.on_run_extra_validation_epoch()
        trainer = self.get_trainer_for_inference(checkpoint_path=None)
        with change_working_directory(self.container.outputs_folder):
            assert self.trainer, "Trainer should be initialized before validation. Call self.init_training() first."
            self.trainer.validate(self.container.model, datamodule=self.data_module)
            trainer.validate(self.container.model, datamodule=self.data_module)

    def validate_model_weights(self) -> None:
        logging.info("Validating model weights.")
        weights = torch.load(self.checkpoint_handler.get_checkpoint_to_test())["state_dict"]
        number_mismatch = 0
        for name, param in self.container.model.named_parameters():
            if not torch.allclose(weights[name].cpu(), param):
                logging.warning(f"Parameter {name} does not match between model and checkpoint.")
                number_mismatch += 1
        logging.info(f"Number of mismatched parameters: {number_mismatch}")

    def run_inference(self) -> None:
        """

@ -312,20 +341,12 @@ class MLRunner:
        # We run inference on a single device because distributed strategies such as DDP use DistributedSampler
        # internally, which replicates some samples to make sure all devices have some batch size in case of
        # uneven inputs.
        mlflow_run_id = get_mlflow_run_id_from_previous_loggers(self.trainer)
        self.container.max_num_gpus = 1

        checkpoint_path = (
            self.checkpoint_handler.get_checkpoint_to_test() if self.container.src_checkpoint else None
            self.checkpoint_handler.get_checkpoint_to_test() if self.container.run_inference_only else None
        )
        trainer, _ = create_lightning_trainer(
            container=self.container,
            resume_from_checkpoint=checkpoint_path,
            num_nodes=1,
            azureml_run_for_logging=self.azureml_run_for_logging,
            mlflow_run_for_logging=mlflow_run_id
        )

        trainer = self.get_trainer_for_inference(checkpoint_path)
        # Change to the outputs folder so that the model can write to current working directory, and still
        # everything is put into the right place in AzureML (there, only the contents of the "outputs" folder
        # are retained)

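validate_model_weights is an element-wise comparison between a saved state_dict and the live parameters; the same check, self-contained:

import torch

model = torch.nn.Linear(in_features=1, out_features=1)
# stand-in for torch.load(checkpoint_path)["state_dict"]
state_dict = {name: p.detach().clone() for name, p in model.named_parameters()}

number_mismatch = 0
for name, param in model.named_parameters():
    if not torch.allclose(state_dict[name].cpu(), param):
        number_mismatch += 1
print(f"Number of mismatched parameters: {number_mismatch}")  # 0 here
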
@ -28,13 +28,13 @@ from health_azure import AzureRunInfo, submit_to_azure_if_needed  # noqa: E402
from health_azure.datasets import create_dataset_configs  # noqa: E402
from health_azure.logging import logging_to_stdout  # noqa: E402
from health_azure.paths import is_himl_used_from_git_repo  # noqa: E402
from health_azure.amulet import prepare_amulet_job  # noqa: E402
from health_azure.amulet import prepare_amulet_job, is_amulet_job  # noqa: E402
from health_azure.utils import (get_workspace, get_ml_client, is_local_rank_zero,  # noqa: E402
                                is_running_in_azure_ml, set_environment_variables_for_multi_node,
                                create_argparser, parse_arguments, ParserResult, apply_overrides,
                                filter_v2_input_output_args)

from health_ml.experiment_config import ExperimentConfig  # noqa: E402
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig  # noqa: E402
from health_ml.lightning_container import LightningContainer  # noqa: E402
from health_ml.run_ml import MLRunner  # noqa: E402
from health_ml.utils import fixed_paths  # noqa: E402

@ -164,7 +164,7 @@ class Runner:
        if self.lightning_container.hyperdrive:
            raise ValueError("HyperDrive for hyperparameters tuning is only supported when submitting the job to "
                             "AzureML. You need to specify a compute cluster with the argument --cluster.")
        if self.lightning_container.is_crossvalidation_enabled:
        if self.lightning_container.is_crossvalidation_enabled and not is_amulet_job():
            raise ValueError("Cross-validation is only supported when submitting the job to AzureML. "
                             "You need to specify a compute cluster with the argument --cluster.")

@ -176,7 +176,8 @@ class Runner:
        """
        return {
            "commandline_args": " ".join(script_params),
            "tag": self.lightning_container.tag
            "tag": self.lightning_container.tag,
            **self.lightning_container.get_additional_aml_run_tags()
        }

    def run(self) -> Tuple[LightningContainer, AzureRunInfo]:

@ -222,6 +223,7 @@ class Runner:

        # TODO: Update environment variables
        environment_variables: Dict[str, Any] = {}
        environment_variables[DEBUG_DDP_ENV_VAR] = self.experiment_config.debug_ddp.value

        # Get default datastore from the provided workspace. Authentication can take a few seconds, hence only do
        # that if we are really submitting to AzureML.

@ -252,14 +254,13 @@ class Runner:
                                                 datastore=datastore,
                                                 use_mounting=use_mounting)

        if self.experiment_config.strictly_aml_v1:
            hyperdrive_config = self.lightning_container.get_hyperdrive_config()
            hyperparam_args = None
        else:
            hyperparam_args = self.lightning_container.get_hyperparam_args()
            hyperdrive_config = None

        if self.experiment_config.cluster and not is_running_in_azure_ml():
            if self.experiment_config.strictly_aml_v1:
                hyperdrive_config = self.lightning_container.get_hyperdrive_config()
                hyperparam_args = None
            else:
                hyperparam_args = self.lightning_container.get_hyperparam_args()
                hyperdrive_config = None
            ml_client = get_ml_client()

            env_file = choose_conda_env_file(env_file=self.experiment_config.conda_env)

@ -295,6 +296,7 @@ class Runner:
            azure_run_info = submit_to_azure_if_needed(
                input_datasets=input_datasets,  # type: ignore
                submit_to_azureml=False,
                environment_variables=environment_variables,
                strictly_aml_v1=self.experiment_config.strictly_aml_v1,
            )
        if azure_run_info.run:

@ -308,7 +310,6 @@ class Runner:
            if suffix:
                current_name = self.lightning_container.tag or azure_run_info.run.display_name
                azure_run_info.run.display_name = f"{current_name} {suffix}"

        # submit_to_azure_if_needed calls sys.exit after submitting to AzureML. We only reach this when running
        # the script locally or in AzureML.
        return azure_run_info

@ -327,6 +328,11 @@ class Runner:
        package_setup_and_hacks()
        prepare_amulet_job()

        # Add tags and arguments to Amulet runs
        if is_amulet_job():
            assert azure_run_info.run is not None
            azure_run_info.run.set_tags(self.additional_run_tags(sys.argv[1:]))

        # Set environment variables for multi-node training if needed. This function will terminate early
        # if it detects that it is not in a multi-node environment.
        if self.experiment_config.num_nodes > 1:

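additional_run_tags now folds the container's extra tags in via dict unpacking; the mechanics in isolation (values are illustrative only):

base_tags = {"commandline_args": "--model=HelloWorld", "tag": "my-run"}
extra_tags = {"max_epochs": "20"}  # e.g. from get_additional_aml_run_tags()
run_tags = {**base_tags, **extra_tags}
assert run_tags["max_epochs"] == "20" and run_tags["tag"] == "my-run"
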
@ -4,22 +4,13 @@
#  -------------------------------------------------------------------------------------------

import logging
import os
import uuid
from azureml.core import Run
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

import requests
from azureml.core import Run

from health_azure.utils import is_global_rank_zero
from health_ml.lightning_container import LightningContainer
from health_ml.utils.checkpoint_utils import (
    MODEL_WEIGHTS_DIR_NAME,
    CheckpointDownloader,
    find_recovery_checkpoint_on_disk_or_cloud,
)
from health_ml.utils.checkpoint_utils import find_recovery_checkpoint_on_disk_or_cloud


class CheckpointHandler:

@ -40,9 +31,9 @@ class CheckpointHandler:
        the checkpoint_url, local_checkpoint or checkpoint from an azureml run id.
        This is called at the start of training.
        """
        if self.container.src_checkpoint:
            self.trained_weights_path = self.get_local_checkpoints_path_or_download()
            self.trained_weights_path = self.container.src_checkpoint.get_path(self.container.checkpoint_folder)
            self.container.trained_weights_path = self.trained_weights_path

    def additional_training_done(self) -> None:
        """

@ -85,53 +76,3 @@ class CheckpointHandler:
            logging.info(f"Using pre-trained weights from {self.trained_weights_path}")
            return self.trained_weights_path
        raise ValueError("Unable to determine which checkpoint should be used for testing.")

    @staticmethod
    def download_weights(url: str, download_folder: Path) -> Path:
        """
        Download a checkpoint from checkpoint_url to the model weights directory. The file name is determined
        from the file name in the URL. If that can't be determined, use a random file name.

        :param url: The URL from which the weights should be downloaded.
        :param download_folder: The target folder for the download.
        :return: A path to the downloaded file.
        """
        # assign the same filename as in the download url if possible, so that we can check for duplicates
        # If that fails, map to a random uuid
        file_name = os.path.basename(urlparse(url).path) or str(uuid.uuid4().hex)
        checkpoint_path = download_folder / file_name
        # only download if it hasn't already been downloaded
        if checkpoint_path.is_file():
            logging.info(f"File already exists, skipping download: {checkpoint_path}")
        else:
            logging.info(f"Downloading weights from URL {url}")

            response = requests.get(url, stream=True)
            response.raise_for_status()
            with open(checkpoint_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=1024):
                    file.write(chunk)
        return checkpoint_path

    def get_local_checkpoints_path_or_download(self) -> Path:
        """
        Get the path to the local weights to use or download them.
        """
        if self.container.src_checkpoint_is_local_file:
            checkpoint_path = Path(self.container.src_checkpoint)
        elif self.container.src_checkpoint_is_url:
            download_folder = self.container.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME
            download_folder.mkdir(exist_ok=True, parents=True)
            checkpoint_path = self.download_weights(url=self.container.src_checkpoint, download_folder=download_folder)
        elif self.container.src_checkpoint_is_aml_run_id:
            downloader = CheckpointDownloader(
                run_id=self.container.src_checkpoint, download_dir=self.container.outputs_folder
            )
            checkpoint_path = downloader.local_checkpoint_path
        else:
            raise ValueError("Unable to determine how to get the checkpoint path.")

        if checkpoint_path is None or not checkpoint_path.is_file():
            raise FileNotFoundError(f"Could not find the weights file at {checkpoint_path}")
        return checkpoint_path

@ -2,27 +2,29 @@
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import re
import os
import uuid
import torch
import logging
import tempfile
import requests

from pathlib import Path
from typing import Optional

import torch
from urllib.parse import urlparse
from azureml.core import Run, Workspace
from health_azure import download_checkpoints_from_run_id, get_workspace

from health_azure.utils import (RUN_CONTEXT, download_files_from_run_id, get_run_file_names, is_running_in_azure_ml)
from health_ml.utils.common_utils import (AUTOSAVE_CHECKPOINT_CANDIDATES, DEFAULT_AML_CHECKPOINT_DIR)
from health_ml.utils.common_utils import (AUTOSAVE_CHECKPOINT_CANDIDATES, DEFAULT_AML_CHECKPOINT_DIR, CHECKPOINT_SUFFIX)
from health_ml.utils.type_annotations import PathOrString

CHECKPOINT_SUFFIX = ".ckpt"
# This is a constant that must match a filename defined in pytorch_lightning.ModelCheckpoint, but we don't want
# to import that here.
LAST_CHECKPOINT_FILE_NAME = "last"
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX = LAST_CHECKPOINT_FILE_NAME + CHECKPOINT_SUFFIX
LAST_CHECKPOINT_FILE_NAME = f"last{CHECKPOINT_SUFFIX}"
LEGACY_RECOVERY_CHECKPOINT_FILE_NAME = "recovery"
MODEL_INFERENCE_JSON_FILE_NAME = "model_inference_config.json"
MODEL_WEIGHTS_DIR_NAME = "trained_models"
MODEL_WEIGHTS_DIR_NAME = "pretrained_models"


def get_best_checkpoint_path(path: Path) -> Path:

@ -31,7 +33,7 @@ def get_best_checkpoint_path(path: Path) -> Path:

    :param path: path to the checkpoint folder
    """
    return path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
    return path / LAST_CHECKPOINT_FILE_NAME


def download_folder_from_run_to_temp_folder(folder: str,

@ -120,7 +122,7 @@ def find_recovery_checkpoint(path: Path) -> Optional[Path]:
        logging.warning(f"Found these legacy checkpoint files: {legacy_recovery_checkpoints}")
        raise ValueError("The legacy recovery checkpoint setup is no longer supported. As a workaround, you can take "
                         f"one of the legacy checkpoints and upload as '{AUTOSAVE_CHECKPOINT_CANDIDATES[0]}'")
    candidates = [*AUTOSAVE_CHECKPOINT_CANDIDATES, LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX]
    candidates = [*AUTOSAVE_CHECKPOINT_CANDIDATES, LAST_CHECKPOINT_FILE_NAME]
    highest_epoch: Optional[int] = None
    file_with_highest_epoch: Optional[Path] = None
    for f in candidates:

@ -147,10 +149,10 @@ def cleanup_checkpoints(ckpt_folder: Path) -> None:
    if len(files_in_checkpoint_folder) == 0:
        return
    logging.info(f"Files in checkpoint folder: {' '.join(files_in_checkpoint_folder)}")
    last_ckpt = ckpt_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
    last_ckpt = ckpt_folder / LAST_CHECKPOINT_FILE_NAME
    all_files = f"Existing files: {' '.join(p.name for p in ckpt_folder.glob('*'))}"
    if not last_ckpt.is_file():
        raise FileNotFoundError(f"Checkpoint file {LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX} not found. {all_files}")
        raise FileNotFoundError(f"Checkpoint file {LAST_CHECKPOINT_FILE_NAME} not found. {all_files}")
    # Training is finished now. To save storage, remove the autosave checkpoint which is now obsolete.
    # Lightning does not overwrite checkpoints in-place. Rather, it writes "autosave.ckpt",
    # then "autosave-1.ckpt" and deletes "autosave.ckpt", then "autosave.ckpt" and deletes "autosave-v1.ckpt"

@ -183,7 +185,6 @@ class CheckpointDownloader:
        self.remote_checkpoint_dir = (
            remote_checkpoint_dir or self.extract_remote_checkpoint_dir_from_checkpoint_filename()
        )
        self.download_checkpoint_if_necessary()

    def extract_checkpoint_filename_from_run_id(self) -> str:
        """

@ -192,7 +193,7 @@ class CheckpointDownloader:
        """
        run_id_split = self.run_id.split(":")
        self.run_id = run_id_split[0]
        return run_id_split[-1] if len(run_id_split) > 1 else LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
        return run_id_split[-1] if len(run_id_split) > 1 else LAST_CHECKPOINT_FILE_NAME

    def extract_remote_checkpoint_dir_from_checkpoint_filename(self) -> Path:
        """

@ -232,3 +233,111 @@ class CheckpointDownloader:
            self.run_id, str(self.remote_checkpoint_path), self.local_checkpoint_dir, aml_workspace=workspace
        )
        assert self.local_checkpoint_path.exists(), f"Couldn't download checkpoint from run {self.run_id}."


class CheckpointParser:
    """Wrapper class for parsing checkpoint arguments. A checkpoint can be specified in one of the following ways:
    1. A local checkpoint file path
    2. A remote checkpoint file path
    3. A run ID from which to download the checkpoint file
    """
    AML_RUN_ID_FORMAT = (f"<AzureML_run_id>:<optional/custom/path/to/checkpoints/><filename{CHECKPOINT_SUFFIX}> "
                         f"If no custom path is provided (e.g., <AzureML_run_id>:<filename{CHECKPOINT_SUFFIX}>) "
                         "the checkpoint will be downloaded from the default checkpoint folder "
                         f"(e.g., '{DEFAULT_AML_CHECKPOINT_DIR}'). If no filename is provided "
                         "(e.g., `src_checkpoint=<AzureML_run_id>`) the latest checkpoint "
                         f"({LAST_CHECKPOINT_FILE_NAME}) will be downloaded.")
    INFO_MESSAGE = ("Please provide a valid checkpoint path, URL or AzureML run ID. For custom checkpoint paths "
                    f"within an azureml run, provide a checkpoint in the format {AML_RUN_ID_FORMAT}.")
    DOC = ("We currently support three types of checkpoints: "
           " a. A local checkpoint folder that contains a checkpoint file."
           " b. A URL to a remote checkpoint to be downloaded."
           " c. A previous azureml run id where the checkpoint is supposed to be "
           " saved ('outputs/checkpoints/' folder by default). "
           f"For the latter case 'c', src_checkpoint should be in the format of {AML_RUN_ID_FORMAT}")

    def __init__(self, checkpoint: str = "") -> None:
        self.checkpoint = checkpoint
        self.validate()

    @property
    def is_url(self) -> bool:
        try:
            result = urlparse(self.checkpoint)
            return all([result.scheme, result.netloc])
        except ValueError:
            return False

    @property
    def is_local_file(self) -> bool:
        return Path(self.checkpoint).is_file()

    @property
    def is_aml_run_id(self) -> bool:
        match = re.match(r"[_\w-]*$", self.checkpoint.split(":")[0])
        return match is not None and not self.is_url and not self.is_local_file

    @property
    def is_valid(self) -> bool:
        if self.checkpoint:
            return self.is_local_file or self.is_url or self.is_aml_run_id
        return True

    def validate(self) -> None:
        if not self.is_valid:
            raise ValueError(f"Invalid checkpoint '{self.checkpoint}'. {self.INFO_MESSAGE}")

    @staticmethod
    def download_from_url(url: str, download_folder: Path) -> Path:
        """
        Download a checkpoint from a URL to the download folder. The file name is determined
        from the file name in the URL. If that can't be determined, use a random file name.

        :param url: The URL from which to download.
        :param download_folder: The target folder for the download.
        :return: A path to the downloaded file.
        """
        # assign the same filename as in the download url if possible, so that we can check for duplicates
        # If that fails, map to a random uuid
        file_name = os.path.basename(urlparse(url).path) or str(uuid.uuid4().hex)
        checkpoint_path = download_folder / file_name
        # only download if it hasn't already been downloaded
        if checkpoint_path.is_file():
            logging.info(f"File already exists, skipping download: {checkpoint_path}")
        else:
            logging.info(f"Downloading from URL {url}")

            response = requests.get(url, stream=True)
            response.raise_for_status()
            with open(checkpoint_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=1024):
                    file.write(chunk)
        return checkpoint_path

    def get_path(self, download_dir: Path) -> Path:
        """Returns the path to the checkpoint file. If the checkpoint is a URL, it will be downloaded to the
        checkpoints folder. If the checkpoint is an AzureML run ID, it will be downloaded from the run to the
        checkpoints folder. If the checkpoint is a local file, it will be returned as is.

        :param download_dir: The checkpoints folder to which the checkpoint should be downloaded if it is a URL or
            AzureML run ID.
        :raises ValueError: If the checkpoint is not a local file, URL or AzureML run ID.
        :raises FileNotFoundError: If the checkpoint is a URL or AzureML run ID and the download fails.
        :return: The path to the checkpoint file.
        """
        if self.is_local_file:
            checkpoint_path = Path(self.checkpoint)
        elif self.is_url:
            download_folder = download_dir / MODEL_WEIGHTS_DIR_NAME
            download_folder.mkdir(exist_ok=True, parents=True)
            checkpoint_path = self.download_from_url(url=self.checkpoint, download_folder=download_folder)
        elif self.is_aml_run_id:
            downloader = CheckpointDownloader(run_id=self.checkpoint, download_dir=download_dir)
            downloader.download_checkpoint_if_necessary()
            checkpoint_path = downloader.local_checkpoint_path
        else:
            raise ValueError("Unable to determine how to get the checkpoint path.")

        if checkpoint_path is None or not checkpoint_path.is_file():
            raise FileNotFoundError(f"Could not find the file at {checkpoint_path}")
        return checkpoint_path

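A usage sketch of CheckpointParser as defined above (the temporary file stands in for a real checkpoint; no download is triggered because get_path is only called on the local file):

import tempfile
from pathlib import Path
from health_ml.utils.checkpoint_utils import CheckpointParser

# 1. Local file: get_path returns it unchanged.
local_ckpt = Path(tempfile.mkdtemp()) / "last.ckpt"
local_ckpt.touch()
parser = CheckpointParser(str(local_ckpt))
assert parser.is_local_file
assert parser.get_path(Path(tempfile.mkdtemp())) == local_ckpt

# 2. URL: get_path would download into <download_dir>/pretrained_models/.
assert CheckpointParser("https://my_server/checkpoints/model.ckpt").is_url

# 3. AzureML run id, optionally with a custom path and filename.
assert CheckpointParser("HelloWorld_123:custom/path/best.ckpt").is_aml_run_id

# Anything else fails fast, because __init__ calls validate().
try:
    CheckpointParser("INV@lid%RUN*id")
except ValueError as err:
    print(err)
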
@ -12,10 +12,10 @@ class HEDJitter(object):
    """
    Randomly perturbs the HED color space values of an RGB image.

    First, it disentangled the hematoxylin and eosin color channels by color deconvolution method using a fixed matrix,
    taken from Ruifrok and Johnston (2001): "Quantification of histochemical staining by color deconvolution."
    First, it disentangled the hematoxylin and eosin color channels by color deconvolution method using a fixed matrix.
    Second, it perturbed the hematoxylin, eosin stains independently.
    Third, it transformed the resulting stains into regular RGB color space.
    PyTorch version of: https://github.com/gatsby2016/Augmentation-PyTorch-Transforms/blob/master/myTransforms.py

    Usage example:
        >>> transform = HEDJitter(0.05)

@ -37,45 +37,45 @@ class HEDJitter(object):
        self.hed_from_rgb = torch.tensor([[1.87798274, -1.00767869, -0.55611582],
                                          [-0.06590806, 1.13473037, -0.1355218],
                                          [-0.60190736, -0.48041419, 1.57358807]])
        self.log_adjust = torch.log(torch.tensor(1E-6))

    @staticmethod
    def adjust_hed(img: torch.Tensor,
                   theta: float,
                   stain_from_rgb_mat: torch.Tensor,
                   rgb_from_stain_mat: torch.Tensor
                   ) -> torch.Tensor:
    def adjust_hed(self, img: torch.Tensor) -> torch.Tensor:
        """
        Applies HED jitter to image.

        :param img: Input image.
        :param theta: Strength of the jitter. HED_light: theta=0.05; HED_strong: theta=0.2.
        :param stain_from_rgb_mat: Transformation matrix from HED to RGB.
        :param rgb_from_stain_mat: Transformation matrix from RGB to HED.
        """
        alpha = torch.FloatTensor(1, 3).uniform_(1 - theta, 1 + theta)
        beta = torch.FloatTensor(1, 3).uniform_(-theta, theta)

        # Only perturb the H (=0) and E (=1) channels
        alpha[0][-1] = 1.
        beta[0][-1] = 0.
        alpha = torch.FloatTensor(img.shape[0], 1, 1, 3).uniform_(1 - self.theta, 1 + self.theta)
        beta = torch.FloatTensor(img.shape[0], 1, 1, 3).uniform_(-self.theta, self.theta)

        # Separate stains
        img = img.permute([0, 2, 3, 1])
        img = img + 2  # for consistency with skimage
        stains = -torch.log10(img) @ stain_from_rgb_mat
        stains = alpha * stains + beta  # perturbations in HED color space
        img = torch.maximum(img, 1E-6 * torch.ones(img.shape))
        stains = (torch.log(img) / self.log_adjust) @ self.hed_from_rgb
        stains = torch.maximum(stains, torch.zeros(stains.shape))

        # perturbations in HED color space
        stains = alpha * stains + beta

        # Combine stains
        img = 10 ** (-stains @ rgb_from_stain_mat) - 2
        img = -(stains * (-self.log_adjust)) @ self.rgb_from_hed
        img = torch.exp(img)
        img = torch.clip(img, 0, 1)
        img = img.permute(0, 3, 1, 2)

        # Normalize
        imin = torch.amin(img, dim=[1, 2, 3], keepdim=True)
        imax = torch.amax(img, dim=[1, 2, 3], keepdim=True)
        img = (img - imin) / (imax - imin)

        return img

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        if img.shape[1] != 3:
            raise ValueError("HED jitter can only be applied to images with 3 channels (RGB).")
        return self.adjust_hed(img, self.theta, self.hed_from_rgb, self.rgb_from_hed)
        return self.adjust_hed(img)


class StainNormalization(object):

@ -131,7 +131,20 @@ class StainNormalization(object):
        return nimg

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        return self.stain_normalize(img, self.reference_mean, self.reference_std)
        original_shape = img.shape
        if len(original_shape) == 3:
            img = img.unsqueeze(0)  # add batch dimension if missing
        # if the input is a bag of images, stain normalization needs to run on each image separately
        if img.shape[0] > 1:
            for i in range(img.shape[0]):
                img_tile = img[i]
                img[i] = self.stain_normalize(img_tile.unsqueeze(0), self.reference_mean, self.reference_std)
            return img
        else:
            img = self.stain_normalize(img, self.reference_mean, self.reference_std)
            if len(original_shape) == 3:
                return img.squeeze(0)
            return img


class GaussianBlur(object):

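The reworked HEDJitter above expects a batch or bag shaped [N, 3, H, W] and draws an independent (alpha, beta) perturbation per image. A small sketch against that interface; the import path is an assumption, and the constructor call is taken from the docstring's HEDJitter(0.05) usage:

import torch
from health_ml.utils.data_augmentations import HEDJitter  # path assumed

torch.manual_seed(0)
bag = torch.rand(4, 3, 224, 224)  # a bag of four RGB tiles
jitter = HEDJitter(0.05)          # HED_light strength, per the docstring
out = jitter(bag)
assert out.shape == bag.shape
assert out.min() >= 0.0 and out.max() <= 1.0  # per-image min-max normalization

# Non-RGB inputs are rejected:
try:
    jitter(torch.rand(1, 4, 8, 8))
except ValueError as err:
    print(err)
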
@ -5,7 +5,7 @@ from torch import nn, rand, sum, allclose, ones_like

from health_ml.networks.layers.attention_layers import (AttentionLayer, GatedAttentionLayer,
                                                         MeanPoolingLayer, TransformerPooling,
                                                         MaxPoolingLayer)
                                                         MaxPoolingLayer, TransformerPoolingBenchmark)


def _test_attention_layer(attentionlayer: nn.Module, dim_in: int, dim_att: int,

@ -19,11 +19,7 @@ def _test_attention_layer(attentionlayer: nn.Module, dim_in: int, dim_att: int,
    row_sums = sum(attn_weights, dim=1, keepdim=True)
    assert allclose(row_sums, ones_like(row_sums))

    if isinstance(attentionlayer, TransformerPooling):
        pass
    elif isinstance(attentionlayer, MaxPoolingLayer):
        pass
    else:
    if not isinstance(attentionlayer, (MaxPoolingLayer, TransformerPooling, TransformerPoolingBenchmark)):
        pooled_features = attn_weights @ features.flatten(start_dim=1)
        assert allclose(pooled_features, output_features)


@ -59,8 +55,27 @@ def test_max_pooling(dim_in: int, batch_size: int,) -> None:
@pytest.mark.parametrize("num_heads", [1, 2])
@pytest.mark.parametrize("dim_in", [4, 8])  # dim_in % num_heads must be 0
@pytest.mark.parametrize("batch_size", [1, 7])
def test_transformer_pooling(num_layers: int, num_heads: int, dim_in: int, batch_size: int) -> None:
def test_transformer_pooling(num_layers: int, num_heads: int, dim_in: int,
                             batch_size: int) -> None:
    transformer_dropout = 0.5
    transformer_pooling = TransformerPooling(num_layers=num_layers,
                                             num_heads=num_heads,
                                             dim_representation=dim_in).eval()
                                             dim_representation=dim_in,
                                             transformer_dropout=transformer_dropout).eval()
    _test_attention_layer(transformer_pooling, dim_in=dim_in, dim_att=1, batch_size=batch_size)


@pytest.mark.parametrize("num_layers", [1, 4])
@pytest.mark.parametrize("num_heads", [1, 2])
@pytest.mark.parametrize("dim_in", [4, 8])  # dim_in % num_heads must be 0
@pytest.mark.parametrize("batch_size", [1, 7])
@pytest.mark.parametrize("dim_hid", [1, 4])
def test_transformer_pooling_benchmark(num_layers: int, num_heads: int, dim_in: int,
                                       batch_size: int, dim_hid: int) -> None:
    transformer_dropout = 0.5
    transformer_pooling_benchmark = TransformerPoolingBenchmark(num_layers=num_layers,
                                                                num_heads=num_heads,
                                                                dim_representation=dim_in,
                                                                hidden_dim=dim_hid,
                                                                transformer_dropout=transformer_dropout).eval()
    _test_attention_layer(transformer_pooling_benchmark, dim_in=dim_in, dim_att=1, batch_size=batch_size)

@ -8,52 +8,73 @@ from unittest import mock
|
|||
import pytest
|
||||
|
||||
from health_ml.configs.hello_world import HelloWorld
|
||||
from health_ml.deep_learning_config import WorkflowParams
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.utils.checkpoint_handler import CheckpointHandler
|
||||
from health_ml.utils.checkpoint_utils import (
|
||||
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
|
||||
LAST_CHECKPOINT_FILE_NAME,
|
||||
MODEL_WEIGHTS_DIR_NAME,
|
||||
CheckpointDownloader,
|
||||
)
|
||||
CheckpointParser,)
|
||||
from health_ml.utils.checkpoint_handler import CheckpointHandler
|
||||
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
|
||||
from testhiml.utils.fixed_paths_for_tests import full_test_data_path, mock_run_id
|
||||
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
|
||||
|
||||
|
||||
def test_checkpoint_downloader_run_id() -> None:
|
||||
with mock.patch("health_ml.utils.checkpoint_utils.CheckpointDownloader.download_checkpoint_if_necessary"):
|
||||
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id")
|
||||
assert checkpoint_downloader.run_id == "dummy_run_id"
|
||||
assert checkpoint_downloader.checkpoint_filename == LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
|
||||
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id")
|
||||
assert checkpoint_downloader.run_id == "dummy_run_id"
|
||||
assert checkpoint_downloader.checkpoint_filename == LAST_CHECKPOINT_FILE_NAME
|
||||
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
|
||||
|
||||
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:best.ckpt")
|
||||
assert checkpoint_downloader.run_id == "dummy_run_id"
|
||||
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
|
||||
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
|
||||
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:best.ckpt")
|
||||
assert checkpoint_downloader.run_id == "dummy_run_id"
|
||||
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
|
||||
assert checkpoint_downloader.remote_checkpoint_dir == Path(DEFAULT_AML_CHECKPOINT_DIR)
|
||||
|
||||
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:custom/path/best.ckpt")
|
||||
assert checkpoint_downloader.run_id == "dummy_run_id"
|
||||
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
|
||||
assert checkpoint_downloader.remote_checkpoint_dir == Path("custom/path")
|
||||
checkpoint_downloader = CheckpointDownloader(run_id="dummy_run_id:custom/path/best.ckpt")
|
||||
assert checkpoint_downloader.run_id == "dummy_run_id"
|
||||
assert checkpoint_downloader.checkpoint_filename == "best.ckpt"
|
||||
assert checkpoint_downloader.remote_checkpoint_dir == Path("custom/path")
|
||||
|
||||
|
||||
def _test_invalid_checkpoint(checkpoint: str) -> None:
|
||||
with pytest.raises(ValueError, match=r"Invalid checkpoint "):
|
||||
CheckpointParser(checkpoint=checkpoint)
|
||||
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=checkpoint).validate()
|
||||
|
||||
|
||||
def test_validate_checkpoint_parser() -> None:
|
||||
|
||||
_test_invalid_checkpoint(checkpoint="dummy/local/path/model.ckpt")
|
||||
_test_invalid_checkpoint(checkpoint="INV@lid%RUN*id")
|
||||
_test_invalid_checkpoint(checkpoint="http/dummy_url-com")
|
||||
|
||||
# The following should be okay
|
||||
checkpoint = str(full_test_data_path(suffix="hello_world_checkpoint.ckpt"))
|
||||
CheckpointParser(checkpoint=checkpoint)
|
||||
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=CheckpointParser(checkpoint)).validate()
|
||||
checkpoint = mock_run_id(id=0)
|
||||
CheckpointParser(checkpoint=checkpoint)
|
||||
WorkflowParams(local_datasets=Path("foo"), src_checkpoint=CheckpointParser(checkpoint)).validate()
|
||||
|
||||
|
||||
def get_checkpoint_handler(tmp_path: Path, src_checkpoint: str) -> Tuple[LightningContainer, CheckpointHandler]:
    container = LightningContainer()
    container.set_output_to(tmp_path)
    container.checkpoint_folder.mkdir(parents=True)
    container.src_checkpoint = src_checkpoint
    container.src_checkpoint = CheckpointParser(src_checkpoint)
    return container, CheckpointHandler(container=container, project_root=tmp_path)


def test_load_model_chcekpoints_from_url(tmp_path: Path) -> None:
def test_load_model_checkpoints_from_url(tmp_path: Path) -> None:
    WEIGHTS_URL = (
        "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/" "simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
    )

    container, checkpoint_handler = get_checkpoint_handler(tmp_path, WEIGHTS_URL)
    download_folder = container.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME
    assert container.src_checkpoint_is_url
    assert container.src_checkpoint.is_url
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    assert checkpoint_handler.trained_weights_path
    assert checkpoint_handler.trained_weights_path.exists()

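# Stdlib-only sketch of the behaviour the URL test relies on: fetch the checkpoint
# named by the URL into a download folder, skipping files that already exist. This
# mirrors the spirit of download_recovery_checkpoints_or_weights for URL checkpoints
# but is not the hi-ml implementation.
from pathlib import Path
from urllib.parse import urlparse
from urllib.request import urlretrieve


def download_weights(url: str, download_folder: Path) -> Path:
    download_folder.mkdir(parents=True, exist_ok=True)
    # Use the last component of the URL path as the local filename.
    target = download_folder / Path(urlparse(url).path).name
    if not target.exists():
        urlretrieve(url, target)
    return target
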
@ -64,7 +85,7 @@ def test_load_model_checkpoints_from_local_file(tmp_path: Path) -> None:
    local_checkpoint_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")

    container, checkpoint_handler = get_checkpoint_handler(tmp_path, str(local_checkpoint_path))
    assert container.src_checkpoint_is_local_file
    assert container.src_checkpoint.is_local_file
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    assert checkpoint_handler.trained_weights_path
    assert checkpoint_handler.trained_weights_path.exists()

@ -82,10 +103,10 @@ def test_load_model_checkpoints_from_aml_run_id(src_chekpoint_filename: str, tmp
    src_checkpoint_filename = (
        src_chekpoint_filename.split("/")[-1]
        if src_chekpoint_filename
        else LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
        else LAST_CHECKPOINT_FILE_NAME
    )
    expected_weights_path = container.outputs_folder / run_id / checkpoint_path / src_checkpoint_filename
    assert container.src_checkpoint_is_aml_run_id
    expected_weights_path = container.checkpoint_folder / run_id / checkpoint_path / src_checkpoint_filename
    assert container.src_checkpoint.is_aml_run_id
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    assert checkpoint_handler.trained_weights_path
    assert checkpoint_handler.trained_weights_path.exists()

@ -99,7 +120,7 @@ def test_custom_checkpoint_for_test(tmp_path: Path) -> None:
    container = HelloWorld()
    container.set_output_to(tmp_path)
    container.checkpoint_folder.mkdir(parents=True)
    last_checkpoint = container.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
    last_checkpoint = container.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME
    last_checkpoint.touch()
    checkpoint_handler = CheckpointHandler(container=container, project_root=tmp_path)
    checkpoint_handler.additional_training_done()

@ -16,6 +16,7 @@ dummy_img = torch.Tensor(
      [0.8717, 0.9098]],
     [[0.1592, 0.7216],
      [0.8305, 0.1127]]]])
dummy_bag = torch.stack([dummy_img.squeeze(0), dummy_img.squeeze(0)])


def _test_data_augmentation(data_augmentation: Callable[[Tensor], Tensor],

@ -59,21 +60,38 @@ def test_stain_normalization() -> None:
           [0.8706, 0.4863]],
          [[0.8235, 0.5294],
           [0.8275, 0.7725]]]])
    expected_output_bag = torch.stack([expected_output_img.squeeze(0), expected_output_img.squeeze(0)])

    _test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=False)
    _test_data_augmentation(data_augmentation, dummy_bag, expected_output_bag, stochastic=False)

    # Test tiling on the fly (i.e. when the input image does not have a batch dimension)
    _test_data_augmentation(data_augmentation, dummy_img.squeeze(0), expected_output_img.squeeze(0), stochastic=False)


def test_hed_jitter() -> None:
    data_augmentation = HEDJitter(0.05)
    expected_output_img = torch.Tensor(
        [[[[0.6241, 0.1635],
           [0.9993, 1.0000]],
          [[1.0000, 1.0000],
           [1.0000, 1.0000]],
          [[0.2232, 0.8028],
           [0.9117, 0.1742]]]])
    expected_output_img1 = torch.Tensor(
        [[[[0.9639, 0.4130],
           [0.9134, 1.0000]],
          [[0.3125, 0.0000],
           [0.4474, 0.1820]],
          [[0.9195, 0.5265],
           [0.9118, 0.8291]]]])
    expected_output_img2 = torch.Tensor(
        [[[[0.8411, 0.2361],
           [0.7857, 0.8766]],
          [[0.7075, 0.0000],
           [1.0000, 0.4138]],
          [[0.9694, 0.4674],
           [0.9577, 0.8476]]]])
    expected_output_bag = torch.vstack([expected_output_img1,
                                        expected_output_img2])

    _test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=True)
    _test_data_augmentation(data_augmentation,
                            dummy_bag,
                            expected_output_bag,
                            stochastic=True)


def test_gaussian_blur() -> None:

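# The stochastic=True cases above can only compare against fixed expected tensors
# because the test helper pins the RNG seed before applying the augmentation. A
# pure-torch sketch of that pattern; `jitter` is a stand-in transform, not HEDJitter's
# actual sampling.
import torch


def jitter(img: torch.Tensor) -> torch.Tensor:
    # Stand-in stochastic augmentation: add small uniform noise.
    return (img + 0.05 * torch.rand_like(img)).clamp(0.0, 1.0)


def apply_with_seed(transform, img: torch.Tensor, seed: int = 0) -> torch.Tensor:
    torch.manual_seed(seed)  # pin the RNG so the stochastic output is reproducible
    return transform(img)


img = torch.rand(1, 3, 2, 2)
assert torch.allclose(apply_with_seed(jitter, img), apply_with_seed(jitter, img))
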
@ -13,44 +13,40 @@ from pathlib import Path

from health_ml.deep_learning_config import DatasetParams, WorkflowParams, OutputParams, OptimizerParams, \
    ExperimentFolderHandler, TrainerParams
from health_ml.utils.checkpoint_utils import CheckpointParser
from testhiml.utils.fixed_paths_for_tests import full_test_data_path, mock_run_id


def _test_invalid_pre_checkpoint_workflow_params(src_checkpoint: str) -> None:
    error_message = "Invalid src_checkpoint:"
    with pytest.raises(ValueError) as ex:
        WorkflowParams(local_datasets=Path("foo"), src_checkpoint=src_checkpoint).validate()
    assert error_message in ex.value.args[0]


def test_validate_workflow_params_src_checkpoint() -> None:
    _test_invalid_pre_checkpoint_workflow_params(src_checkpoint="dummy/local/path/model.ckpt")
    _test_invalid_pre_checkpoint_workflow_params(src_checkpoint="INV@lid%RUN*id")
    _test_invalid_pre_checkpoint_workflow_params(src_checkpoint="http/dummy_url-com")

    # The following should be okay
    full_file_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
    WorkflowParams(local_dataset=Path("foo"), src_checkpoint=str(full_file_path)).validate()
    run_id = mock_run_id(id=0)
    WorkflowParams(local_dataset=Path("foo"), src_checkpoint=run_id).validate()


def test_validate_workflow_params_for_inference_only() -> None:
    error_message = "Cannot run inference without a src_checkpoint."
    with pytest.raises(ValueError) as ex:
    with pytest.raises(ValueError, match=r"Cannot run inference without a src_checkpoint."):
        WorkflowParams(local_datasets=Path("foo"), run_inference_only=True).validate()
    assert error_message in ex.value.args[0]

    full_file_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
    run_id = mock_run_id(id=0)
    WorkflowParams(local_dataset=Path("foo"), run_inference_only=True, src_checkpoint=run_id).validate()
    WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
                   src_checkpoint=f"{run_id}:best_val_loss.ckpt").validate()
                   src_checkpoint=CheckpointParser(run_id)).validate()
    WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
                   src_checkpoint=f"{run_id}:custom/path/model.ckpt").validate()
                   src_checkpoint=CheckpointParser(f"{run_id}:best_val_loss.ckpt")).validate()
    WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
                   src_checkpoint=str(full_file_path)).validate()
                   src_checkpoint=CheckpointParser(f"{run_id}:custom/path/model.ckpt")).validate()
    WorkflowParams(local_dataset=Path("foo"), run_inference_only=True,
                   src_checkpoint=CheckpointParser(str(full_file_path))).validate()


def test_validate_workflow_params_for_resume_training() -> None:
    with pytest.raises(ValueError, match=r"Cannot resume training without a src_checkpoint."):
        WorkflowParams(local_datasets=Path("foo"), resume_training=True).validate()

    full_file_path = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
    run_id = mock_run_id(id=0)
    WorkflowParams(local_dataset=Path("foo"), resume_training=True,
                   src_checkpoint=CheckpointParser(run_id)).validate()
    WorkflowParams(local_dataset=Path("foo"), resume_training=True,
                   src_checkpoint=CheckpointParser(f"{run_id}:best_val_loss.ckpt")).validate()
    WorkflowParams(local_dataset=Path("foo"), resume_training=True,
                   src_checkpoint=CheckpointParser(f"{run_id}:custom/path/model.ckpt")).validate()
    WorkflowParams(local_dataset=Path("foo"), resume_training=True,
                   src_checkpoint=CheckpointParser(str(full_file_path))).validate()


@pytest.mark.fast

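# The validate() behaviour exercised above, reduced to a sketch: inference-only and
# resume-training runs must name a checkpoint to start from. `Params` is a stripped-down
# stand-in for WorkflowParams; only the error messages are taken from the tests.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Params:
    run_inference_only: bool = False
    resume_training: bool = False
    src_checkpoint: Optional[str] = None

    def validate(self) -> None:
        if self.run_inference_only and not self.src_checkpoint:
            raise ValueError("Cannot run inference without a src_checkpoint.")
        if self.resume_training and not self.src_checkpoint:
            raise ValueError("Cannot resume training without a src_checkpoint.")
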
@ -41,6 +41,7 @@ def test_create_lightning_trainer() -> None:
    assert trainer.default_root_dir == str(container.outputs_folder)
    assert trainer.limit_train_batches == 1.0
    assert trainer._detect_anomaly == container.detect_anomaly
    assert trainer.accumulate_grad_batches == 1

    assert isinstance(trainer.callbacks[0], TQDMProgressBar)
    assert isinstance(trainer.callbacks[1], ModelSummary)

@ -192,3 +193,11 @@ def test_create_lightning_trainer_limit_batches() -> None:
    assert trainer3.num_training_batches == int(limit_train_batches_float * original_num_train_batches)
    assert trainer3.num_val_batches[0] == int(limit_val_batches_float * original_num_val_batches)
    assert trainer3.num_test_batches[0] == int(limit_test_batches_float * original_num_test_batches)


def test_flag_grad_accum() -> None:
    num_batches = 4
    container = LightningContainer()
    container.pl_accumulate_grad_batches = num_batches
    trainer, _ = create_lightning_trainer(container)
    assert trainer.accumulate_grad_batches == num_batches

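# test_flag_grad_accum only checks that the container flag reaches the trainer; the
# underlying PyTorch Lightning API is Trainer(accumulate_grad_batches=...), which sums
# gradients over N batches before each optimizer step. Minimal usage:
from pytorch_lightning import Trainer

trainer = Trainer(accumulate_grad_batches=4, max_epochs=1)
assert trainer.accumulate_grad_batches == 4
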
@ -8,6 +8,11 @@ import pytest
from pathlib import Path
from typing import Generator
from unittest.mock import DEFAULT, MagicMock, Mock, patch
from _pytest.logging import LogCaptureFixture
from pytorch_lightning import LightningModule

import mlflow
from pytorch_lightning import Trainer

@ -16,6 +21,7 @@ from health_ml.configs.hello_world import HelloWorld  # type: ignore
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.run_ml import MLRunner, get_mlflow_run_id_from_previous_loggers
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_ml.utils.common_utils import is_gpu_available
from health_ml.utils.lightning_loggers import HimlMLFlowLogger, StoringLogger
from health_azure.utils import is_global_rank_zero

@ -62,7 +68,7 @@ def ml_runner_with_run_id() -> Generator:
    experiment_config = ExperimentConfig(model="HelloWorld")
    container = HelloWorld()
    container.save_checkpoint = True
    container.src_checkpoint = mock_run_id(id=0)
    container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
    with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
        mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
        runner = MLRunner(experiment_config=experiment_config, container=container)

@ -164,26 +170,79 @@ def test_run_validation(run_extra_val_epoch: bool) -> None:
    container = HelloWorld()
    container.create_lightning_module_and_store()
    container.run_extra_val_epoch = run_extra_val_epoch
    container.model.run_extra_val_epoch = run_extra_val_epoch  # type: ignore
    runner = MLRunner(experiment_config=experiment_config, container=container)

    with patch.object(container, "get_data_module"):
        with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
            runner.setup()
            mock_trainer = MagicMock()
            mock_storing_logger = MagicMock()
            mock_create_trainer.return_value = mock_trainer, mock_storing_logger
            runner.init_training()
        with patch.object(container, "on_run_extra_validation_epoch") as mock_on_run_extra_validation_epoch:
            with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
                runner.setup()
                mock_trainer = MagicMock()
                mock_storing_logger = MagicMock()
                mock_create_trainer.return_value = mock_trainer, mock_storing_logger
                runner.init_training()

            assert runner.trainer == mock_trainer
            assert runner.storing_logger == mock_storing_logger
                assert runner.trainer == mock_trainer
                assert runner.storing_logger == mock_storing_logger

            mock_trainer.validate = Mock()
                mock_trainer.validate = Mock()

            if run_extra_val_epoch:
                runner.run_validation()
                if run_extra_val_epoch:
                    with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
                        runner.run_validation()
                        mock_validate_model_weights.assert_called_once()

            assert mock_trainer.validate.called == run_extra_val_epoch
                assert mock_on_run_extra_validation_epoch.called == run_extra_val_epoch
                assert hasattr(container.model, "on_run_extra_validation_epoch")
                assert mock_trainer.validate.called == run_extra_val_epoch


@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
def test_model_extra_val_epoch(run_extra_val_epoch: bool) -> None:
    experiment_config = ExperimentConfig(model="HelloWorld")
    with patch(
        "health_ml.configs.hello_world.HelloRegression.on_run_extra_validation_epoch"
    ) as mock_on_run_extra_validation_epoch:
        container = HelloWorld()
        container.run_extra_val_epoch = run_extra_val_epoch
        container.create_lightning_module_and_store()
        runner = MLRunner(experiment_config=experiment_config, container=container)
        with patch.object(container, "get_data_module"):
            with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
                runner.setup()
                mock_trainer = MagicMock()
                mock_create_trainer.return_value = mock_trainer, MagicMock()
                runner.init_training()
                mock_trainer.validate = Mock()

                if run_extra_val_epoch:
                    with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
                        runner.run_validation()
                        mock_validate_model_weights.assert_called_once()
                assert mock_on_run_extra_validation_epoch.called == run_extra_val_epoch
                assert mock_trainer.validate.called == run_extra_val_epoch


def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
    experiment_config = ExperimentConfig(model="HelloWorld")

    def _create_model(self) -> LightningModule:  # type: ignore
        return LightningModule()

    with patch("health_ml.configs.hello_world.HelloWorld.create_model", _create_model):
        container = HelloWorld()
        container.create_lightning_module_and_store()
        container.run_extra_val_epoch = True
        runner = MLRunner(experiment_config=experiment_config, container=container)
        with patch.object(container, "get_data_module"):
            with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
                runner.setup()
                mock_create_trainer.return_value = MagicMock(), MagicMock()
                runner.init_training()
                with patch.object(runner, "validate_model_weights") as mock_validate_model_weights:
                    runner.run_validation()
                    mock_validate_model_weights.assert_called_once()
                latest_message = caplog.records[-1].getMessage()
                assert "Hook `on_run_extra_validation_epoch` is not implemented by lightning module." in latest_message


def test_run_inference(ml_runner_with_container: MLRunner, tmp_path: Path) -> None:

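# The missing-hook test above depends on calling an optional hook only when the module
# implements it, and logging the asserted message otherwise. A generic sketch of that
# dispatch; `call_optional_hook` is illustrative, not the hi-ml implementation.
import logging


def call_optional_hook(module: object, hook_name: str) -> None:
    hook = getattr(module, hook_name, None)
    if callable(hook):
        hook()
    else:
        logging.warning("Hook `%s` is not implemented by lightning module.", hook_name)
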
@ -218,7 +277,9 @@ def test_run_inference(ml_runner_with_container: MLRunner, tmp_path: Path) -> No

    actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
    assert actual_train_ckpt_path is None
    ml_runner_with_container.run()
    with patch.object(ml_runner_with_container, "validate_model_weights") as mock_validate_model_weights:
        ml_runner_with_container.run()
        mock_validate_model_weights.assert_called_once()
    actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
    assert actual_train_ckpt_path == expected_ckpt_path

@ -268,7 +329,10 @@ def test_run_inference_only(ml_runner_with_run_id: MLRunner) -> None:
    assert ml_runner_with_run_id.checkpoint_handler.trained_weights_path
    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
        with patch.multiple(
            ml_runner_with_run_id, run_training=DEFAULT, run_validation=DEFAULT
            ml_runner_with_run_id,
            run_training=DEFAULT,
            run_validation=DEFAULT,
            validate_model_weights=DEFAULT
        ) as mocks:
            mock_trainer = MagicMock()
            mock_create_trainer.return_value = mock_trainer, MagicMock()

@ -278,6 +342,7 @@ def test_run_inference_only(ml_runner_with_run_id: MLRunner) -> None:
            assert recovery_checkpoint == ml_runner_with_run_id.checkpoint_handler.trained_weights_path
            mocks["run_training"].assert_not_called()
            mocks["run_validation"].assert_not_called()
            mocks["validate_model_weights"].assert_not_called()
            mock_trainer.test.assert_called_once()

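# patch.multiple, as used above, replaces several attributes on one target inside a
# single context manager. Passing DEFAULT asks for an auto-created MagicMock per name,
# returned in a dict keyed by attribute name. `Runner` here is a stand-in class, not
# health_ml's MLRunner.
from unittest.mock import DEFAULT, patch


class Runner:
    def run_training(self) -> None: ...
    def run_validation(self) -> None: ...


runner = Runner()
with patch.multiple(runner, run_training=DEFAULT, run_validation=DEFAULT) as mocks:
    runner.run_training()
    mocks["run_training"].assert_called_once()
    mocks["run_validation"].assert_not_called()
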
@ -297,7 +362,8 @@ def test_model_weights_when_resume_training() -> None:
    experiment_config = ExperimentConfig(model="HelloWorld")
    container = HelloWorld()
    container.max_num_gpus = 0
    container.src_checkpoint = mock_run_id(id=0)
    container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
    container.resume_training = True
    with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
        mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
        runner = MLRunner(experiment_config=experiment_config, container=container)

@ -315,7 +381,7 @@ def test_runner_end_to_end() -> None:
    experiment_config = ExperimentConfig(model="HelloWorld")
    container = HelloWorld()
    container.max_num_gpus = 0
    container.src_checkpoint = mock_run_id(id=0)
    container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
    with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
        mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
        runner = MLRunner(experiment_config=experiment_config, container=container)

@ -6,7 +6,7 @@ from contextlib import contextmanager
import shutil
import sys
from pathlib import Path
from typing import Generator, List, Optional
from typing import Any, Dict, Generator, List, Optional
from unittest.mock import patch, MagicMock

import pytest

@ -17,6 +17,7 @@ from health_azure import AzureRunInfo, DatasetConfig
from health_azure.paths import ENVIRONMENT_YAML_FILE_NAME
from health_ml.configs.hello_world import HelloWorld  # type: ignore
from health_ml.deep_learning_config import WorkflowParams
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, DebugDDPOptions
from health_ml.lightning_container import LightningContainer
from health_ml.runner import Runner
from health_ml.utils.common_utils import change_working_directory

@ -83,6 +84,35 @@ def test_parse_and_load_model(mock_runner: Runner, model_name: Optional[str], cl
    assert mock_runner.lightning_container.model_name == model_name


@pytest.mark.parametrize("debug_ddp", ["OFF", "INFO", "DETAIL"])
def test_ddp_debug_flag(debug_ddp: DebugDDPOptions, mock_runner: Runner) -> None:
    model_name = "HelloWorld"
    arguments = ["", f"--debug_ddp={debug_ddp}", f"--model={model_name}"]
    with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
        with patch("health_ml.runner.get_workspace"):
            with patch("health_ml.runner.Runner.run_in_situ"):
                with patch.object(sys, "argv", arguments):
                    mock_runner.run()
        mock_submit_to_azure_if_needed.assert_called_once()
        assert mock_submit_to_azure_if_needed.call_args[1]["environment_variables"][DEBUG_DDP_ENV_VAR] == debug_ddp


def test_additional_aml_run_tags(mock_runner: Runner) -> None:
    model_name = "HelloWorld"
    arguments = ["", f"--model={model_name}", "--cluster=foo"]
    with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
        with patch("health_ml.runner.check_conda_environment"):
            with patch("health_ml.runner.get_workspace"):
                with patch("health_ml.runner.get_ml_client"):
                    with patch("health_ml.runner.Runner.run_in_situ"):
                        with patch.object(sys, "argv", arguments):
                            mock_runner.run()
        mock_submit_to_azure_if_needed.assert_called_once()
        assert "commandline_args" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
        assert "tag" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
        assert "max_epochs" in mock_submit_to_azure_if_needed.call_args[1]["tags"]


def test_run(mock_runner: Runner) -> None:
    model_name = "HelloWorld"
    arguments = ["", f"--model={model_name}"]

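# The deeply nested `with patch(...)` blocks above can be flattened with
# contextlib.ExitStack, which enters any number of context managers and unwinds them
# together. A sketch assuming the same patch targets as the tests above:
import sys
from contextlib import ExitStack
from unittest.mock import patch

with ExitStack() as stack:
    mock_submit = stack.enter_context(patch("health_ml.runner.submit_to_azure_if_needed"))
    stack.enter_context(patch("health_ml.runner.get_workspace"))
    stack.enter_context(patch.object(sys, "argv", ["", "--model=HelloWorld"]))
    mock_submit.assert_not_called()  # nothing has invoked the runner yet
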
@ -106,7 +136,8 @@ def test_submit_to_azureml_if_needed(mock_get_workspace: MagicMock,
                                     mock_runner: Runner
                                     ) -> None:
    def _mock_dont_submit_to_aml(input_datasets: List[DatasetConfig],
                                 submit_to_azureml: bool, strictly_aml_v1: bool,  # type: ignore
                                 environment_variables: Dict[str, Any],  # type: ignore
                                 ) -> AzureRunInfo:
        datasets_input = [d.target_folder for d in input_datasets] if input_datasets else []
        return AzureRunInfo(input_datasets=datasets_input,