ENH: Small Fixes for Transfer Main merge

This commit is contained in:
Kenza Bouzid 2022-11-11 13:46:58 +00:00
Родитель 97724f3686
Коммит b8a1ae0bc0
15 изменённых файлов: 55 добавлений и 36 удалений

4
.github/workflows/codeql-analysis.yml поставляемый
Просмотреть файл

@ -13,10 +13,10 @@ name: "CodeQL"
on:
push:
branches: [ main, PR, transfer_main ]
branches: [ main, PR ]
pull_request:
# The branches below must be a subset of the branches above
branches: [ main, transfer_main ]
branches: [ main ]
schedule:
- cron: '20 1 * * 6' # https://crontab.guru/#20_1_*_*_6 this is: At 01:20 on Saturday.

4
.github/workflows/cpath-pr.yml поставляемый
Просмотреть файл

@ -1,11 +1,11 @@
name: Pathology PR
on:
push:
branches: [ main, transfer_main ]
branches: [ main ]
tags:
- '*'
pull_request:
branches: [ main, transfer_main ]
branches: [ main ]
paths:
- "hi-ml-cpath/**"
- ".github/workflows/cpath-pr.yml"

4
.github/workflows/credscan.yml поставляемый
Просмотреть файл

@ -2,9 +2,9 @@ name: CredScan
on:
push:
branches: [ main, transfer_main ]
branches: [ main ]
pull_request:
branches: [ main, transfer_main ]
branches: [ main ]
jobs:
CredScan:

4
.github/workflows/hi-ml-pr.yml поставляемый
Просмотреть файл

@ -1,11 +1,11 @@
name: HI-ML HI-ML-Azure PR
on:
push:
branches: [ main, transfer_main ]
branches: [ main ]
tags:
- '*'
pull_request:
branches: [ main, transfer_main ]
branches: [ main ]
paths:
- "hi-ml-azure/**"
- "hi-ml/**"

Просмотреть файл

@ -237,7 +237,7 @@ himl-runner --model=Mycontainer --run_inference_only --src_checkpoint=MyContaine
## Resume training from a given checkpoint
Analogously, one can resume training by setting `--src_checkpoint` and `--resume_training`to train a model longer.
Analogously, one can resume training by setting `--src_checkpoint` and `--resume_training` to train a model longer.
The pytorch lightning trainer will initialize the lightning module from the given checkpoint corresponding to the best
validation loss epoch as set in the following comandline.

Просмотреть файл

@ -103,7 +103,6 @@ DEFAULT_ENVIRONMENT_VARIABLES = {
"RSLEX_DIRECT_VOLUME_MOUNT": "true",
"RSLEX_DIRECT_VOLUME_MOUNT_MAX_CACHE_SIZE": "1",
"DATASET_MOUNT_CACHE_SIZE": "1",
"AZUREML_COMPUTE_USE_COMMON_RUNTIME": "false",
}

Просмотреть файл

@ -31,7 +31,7 @@ pip_from_conda:
# Lock the current Conda environment secondary dependencies versions
lock_env:
./create_and_lock_environment.sh
../create_and_lock_environment.sh
# clean build artifacts
clean:
@ -111,7 +111,7 @@ define CRCKSIMCLR_ARGS
endef
define REGRESSION_TEST_ARGS
--cluster dedicated-nc24s-v2 --regression_test_csv_tolerance=0.5
--cluster dedicated-nc24s-v2 --regression_test_csv_tolerance=0.5 --strictly_aml_v1=True
endef
define PANDA_REGRESSION_METRICS

Просмотреть файл

@ -133,9 +133,10 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
options.add(PlotOption.TOP_BOTTOM_TILES)
return options
# overwrite this method if you want to produce validation plots at each epoch. By default, at the end of the
# training an extra validation epoch is run where val_plot_options = test_plot_options
def get_val_plot_options(self) -> Set[PlotOption]:
""" Override this method if you want to produce validation plots at each epoch. By default, at the end of the
training an extra validation epoch is run where val_plot_options = test_plot_options
"""
return set()
def get_outputs_handler(self) -> DeepMILOutputsHandler:

Просмотреть файл

@ -137,7 +137,11 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
class TilesDataModule(HistoDataModule[TilesDataset]):
"""Base class to load the tiles of a dataset as train, val, test sets"""
"""Base class to load the tiles of a dataset as train, val, test sets. Note that tiles are always shuffled by
default. This means that we sample a random subset of tiles from each bag at each epoch. This is different from
slides shuffling that is switched on during training time only. This is done to avoid overfitting to the order of
the tiles in each bag.
"""
def __init__(
self,

Просмотреть файл

@ -148,6 +148,7 @@ class LossAnalysisCallback(Callback):
return self.outputs_folder / f"{stage}/loss_anomalies"
def create_outputs_folders(self) -> None:
"""Creates the output folders if they don't exist."""
folders = [
self.get_cache_folder,
self.get_scatter_folder,
@ -161,12 +162,14 @@ class LossAnalysisCallback(Callback):
os.makedirs(folder(stage), exist_ok=True)
def get_empty_loss_cache(self) -> LossCacheDictType:
"""Returns an empty loss cache dictionary for keys: slide_id, loss, entropy and tile_ids if save_tile_ids."""
keys = [ResultsKey.SLIDE_ID, ResultsKey.LOSS, ResultsKey.ENTROPY]
if self.save_tile_ids:
keys.append(ResultsKey.TILE_ID)
return {key: [] for key in keys}
def _format_epoch(self, epoch: int) -> str:
"""Formats the epoch number to a string with 3 digits."""
return str(epoch).zfill(len(str(self.max_epochs)))
def get_loss_cache_file(self, epoch: int, stage: ModelKey) -> Path:
@ -201,6 +204,7 @@ class LossAnalysisCallback(Callback):
return pd.read_csv(self.get_loss_cache_file(epoch, stage), index_col=idx_col, usecols=columns)
def should_cache_loss_values(self, current_epoch: int) -> bool:
"""Returns True if the current epoch is a multiple of the epochs_interval."""
if current_epoch >= self.max_epochs:
return False # Don't cache loss values for the extra validation epoch
current_epoch = current_epoch + 1
@ -244,7 +248,15 @@ class LossAnalysisCallback(Callback):
high: Optional[bool] = None,
num_values: Optional[int] = None
) -> List[np.ndarray]:
"""Selects the values corresponding to keys from a dataframe for the given epoch and stage"""
"""Selects the values corresponding to keys from a dataframe for the given epoch and stage.
:param keys: The keys to select.
:param epoch: The epoch to select.
:param stage: The model's stage e.g. train, val, test.
:param high: If True, selects the highest values, if False, selects the lowest values, if None, selects all
values.
:param num_values: The number of values to select.
"""
loss_cache = self.read_loss_cache(epoch, stage)
return_values = []
for key in keys:
@ -319,7 +331,11 @@ class LossAnalysisCallback(Callback):
return np.array(slides).T, np.array(slides_entropy).T
def save_slide_ids(self, slide_ids: List[str], path: Path) -> None:
"""Dumps the slides ids in a txt file."""
"""Dumps the slides ids in a txt file.
:param slide_ids: The slides ids to save.
:param path: The path to save the slides ids to.
"""
if slide_ids:
with open(path, "w") as f:
for slide_id in slide_ids:
@ -481,7 +497,7 @@ class LossAnalysisCallback(Callback):
self.loss_cache[stage] = self.get_empty_loss_cache() # reset loss cache
def handle_loss_exceptions(self, stage: ModelKey, exception: Exception) -> None:
"""Handles the loss exceptions."""
"""Handles the loss exceptions. If log_exceptions is True, logs the exception as warnings, else raises it."""
if self.log_exceptions:
# If something goes wrong, we don't want to crash the training. We just log the error and carry on
# validation.

Просмотреть файл

@ -57,6 +57,11 @@ def validate_class_names(class_names: Optional[Sequence[str]], n_classes: int) -
def validate_slide_datasets_for_plot_options(
plot_options: Collection[PlotOption], slides_dataset: Optional[SlidesDataset]
) -> None:
"""Validate that the specified plot options are compatible with the specified slides dataset.
:param plot_options: Plot options to validate.
:param slides_dataset: Slides dataset to validate against.
"""
def _validate_slide_plot_option(plot_option: PlotOption) -> None:
if plot_option in plot_options and not slides_dataset:

Просмотреть файл

@ -68,7 +68,7 @@ def save_pr_curve(results: ResultsType, figures_dir: Path, stage: str = '') -> N
format_pr_or_roc_axes(plot_type='pr', ax=ax)
save_figure(fig=fig, figpath=figures_dir / f"pr_curve_{stage}.png")
else:
raise Warning("The PR curve plot implementation works only for binary cases, this plot will be skipped.")
logging.warning("The PR curve plot implementation works only for binary cases, this plot will be skipped.")
def save_confusion_matrix(results: ResultsType, class_names: Sequence[str], figures_dir: Path, stage: str = '') -> None:
@ -184,7 +184,7 @@ class DeepMILPlotsHandler:
class_names: Optional[Sequence[str]] = None,
wsi_has_mask: bool = True,
) -> None:
"""_summary_
"""Class that handles the plotting of DeepMIL results.
:param plot_options: A set of plot options to produce the desired plot outputs.
:param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original,

Просмотреть файл

@ -158,19 +158,14 @@ def test_slides_datamodule_pl_replace_sampler_ddp(mock_panda_slides_root_dir: Pa
run_distributed(_test_datamodule_pl_ddp_sampler_false, [slides_datamodule], world_size=2)
@pytest.mark.skip(reason="Test fails with Broken Pipe Error. To be fixed.")
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_tiles_datamodule_pl_replace_sampler_ddp_true(mock_panda_tiles_root_dir: Path) -> None:
def test_tiles_datamodule_pl_replace_sampler_ddp(mock_panda_tiles_root_dir: Path) -> None:
tiles_datamodule = PandaTilesDataModule(root_path=mock_panda_tiles_root_dir, seed=42, pl_replace_sampler_ddp=True)
run_distributed(_test_datamodule_pl_ddp_sampler_true, [tiles_datamodule], world_size=2)
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_tiles_datamodule_pl_replace_sampler_ddp_false(mock_panda_tiles_root_dir: Path) -> None:
tiles_datamodule = PandaTilesDataModule(root_path=mock_panda_tiles_root_dir, seed=42, pl_replace_sampler_ddp=False)
tiles_datamodule.pl_replace_sampler_ddp = False
run_distributed(_test_datamodule_pl_ddp_sampler_false, [tiles_datamodule], world_size=2)

Просмотреть файл

@ -144,7 +144,7 @@ def test_save_conf_matrix_integration(tmp_path: Path) -> None:
assert actual_conf_matrix.shape == expected_conf_matrix_shape
def test_pr_curve_integration(tmp_path: Path) -> None:
def test_pr_curve_integration(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
results = {
ResultsKey.TRUE_LABEL: [0, 1, 0, 1, 0, 1],
ResultsKey.PROB: [0.1, 0.8, 0.6, 0.3, 0.5, 0.4]
@ -156,9 +156,11 @@ def test_pr_curve_integration(tmp_path: Path) -> None:
assert file.exists()
os.remove(file)
# check warning is raised and plot is not produced if NOT a binary case
# check warning is logged and plot is not produced if NOT a binary case
results[ResultsKey.TRUE_LABEL] = [0, 1, 0, 2, 0, 1]
with pytest.raises(Warning) as w:
save_pr_curve(results, tmp_path, stage='foo') # type: ignore
assert "The PR curve plot implementation works only for binary cases, this plot will be skipped." in str(w)
save_pr_curve(results, tmp_path, stage='foo') # type: ignore
warning_message = "The PR curve plot implementation works only for binary cases, this plot will be skipped."
assert warning_message in caplog.records[-1].getMessage()
assert not file.exists()

Просмотреть файл

@ -14,9 +14,6 @@ from pytorch_lightning import LightningModule
import mlflow
from pytorch_lightning import Trainer
import mlflow
from pytorch_lightning import Trainer
from health_ml.configs.hello_world import HelloWorld # type: ignore
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer