зеркало из https://github.com/microsoft/hi-ml.git
BUG: Fix hanging smoke tests (#666)
Main changes: - Change prints to logging.info to prevent all ranks from logging - Run smoke tests with the --strictly_aml_v1 flag to be able to wait for completion - Set progress=False in LoadTilesBatchd to avoid unnecessary verbosity in logs - Fix regression after merge: kill_ddp_processes right after run_training because we run the extra val epoch on a single device - Remove a test that does nothing - Use binary/multiclass metrics - Ignore PL warnings
This commit is contained in:
Родитель
c4e25b8a79
Коммит
321de75cb3
|
@ -8,7 +8,7 @@ import time
|
|||
from contextlib import contextmanager
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from health_azure.utils import check_is_any_of
|
||||
from health_azure.utils import check_is_any_of, is_global_rank_zero
|
||||
|
||||
logging_stdout_handler: Optional[logging.StreamHandler] = None
|
||||
logging_to_file_handler: Optional[logging.StreamHandler] = None
|
||||
|
@ -29,14 +29,16 @@ def logging_to_stdout(log_level: Union[int, str] = logging.INFO) -> None:
|
|||
# logging lines.
|
||||
global logging_stdout_handler
|
||||
if not logging_stdout_handler:
|
||||
print("Setting up logging to stdout.")
|
||||
if is_global_rank_zero():
|
||||
print("Setting up logging to stdout.")
|
||||
# At startup, logging has one handler set, that writes to stderr, with a log level of 0 (logging.NOTSET)
|
||||
if len(logger.handlers) == 1:
|
||||
logger.removeHandler(logger.handlers[0])
|
||||
logging_stdout_handler = logging.StreamHandler(stream=sys.stdout)
|
||||
_add_formatter(logging_stdout_handler)
|
||||
logger.addHandler(logging_stdout_handler)
|
||||
print(f"Setting logging level to {log_level}")
|
||||
if is_global_rank_zero():
|
||||
print(f"Setting logging level to {log_level}")
|
||||
logging_stdout_handler.setLevel(log_level)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
|
|
|
@ -160,11 +160,11 @@ define DEFAULT_SMOKE_TEST_ARGS
|
|||
endef
|
||||
|
||||
define AML_ONE_DEVICE_ARGS
|
||||
--cluster=testing-nc6 --wait_for_completion --num_nodes=1 --max_num_gpus=1
|
||||
--cluster=testing-nc6 --wait_for_completion --num_nodes=1 --max_num_gpus=1 --strictly_aml_v1=True
|
||||
endef
|
||||
|
||||
define AML_MULTIPLE_DEVICE_ARGS
|
||||
--cluster=dedicated-nc24s-v2 --wait_for_completion
|
||||
--cluster=dedicated-nc24s-v2 --wait_for_completion --strictly_aml_v1=True
|
||||
endef
|
||||
|
||||
define DEEPSMILEDEFAULT_SMOKE_TEST_ARGS
|
||||
|
|
|
@ -4,8 +4,9 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
import logging
|
||||
import param
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
|
||||
|
||||
from pathlib import Path
|
||||
|
@ -92,6 +93,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
|
|||
logging.info(
|
||||
"Replacing sampler with `DistributedSampler` is disabled. Make sure to set your own DDP sampler"
|
||||
)
|
||||
self.ignore_pl_warnings()
|
||||
|
||||
def validate(self) -> None:
|
||||
super().validate()
|
||||
|
@ -222,7 +224,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
|
|||
# to avoid division by zero error when computing `workers_per_gpu`
|
||||
workers_per_gpu = num_cpus // (num_devices or 1)
|
||||
workers_per_gpu = min(self.max_num_workers, workers_per_gpu)
|
||||
print(f"Using {workers_per_gpu} data loader worker processes per GPU")
|
||||
logging.info(f"Using {workers_per_gpu} data loader worker processes per GPU")
|
||||
dataloader_kwargs = dict(num_workers=workers_per_gpu, pin_memory=True)
|
||||
return dataloader_kwargs
|
||||
|
||||
|
@ -242,6 +244,14 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
|
|||
def get_slides_dataset(self) -> Optional[SlidesDataset]:
|
||||
return None
|
||||
|
||||
def ignore_pl_warnings(self) -> None:
|
||||
# Pytorch Lightning prints a warning if the batch size is not consistent across all batches. The way PL infers
|
||||
# the batch size is not compatible with our data loaders. It searches for the first item in the batch that is a
|
||||
# tensor and uses its size[0] as the batch size. However, in our case, the batch is a list of tensors, so it
|
||||
# thinks that the batch size is the bag_size which can be different for each WSI in the batch. This is why we
|
||||
# ignore this warning to avoid noisy logs.
|
||||
warnings.filterwarnings("ignore", ".*Trying to infer the `batch_size` from an ambiguous collection.*")
|
||||
|
||||
|
||||
class BaseMILTiles(BaseMIL):
|
||||
"""BaseMILTiles is an abstract subclass of BaseMIL for running MIL experiments on tiles datasets. It is responsible
|
||||
|
@ -278,11 +288,11 @@ class BaseMILTiles(BaseMIL):
|
|||
if self.is_caching:
|
||||
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.outputs_folder)
|
||||
transform = Compose([
|
||||
LoadTilesBatchd(image_key, progress=True),
|
||||
LoadTilesBatchd(image_key, progress=False),
|
||||
EncodeTilesBatchd(image_key, encoder, chunk_size=self.encoding_chunk_size) # type: ignore
|
||||
])
|
||||
else:
|
||||
transform = LoadTilesBatchd(image_key, progress=True) # type: ignore
|
||||
transform = LoadTilesBatchd(image_key, progress=False) # type: ignore
|
||||
# in case the transformations for training contain augmentations, val and test transform will be different
|
||||
return {ModelKey.TRAIN: transform, ModelKey.VAL: transform, ModelKey.TEST: transform}
|
||||
|
||||
|
|
|
@ -8,9 +8,10 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
|||
from pathlib import Path
|
||||
from pytorch_lightning import LightningModule
|
||||
from torch import Tensor, argmax, mode, nn, optim, round
|
||||
from torchmetrics import (AUROC, F1Score, Accuracy, ConfusionMatrix, Precision,
|
||||
Recall, CohenKappa, AveragePrecision, Specificity)
|
||||
|
||||
from torchmetrics.classification import (MulticlassAUROC, MulticlassAccuracy, MulticlassConfusionMatrix,
|
||||
MulticlassCohenKappa, MulticlassAveragePrecision, BinaryConfusionMatrix,
|
||||
BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score, BinaryCohenKappa,
|
||||
BinaryAUROC, BinarySpecificity, BinaryAveragePrecision)
|
||||
from health_ml.utils import log_on_epoch
|
||||
from health_ml.deep_learning_config import OptimizerParams
|
||||
from health_cpath.models.encoders import IdentityEncoder
|
||||
|
@ -175,30 +176,30 @@ class BaseDeepMILModule(LightningModule):
|
|||
def get_metrics(self) -> nn.ModuleDict:
|
||||
if self.n_classes > 1:
|
||||
return nn.ModuleDict({
|
||||
MetricsKey.ACC: Accuracy(num_classes=self.n_classes),
|
||||
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
|
||||
MetricsKey.AVERAGE_PRECISION: AveragePrecision(num_classes=self.n_classes),
|
||||
MetricsKey.ACC: MulticlassAccuracy(num_classes=self.n_classes),
|
||||
MetricsKey.AUROC: MulticlassAUROC(num_classes=self.n_classes),
|
||||
MetricsKey.AVERAGE_PRECISION: MulticlassAveragePrecision(num_classes=self.n_classes),
|
||||
# Quadratic Weighted Kappa (QWK) used in PANDA challenge
|
||||
# is calculated using Cohen's Kappa with quadratic weights
|
||||
# https://www.kaggle.com/code/reighns/understanding-the-quadratic-weighted-kappa/
|
||||
MetricsKey.COHENKAPPA: CohenKappa(num_classes=self.n_classes, weights='quadratic'),
|
||||
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes),
|
||||
MetricsKey.COHENKAPPA: MulticlassCohenKappa(num_classes=self.n_classes, weights='quadratic'),
|
||||
MetricsKey.CONF_MATRIX: MulticlassConfusionMatrix(num_classes=self.n_classes),
|
||||
# Metrics below are computed for multi-class case only
|
||||
MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'),
|
||||
MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted')})
|
||||
MetricsKey.ACC_MACRO: MulticlassAccuracy(num_classes=self.n_classes, average='macro'),
|
||||
MetricsKey.ACC_WEIGHTED: MulticlassAccuracy(num_classes=self.n_classes, average='weighted')})
|
||||
else:
|
||||
return nn.ModuleDict({
|
||||
MetricsKey.ACC: Accuracy(),
|
||||
MetricsKey.AUROC: AUROC(num_classes=None),
|
||||
MetricsKey.ACC: BinaryAccuracy(),
|
||||
MetricsKey.AUROC: BinaryAUROC(),
|
||||
# Average precision is a measure of area under the PR curve
|
||||
MetricsKey.AVERAGE_PRECISION: AveragePrecision(),
|
||||
MetricsKey.COHENKAPPA: CohenKappa(num_classes=2, weights='quadratic'),
|
||||
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2),
|
||||
MetricsKey.AVERAGE_PRECISION: BinaryAveragePrecision(),
|
||||
MetricsKey.COHENKAPPA: BinaryCohenKappa(weights='quadratic'),
|
||||
MetricsKey.CONF_MATRIX: BinaryConfusionMatrix(),
|
||||
# Metrics below are computed for binary case only
|
||||
MetricsKey.F1: F1Score(),
|
||||
MetricsKey.PRECISION: Precision(),
|
||||
MetricsKey.RECALL: Recall(),
|
||||
MetricsKey.SPECIFICITY: Specificity()})
|
||||
MetricsKey.F1: BinaryF1Score(),
|
||||
MetricsKey.PRECISION: BinaryPrecision(),
|
||||
MetricsKey.RECALL: BinaryRecall(),
|
||||
MetricsKey.SPECIFICITY: BinarySpecificity()})
|
||||
|
||||
def get_extra_prefix(self) -> str:
|
||||
"""Get extra prefix for the metrics name to avoid overriding best validation metrics."""
|
||||
|
|
|
@ -15,8 +15,6 @@ from azureml.core.run import _OfflineRun
|
|||
HIML_ROOT = Path(__file__).parent.parent.parent.parent.parent.absolute()
|
||||
health_ml_root = HIML_ROOT / "hi-ml" / "src"
|
||||
health_azure_root = HIML_ROOT / "hi-ml-azure" / "src"
|
||||
print(f"Inserting into sys path: {health_ml_root}")
|
||||
print(f"Inserting into sys path: {health_azure_root}")
|
||||
sys.path.insert(0, str(health_ml_root))
|
||||
sys.path.insert(0, str(health_azure_root))
|
||||
|
||||
|
|
|
@ -132,9 +132,9 @@ class ExperimentFolderHandler(Parameterized):
|
|||
outputs_folder = snapshot_dir / DEFAULT_AML_UPLOAD_DIR
|
||||
logs_folder = snapshot_dir / DEFAULT_LOGS_DIR_NAME
|
||||
|
||||
print(f"Run outputs folder: {outputs_folder}")
|
||||
print(f"Logs folder: {logs_folder}")
|
||||
print(f"Run root directory: {run_folder}")
|
||||
logging.info(f"Run outputs folder: {outputs_folder}")
|
||||
logging.info(f"Logs folder: {logs_folder}")
|
||||
logging.info(f"Run root directory: {run_folder}")
|
||||
return ExperimentFolderHandler(
|
||||
outputs_folder=outputs_folder,
|
||||
logs_folder=logs_folder,
|
||||
|
|
|
@ -338,10 +338,6 @@ class MLRunner:
|
|||
if self.container.has_custom_test_step():
|
||||
# Run Lightning's built-in test procedure if the `test_step` method has been overridden
|
||||
logging.info("Running inference via the LightningModule.test_step method")
|
||||
# We run inference on a single device because distributed strategies such as DDP use DistributedSampler
|
||||
# internally, which replicates some samples to make sure all devices have some batch size in case of
|
||||
# uneven inputs.
|
||||
self.container.max_num_gpus = 1
|
||||
|
||||
checkpoint_path = (
|
||||
self.checkpoint_handler.get_checkpoint_to_test() if self.container.run_inference_only else None
|
||||
|
@ -408,6 +404,9 @@ class MLRunner:
|
|||
with logging_section("Model training"):
|
||||
self.run_training()
|
||||
|
||||
# Kill all processes besides rank 0
|
||||
self.after_ddp_cleanup(old_environ)
|
||||
|
||||
# load model checkpoint for custom inference or additional validation step
|
||||
if self.container.has_custom_test_step() or self.container.run_extra_val_epoch:
|
||||
self.load_model_checkpoint()
|
||||
|
@ -416,10 +415,6 @@ class MLRunner:
|
|||
if self.container.run_extra_val_epoch:
|
||||
with logging_section("Model Validation to save plots on validation set"):
|
||||
self.run_validation()
|
||||
|
||||
# Kill all processes besides rank 0
|
||||
self.after_ddp_cleanup(old_environ)
|
||||
|
||||
# Run inference on a single device
|
||||
with logging_section("Model inference"):
|
||||
self.run_inference()
|
||||
|
|
|
@ -32,7 +32,7 @@ from health_azure.amulet import prepare_amulet_job, is_amulet_job # noqa: E402
|
|||
from health_azure.utils import (get_workspace, get_ml_client, is_local_rank_zero, # noqa: E402
|
||||
is_running_in_azure_ml, set_environment_variables_for_multi_node,
|
||||
create_argparser, parse_arguments, ParserResult, apply_overrides,
|
||||
filter_v2_input_output_args)
|
||||
filter_v2_input_output_args, is_global_rank_zero)
|
||||
|
||||
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig # noqa: E402
|
||||
from health_ml.lightning_container import LightningContainer # noqa: E402
|
||||
|
@ -354,7 +354,8 @@ def run(project_root: Path) -> Tuple[LightningContainer, AzureRunInfo]:
|
|||
:return: If submitting to AzureML, returns the model configuration that was used for training,
|
||||
including commandline overrides applied (if any). For details on the arguments, see the constructor of Runner.
|
||||
"""
|
||||
print(f"project root: {project_root}")
|
||||
if is_global_rank_zero():
|
||||
print(f"project root: {project_root}")
|
||||
runner = Runner(project_root)
|
||||
return runner.run()
|
||||
|
||||
|
|
|
@ -374,19 +374,6 @@ def test_model_weights_when_resume_training() -> None:
|
|||
assert recovery_checkpoint == runner.checkpoint_handler.trained_weights_path
|
||||
|
||||
|
||||
def test_runner_end_to_end() -> None:
|
||||
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||
container = HelloWorld()
|
||||
container.max_num_gpus = 0
|
||||
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
|
||||
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
|
||||
mock_get_workspace.return_value = DEFAULT_WORKSPACE.workspace
|
||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
||||
runner.setup()
|
||||
runner.init_training()
|
||||
runner.run_training()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("log_from_vm", [True, False])
|
||||
def test_log_on_vm(log_from_vm: bool) -> None:
|
||||
"""Test if the AzureML logger is called when the experiment is run outside AzureML."""
|
||||
|
|
Загрузка…
Ссылка в новой задаче