This PR adds configs to train Covid detection models from Chest-Xray data. 

Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com>
This commit is contained in:
melanibe 2021-05-20 17:27:25 +01:00 коммит произвёл GitHub
Родитель 8bae42eb92
Коммит 55120d7a6b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 628 добавлений и 12 удалений

Просмотреть файл

@ -84,6 +84,7 @@ console for easier diagnostics.
Additionally, the `TrainHelloWorldAndHelloContainer` job in the PR build has been split into two jobs, `TrainHelloWorld` and
`TrainHelloContainer`. A pytest marker `after_training_hello_container` has been added to run tests after training is
finished in the `TrainHelloContainer` job.
- ([#456](https://github.com/microsoft/InnerEye-DeepLearning/pull/456)) Adding configs to train Covid detection models.
### Changed

Просмотреть файл

@ -4,11 +4,12 @@
# ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Tuple
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Subset
from torchvision.datasets import VisionDataset
from InnerEye.Common.type_annotations import PathOrString
@ -175,3 +176,39 @@ class CheXpert(InnerEyeCXRDatasetWithReturnIndex):
self.dataset_dataframe.Path = self.dataset_dataframe.Path.apply(lambda x: x[strip_n:])
self.indices = np.arange(len(self.dataset_dataframe))
self.filenames = [self.root / p for p in self.dataset_dataframe.Path.values]
class CovidDataset(InnerEyeCXRDatasetWithReturnIndex):
"""
Dataset class to load CovidDataset dataset as datamodule for monitoring SSL training quality directly on
CovidDataset data.
We use CVX03 against CVX12 as proxy task.
"""
def _prepare_dataset(self) -> None:
self.dataset_dataframe = pd.read_csv(self.root / "dataset.csv")
mapping = {0: 0, 3: 0, 1: 1, 2: 1}
# For monitoring purpose with use binary classification CV03vsCV12
self.dataset_dataframe["final_label"] = self.dataset_dataframe.final_label.apply(lambda x: mapping[x])
self.indices = np.arange(len(self.dataset_dataframe))
self.subject_ids = self.dataset_dataframe.subject.values
self.filenames = [self.root / file for file in self.dataset_dataframe.filepath.values]
self.targets = self.dataset_dataframe.final_label.values.astype(np.int64).reshape(-1)
@property
def num_classes(self) -> int:
return 2
def _split_dataset(self, val_split: float, seed: int) -> Tuple[Subset, Subset]:
"""
Implements val - train split.
:param val_split: proportion to use for validation
:param seed: random seed for splitting
:return: dataset_train, dataset_val
"""
shuffled_subject_ids = np.random.RandomState(seed).permutation(np.unique(self.subject_ids))
n_val = int(len(shuffled_subject_ids) * val_split)
val_subjects, train_subjects = shuffled_subject_ids[:n_val], shuffled_subject_ids[n_val:]
train_ids, val_ids = np.where(np.isin(self.subject_ids, train_subjects))[0], \
np.where(np.isin(self.subject_ids, val_subjects))[0]
return Subset(self, train_ids), Subset(self, val_ids)

Просмотреть файл

@ -13,7 +13,7 @@ from pytorch_lightning import LightningModule
from yacs.config import CfgNode
from InnerEye.ML.SSL.datamodules_and_datasets.cifar_datasets import InnerEyeCIFAR10, InnerEyeCIFAR100
from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import CheXpert, NIHCXR, RSNAKaggleCXR
from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import CheXpert, CovidDataset, NIHCXR, RSNAKaggleCXR
from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule, InnerEyeVisionDataModule
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \
InnerEyeCIFARTrainTransform, \
@ -42,11 +42,12 @@ class EncoderName(Enum):
class SSLDatasetName(Enum):
RSNAKaggleCXR = "RSNAKaggleCXR"
NIHCXR = "NIHCXR"
CIFAR10 = "CIFAR10"
CIFAR100 = "CIFAR100"
RSNAKaggleCXR = "RSNAKaggleCXR"
NIHCXR = "NIHCXR"
CheXpert = "CheXpert"
Covid = "CovidDataset"
InnerEyeDataModuleTypes = Union[InnerEyeVisionDataModule, CombinedDataModule]
@ -62,11 +63,12 @@ class SSLContainer(LightningContainer):
Note that this container is also used as the base class for SSLImageClassifier (finetuning container) as they share
setup and datamodule methods.
"""
_SSLDataClassMappings = {SSLDatasetName.RSNAKaggleCXR.value: RSNAKaggleCXR,
SSLDatasetName.NIHCXR.value: NIHCXR,
SSLDatasetName.CIFAR10.value: InnerEyeCIFAR10,
_SSLDataClassMappings = {SSLDatasetName.CIFAR10.value: InnerEyeCIFAR10,
SSLDatasetName.CIFAR100.value: InnerEyeCIFAR100,
SSLDatasetName.CheXpert.value: CheXpert}
SSLDatasetName.RSNAKaggleCXR.value: RSNAKaggleCXR,
SSLDatasetName.NIHCXR.value: NIHCXR,
SSLDatasetName.CheXpert.value: CheXpert,
SSLDatasetName.Covid.value: CovidDataset}
ssl_augmentation_config = param.ClassSelector(class_=Path, allow_None=True,
doc="The path to the yaml config defining the parameters of the "
@ -91,7 +93,9 @@ class SSLContainer(LightningContainer):
linear_head_dataset_name = param.ClassSelector(class_=SSLDatasetName,
doc="Name of the dataset to use for the linear head training")
linear_head_batch_size = param.Integer(default=256, doc="Batch size for linear head tuning")
learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4, doc="Learning rate for linear head training during SSL training.")
learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4,
doc="Learning rate for linear head training during "
"SSL training.")
def setup(self) -> None:
from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer
@ -173,7 +177,8 @@ class SSLContainer(LightningContainer):
return self.data_module
encoder_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=True)
linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False)
return CombinedDataModule(encoder_data_module, linear_data_module, self.use_balanced_binary_loss_for_linear_head)
return CombinedDataModule(encoder_data_module, linear_data_module,
self.use_balanced_binary_loss_for_linear_head)
def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisionDataModule:
"""
@ -220,7 +225,10 @@ class SSLContainer(LightningContainer):
applied on. If False, return only one transformation.
:return: training transformation pipeline and validation transformation pipeline.
"""
if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value, SSLDatasetName.NIHCXR.value, SSLDatasetName.CheXpert.value]:
if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value,
SSLDatasetName.NIHCXR.value,
SSLDatasetName.CheXpert.value,
SSLDatasetName.Covid.value]:
assert augmentation_config is not None
train_transforms, val_transforms = get_cxr_ssl_transforms(augmentation_config,
return_two_views_per_sample=is_ssl_encoder_module,

Просмотреть файл

@ -0,0 +1,247 @@
import codecs
import logging
import pickle
from pathlib import Path
from typing import Any, Callable
import PIL
import pandas as pd
import param
import torch
from PIL import Image
from pytorch_lightning import LightningModule
from torchvision.transforms import Compose
from InnerEye.Common.common_util import ModelProcessing, get_best_epoch_results_path
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import create_chest_xray_transform
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName
from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
from InnerEye.ML.SSL.utils import create_ssl_encoder, create_ssl_image_classifier, load_ssl_augmentation_config
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_linear_head_augmentation_cxr
from InnerEye.ML.deep_learning_config import LRSchedulerType, MultiprocessingStartMethod, \
OptimizerType
from InnerEye.ML.model_config_base import ModelTransformsPerExecutionMode
from InnerEye.ML.model_testing import MODEL_OUTPUT_CSV
from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImagingFeatureType
from InnerEye.ML.reports.notebook_report import generate_notebook, get_ipynb_report_name, str_or_empty
from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
from InnerEye.ML.utils.augmentation import ScalarItemAugmentation
from InnerEye.ML.utils.run_recovery import RunRecovery
from InnerEye.ML.utils.split_dataset import DatasetSplits
from InnerEye.ML.configs.ssl.CovidContainers import COVID_DATASET_ID
from InnerEye.Common import fixed_paths as fixed_paths_innereye
class CovidHierarchicalModel(ScalarModelBase):
"""
Model to train a CovidDataset model from scratch or finetune from SSL-pretrained model.
For AML you need to provide the run_id of your SSL training job as a command line argument
--pretraining_run_recovery_id=id_of_your_ssl_model, this will download the checkpoints of the run to your
machine and load the corresponding pretrained model.
To recover from a particular checkpoint from your SSL run e.g. "recovery_epoch=499.ckpt" please use hte
--name_of_checkpoint argument.
"""
use_pretrained_model = param.Boolean(default=False, doc="If True, start training from a model pretrained with SSL."
"If False, start training a DenseNet model from scratch"
"(random initialization).")
freeze_encoder = param.Boolean(default=False, doc="Whether to freeze the pretrained encoder or not.")
name_of_checkpoint = param.String(default=None, doc="Filename of checkpoint to use for recovery")
test_set_ids_csv = param.String(default=None,
doc="Name of the csv file in the dataset folder with the test set ids. The dataset"
"is expected to have a 'series' and a 'subject' column. The subject column"
"is assumed to contain unique ids.")
def __init__(self, covid_dataset_id: str = COVID_DATASET_ID, **kwargs: Any):
learning_rate = 1e-5 if self.use_pretrained_model else 1e-4
super().__init__(target_names=['CVX03vs12', 'CVX0vs3', 'CVX1vs2'],
loss_type=ScalarLoss.CustomClassification,
class_names=['CVX0', 'CVX1', 'CVX2', 'CVX3'],
max_num_gpus=1,
azure_dataset_id=covid_dataset_id,
subject_column="series",
image_file_column="filepath",
label_value_column="final_label",
non_image_feature_channels=[],
numerical_columns=[],
use_mixed_precision=False,
num_dataload_workers=12,
multiprocessing_start_method=MultiprocessingStartMethod.fork,
train_batch_size=64,
optimizer_type=OptimizerType.Adam,
num_epochs=50,
l_rate_scheduler=LRSchedulerType.Step,
l_rate_step_gamma=1.0,
l_rate=learning_rate,
l_rate_multi_step_milestones=None,
**kwargs)
self.num_classes = 3
if not self.use_pretrained_model and self.freeze_encoder:
raise ValueError("No encoder to freeze when training from scratch. You requested training from scratch and"
"encoder freezing.")
def should_generate_multilabel_report(self) -> bool:
return False
def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
if self.test_set_ids_csv:
test_df = pd.read_csv(self.local_dataset / self.test_set_ids_csv)
in_test_set = dataset_df.series.isin(test_df.series)
train_ids = dataset_df.series[~in_test_set].values
test_ids = dataset_df.series[in_test_set].values
num_val_samples = 400
val_ids = train_ids[:num_val_samples]
train_ids = train_ids[num_val_samples:]
return DatasetSplits.from_subject_ids(dataset_df, train_ids=train_ids, val_ids=val_ids, test_ids=test_ids,
subject_column="series", group_column="subject")
else:
return DatasetSplits.from_proportions(dataset_df,
proportion_train=0.8,
proportion_val=0.1,
proportion_test=0.1,
subject_column="series",
group_column="subject",
shuffle=True)
# noinspection PyTypeChecker
def get_image_sample_transforms(self) -> ModelTransformsPerExecutionMode:
config = load_ssl_augmentation_config(path_linear_head_augmentation_cxr)
train_transforms = ScalarItemAugmentation(
Compose(
[DicomPreparation(), create_chest_xray_transform(config, apply_augmentations=True)]))
val_transforms = ScalarItemAugmentation(
Compose(
[DicomPreparation(), create_chest_xray_transform(config, apply_augmentations=False)]))
return ModelTransformsPerExecutionMode(train=train_transforms,
val=val_transforms,
test=val_transforms)
def create_model(self) -> LightningModule:
"""
This method must create the actual Lightning model that will be trained.
"""
if self.use_pretrained_model:
path_to_checkpoint = self._get_ssl_checkpoint_path()
model = create_ssl_image_classifier(
num_classes=self.num_classes,
pl_checkpoint_path=str(path_to_checkpoint),
freeze_encoder=self.freeze_encoder)
else:
encoder = create_ssl_encoder(encoder_name=EncoderName.densenet121.value)
model = SSLClassifier(num_classes=self.num_classes,
encoder=encoder,
freeze_encoder=self.freeze_encoder,
class_weights=None)
# Next args are just here because we are using this model within an InnerEyeContainer
model.imaging_feature_type = ImagingFeatureType.Image # type: ignore
model.num_non_image_features = 0 # type: ignore
model.encode_channels_jointly = True # type: ignore
return model
def _get_ssl_checkpoint_path(self) -> Path:
# Get the SSL weights from the AML run provided via "pretraining_run_recovery_id" command line argument.
# Accessible via extra_downloaded_run_id field of the config.
assert self.extra_downloaded_run_id is not None
assert isinstance(self.extra_downloaded_run_id, RunRecovery)
ssl_path = self.checkpoint_folder / "ssl_checkpoint.ckpt"
if not ssl_path.exists(): # for test (when it is already present) we don't need to redo this.
if self.name_of_checkpoint is not None:
logging.info(f"Using checkpoint: {self.name_of_checkpoint} as starting point.")
path_to_checkpoint = self.extra_downloaded_run_id.checkpoints_roots[0] / self.name_of_checkpoint
else:
path_to_checkpoint = self.extra_downloaded_run_id.get_best_checkpoint_paths()[0]
if not path_to_checkpoint.exists():
logging.info("No best checkpoint found for this model. Getting the latest recovery "
"checkpoint instead.")
path_to_checkpoint = self.extra_downloaded_run_id.get_recovery_checkpoint_paths()[0]
assert path_to_checkpoint.exists()
path_to_checkpoint.rename(ssl_path)
return ssl_path
def pre_process_dataset_dataframe(self) -> None:
pass
@staticmethod
def get_posthoc_label_transform() -> Callable:
import torch
def multiclass_to_hierarchical_labels(classes: torch.Tensor) -> torch.Tensor:
classes = classes.clone()
cvx03vs12 = classes[..., 1] + classes[..., 2]
cvx0vs3 = classes[..., 3]
cvx1vs2 = classes[..., 2]
cvx0vs3[cvx03vs12 == 1] = float('nan') # CVX0vs3 only gets gradient for CVX03
cvx1vs2[cvx03vs12 == 0] = float('nan') # CVX1vs2 only gets gradient for CVX12
return torch.stack([cvx03vs12, cvx0vs3, cvx1vs2], -1)
return multiclass_to_hierarchical_labels
@staticmethod
def get_loss_function() -> Callable:
import torch
import torch.nn.functional as F
def nan_bce_with_logits(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""Compute BCE with logits, ignoring NaN values"""
valid = labels.isfinite()
losses = F.binary_cross_entropy_with_logits(output[valid], labels[valid], reduction='none')
return losses.sum() / labels.shape[0]
return nan_bce_with_logits
def generate_custom_report(self, report_dir: Path, model_proc: ModelProcessing) -> Path:
"""
Generate a custom report for the CovidDataset Hierarchical model. At the moment, this report will read the
file model_output.csv generated for the training, validation or test sets and compute a 4 class accuracy
and confusion matrix based on this.
:param report_dir: Directory report is to be written to
:param model_proc: Whether this is a single or ensemble model (model_output.csv will be located in different
paths for single vs ensemble runs.)
"""
def get_output_csv_path(mode: ModelExecutionMode) -> Path:
p = get_best_epoch_results_path(mode=mode, model_proc=model_proc)
return self.outputs_folder / p / MODEL_OUTPUT_CSV
train_metrics = get_output_csv_path(ModelExecutionMode.TRAIN)
val_metrics = get_output_csv_path(ModelExecutionMode.VAL)
test_metrics = get_output_csv_path(ModelExecutionMode.TEST)
notebook_params = \
{
'innereye_path': str(fixed_paths_innereye.repository_root_directory()),
'train_metrics_csv': str_or_empty(train_metrics),
'val_metrics_csv': str_or_empty(val_metrics),
'test_metrics_csv': str_or_empty(test_metrics),
"config": codecs.encode(pickle.dumps(self), "base64").decode(),
"is_crossval_report": False
}
template = Path(__file__).absolute().parent.parent / "reports" / "CovidHierarchicalModelReport.ipynb"
return generate_notebook(template,
notebook_params=notebook_params,
result_notebook=report_dir / get_ipynb_report_name(
f"{self.model_category.value}_hierarchical"))
class DicomPreparation:
def __call__(self, item: torch.Tensor) -> PIL.Image:
# Item will be of dimension [C, Z, X, Y]
images = item.numpy()
assert images.shape[0] == 1 and images.shape[1] == 1
images = images.reshape(images.shape[2:])
normalized_image = (images - images.min()) * 255. / (images.max() - images.min())
image = Image.fromarray(normalized_image).convert("L")
return image

Просмотреть файл

@ -0,0 +1,161 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"%%javascript\n",
"IPython.OutputArea.prototype._should_scroll = function(lines) {\n",
" return false;\n",
"}\n",
"// Stops auto-scrolling so entire output is visible: see https://stackoverflow.com/a/41646403"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {
"tags": [
"parameters"
]
},
"outputs": [],
"source": [
"# Default parameter values. They will be overwritten by papermill notebook parameters.\n",
"# This cell must carry the tag \"parameters\" in its metadata.\n",
"from pathlib import Path\n",
"import pickle\n",
"import codecs\n",
"\n",
"innereye_path = Path.cwd().parent.parent.parent.parent\n",
"train_metrics_csv = \"\"\n",
"val_metrics_csv = \"\"\n",
"test_metrics_csv = \"\"\n",
"config = \"\"\n",
"is_crossval_report = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"print(f\"Adding to path: {innereye_path}\")\n",
"if str(innereye_path) not in sys.path:\n",
" sys.path.append(str(innereye_path))\n",
"\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"config = pickle.loads(codecs.decode(config.encode(), \"base64\"))\n",
"\n",
"from InnerEye.ML.common import ModelExecutionMode\n",
"from InnerEye.ML.reports.notebook_report import print_header\n",
"from InnerEye.ML.configs.reports.covid_hierarchical_model_report import print_metrics_from_csv\n",
"\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"plt.rcParams['figure.figsize'] = (20, 10)\n",
"\n",
"#convert params to Path\n",
"train_metrics_csv = Path(train_metrics_csv)\n",
"val_metrics_csv = Path(val_metrics_csv)\n",
"test_metrics_csv = Path(test_metrics_csv)"
]
},
{
"cell_type": "markdown",
"id": "4",
"metadata": {},
"source": [
"# Metrics\n",
"## Train Set"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"if train_metrics_csv.is_file():\n",
" print_metrics_from_csv(csv_to_set_optimal_threshold=train_metrics_csv,\n",
" csv_to_compute_metrics=train_metrics_csv,\n",
" config=config, is_crossval_report=is_crossval_report)"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"## Validation Set"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"if val_metrics_csv.is_file():\n",
" print_metrics_from_csv(csv_to_set_optimal_threshold=val_metrics_csv,\n",
" csv_to_compute_metrics=val_metrics_csv,\n",
" config=config, is_crossval_report=is_crossval_report)"
]
},
{
"cell_type": "markdown",
"id": "8",
"metadata": {},
"source": [
"## Test Set"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" print_metrics_from_csv(csv_to_set_optimal_threshold=val_metrics_csv,\n",
" csv_to_compute_metrics=test_metrics_csv,\n",
" config=config, is_crossval_report=is_crossval_report)"
]
}
],
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

Просмотреть файл

@ -0,0 +1,104 @@
import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.metrics import accuracy_score, confusion_matrix
from typing import Dict
from InnerEye.Common.metrics_constants import LoggingColumns
from InnerEye.ML.reports.classification_report import get_labels_and_predictions_from_dataframe, LabelsAndPredictions
from InnerEye.ML.reports.notebook_report import print_table
from InnerEye.ML.scalar_config import ScalarModelBase
TARGET_NAMES = ['CVX03vs12', 'CVX0vs3', 'CVX1vs2']
MULTICLASS_HUE_NAME = "Multiclass"
def get_label_from_label_dict(label_dict: Dict[str, float]) -> int:
"""
Converts strings CVX03vs12, CVX1vs2, CVX0vs3 to the corresponding class as int.
"""
if label_dict['CVX03vs12'] == 0:
assert np.isnan(label_dict['CVX1vs2'])
if label_dict['CVX0vs3'] == 0:
label = 0
elif label_dict['CVX0vs3'] == 1:
label = 3
else:
raise ValueError("CVX0vs3 should be 0 or 1.")
elif label_dict['CVX03vs12'] == 1:
assert np.isnan(label_dict['CVX0vs3'])
if label_dict['CVX1vs2'] == 0:
label = 1
elif label_dict['CVX1vs2'] == 1:
label = 2
else:
raise ValueError("CVX1vs2 should be 0 or 1.")
else:
raise ValueError("CVX03vs12 should be 0 or 1.")
return label
def get_model_prediction_by_probabilities(output_dict: Dict[str, float]) -> int:
"""
Based on the values for CVX03vs12, CVX0vs3 and CVX1vs2 predicted by the model, predict the CVX scores as followed:
score(CVX0) = [1 - score(CVX03vs12)][1 - score(CVX0vs3)]
score(CVX1) = score(CVX03vs12)[1 - score(CVX1vs2)]
score(CVX2) = score(CVX03vs12)score(CVX1vs2)
score(CVX3) = [1 - score(CVX03vs12)]score(CVX0vs3)
"""
cvx0 = (1 - output_dict['CVX03vs12']) * (1 - output_dict['CVX0vs3'])
cvx3 = (1 - output_dict['CVX03vs12']) * output_dict['CVX0vs3']
cvx1 = output_dict['CVX03vs12'] * (1 - output_dict['CVX1vs2'])
cvx2 = output_dict['CVX03vs12'] * output_dict['CVX1vs2']
return np.argmax([cvx0, cvx1, cvx2, cvx3])
def get_dataframe_with_covid_labels(metrics_df: pd.DataFrame) -> pd.DataFrame:
def get_CVX_labels(df: pd.DataFrame) -> pd.DataFrame:
"""
Given a dataframe (with only one subject) with the model outputs for CVX03vs12, CVX0vs3 and CVX1vs2,
returns a corresponding dataframe with scores for CVX0, CVX1, CVX2 and CVX3 for this subject. See
`get_model_prediction_by_probabilities` for details on mapping the model output to CVX labels.
"""
df_by_hue = df[df[LoggingColumns.Hue.value].isin(TARGET_NAMES)].set_index(LoggingColumns.Hue.value)
model_output = get_model_prediction_by_probabilities(df_by_hue[LoggingColumns.ModelOutput.value].to_dict())
label = get_label_from_label_dict(df_by_hue[LoggingColumns.Label.value].to_dict())
return pd.DataFrame.from_dict({LoggingColumns.Patient.value: [df.iloc[0][LoggingColumns.Patient.value]],
LoggingColumns.ModelOutput.value: [model_output],
LoggingColumns.Label.value: [label]})
df = metrics_df.copy()
# Group by subject, and for each subject, convert the CVX03vs12, CVX0vs3 and CVX1vs2 predictions to CVX labels.
df = df.groupby(LoggingColumns.Patient.value, as_index=False).apply(get_CVX_labels).reset_index(drop=True)
df[LoggingColumns.Hue.value] = [MULTICLASS_HUE_NAME] * len(df)
return df
def get_labels_and_predictions_covid_labels(csv: Path) -> LabelsAndPredictions:
metrics_df = pd.read_csv(csv)
df = get_dataframe_with_covid_labels(metrics_df=metrics_df)
return get_labels_and_predictions_from_dataframe(df)
def print_metrics_from_csv(csv_to_set_optimal_threshold: Path,
csv_to_compute_metrics: Path,
config: ScalarModelBase,
is_crossval_report: bool) -> None:
assert config.target_names == TARGET_NAMES
predictions_to_compute_metrics = get_labels_and_predictions_covid_labels(
csv=csv_to_compute_metrics)
acc = accuracy_score(predictions_to_compute_metrics.labels, predictions_to_compute_metrics.model_outputs)
rows = [[f"{acc:.4f}"]]
print_table(rows, header=["Multiclass Accuracy"])
conf_matrix = confusion_matrix(predictions_to_compute_metrics.labels, predictions_to_compute_metrics.model_outputs)
rows = []
header = ["", "CVX0 predicted", "CVX1 predicted", "CVX2 predicted", "CVX3 predicted"]
for i in range(conf_matrix.shape[0]):
line = [f"CVX{i} GT"] + list(conf_matrix[i])
rows.append(line)
print_table(rows, header=header)

Просмотреть файл

@ -0,0 +1,36 @@
from typing import Any
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLContainer, SSLDatasetName
from InnerEye.ML.SSL.utils import SSLTrainingType
from InnerEye.ML.configs.ssl.CXR_SSL_configs import NIH_AZURE_DATASET_ID, path_encoder_augmentation_cxr, \
path_linear_head_augmentation_cxr
COVID_DATASET_ID = "id-of-your-dataset"
class NIH_COVID_BYOL(SSLContainer):
"""
Class to train a SSL model on NIH dataset and monitor embeddings quality on a Covid Dataset.
"""
def __init__(self,
covid_dataset_id: str = COVID_DATASET_ID,
**kwargs: Any):
super().__init__(ssl_training_dataset_name=SSLDatasetName.NIHCXR,
linear_head_dataset_name=SSLDatasetName.Covid,
random_seed=1,
recovery_checkpoint_save_interval=50,
recovery_checkpoints_save_last_k=3,
num_epochs=500,
ssl_training_batch_size=1200, # This runs with 16 gpus (4 nodes)
num_workers=12,
ssl_encoder=EncoderName.densenet121,
ssl_training_type=SSLTrainingType.BYOL,
use_balanced_binary_loss_for_linear_head=True,
ssl_augmentation_config=path_encoder_augmentation_cxr,
extra_azure_dataset_ids=[covid_dataset_id],
azure_dataset_id=NIH_AZURE_DATASET_ID,
linear_head_augmentation_config=path_linear_head_augmentation_cxr,
online_evaluator_lr=1e-5,
linear_head_batch_size=64,
**kwargs)

Просмотреть файл

@ -0,0 +1,22 @@
import pandas as pd
from math import nan
from InnerEye.Common.metrics_constants import LoggingColumns
from InnerEye.ML.configs.reports.covid_hierarchical_model_report import MULTICLASS_HUE_NAME, \
get_dataframe_with_covid_labels
def test_get_dataframe_with_covid_labels() -> None:
df = pd.DataFrame.from_dict({LoggingColumns.Patient.value: [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4],
LoggingColumns.Hue.value: ['CVX03vs12', 'CVX0vs3', 'CVX1vs2'] * 4,
LoggingColumns.Label.value: [0, 0, nan, 0, 1, nan, 1, nan, 0, 1, nan, 1],
LoggingColumns.ModelOutput.value: [0.1, 0.1, 0.5, 0.1, 0.9, 0.5, 0.9, 0.9, 0.9, 0.1, 0.2, 0.1]})
expected_df = pd.DataFrame.from_dict({LoggingColumns.Patient.value: [1, 2, 3, 4],
LoggingColumns.ModelOutput.value: [0, 3, 2, 0],
LoggingColumns.Label.value: [0, 3, 1, 2],
LoggingColumns.Hue.value: [MULTICLASS_HUE_NAME] * 4
})
multiclass_df = get_dataframe_with_covid_labels(df)
assert expected_df.equals(multiclass_df)