Add Covid configs (#456)

This PR adds configs to train Covid detection models from chest X-ray data.

Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com>

Parent: 8bae42eb92
Commit: 55120d7a6b

@@ -84,6 +84,7 @@ console for easier diagnostics.
 Additionally, the `TrainHelloWorldAndHelloContainer` job in the PR build has been split into two jobs, `TrainHelloWorld` and
 `TrainHelloContainer`. A pytest marker `after_training_hello_container` has been added to run tests after training is
 finished in the `TrainHelloContainer` job.
+- ([#456](https://github.com/microsoft/InnerEye-DeepLearning/pull/456)) Adding configs to train Covid detection models.

 ### Changed

@@ -4,11 +4,12 @@
 # ------------------------------------------------------------------------------------------
 import logging
 from pathlib import Path
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Tuple

 import numpy as np
 import pandas as pd
 from PIL import Image
+from torch.utils.data import Subset
 from torchvision.datasets import VisionDataset

 from InnerEye.Common.type_annotations import PathOrString
@@ -175,3 +176,39 @@ class CheXpert(InnerEyeCXRDatasetWithReturnIndex):
         self.dataset_dataframe.Path = self.dataset_dataframe.Path.apply(lambda x: x[strip_n:])
         self.indices = np.arange(len(self.dataset_dataframe))
         self.filenames = [self.root / p for p in self.dataset_dataframe.Path.values]
+
+
+class CovidDataset(InnerEyeCXRDatasetWithReturnIndex):
+    """
+    Dataset class to load the CovidDataset dataset as a datamodule, for monitoring SSL training quality directly on
+    CovidDataset data.
+    We use CVX03 vs. CVX12 as the proxy task.
+    """
+
+    def _prepare_dataset(self) -> None:
+        self.dataset_dataframe = pd.read_csv(self.root / "dataset.csv")
+        mapping = {0: 0, 3: 0, 1: 1, 2: 1}
+        # For monitoring purposes we use the binary classification task CVX03 vs. CVX12
+        self.dataset_dataframe["final_label"] = self.dataset_dataframe.final_label.apply(lambda x: mapping[x])
+        self.indices = np.arange(len(self.dataset_dataframe))
+        self.subject_ids = self.dataset_dataframe.subject.values
+        self.filenames = [self.root / file for file in self.dataset_dataframe.filepath.values]
+        self.targets = self.dataset_dataframe.final_label.values.astype(np.int64).reshape(-1)
+
+    @property
+    def num_classes(self) -> int:
+        return 2
+
+    def _split_dataset(self, val_split: float, seed: int) -> Tuple[Subset, Subset]:
+        """
+        Implements the train/validation split at the subject level.
+        :param val_split: proportion to use for validation
+        :param seed: random seed for splitting
+        :return: dataset_train, dataset_val
+        """
+        shuffled_subject_ids = np.random.RandomState(seed).permutation(np.unique(self.subject_ids))
+        n_val = int(len(shuffled_subject_ids) * val_split)
+        val_subjects, train_subjects = shuffled_subject_ids[:n_val], shuffled_subject_ids[n_val:]
+        train_ids, val_ids = np.where(np.isin(self.subject_ids, train_subjects))[0], \
+                             np.where(np.isin(self.subject_ids, val_subjects))[0]
+        return Subset(self, train_ids), Subset(self, val_ids)
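
The subject-level split above is what keeps all images of a patient on one side of the train/validation boundary. A minimal, self-contained sketch of that logic (hypothetical subject ids, not part of the PR):

import numpy as np

subject_ids = np.array([1, 1, 2, 2, 3, 3, 4, 4])  # two images per (made-up) subject
val_split, seed = 0.25, 42

# Shuffle the unique subjects, then cut off a validation share, as in CovidDataset._split_dataset.
shuffled = np.random.RandomState(seed).permutation(np.unique(subject_ids))
n_val = int(len(shuffled) * val_split)
val_subjects, train_subjects = shuffled[:n_val], shuffled[n_val:]
train_ids = np.where(np.isin(subject_ids, train_subjects))[0]
val_ids = np.where(np.isin(subject_ids, val_subjects))[0]

# No subject ends up in both sets, so there is no image-level leakage across the split.
assert not set(subject_ids[train_ids]) & set(subject_ids[val_ids])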
@@ -13,7 +13,7 @@ from pytorch_lightning import LightningModule
 from yacs.config import CfgNode

 from InnerEye.ML.SSL.datamodules_and_datasets.cifar_datasets import InnerEyeCIFAR10, InnerEyeCIFAR100
-from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import CheXpert, NIHCXR, RSNAKaggleCXR
+from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import CheXpert, CovidDataset, NIHCXR, RSNAKaggleCXR
 from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule, InnerEyeVisionDataModule
 from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \
     InnerEyeCIFARTrainTransform, \
@@ -42,11 +42,12 @@ class EncoderName(Enum):


 class SSLDatasetName(Enum):
-    RSNAKaggleCXR = "RSNAKaggleCXR"
-    NIHCXR = "NIHCXR"
     CIFAR10 = "CIFAR10"
     CIFAR100 = "CIFAR100"
+    RSNAKaggleCXR = "RSNAKaggleCXR"
+    NIHCXR = "NIHCXR"
     CheXpert = "CheXpert"
+    Covid = "CovidDataset"


 InnerEyeDataModuleTypes = Union[InnerEyeVisionDataModule, CombinedDataModule]
@@ -62,11 +63,12 @@ class SSLContainer(LightningContainer):
     Note that this container is also used as the base class for SSLImageClassifier (finetuning container) as they share
     setup and datamodule methods.
     """
-    _SSLDataClassMappings = {SSLDatasetName.RSNAKaggleCXR.value: RSNAKaggleCXR,
-                             SSLDatasetName.NIHCXR.value: NIHCXR,
-                             SSLDatasetName.CIFAR10.value: InnerEyeCIFAR10,
+    _SSLDataClassMappings = {SSLDatasetName.CIFAR10.value: InnerEyeCIFAR10,
                              SSLDatasetName.CIFAR100.value: InnerEyeCIFAR100,
-                             SSLDatasetName.CheXpert.value: CheXpert}
+                             SSLDatasetName.RSNAKaggleCXR.value: RSNAKaggleCXR,
+                             SSLDatasetName.NIHCXR.value: NIHCXR,
+                             SSLDatasetName.CheXpert.value: CheXpert,
+                             SSLDatasetName.Covid.value: CovidDataset}

     ssl_augmentation_config = param.ClassSelector(class_=Path, allow_None=True,
                                                   doc="The path to the yaml config defining the parameters of the "
@@ -87,11 +89,13 @@ class SSLContainer(LightningContainer):
                                            "Used for debugging and tests.")
     linear_head_augmentation_config = param.ClassSelector(class_=Path,
                                                           doc="The path to the yaml config for the linear head "
-                                                               "augmentations")
+                                                              "augmentations")
     linear_head_dataset_name = param.ClassSelector(class_=SSLDatasetName,
                                                    doc="Name of the dataset to use for the linear head training")
     linear_head_batch_size = param.Integer(default=256, doc="Batch size for linear head tuning")
-    learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4, doc="Learning rate for linear head training during SSL training.")
+    learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4,
+                                                                 doc="Learning rate for linear head training during "
+                                                                     "SSL training.")

     def setup(self) -> None:
         from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer
@@ -173,7 +177,8 @@ class SSLContainer(LightningContainer):
             return self.data_module
         encoder_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=True)
         linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False)
-        return CombinedDataModule(encoder_data_module, linear_data_module, self.use_balanced_binary_loss_for_linear_head)
+        return CombinedDataModule(encoder_data_module, linear_data_module,
+                                  self.use_balanced_binary_loss_for_linear_head)

     def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisionDataModule:
         """
@@ -220,7 +225,10 @@ class SSLContainer(LightningContainer):
                           applied on. If False, return only one transformation.
         :return: training transformation pipeline and validation transformation pipeline.
         """
-        if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value, SSLDatasetName.NIHCXR.value, SSLDatasetName.CheXpert.value]:
+        if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value,
+                            SSLDatasetName.NIHCXR.value,
+                            SSLDatasetName.CheXpert.value,
+                            SSLDatasetName.Covid.value]:
             assert augmentation_config is not None
             train_transforms, val_transforms = get_cxr_ssl_transforms(augmentation_config,
                                                                       return_two_views_per_sample=is_ssl_encoder_module,

@@ -0,0 +1,247 @@
import codecs
import logging
import pickle
from pathlib import Path
from typing import Any, Callable

import PIL
import pandas as pd
import param
import torch
from PIL import Image
from pytorch_lightning import LightningModule
from torchvision.transforms import Compose

from InnerEye.Common.common_util import ModelProcessing, get_best_epoch_results_path
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import create_chest_xray_transform
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName
from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
from InnerEye.ML.SSL.utils import create_ssl_encoder, create_ssl_image_classifier, load_ssl_augmentation_config
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_linear_head_augmentation_cxr
from InnerEye.ML.deep_learning_config import LRSchedulerType, MultiprocessingStartMethod, \
    OptimizerType
from InnerEye.ML.model_config_base import ModelTransformsPerExecutionMode
from InnerEye.ML.model_testing import MODEL_OUTPUT_CSV
from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImagingFeatureType
from InnerEye.ML.reports.notebook_report import generate_notebook, get_ipynb_report_name, str_or_empty
from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
from InnerEye.ML.utils.augmentation import ScalarItemAugmentation
from InnerEye.ML.utils.run_recovery import RunRecovery
from InnerEye.ML.utils.split_dataset import DatasetSplits

from InnerEye.ML.configs.ssl.CovidContainers import COVID_DATASET_ID
from InnerEye.Common import fixed_paths as fixed_paths_innereye


class CovidHierarchicalModel(ScalarModelBase):
    """
    Model to train a CovidDataset model from scratch or to fine-tune from an SSL-pretrained model.

    For AML you need to provide the run_id of your SSL training job as the command line argument
    --pretraining_run_recovery_id=id_of_your_ssl_model. This will download the checkpoints of that run to your
    machine and load the corresponding pretrained model.

    To recover from a particular checkpoint of your SSL run, e.g. "recovery_epoch=499.ckpt", please use the
    --name_of_checkpoint argument.
    """
    use_pretrained_model = param.Boolean(default=False, doc="If True, start training from a model pretrained with "
                                                            "SSL. If False, start training a DenseNet model from "
                                                            "scratch (random initialization).")
    freeze_encoder = param.Boolean(default=False, doc="Whether to freeze the pretrained encoder or not.")
    name_of_checkpoint = param.String(default=None, doc="Filename of the checkpoint to use for recovery")
    test_set_ids_csv = param.String(default=None,
                                    doc="Name of the csv file in the dataset folder with the test set ids. The "
                                        "dataset is expected to have a 'series' and a 'subject' column. The subject "
                                        "column is assumed to contain unique ids.")

    def __init__(self, covid_dataset_id: str = COVID_DATASET_ID, **kwargs: Any):
        learning_rate = 1e-5 if self.use_pretrained_model else 1e-4
        super().__init__(target_names=['CVX03vs12', 'CVX0vs3', 'CVX1vs2'],
                         loss_type=ScalarLoss.CustomClassification,
                         class_names=['CVX0', 'CVX1', 'CVX2', 'CVX3'],
                         max_num_gpus=1,
                         azure_dataset_id=covid_dataset_id,
                         subject_column="series",
                         image_file_column="filepath",
                         label_value_column="final_label",
                         non_image_feature_channels=[],
                         numerical_columns=[],
                         use_mixed_precision=False,
                         num_dataload_workers=12,
                         multiprocessing_start_method=MultiprocessingStartMethod.fork,
                         train_batch_size=64,
                         optimizer_type=OptimizerType.Adam,
                         num_epochs=50,
                         l_rate_scheduler=LRSchedulerType.Step,
                         l_rate_step_gamma=1.0,
                         l_rate=learning_rate,
                         l_rate_multi_step_milestones=None,
                         **kwargs)
        self.num_classes = 3
        if not self.use_pretrained_model and self.freeze_encoder:
            raise ValueError("No encoder to freeze when training from scratch. You requested training from scratch "
                             "and encoder freezing.")

    def should_generate_multilabel_report(self) -> bool:
        return False

    def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
        if self.test_set_ids_csv:
            test_df = pd.read_csv(self.local_dataset / self.test_set_ids_csv)
            in_test_set = dataset_df.series.isin(test_df.series)
            train_ids = dataset_df.series[~in_test_set].values
            test_ids = dataset_df.series[in_test_set].values
            num_val_samples = 400
            val_ids = train_ids[:num_val_samples]
            train_ids = train_ids[num_val_samples:]
            return DatasetSplits.from_subject_ids(dataset_df, train_ids=train_ids, val_ids=val_ids, test_ids=test_ids,
                                                  subject_column="series", group_column="subject")
        else:
            return DatasetSplits.from_proportions(dataset_df,
                                                  proportion_train=0.8,
                                                  proportion_val=0.1,
                                                  proportion_test=0.1,
                                                  subject_column="series",
                                                  group_column="subject",
                                                  shuffle=True)

    # noinspection PyTypeChecker
    def get_image_sample_transforms(self) -> ModelTransformsPerExecutionMode:
        config = load_ssl_augmentation_config(path_linear_head_augmentation_cxr)
        train_transforms = ScalarItemAugmentation(
            Compose(
                [DicomPreparation(), create_chest_xray_transform(config, apply_augmentations=True)]))
        val_transforms = ScalarItemAugmentation(
            Compose(
                [DicomPreparation(), create_chest_xray_transform(config, apply_augmentations=False)]))

        return ModelTransformsPerExecutionMode(train=train_transforms,
                                               val=val_transforms,
                                               test=val_transforms)

    def create_model(self) -> LightningModule:
        """
        This method must create the actual Lightning model that will be trained.
        """
        if self.use_pretrained_model:
            path_to_checkpoint = self._get_ssl_checkpoint_path()
            model = create_ssl_image_classifier(
                num_classes=self.num_classes,
                pl_checkpoint_path=str(path_to_checkpoint),
                freeze_encoder=self.freeze_encoder)
        else:
            encoder = create_ssl_encoder(encoder_name=EncoderName.densenet121.value)
            model = SSLClassifier(num_classes=self.num_classes,
                                  encoder=encoder,
                                  freeze_encoder=self.freeze_encoder,
                                  class_weights=None)
        # The next attributes are set only because we are using this model within an InnerEyeContainer.
        model.imaging_feature_type = ImagingFeatureType.Image  # type: ignore
        model.num_non_image_features = 0  # type: ignore
        model.encode_channels_jointly = True  # type: ignore
        return model

    def _get_ssl_checkpoint_path(self) -> Path:
        # Get the SSL weights from the AML run provided via the "pretraining_run_recovery_id" command line argument.
        # They are accessible via the extra_downloaded_run_id field of the config.
        assert self.extra_downloaded_run_id is not None
        assert isinstance(self.extra_downloaded_run_id, RunRecovery)
        ssl_path = self.checkpoint_folder / "ssl_checkpoint.ckpt"

        if not ssl_path.exists():  # for tests (when the checkpoint is already present) we don't need to redo this
            if self.name_of_checkpoint is not None:
                logging.info(f"Using checkpoint: {self.name_of_checkpoint} as starting point.")
                path_to_checkpoint = self.extra_downloaded_run_id.checkpoints_roots[0] / self.name_of_checkpoint
            else:
                path_to_checkpoint = self.extra_downloaded_run_id.get_best_checkpoint_paths()[0]
                if not path_to_checkpoint.exists():
                    logging.info("No best checkpoint found for this model. Getting the latest recovery "
                                 "checkpoint instead.")
                    path_to_checkpoint = self.extra_downloaded_run_id.get_recovery_checkpoint_paths()[0]
            assert path_to_checkpoint.exists()
            path_to_checkpoint.rename(ssl_path)
        return ssl_path

    def pre_process_dataset_dataframe(self) -> None:
        pass

    @staticmethod
    def get_posthoc_label_transform() -> Callable:
        import torch

        def multiclass_to_hierarchical_labels(classes: torch.Tensor) -> torch.Tensor:
            classes = classes.clone()
            cvx03vs12 = classes[..., 1] + classes[..., 2]
            cvx0vs3 = classes[..., 3]
            cvx1vs2 = classes[..., 2]
            cvx0vs3[cvx03vs12 == 1] = float('nan')  # CVX0vs3 only gets gradient for CVX03
            cvx1vs2[cvx03vs12 == 0] = float('nan')  # CVX1vs2 only gets gradient for CVX12
            return torch.stack([cvx03vs12, cvx0vs3, cvx1vs2], -1)

        return multiclass_to_hierarchical_labels
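
    # A worked example (illustrative, not part of the PR) of the transform above: a one-hot
    # CVX2 label [0, 0, 1, 0] gives cvx03vs12 = 0 + 1 = 1, cvx0vs3 = 0 -> NaN (masked, since
    # CVX0vs3 is undefined for a CVX12 sample) and cvx1vs2 = 1, i.e. the hierarchical target
    # is [1, nan, 1].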

    @staticmethod
    def get_loss_function() -> Callable:
        import torch
        import torch.nn.functional as F

        def nan_bce_with_logits(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
            """Compute BCE with logits, ignoring NaN values."""
            valid = labels.isfinite()
            losses = F.binary_cross_entropy_with_logits(output[valid], labels[valid], reduction='none')
            return losses.sum() / labels.shape[0]

        return nan_bce_with_logits
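
    # Note: the summed loss is divided by labels.shape[0] (the batch size) rather than by the
    # number of valid entries. With the hierarchical targets above, each sample keeps exactly
    # two valid entries (CVX03vs12 plus one conditional head), so this still yields a
    # well-defined per-sample loss despite the NaN masking.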

    def generate_custom_report(self, report_dir: Path, model_proc: ModelProcessing) -> Path:
        """
        Generates a custom report for the CovidDataset hierarchical model. At the moment, this report reads the
        model_output.csv file generated for the training, validation and test sets and computes a 4-class accuracy
        and confusion matrix based on it.

        :param report_dir: Directory the report is to be written to.
        :param model_proc: Whether this is a single or ensemble model (model_output.csv is located in different
            paths for single vs. ensemble runs).
        """

        def get_output_csv_path(mode: ModelExecutionMode) -> Path:
            p = get_best_epoch_results_path(mode=mode, model_proc=model_proc)
            return self.outputs_folder / p / MODEL_OUTPUT_CSV

        train_metrics = get_output_csv_path(ModelExecutionMode.TRAIN)
        val_metrics = get_output_csv_path(ModelExecutionMode.VAL)
        test_metrics = get_output_csv_path(ModelExecutionMode.TEST)

        notebook_params = \
            {
                'innereye_path': str(fixed_paths_innereye.repository_root_directory()),
                'train_metrics_csv': str_or_empty(train_metrics),
                'val_metrics_csv': str_or_empty(val_metrics),
                'test_metrics_csv': str_or_empty(test_metrics),
                "config": codecs.encode(pickle.dumps(self), "base64").decode(),
                "is_crossval_report": False
            }
        template = Path(__file__).absolute().parent.parent / "reports" / "CovidHierarchicalModelReport.ipynb"
        return generate_notebook(template,
                                 notebook_params=notebook_params,
                                 result_notebook=report_dir / get_ipynb_report_name(
                                     f"{self.model_category.value}_hierarchical"))

class DicomPreparation:
    def __call__(self, item: torch.Tensor) -> PIL.Image:
        # Item will be of dimension [C, Z, X, Y]
        images = item.numpy()
        assert images.shape[0] == 1 and images.shape[1] == 1
        images = images.reshape(images.shape[2:])
        normalized_image = (images - images.min()) * 255. / (images.max() - images.min())
        image = Image.fromarray(normalized_image).convert("L")
        return image
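
As a usage sketch (hedged: the flag spelling follows InnerEye's param-based command line conventions, and the run id is a placeholder), fine-tuning this model from an SSL checkpoint would be launched roughly as:

python InnerEye/ML/runner.py --model=CovidHierarchicalModel --azureml=True \
    --pretraining_run_recovery_id=id_of_your_ssl_model --use_pretrained_model=True

Training from scratch simply omits the last two arguments, leaving use_pretrained_model at its False default.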

@@ -0,0 +1,161 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%javascript\n",
    "IPython.OutputArea.prototype._should_scroll = function(lines) {\n",
    "    return false;\n",
    "}\n",
    "// Stops auto-scrolling so entire output is visible: see https://stackoverflow.com/a/41646403"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Default parameter values. They will be overwritten by papermill notebook parameters.\n",
    "# This cell must carry the tag \"parameters\" in its metadata.\n",
    "from pathlib import Path\n",
    "import pickle\n",
    "import codecs\n",
    "\n",
    "innereye_path = Path.cwd().parent.parent.parent.parent\n",
    "train_metrics_csv = \"\"\n",
    "val_metrics_csv = \"\"\n",
    "test_metrics_csv = \"\"\n",
    "config = \"\"\n",
    "is_crossval_report = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "print(f\"Adding to path: {innereye_path}\")\n",
    "if str(innereye_path) not in sys.path:\n",
    "    sys.path.append(str(innereye_path))\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "config = pickle.loads(codecs.decode(config.encode(), \"base64\"))\n",
    "\n",
    "from InnerEye.ML.common import ModelExecutionMode\n",
    "from InnerEye.ML.reports.notebook_report import print_header\n",
    "from InnerEye.ML.configs.reports.covid_hierarchical_model_report import print_metrics_from_csv\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "plt.rcParams['figure.figsize'] = (20, 10)\n",
    "\n",
    "# convert params to Path\n",
    "train_metrics_csv = Path(train_metrics_csv)\n",
    "val_metrics_csv = Path(val_metrics_csv)\n",
    "test_metrics_csv = Path(test_metrics_csv)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
    "# Metrics\n",
    "## Train Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "if train_metrics_csv.is_file():\n",
    "    print_metrics_from_csv(csv_to_set_optimal_threshold=train_metrics_csv,\n",
    "                           csv_to_compute_metrics=train_metrics_csv,\n",
    "                           config=config, is_crossval_report=is_crossval_report)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Validation Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "if val_metrics_csv.is_file():\n",
    "    print_metrics_from_csv(csv_to_set_optimal_threshold=val_metrics_csv,\n",
    "                           csv_to_compute_metrics=val_metrics_csv,\n",
    "                           config=config, is_crossval_report=is_crossval_report)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "## Test Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
    "    print_metrics_from_csv(csv_to_set_optimal_threshold=val_metrics_csv,\n",
    "                           csv_to_compute_metrics=test_metrics_csv,\n",
    "                           config=config, is_crossval_report=is_crossval_report)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@@ -0,0 +1,104 @@
import pandas as pd
import numpy as np

from pathlib import Path
from sklearn.metrics import accuracy_score, confusion_matrix
from typing import Dict

from InnerEye.Common.metrics_constants import LoggingColumns
from InnerEye.ML.reports.classification_report import get_labels_and_predictions_from_dataframe, LabelsAndPredictions
from InnerEye.ML.reports.notebook_report import print_table
from InnerEye.ML.scalar_config import ScalarModelBase

TARGET_NAMES = ['CVX03vs12', 'CVX0vs3', 'CVX1vs2']
MULTICLASS_HUE_NAME = "Multiclass"


def get_label_from_label_dict(label_dict: Dict[str, float]) -> int:
    """
    Converts a dictionary of CVX03vs12, CVX0vs3 and CVX1vs2 values to the corresponding CVX class as an int.
    """
    if label_dict['CVX03vs12'] == 0:
        assert np.isnan(label_dict['CVX1vs2'])
        if label_dict['CVX0vs3'] == 0:
            label = 0
        elif label_dict['CVX0vs3'] == 1:
            label = 3
        else:
            raise ValueError("CVX0vs3 should be 0 or 1.")
    elif label_dict['CVX03vs12'] == 1:
        assert np.isnan(label_dict['CVX0vs3'])
        if label_dict['CVX1vs2'] == 0:
            label = 1
        elif label_dict['CVX1vs2'] == 1:
            label = 2
        else:
            raise ValueError("CVX1vs2 should be 0 or 1.")
    else:
        raise ValueError("CVX03vs12 should be 0 or 1.")
    return label
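
# An illustrative example (not part of the PR) of the decoding above: the hierarchical labels
# {'CVX03vs12': 1, 'CVX0vs3': nan, 'CVX1vs2': 0} decode to class 1 (CVX1), while
# {'CVX03vs12': 0, 'CVX0vs3': 1, 'CVX1vs2': nan} decode to class 3 (CVX3).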

def get_model_prediction_by_probabilities(output_dict: Dict[str, float]) -> int:
    """
    Based on the values for CVX03vs12, CVX0vs3 and CVX1vs2 predicted by the model, computes the CVX scores as follows:
        score(CVX0) = [1 - score(CVX03vs12)] * [1 - score(CVX0vs3)]
        score(CVX1) = score(CVX03vs12) * [1 - score(CVX1vs2)]
        score(CVX2) = score(CVX03vs12) * score(CVX1vs2)
        score(CVX3) = [1 - score(CVX03vs12)] * score(CVX0vs3)
    and returns the class with the highest score.
    """
    cvx0 = (1 - output_dict['CVX03vs12']) * (1 - output_dict['CVX0vs3'])
    cvx3 = (1 - output_dict['CVX03vs12']) * output_dict['CVX0vs3']
    cvx1 = output_dict['CVX03vs12'] * (1 - output_dict['CVX1vs2'])
    cvx2 = output_dict['CVX03vs12'] * output_dict['CVX1vs2']
    return np.argmax([cvx0, cvx1, cvx2, cvx3])
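
# A worked example of the probability combination (illustrative numbers): with model outputs
# CVX03vs12 = 0.9, CVX0vs3 = 0.5 and CVX1vs2 = 0.9, the scores are
# score(CVX0) = 0.1 * 0.5 = 0.05, score(CVX1) = 0.9 * 0.1 = 0.09,
# score(CVX2) = 0.9 * 0.9 = 0.81 and score(CVX3) = 0.1 * 0.5 = 0.05,
# so argmax picks class 2 (CVX2).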

def get_dataframe_with_covid_labels(metrics_df: pd.DataFrame) -> pd.DataFrame:
    def get_CVX_labels(df: pd.DataFrame) -> pd.DataFrame:
        """
        Given a dataframe (with only one subject) with the model outputs for CVX03vs12, CVX0vs3 and CVX1vs2,
        returns a corresponding dataframe with scores for CVX0, CVX1, CVX2 and CVX3 for this subject. See
        `get_model_prediction_by_probabilities` for details on mapping the model output to CVX labels.
        """
        df_by_hue = df[df[LoggingColumns.Hue.value].isin(TARGET_NAMES)].set_index(LoggingColumns.Hue.value)
        model_output = get_model_prediction_by_probabilities(df_by_hue[LoggingColumns.ModelOutput.value].to_dict())
        label = get_label_from_label_dict(df_by_hue[LoggingColumns.Label.value].to_dict())

        return pd.DataFrame.from_dict({LoggingColumns.Patient.value: [df.iloc[0][LoggingColumns.Patient.value]],
                                       LoggingColumns.ModelOutput.value: [model_output],
                                       LoggingColumns.Label.value: [label]})

    df = metrics_df.copy()
    # Group by subject, and for each subject convert the CVX03vs12, CVX0vs3 and CVX1vs2 predictions to CVX labels.
    df = df.groupby(LoggingColumns.Patient.value, as_index=False).apply(get_CVX_labels).reset_index(drop=True)
    df[LoggingColumns.Hue.value] = [MULTICLASS_HUE_NAME] * len(df)
    return df

def get_labels_and_predictions_covid_labels(csv: Path) -> LabelsAndPredictions:
    metrics_df = pd.read_csv(csv)
    df = get_dataframe_with_covid_labels(metrics_df=metrics_df)
    return get_labels_and_predictions_from_dataframe(df)


def print_metrics_from_csv(csv_to_set_optimal_threshold: Path,
                           csv_to_compute_metrics: Path,
                           config: ScalarModelBase,
                           is_crossval_report: bool) -> None:
    assert config.target_names == TARGET_NAMES

    predictions_to_compute_metrics = get_labels_and_predictions_covid_labels(
        csv=csv_to_compute_metrics)

    acc = accuracy_score(predictions_to_compute_metrics.labels, predictions_to_compute_metrics.model_outputs)
    rows = [[f"{acc:.4f}"]]
    print_table(rows, header=["Multiclass Accuracy"])

    conf_matrix = confusion_matrix(predictions_to_compute_metrics.labels, predictions_to_compute_metrics.model_outputs)
    rows = []
    header = ["", "CVX0 predicted", "CVX1 predicted", "CVX2 predicted", "CVX3 predicted"]
    for i in range(conf_matrix.shape[0]):
        line = [f"CVX{i} GT"] + list(conf_matrix[i])
        rows.append(line)
    print_table(rows, header=header)

@@ -0,0 +1,36 @@
from typing import Any

from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLContainer, SSLDatasetName
from InnerEye.ML.SSL.utils import SSLTrainingType
from InnerEye.ML.configs.ssl.CXR_SSL_configs import NIH_AZURE_DATASET_ID, path_encoder_augmentation_cxr, \
    path_linear_head_augmentation_cxr

COVID_DATASET_ID = "id-of-your-dataset"


class NIH_COVID_BYOL(SSLContainer):
    """
    Class to train an SSL model on the NIH dataset and monitor embedding quality on a Covid dataset.
    """

    def __init__(self,
                 covid_dataset_id: str = COVID_DATASET_ID,
                 **kwargs: Any):
        super().__init__(ssl_training_dataset_name=SSLDatasetName.NIHCXR,
                         linear_head_dataset_name=SSLDatasetName.Covid,
                         random_seed=1,
                         recovery_checkpoint_save_interval=50,
                         recovery_checkpoints_save_last_k=3,
                         num_epochs=500,
                         ssl_training_batch_size=1200,  # This runs with 16 GPUs (4 nodes)
                         num_workers=12,
                         ssl_encoder=EncoderName.densenet121,
                         ssl_training_type=SSLTrainingType.BYOL,
                         use_balanced_binary_loss_for_linear_head=True,
                         ssl_augmentation_config=path_encoder_augmentation_cxr,
                         extra_azure_dataset_ids=[covid_dataset_id],
                         azure_dataset_id=NIH_AZURE_DATASET_ID,
                         linear_head_augmentation_config=path_linear_head_augmentation_cxr,
                         online_evaluator_lr=1e-5,
                         linear_head_batch_size=64,
                         **kwargs)
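
As with the other InnerEye containers, this config would be launched by model name (a hedged sketch, assuming the standard runner entry point):

python InnerEye/ML/runner.py --model=NIH_COVID_BYOL --azureml=True

Note that COVID_DATASET_ID is a placeholder and must be replaced with the id of your own Azure dataset; it is passed through extra_azure_dataset_ids so the linear head can monitor embedding quality on it during BYOL training.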

@@ -0,0 +1,22 @@
import pandas as pd
from math import nan

from InnerEye.Common.metrics_constants import LoggingColumns
from InnerEye.ML.configs.reports.covid_hierarchical_model_report import MULTICLASS_HUE_NAME, \
    get_dataframe_with_covid_labels


def test_get_dataframe_with_covid_labels() -> None:
    df = pd.DataFrame.from_dict({LoggingColumns.Patient.value: [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4],
                                 LoggingColumns.Hue.value: ['CVX03vs12', 'CVX0vs3', 'CVX1vs2'] * 4,
                                 LoggingColumns.Label.value: [0, 0, nan, 0, 1, nan, 1, nan, 0, 1, nan, 1],
                                 LoggingColumns.ModelOutput.value: [0.1, 0.1, 0.5, 0.1, 0.9, 0.5, 0.9, 0.9, 0.9,
                                                                    0.1, 0.2, 0.1]})
    expected_df = pd.DataFrame.from_dict({LoggingColumns.Patient.value: [1, 2, 3, 4],
                                          LoggingColumns.ModelOutput.value: [0, 3, 2, 0],
                                          LoggingColumns.Label.value: [0, 3, 1, 2],
                                          LoggingColumns.Hue.value: [MULTICLASS_HUE_NAME] * 4})

    multiclass_df = get_dataframe_with_covid_labels(df)
    assert expected_df.equals(multiclass_df)