Main changes:
- Change prints to logging.info to prevent all ranks from logging
- Run smoke tests with the --strictly_aml_v1 flag to be able to wait for completion
- Set progress=False in LoadTilesBatchd to avoid unnecessary verbosity in logs
- Fix regression after merge: kill DDP processes right after run_training because the extra val epoch runs on a single device
- Remove a test that does nothing
- Use binary/multiclass metrics
- Ignore PL warnings
Kenza Bouzid 2022-11-16 14:59:05 +00:00 committed by GitHub
Parent c4e25b8a79
Commit 321de75cb3
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 50 additions and 56 deletions

View file

@@ -8,7 +8,7 @@ import time
from contextlib import contextmanager
from typing import Generator, Optional, Union
from health_azure.utils import check_is_any_of
from health_azure.utils import check_is_any_of, is_global_rank_zero
logging_stdout_handler: Optional[logging.StreamHandler] = None
logging_to_file_handler: Optional[logging.StreamHandler] = None
@@ -29,14 +29,16 @@ def logging_to_stdout(log_level: Union[int, str] = logging.INFO) -> None:
# logging lines.
global logging_stdout_handler
if not logging_stdout_handler:
print("Setting up logging to stdout.")
if is_global_rank_zero():
print("Setting up logging to stdout.")
# At startup, logging has one handler set, that writes to stderr, with a log level of 0 (logging.NOTSET)
if len(logger.handlers) == 1:
logger.removeHandler(logger.handlers[0])
logging_stdout_handler = logging.StreamHandler(stream=sys.stdout)
_add_formatter(logging_stdout_handler)
logger.addHandler(logging_stdout_handler)
print(f"Setting logging level to {log_level}")
if is_global_rank_zero():
print(f"Setting logging level to {log_level}")
logging_stdout_handler.setLevel(log_level)
logger.setLevel(log_level)
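The same rank-zero gating can be reused for any one-off status message. A minimal sketch of the pattern (the helper name and message are illustrative; is_global_rank_zero is the health_azure.utils function imported above):

from health_azure.utils import is_global_rank_zero

def print_once(message: str) -> None:
    # Only the process with global rank 0 prints, so multi-GPU/multi-node runs emit the message once.
    if is_global_rank_zero():
        print(message)

print_once("Setting up logging to stdout.")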

View file

@@ -160,11 +160,11 @@ define DEFAULT_SMOKE_TEST_ARGS
endef
define AML_ONE_DEVICE_ARGS
--cluster=testing-nc6 --wait_for_completion --num_nodes=1 --max_num_gpus=1
--cluster=testing-nc6 --wait_for_completion --num_nodes=1 --max_num_gpus=1 --strictly_aml_v1=True
endef
define AML_MULTIPLE_DEVICE_ARGS
--cluster=dedicated-nc24s-v2 --wait_for_completion
--cluster=dedicated-nc24s-v2 --wait_for_completion --strictly_aml_v1=True
endef
define DEEPSMILEDEFAULT_SMOKE_TEST_ARGS

View file

@@ -4,8 +4,9 @@
# ------------------------------------------------------------------------------------------
import os
import logging
import param
import logging
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
from pathlib import Path
@@ -92,6 +93,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
logging.info(
"Replacing sampler with `DistributedSampler` is disabled. Make sure to set your own DDP sampler"
)
self.ignore_pl_warnings()
def validate(self) -> None:
super().validate()
@@ -222,7 +224,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
# to avoid division by zero error when computing `workers_per_gpu`
workers_per_gpu = num_cpus // (num_devices or 1)
workers_per_gpu = min(self.max_num_workers, workers_per_gpu)
print(f"Using {workers_per_gpu} data loader worker processes per GPU")
logging.info(f"Using {workers_per_gpu} data loader worker processes per GPU")
dataloader_kwargs = dict(num_workers=workers_per_gpu, pin_memory=True)
return dataloader_kwargs
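As a worked example of the formula above (the machine sizes are assumptions, not taken from the code):

# Hypothetical machine: 24 CPU cores, 4 GPUs, max_num_workers capped at 8.
num_cpus, num_devices, max_num_workers = 24, 4, 8
workers_per_gpu = min(max_num_workers, num_cpus // (num_devices or 1))  # min(8, 6) == 6
# On a CPU-only run num_devices is 0, so `(num_devices or 1)` falls back to 1
# and avoids a ZeroDivisionError.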
@@ -242,6 +244,14 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
def get_slides_dataset(self) -> Optional[SlidesDataset]:
return None
def ignore_pl_warnings(self) -> None:
# Pytorch Lightning prints a warning if the batch size is not consistent across all batches. The way PL infers
# the batch size is not compatible with our data loaders. It searches for the first item in the batch that is a
# tensor and uses its size[0] as the batch size. However, in our case, the batch is a list of tensors, so it
# thinks that the batch size is the bag_size which can be different for each WSI in the batch. This is why we
# ignore this warning to avoid noisy logs.
warnings.filterwarnings("ignore", ".*Trying to infer the `batch_size` from an ambiguous collection.*")
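To make the ambiguity concrete, a small illustration (keys and shapes are made up, not the real collate output): each bag in the batch holds a different number of tiles, so the first tensor Lightning inspects reports the bag size rather than the batch size.

import torch

batch = {
    "image": [torch.rand(43, 3, 224, 224), torch.rand(57, 3, 224, 224)],  # bag sizes 43 and 57
    "label": torch.tensor([0, 1]),  # the true batch size is 2
}
# Lightning's batch-size inference walks the collection, finds the first tensor and takes its
# leading dimension (43 here), then warns that the collection is ambiguous; this is the warning
# silenced by the filter above.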
class BaseMILTiles(BaseMIL):
"""BaseMILTiles is an abstract subclass of BaseMIL for running MIL experiments on tiles datasets. It is responsible
@@ -278,11 +288,11 @@ class BaseMILTiles(BaseMIL):
if self.is_caching:
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.outputs_folder)
transform = Compose([
LoadTilesBatchd(image_key, progress=True),
LoadTilesBatchd(image_key, progress=False),
EncodeTilesBatchd(image_key, encoder, chunk_size=self.encoding_chunk_size) # type: ignore
])
else:
transform = LoadTilesBatchd(image_key, progress=True) # type: ignore
transform = LoadTilesBatchd(image_key, progress=False) # type: ignore
# in case the transformations for training contain augmentations, val and test transform will be different
return {ModelKey.TRAIN: transform, ModelKey.VAL: transform, ModelKey.TEST: transform}

View file

@@ -8,9 +8,10 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pathlib import Path
from pytorch_lightning import LightningModule
from torch import Tensor, argmax, mode, nn, optim, round
from torchmetrics import (AUROC, F1Score, Accuracy, ConfusionMatrix, Precision,
Recall, CohenKappa, AveragePrecision, Specificity)
from torchmetrics.classification import (MulticlassAUROC, MulticlassAccuracy, MulticlassConfusionMatrix,
MulticlassCohenKappa, MulticlassAveragePrecision, BinaryConfusionMatrix,
BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score, BinaryCohenKappa,
BinaryAUROC, BinarySpecificity, BinaryAveragePrecision)
from health_ml.utils import log_on_epoch
from health_ml.deep_learning_config import OptimizerParams
from health_cpath.models.encoders import IdentityEncoder
@@ -175,30 +176,30 @@ class BaseDeepMILModule(LightningModule):
def get_metrics(self) -> nn.ModuleDict:
if self.n_classes > 1:
return nn.ModuleDict({
MetricsKey.ACC: Accuracy(num_classes=self.n_classes),
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
MetricsKey.AVERAGE_PRECISION: AveragePrecision(num_classes=self.n_classes),
MetricsKey.ACC: MulticlassAccuracy(num_classes=self.n_classes),
MetricsKey.AUROC: MulticlassAUROC(num_classes=self.n_classes),
MetricsKey.AVERAGE_PRECISION: MulticlassAveragePrecision(num_classes=self.n_classes),
# Quadratic Weighted Kappa (QWK) used in PANDA challenge
# is calculated using Cohen's Kappa with quadratic weights
# https://www.kaggle.com/code/reighns/understanding-the-quadratic-weighted-kappa/
MetricsKey.COHENKAPPA: CohenKappa(num_classes=self.n_classes, weights='quadratic'),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes),
MetricsKey.COHENKAPPA: MulticlassCohenKappa(num_classes=self.n_classes, weights='quadratic'),
MetricsKey.CONF_MATRIX: MulticlassConfusionMatrix(num_classes=self.n_classes),
# Metrics below are computed for multi-class case only
MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'),
MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted')})
MetricsKey.ACC_MACRO: MulticlassAccuracy(num_classes=self.n_classes, average='macro'),
MetricsKey.ACC_WEIGHTED: MulticlassAccuracy(num_classes=self.n_classes, average='weighted')})
else:
return nn.ModuleDict({
MetricsKey.ACC: Accuracy(),
MetricsKey.AUROC: AUROC(num_classes=None),
MetricsKey.ACC: BinaryAccuracy(),
MetricsKey.AUROC: BinaryAUROC(),
# Average precision is a measure of area under the PR curve
MetricsKey.AVERAGE_PRECISION: AveragePrecision(),
MetricsKey.COHENKAPPA: CohenKappa(num_classes=2, weights='quadratic'),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2),
MetricsKey.AVERAGE_PRECISION: BinaryAveragePrecision(),
MetricsKey.COHENKAPPA: BinaryCohenKappa(weights='quadratic'),
MetricsKey.CONF_MATRIX: BinaryConfusionMatrix(),
# Metrics below are computed for binary case only
MetricsKey.F1: F1Score(),
MetricsKey.PRECISION: Precision(),
MetricsKey.RECALL: Recall(),
MetricsKey.SPECIFICITY: Specificity()})
MetricsKey.F1: BinaryF1Score(),
MetricsKey.PRECISION: BinaryPrecision(),
MetricsKey.RECALL: BinaryRecall(),
MetricsKey.SPECIFICITY: BinarySpecificity()})
def get_extra_prefix(self) -> str:
"""Get extra prefix for the metrics name to avoir overriding best validation metrics."""

View file

@@ -15,8 +15,6 @@ from azureml.core.run import _OfflineRun
HIML_ROOT = Path(__file__).parent.parent.parent.parent.parent.absolute()
health_ml_root = HIML_ROOT / "hi-ml" / "src"
health_azure_root = HIML_ROOT / "hi-ml-azure" / "src"
print(f"Inserting into sys path: {health_ml_root}")
print(f"Inserting into sys path: {health_azure_root}")
sys.path.insert(0, str(health_ml_root))
sys.path.insert(0, str(health_azure_root))

View file

@@ -132,9 +132,9 @@ class ExperimentFolderHandler(Parameterized):
outputs_folder = snapshot_dir / DEFAULT_AML_UPLOAD_DIR
logs_folder = snapshot_dir / DEFAULT_LOGS_DIR_NAME
print(f"Run outputs folder: {outputs_folder}")
print(f"Logs folder: {logs_folder}")
print(f"Run root directory: {run_folder}")
logging.info(f"Run outputs folder: {outputs_folder}")
logging.info(f"Logs folder: {logs_folder}")
logging.info(f"Run root directory: {run_folder}")
return ExperimentFolderHandler(
outputs_folder=outputs_folder,
logs_folder=logs_folder,

View file

@@ -338,10 +338,6 @@ class MLRunner:
if self.container.has_custom_test_step():
# Run Lightning's built-in test procedure if the `test_step` method has been overridden
logging.info("Running inference via the LightningModule.test_step method")
# We run inference on a single device because distributed strategies such as DDP use DistributedSampler
# internally, which replicates some samples to make sure all devices have some batch size in case of
# uneven inputs.
self.container.max_num_gpus = 1
checkpoint_path = (
self.checkpoint_handler.get_checkpoint_to_test() if self.container.run_inference_only else None
@@ -408,6 +404,9 @@ class MLRunner:
with logging_section("Model training"):
self.run_training()
# Kill all processes besides rank 0
self.after_ddp_cleanup(old_environ)
# load model checkpoint for custom inference or additional validation step
if self.container.has_custom_test_step() or self.container.run_extra_val_epoch:
self.load_model_checkpoint()
@@ -416,10 +415,6 @@ class MLRunner:
if self.container.run_extra_val_epoch:
with logging_section("Model Validation to save plots on validation set"):
self.run_validation()
# Kill all processes besides rank 0
self.after_ddp_cleanup(old_environ)
# Run inference on a single device
with logging_section("Model inference"):
self.run_inference()
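A runnable toy sketch of the resulting call order (the class and method bodies are stand-ins; only the ordering mirrors the hunks above):

class ToyRunner:
    run_extra_val_epoch = True

    def run_training(self) -> None:
        print("training on all DDP ranks")

    def after_ddp_cleanup(self) -> None:
        print("kill all processes besides rank 0; everything below runs on a single device")

    def run_validation(self) -> None:
        print("extra validation epoch, single device")

    def run_inference(self) -> None:
        print("inference, single device")

    def run(self) -> None:
        self.run_training()
        # Cleanup must happen right after training, before the single-device steps.
        self.after_ddp_cleanup()
        if self.run_extra_val_epoch:
            self.run_validation()
        self.run_inference()

ToyRunner().run()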

View file

@@ -32,7 +32,7 @@ from health_azure.amulet import prepare_amulet_job, is_amulet_job # noqa: E402
from health_azure.utils import (get_workspace, get_ml_client, is_local_rank_zero, # noqa: E402
is_running_in_azure_ml, set_environment_variables_for_multi_node,
create_argparser, parse_arguments, ParserResult, apply_overrides,
filter_v2_input_output_args)
filter_v2_input_output_args, is_global_rank_zero)
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig # noqa: E402
from health_ml.lightning_container import LightningContainer # noqa: E402
@@ -354,7 +354,8 @@ def run(project_root: Path) -> Tuple[LightningContainer, AzureRunInfo]:
:return: If submitting to AzureML, returns the model configuration that was used for training,
including commandline overrides applied (if any). For details on the arguments, see the constructor of Runner.
"""
print(f"project root: {project_root}")
if is_global_rank_zero():
print(f"project root: {project_root}")
runner = Runner(project_root)
return runner.run()

View file

@@ -374,19 +374,6 @@ def test_model_weights_when_resume_training() -> None:
assert recovery_checkpoint == runner.checkpoint_handler.trained_weights_path
def test_runner_end_to_end() -> None:
experiment_config = ExperimentConfig(model="HelloWorld")
container = HelloWorld()
container.max_num_gpus = 0
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
runner = MLRunner(experiment_config=experiment_config, container=container)
runner.setup()
runner.init_training()
runner.run_training()
@pytest.mark.parametrize("log_from_vm", [True, False])
def test_log_on_vm(log_from_vm: bool) -> None:
"""Test if the AzureML logger is called when the experiment is run outside AzureML."""