Adding Active Label Cleaning code (#559)

* initial commit

* updating the build

* flake8

* update main page

* add the links

* try to fix the env

* update build, gitignore and remove duplicate license

* update gitignore again

* Adding to changelog

* conda activate

* update again

* wrong instruction

* add data quality

* rephrase

* first pass on Readme.md

* switch from our to the, and clarify the cxr datasets

* move content to a separate markdown file

* move additional content to config readme file

* finish updating dataquality readme

* Rename

* pr ocmment

* todos

* changed default dir for cifar10 dataset

Co-authored-by: Ozan Oktay <ozan.oktay@microsoft.com>
This commit is contained in:
melanibe 2021-09-21 10:22:03 +01:00 коммит произвёл GitHub
Родитель 521c004357
Коммит 94553a5c0b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
142 изменённых файлов: 10303 добавлений и 2 удалений

9
.gitignore поставляемый
Просмотреть файл

@ -158,4 +158,11 @@ Tests/ML/test_outputs
# This file contains the run recovery ID of the most recent job
most_recent_run.txt
# The default folder that contains downloaded Tensorboard files
tensorboard_runs
tensorboard_runs
# InnerEye-DataQuality
InnerEye-DataQuality/cifar-10-python.tar.gz
InnerEye-DataQuality/name_stats_scoring.png
InnerEye-DataQuality/cifar-10-batches-py
InnerEye-DataQuality/logs
InnerEye-DataQuality/data

Просмотреть файл

@ -22,6 +22,7 @@ jobs that run in AzureML.
ensemble) using the parameter `model_id`.
- ([#554](https://github.com/microsoft/InnerEye-DeepLearning/pull/554)) Added a parameter `pretraining_dataset_id` to
`NIH_COVID_BYOL` to specify the name of the SSL training dataset.
- ([#559](https://github.com/microsoft/InnerEye-DeepLearning/pull/559)) Adding the accompanying code for the ["Active label cleaning: Improving dataset quality under resource constraints"](https://arxiv.org/abs/2109.00574) paper. The code can be found in the [InnerEye-DataQuality](InnerEye-DataQuality/README.md) subfolder. It provides tools for training noise robust models, running label cleaning simulation and loading our label cleaning benchmark datasets.
### Changed
- ([#531](https://github.com/microsoft/InnerEye-DeepLearning/pull/531)) Updated PL to 1.3.8, torchmetrics and pl-bolts and changed relevant metrics and SSL code API.

Просмотреть файл

@ -0,0 +1,5 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

Просмотреть файл

@ -0,0 +1,115 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import time
from dataclasses import dataclass
from typing import Any, Optional
import numpy as np
import scipy
from scipy.special import softmax
from sklearn.neighbors import kneighbors_graph
@dataclass(init=True, frozen=True)
class GraphParameters:
"""Class for setting graph connectivity and diffusion parameters."""
n_neighbors: int
diffusion_alpha: float
cg_solver_max_iter: int
diffusion_batch_size: Optional[int]
distance_kernel: str # {'euclidean' or 'cosine'}
def _get_affinity_matrix(embeddings: np.ndarray,
n_neighbors: int,
distance_kernel: str = 'cosine') -> scipy.sparse.csr.csr_matrix:
"""
:param embeddings: Input sample embeddings (n_samples x n_embedding_dim)
:param n_neighbors: Number of neighbors in the KNN Graph
:param distance_kernel: Distance kernel to compute sample similarities {'euclidean' or 'cosine'}
"""
# Build a k-NN graph using the embeddings
if distance_kernel == 'euclidean':
sigma = embeddings.shape[1]
knn_dist_graph = kneighbors_graph(embeddings, n_neighbors, mode='distance', metric='euclidean', n_jobs=-1)
knn_dist_graph.data = np.exp(-1.0 * np.asarray(knn_dist_graph.data) ** 2 / (2.0 * sigma ** 2))
elif distance_kernel == 'cosine':
knn_dist_graph = kneighbors_graph(embeddings, n_neighbors, mode='distance', metric='cosine', n_jobs=-1)
knn_dist_graph.data = 1.0 - np.asarray(knn_dist_graph.data) / 2.0
else:
raise ValueError(f"Unknown sample distance kernel {distance_kernel}")
return knn_dist_graph
def build_connectivity_graph(normalised: bool = True, **affinity_kwargs: Any) -> np.ndarray:
"""
Builds connectivity graph and returns adjacency and degree matrix
:param normalised: If set to True, graph adjacency is normalised with the norm of degree matrix
:param affinity_kwargs: Arguments required to construct an affinity matrix
(weights representing similarity between points)
"""
# Build a symmetric adjacency matrix
A = _get_affinity_matrix(**affinity_kwargs)
W = 0.5 * (A + A.T)
if normalised:
# Normalize the similarity graph
W = W - scipy.sparse.diags(W.diagonal())
D = W.sum(axis=1)
D[D == 0] = 1
D_sqrt_inv = np.array(1. / np.sqrt(D))
D_sqrt_inv = scipy.sparse.diags(D_sqrt_inv.reshape(-1))
L_norm = D_sqrt_inv * W * D_sqrt_inv
return L_norm
else:
num_samples = W.shape[0]
D = W.sum(axis=1)
D = np.diag(np.asarray(D).reshape(num_samples, ))
L = D - W
return L
def label_diffusion(inv_laplacian: np.ndarray,
labels: np.ndarray,
query_batch_ids: np.ndarray,
class_priors: Optional[np.ndarray] = None,
diffusion_normalizing_factor: float = 0.01) -> np.ndarray:
"""
:param laplacian_inv: inverse laplacian of the graph
:param labels:
:param query_batch_ids: the ids of the "labeled" samples
:param class_priors: prior distribution of each class [n_classes,]
:param diffusion_normalizing_factor: factor to normalize the diffused labels
"""
diffusion_start = time.time()
# Input number of nodes and classes
n_samples = labels.shape[0]
n_classes = labels.shape[1]
# Find the labelled set of nodes in the graph
all_idx = np.array(range(n_samples))
labeled_idx = np.setdiff1d(all_idx, query_batch_ids.flatten())
assert (np.all(np.isin(query_batch_ids, all_idx)))
assert (np.allclose(np.sum(labels, axis=1), np.ones(shape=(labels.shape[0])), rtol=1e-3))
# Initialize the y vector for each class (eq 5 from the paper, normalized with the class size)
# and apply label propagation
y = np.zeros((n_samples, n_classes))
y[labeled_idx] = labels[labeled_idx] / np.sum(labels[labeled_idx], axis=0, keepdims=True)
if class_priors is not None:
y = y * class_priors
Z = np.matmul(inv_laplacian[query_batch_ids, :], y)
# Normalise the diffused logits
output = softmax(Z / diffusion_normalizing_factor, axis=1)
# output = Z / Z.sum(axis=1)
logging.debug(f"Graph diffusion time: {0: .2f} secs".format(time.time() - diffusion_start))
return output

Просмотреть файл

@ -0,0 +1,60 @@
## Config arguments for model training:
All possible config arguments are defined in [model_config.py](InnerEyeDataQuality/configs/models/model_config.py). Here you will find a summary of the most important config arguments:
* If you want to train a model with co_teaching, you will need to set `train.use_co_teaching: True` in your config.
* If you want to finetune from a pretrained SSL checkpoint:
* You will need to set `train.use_self_supervision: True` to tell the code to load a pretrained checkpoint.
* You will need update the `train.self_supervision.checkpoints: [PATH_TO_SSL]` field with the checkpoints to use for initialization of your model. Note that if you want to train a co-teaching model in combination with SSL pretrained initialization your list of checkpoint needs to be of length 2.
* You can also choose whether to freeze the encoder or not during finetuning with `train.self_supervision.freeze_encoder` field.
* If you want to train a model using ELR, you can set `train.use_elr: True`
### CIFAR10H
We provide configurations to run experiments on CIFAR10H with resp. 15% and 30% noise rate.
* Configs for 15% noise rate experiments can be found in [configs/models/cifar10h_noise_15](InnerEyeDataQuality/configs/models/cifar10h_noise_15). In detail this folder contains configs for
* vanilla resnet training: [InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet.yaml)
* co-teaching resnet training: [InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_co_teaching.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_co_teaching.yaml)
* SSL + linear head training: [InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_self_supervision_v3.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_self_supervision_v3.yaml)
* Configs for 30% noise rate experiments can be found in [configs/models/cifar10h_noise_30](InnerEyeDataQuality/configs/models/cifar10h_noise_30). In detail this folder contains configs for:
* vanilla resnet training: [InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet.yaml)
* co-teaching resnet training: [InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_co_teaching.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_co_teaching.yaml)
* SSL + linear head training: [InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_self_supervision_v3.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_self_supervision_v3.yaml)
* ELR training: [InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_elr.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_elr.yaml)
* Examples of configs for models used in the model selection benchmark experiment can be found in the [configs/models/benchmark_3_idn](InnerEyeDataQuality/configs/models/benchmark_3_idn)
### Noisy Chest-Xray
*Note:* To run any model on this dataset, you will need to first make sure you have the dataset ready onto your machine (see Benchmark datasets > Noisy Chest-Xray > Pre-requisite section).
* With provide configurations corresponding to the experiments on the NoisyChestXray dataset with 13% noise rate in the [configs/models/cxr](InnerEyeDataQuality/configs/models/cxr) folder:
* Vanilla resnet training: [InnerEyeDataQuality/configs/models/cxr/resnet.yaml](InnerEyeDataQuality/configs/models/cxr/resnet.yaml)
* Co-teaching resnet training: [InnerEyeDataQuality/configs/models/cxr/resnet_coteaching.yaml](InnerEyeDataQuality/configs/models/cxr/resnet_coteaching.yaml)
* Co-teaching from a pretrained SSL checkpoint: [InnerEyeDataQuality/configs/models/cxr/resnet_ssl_coteaching.yaml]([InnerEyeDataQuality/configs/models/cxr/resnet_ssl_coteaching.yaml])
<br/><br/>
## Config arguments for label cleaning simulation:
#### More details about the selector config
Here is an example of a selector config, with details about each field:
* `selector:`
* `type`: Which selector to use. There are several options available:
* `PosteriorBasedSelectorJoint`: Using the ranking function proposed in the manuscript CE(posteriors, labels) - H(posteriors)
* `PosteriorBasedSelector`: Using CE(posteriors, labels) as the ranking function
* `GraphBasedSelector`: Graph based selection of the next samples based on the embeddings of each sample.
* `BaldSelector`: Selection of the next sample with the BALD objective.
* `model_name`: The name that will be used for the legend of the simulation plot
* `model_config_path`: Path to the config file used to train the selection model.
* `use_active_relabelling`: Whether to turn on the active component of the active learning framework. If set to True, the model will be retrained regularly during the selection process.
* `output_directory`: Optional field where can specify the output directory to store the results in.
#### Off-the-shelf simulation configs
* We provide the configs for various selectors for the CIFAR10H experiments in the [configs/selection/cifar10h_noise_15](InnerEyeDataQuality/configs/selection/cifar10h_noise_15) and in the [configs/selection/cifar10h_noise_30](InnerEyeDataQuality/configs/selection/cifar10h_noise_30) folders.
* And likewise for the NoisyChestXray dataset, you can find a set of selector configs in the [configs/selection/cxr](InnerEyeDataQuality/configs/selection/cxr) folder.
<br/><br/>
## Configs for self-supervised model training:
CIFAR10H: To pretrain embeddings with contrastive learning on CIFAR10H you can use the
[cifar10h_byol.yaml](InnerEyeDataQuality/deep_learning/self_supervised/configs/cifar10h_byol.yaml) or the [cifar10h_simclr.yaml](InnerEyeDataQuality/deep_learning/self_supervised/configs/cifar10h_simclr.yaml) config files.
Chest X-ray: Provided that you have downloaded the dataset as explained in the Benchmark Datasets > Other Chest Xray Datasets > NIH Datasets > Pre-requisites section, you can easily launch a unsupervised pretraining run on the full NIH dataset using the [nih_byol.yaml](InnerEyeDataQuality/deep_learning/self_supervised/configs/nih_byol.yaml) or the [nih_simclr.yaml](InnerEyeDataQuality/deep_learning/self_supervised/configs/nih_simclr.yaml)
configs. Don't forget to update the `dataset_dir` field of your config to reflect your local path.

Просмотреть файл

@ -0,0 +1,48 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Dict, List, Optional, Union
import yacs.config
class ConfigNode(yacs.config.CfgNode):
def __init__(self, init_dict: Optional[Dict] = None, key_list: Optional[List] = None, new_allowed: bool = False):
super().__init__(init_dict, key_list, new_allowed)
def __str__(self) -> str:
def _indent(s_: str, num_spaces: int) -> Union[str, List[str]]:
s = s_.split('\n')
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * ' ') + line for line in s]
s = '\n'.join(s) # type: ignore
s = first + '\n' + s # type: ignore
return s
r = ''
s = []
for k, v in self.items():
separator = '\n' if isinstance(v, ConfigNode) else ' '
if isinstance(v, str) and not v:
v = '\'\''
attr_str = f'{str(k)}:{separator}{str(v)}'
attr_str = _indent(attr_str, 2) # type: ignore
s.append(attr_str)
r += '\n'.join(s)
return r
def as_dict(self) -> Dict:
def convert_to_dict(node: ConfigNode) -> Dict:
if not isinstance(node, ConfigNode):
return node
else:
dic = dict()
for k, v in node.items():
dic[k] = convert_to_dict(v)
return dic
return convert_to_dict(self)

Просмотреть файл

@ -0,0 +1,205 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from .config_node import ConfigNode
config = ConfigNode()
config.pretty_name = ""
config.device = 'cuda'
# cuDNN
config.cudnn = ConfigNode()
config.cudnn.benchmark = True
config.cudnn.deterministic = False
config.dataset = ConfigNode()
config.dataset.name = 'CIFAR10'
config.dataset.dataset_dir = ''
config.dataset.image_size = 32
config.dataset.n_channels = 3
config.dataset.n_classes = 10
config.dataset.num_samples = None
config.dataset.noise_temperature = 1.0
config.dataset.noise_offset = 0.0
config.dataset.noise_rate = None
config.dataset.csv_to_ignore = None
config.dataset.cxr_consolidation_noise_rate = 0.10
config.model = ConfigNode()
# options: 'cifar', 'imagenet'
# Use 'cifar' for small input images
config.model.type = 'cifar'
config.model.name = 'resnet_preact'
config.model.init_mode = 'kaiming_fan_out'
config.model.use_dropout = False
config.model.resnet = ConfigNode()
config.model.resnet.depth = 110 # for cifar type model
config.model.resnet.n_blocks = [2, 2, 2, 2] # for imagenet type model
config.model.resnet.block_type = 'basic'
config.model.resnet.initial_channels = 16
config.model.resnet.apply_l2_norm = False # if set to True, last activations are l2-normalised.
config.model.densenet = ConfigNode()
config.model.densenet.depth = 100 # for cifar type model
config.model.densenet.n_blocks = [6, 12, 24, 16] # for imagenet type model
config.model.densenet.block_type = 'bottleneck'
config.model.densenet.growth_rate = 12
config.model.densenet.drop_rate = 0.0
config.model.densenet.compression_rate = 0.5
config.train = ConfigNode()
config.train.root_dir = ''
config.train.checkpoint = ''
config.train.resume_epoch = 0
config.train.restore_scheduler = True
config.train.batch_size = 128
# optimizer (options: sgd, adam)
config.train.optimizer = 'sgd'
config.train.base_lr = 0.1
config.train.momentum = 0.9
config.train.nesterov = True
config.train.weight_decay = 1e-4
config.train.start_epoch = 0
config.train.seed = 0
config.train.pretrained = False
config.train.no_weight_decay_on_bn = False
config.train.use_balanced_sampler = False
config.train.output_dir = 'experiments/exp00'
config.train.log_period = 100
config.train.checkpoint_period = 10
config.train.use_elr = False
# co_teaching defaults
config.train.use_co_teaching = False
config.train.use_teacher_model = False
config.train.co_teaching_consistency_loss = False
config.train.co_teaching_forget_rate = 0.2
config.train.co_teaching_num_gradual = 10
config.train.co_teaching_use_graph = False
config.train.co_teaching_num_warmup = 25
# self-supervision defaults
config.train.use_self_supervision = False
config.train.self_supervision = ConfigNode()
config.train.self_supervision.checkpoints = ['', '']
config.train.self_supervision.freeze_encoder = True
config.train.tanh_regularisation = 0.0
# optimizer
config.optim = ConfigNode()
# Adam
config.optim.adam = ConfigNode()
config.optim.adam.betas = (0.9, 0.999)
# scheduler
config.scheduler = ConfigNode()
config.scheduler.epochs = 160
# warm up (options: none, linear, exponential)
config.scheduler.warmup = ConfigNode()
config.scheduler.warmup.type = 'none'
config.scheduler.warmup.epochs = 0
config.scheduler.warmup.start_factor = 1e-3
config.scheduler.warmup.exponent = 4
# main scheduler (options: constant, linear, multistep, cosine)
config.scheduler.type = 'multistep'
config.scheduler.milestones = [80, 120]
config.scheduler.lr_decay = 0.1
config.scheduler.lr_min_factor = 0.001
# tensorboard
config.tensorboard = ConfigNode()
config.tensorboard.save_events = True
# train data loader
config.train.dataloader = ConfigNode()
config.train.dataloader.num_workers = 2
config.train.dataloader.drop_last = True
config.train.dataloader.pin_memory = False
config.train.dataloader.non_blocking = False
# validation data loader
config.validation = ConfigNode()
config.validation.batch_size = 256
config.validation.dataloader = ConfigNode()
config.validation.dataloader.num_workers = 2
config.validation.dataloader.drop_last = False
config.validation.dataloader.pin_memory = False
config.validation.dataloader.non_blocking = False
config.augmentation = ConfigNode()
config.augmentation.use_random_crop = True
config.augmentation.use_random_horizontal_flip = True
config.augmentation.use_random_affine = False
config.augmentation.use_label_smoothing = False
config.augmentation.use_random_color = False
config.augmentation.add_gaussian_noise = False
config.augmentation.use_gamma_transform = False
config.augmentation.use_random_erasing = False
config.augmentation.use_elastic_transform = False
config.augmentation.elastic_transform = ConfigNode()
config.augmentation.elastic_transform.sigma = 4
config.augmentation.elastic_transform.alpha = 35
config.augmentation.elastic_transform.p_apply = 0.5
config.augmentation.random_crop = ConfigNode()
config.augmentation.random_crop.scale = (0.9, 1.0)
config.augmentation.gaussian_noise = ConfigNode()
config.augmentation.gaussian_noise.std = 0.01
config.augmentation.gaussian_noise.p_apply = 0.5
config.augmentation.random_horizontal_flip = ConfigNode()
config.augmentation.random_horizontal_flip.prob = 0.5
config.augmentation.random_affine = ConfigNode()
config.augmentation.random_affine.max_angle = 0
config.augmentation.random_affine.max_horizontal_shift = 0.0
config.augmentation.random_affine.max_vertical_shift = 0.0
config.augmentation.random_affine.max_shear = 5
config.augmentation.random_color = ConfigNode()
config.augmentation.random_color.brightness = 0.0
config.augmentation.random_color.contrast = 0.1
config.augmentation.random_color.saturation = 0.1
config.augmentation.gamma = ConfigNode()
config.augmentation.gamma.scale = (0.5, 1.5)
config.augmentation.label_smoothing = ConfigNode()
config.augmentation.label_smoothing.epsilon = 0.1
config.augmentation.random_erasing = ConfigNode()
config.augmentation.random_erasing.scale = (0.01, 0.1)
config.augmentation.random_erasing.ratio = (0.3, 3.3)
config.preprocess = ConfigNode()
config.preprocess.use_resize = False
config.preprocess.use_center_crop = False
config.preprocess.center_crop_size = None
config.preprocess.histogram_normalization = ConfigNode()
config.preprocess.histogram_normalization.disk_size = 30
config.preprocess.resize = 32
# test config
config.test = ConfigNode()
config.test.checkpoint = None
config.test.output_dir = None
config.test.batch_size = 256
# test data loader
config.test.dataloader = ConfigNode()
config.test.dataloader.num_workers = 2
config.test.dataloader.pin_memory = False
def get_default_model_config() -> ConfigNode:
return config.clone()

Просмотреть файл

@ -0,0 +1,75 @@
device: cuda
cudnn:
benchmark: True
deterministic: True
dataset:
name: CIFAR10IDN
num_samples: 10000
noise_rate: 0.4
model:
type: cifar
name: resnet
init_mode: kaiming_fan_out
resnet:
depth: 50
initial_channels: 16
block_type: basic
apply_l2_norm: True
train:
resume_epoch: 0
seed: 1
batch_size: 256
optimizer: sgd
base_lr: 0.05
momentum: 0.9
nesterov: True
weight_decay: 1e-4
output_dir: experiments/benchmark_3_idn/co_teaching_v5_res50
log_period: 100
checkpoint_period: 100
use_co_teaching: True
co_teaching_forget_rate: 0.38
co_teaching_num_gradual: 10
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
validation:
batch_size: 512
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
test:
batch_size: 512
dataloader:
num_workers: 2
pin_memory: False
scheduler:
epochs: 120
type: multistep
milestones: [70, 100]
lr_decay: 0.1
augmentation:
use_random_horizontal_flip: True
use_label_smoothing: False
use_random_affine: False
use_random_color: True
random_horizontal_flip:
prob: 0.5
random_affine:
max_angle: 15
max_horizontal_shift: 0.05
max_vertical_shift: 0.05
max_shear: 5
random_color:
brightness: 0.5
contrast: 0.5
saturation: 0.5

Просмотреть файл

@ -0,0 +1,7 @@
selector:
type: ['PosteriorBasedSelector']
model_name: 'Self-Supervision'
model_config_path: 'InnerEyeDataQuality/configs/models/benchmark_3_idn/b3_ssl_cleaning_v2.yaml'
output_directory: 'b3_ssl_cleaning_v2'
tensorboard:
save_events: False

Просмотреть файл

@ -0,0 +1,8 @@
selector:
type: ['PosteriorBasedSelector']
model_name: 'Self-Supervision-Active'
model_config_path: 'InnerEyeDataQuality/configs/models/benchmark_3_idn/b3_ssl_cleaning_v2.yaml'
output_directory: 'b3_ssl_cleaning_v2_active'
use_active_relabelling: True
tensorboard:
save_events: False

Просмотреть файл

@ -0,0 +1,73 @@
device: cuda
cudnn:
benchmark: True
deterministic: True
dataset:
name: CIFAR10IDN
num_samples: 10000
noise_rate: 0.4
model:
type: cifar
name: resnet
init_mode: kaiming_fan_out
resnet:
depth: 110
initial_channels: 16
block_type: basic
apply_l2_norm: True
train:
resume_epoch: 0
seed: 1
batch_size: 256
optimizer: sgd
base_lr: 0.005
momentum: 0.9
nesterov: True
weight_decay: 1e-3
output_dir: experiments/benchmark_3_idn/ssl_cleaning_v2
log_period: 100
checkpoint_period: 100
tanh_regularisation: 0.30
use_self_supervision: True
self_supervision:
checkpoints: ["cifar10h/self_supervised/lightning_logs/BYOL_seed_1/checkpoints/last.ckpt"]
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
validation:
batch_size: 512
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
test:
batch_size: 128
dataloader:
num_workers: 2
pin_memory: False
scheduler:
epochs: 120
type: multistep
milestones: [90]
lr_decay: 0.1
augmentation:
use_random_horizontal_flip: True
use_label_smoothing: True
use_random_affine: False
use_random_color: True
random_crop:
scale: (0.9, 1.0)
random_horizontal_flip:
prob: 0.5
random_color:
brightness: 0.5
contrast: 0.5
saturation: 0.5

Просмотреть файл

@ -0,0 +1,73 @@
device: cuda
cudnn:
benchmark: True
deterministic: True
dataset:
name: CIFAR10IDN
num_samples: 10000
noise_rate: 0.4
model:
type: cifar
name: resnet
init_mode: kaiming_fan_out
resnet:
depth: 110
initial_channels: 16
block_type: basic
apply_l2_norm: True
train:
resume_epoch: 0
seed: 1
batch_size: 256
optimizer: sgd
base_lr: 0.005
momentum: 0.9
nesterov: True
weight_decay: 1e-4
output_dir: experiments/benchmark_3_idn/ssl_v1
log_period: 100
checkpoint_period: 100
tanh_regularisation: 0.00
use_self_supervision: True
self_supervision:
checkpoints: ["cifar10h/self_supervised/lightning_logs/BYOL_seed_1/checkpoints/last.ckpt"]
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
validation:
batch_size: 512
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
test:
batch_size: 128
dataloader:
num_workers: 2
pin_memory: False
scheduler:
epochs: 120
type: multistep
milestones: [90]
lr_decay: 0.1
augmentation:
use_random_horizontal_flip: True
use_label_smoothing: False
use_random_affine: False
use_random_color: True
random_crop:
scale: (0.9, 1.0)
random_horizontal_flip:
prob: 0.5
random_color:
brightness: 0.5
contrast: 0.5
saturation: 0.5

Просмотреть файл

@ -0,0 +1,72 @@
device: cuda
cudnn:
benchmark: True
deterministic: True
dataset:
name: CIFAR10IDN
num_samples: 10000
noise_rate: 0.4
model:
type: cifar
name: resnet
init_mode: kaiming_fan_out
resnet:
depth: 50
initial_channels: 16
block_type: basic
apply_l2_norm: True
train:
resume_epoch: 0
seed: 1
batch_size: 256
optimizer: sgd
base_lr: 0.05
momentum: 0.9
nesterov: True
weight_decay: 1e-4
output_dir: experiments/benchmark_3_idn/vanilla_model_res50
log_period: 100
checkpoint_period: 100
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
validation:
batch_size: 512
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
test:
batch_size: 512
dataloader:
num_workers: 2
pin_memory: False
scheduler:
epochs: 120
type: multistep
milestones: [70, 100]
lr_decay: 0.1
augmentation:
use_random_horizontal_flip: True
use_label_smoothing: False
use_random_affine: False
use_random_color: True
random_horizontal_flip:
prob: 0.5
random_affine:
max_angle: 15
max_horizontal_shift: 0.05
max_vertical_shift: 0.05
max_shear: 5
random_color:
brightness: 0.5
contrast: 0.5
saturation: 0.5

Просмотреть файл

@ -0,0 +1,72 @@
device: cuda
cudnn:
benchmark: True
deterministic: True
dataset:
name: CIFAR10H
noise_temperature: 2.0
num_samples: 5000
model:
type: cifar
name: resnet
init_mode: kaiming_fan_out
resnet:
depth: 110
initial_channels: 16
block_type: basic
apply_l2_norm: True
train:
resume_epoch: 0
seed: 1
batch_size: 256
optimizer: sgd
base_lr: 0.1
momentum: 0.9
nesterov: True
weight_decay: 1e-4
output_dir: cifar10h_noise_15/vanilla_model
log_period: 100
checkpoint_period: 100
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
validation:
batch_size: 512
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
test:
batch_size: 512
dataloader:
num_workers: 2
pin_memory: False
scheduler:
epochs: 160
type: multistep
milestones: [80, 120]
lr_decay: 0.1
augmentation:
use_random_horizontal_flip: True
use_label_smoothing: False
use_random_affine: True
use_random_color: True
random_horizontal_flip:
prob: 0.5
random_affine:
max_angle: 15
max_horizontal_shift: 0.05
max_vertical_shift: 0.05
max_shear: 5
random_color:
brightness: 0.5
contrast: 0.5
saturation: 0.5

Просмотреть файл

@ -0,0 +1,75 @@
device: cuda
cudnn:
benchmark: True
deterministic: True
dataset:
name: CIFAR10H
noise_temperature: 2.0
num_samples: 5000
model:
type: cifar
name: resnet
init_mode: kaiming_fan_out
resnet:
depth: 110
initial_channels: 16
block_type: basic
apply_l2_norm: True
train:
resume_epoch: 0
seed: 1
batch_size: 256
optimizer: sgd
base_lr: 0.1
momentum: 0.9
nesterov: True
weight_decay: 1e-4
output_dir: cifar10h_noise_15/co_teaching
log_period: 100
checkpoint_period: 100
use_co_teaching: True
co_teaching_forget_rate: 0.15
co_teaching_num_gradual: 10
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
validation:
batch_size: 2048
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
test:
batch_size: 128
dataloader:
num_workers: 2
pin_memory: False
scheduler:
epochs: 160
type: multistep
milestones: [80, 120]
lr_decay: 0.1
augmentation:
use_random_horizontal_flip: True
use_label_smoothing: False
use_random_affine: True
use_random_color: True
random_horizontal_flip:
prob: 0.5
random_affine:
max_angle: 15
max_horizontal_shift: 0.05
max_vertical_shift: 0.05
max_shear: 5
random_color:
brightness: 0.5
contrast: 0.5
saturation: 0.5

Просмотреть файл

@ -0,0 +1,64 @@
device: cuda
cudnn:
benchmark: True
deterministic: True
dataset:
name: CIFAR10H
noise_temperature: 2.0
num_samples: 5000
train:
resume_epoch: 0
seed: 2
batch_size: 256
optimizer: sgd
base_lr: 0.01
momentum: 0.9
nesterov: True
weight_decay: 1e-3
output_dir: experiments/cifar10h_noise_15/vanilla_self_supervision_v3
log_period: 100
checkpoint_period: 100
tanh_regularisation: 0.30
use_self_supervision: True
self_supervision:
checkpoints: ["cifar10h/self_supervised/lightning_logs/BYOL_seed_1/checkpoints/last.ckpt"]
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
validation:
batch_size: 512
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
test:
batch_size: 128
dataloader:
num_workers: 2
pin_memory: False
scheduler:
epochs: 120
type: multistep
milestones: [90]
lr_decay: 0.1
augmentation:
use_random_horizontal_flip: True
use_label_smoothing: True
use_random_affine: False
use_random_color: True
random_crop:
scale: (0.9, 1.0)
random_horizontal_flip:
prob: 0.5
random_color:
brightness: 0.5
contrast: 0.5
saturation: 0.5

Просмотреть файл

@ -0,0 +1,72 @@
device: cuda
cudnn:
benchmark: True
deterministic: True
dataset:
name: CIFAR10H
noise_temperature: 10.0
num_samples: 5000
model:
type: cifar
name: resnet
init_mode: kaiming_fan_out
resnet:
depth: 110
initial_channels: 16
block_type: basic
apply_l2_norm: True
train:
resume_epoch: 0
seed: 1
batch_size: 256
optimizer: sgd
base_lr: 0.1
momentum: 0.9
nesterov: True
weight_decay: 1e-4
output_dir: experiments/cifar10h_noise_30/vanilla_model
log_period: 100
checkpoint_period: 100
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
validation:
batch_size: 512
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
test:
batch_size: 512
dataloader:
num_workers: 2
pin_memory: False
scheduler:
epochs: 160
type: multistep
milestones: [80, 120]
lr_decay: 0.1
augmentation:
use_random_horizontal_flip: True
use_label_smoothing: False
use_random_affine: True
use_random_color: True
random_horizontal_flip:
prob: 0.5
random_affine:
max_angle: 15
max_horizontal_shift: 0.05
max_vertical_shift: 0.05
max_shear: 5
random_color:
brightness: 0.5
contrast: 0.5
saturation: 0.5

Просмотреть файл

@ -0,0 +1,75 @@
device: cuda
cudnn:
benchmark: True
deterministic: True
dataset:
name: CIFAR10H
noise_temperature: 10.0
num_samples: 5000
model:
type: cifar
name: resnet
init_mode: kaiming_fan_out
resnet:
depth: 110
initial_channels: 16
block_type: basic
apply_l2_norm: True
train:
resume_epoch: 0
seed: 1
batch_size: 256
optimizer: sgd
base_lr: 0.1
momentum: 0.9
nesterov: True
weight_decay: 1e-4
output_dir: experiments/cifar10h_noise_30/co_teaching
log_period: 100
checkpoint_period: 100
use_co_teaching: True
co_teaching_forget_rate: 0.28
co_teaching_num_gradual: 10
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
validation:
batch_size: 2048
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
test:
batch_size: 128
dataloader:
num_workers: 2
pin_memory: False
scheduler:
epochs: 160
type: multistep
milestones: [80, 120]
lr_decay: 0.1
augmentation:
use_random_horizontal_flip: True
use_label_smoothing: False
use_random_affine: True
use_random_color: True
random_horizontal_flip:
prob: 0.5
random_affine:
max_angle: 15
max_horizontal_shift: 0.05
max_vertical_shift: 0.05
max_shear: 5
random_color:
brightness: 0.5
contrast: 0.5
saturation: 0.5

Просмотреть файл

@ -0,0 +1,67 @@
device: cuda:2
cudnn:
benchmark: True
deterministic: True
dataset:
name: CIFAR10H
noise_temperature: 10.0
num_samples: 5000
model:
type: cifar
name: resnet
init_mode: kaiming_fan_out
use_dropout: False
resnet:
depth: 110
initial_channels: 16
block_type: basic
apply_l2_norm: True
train:
use_elr: True
resume_epoch: 0
seed: 0
batch_size: 256
optimizer: sgd
base_lr: 0.1
momentum: 0.9
nesterov: True
weight_decay: 1e-4
output_dir: experiments/cifar10h_noise_30/elr
log_period: 100
checkpoint_period: 100
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
validation:
batch_size: 512
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
test:
batch_size: 512
dataloader:
num_workers: 2
pin_memory: False
scheduler:
epochs: 160
type: multistep
milestones: [80, 120]
lr_decay: 0.1
augmentation:
use_random_horizontal_flip: True
use_label_smoothing: False
use_random_affine: True
use_random_color: True
random_horizontal_flip:
prob: 0.5
random_affine:
max_angle: 15
max_horizontal_shift: 0.05
max_vertical_shift: 0.05
max_shear: 5
random_color:
brightness: 0.5
contrast: 0.5
saturation: 0.5

Просмотреть файл

@ -0,0 +1,64 @@
device: cuda
cudnn:
benchmark: True
deterministic: True
dataset:
name: CIFAR10H
noise_temperature: 10.0
num_samples: 5000
train:
resume_epoch: 0
seed: 2
batch_size: 256
optimizer: sgd
base_lr: 0.01
momentum: 0.9
nesterov: True
weight_decay: 1e-3
output_dir: experiments/cifar10h_noise_30/vanilla_self_supervision_v3
log_period: 100
checkpoint_period: 100
tanh_regularisation: 0.30
use_self_supervision: True
self_supervision:
checkpoints: ["logs/cifar10h/self_supervised/byol_seed_1/checkpoints/last.ckpt"]
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
validation:
batch_size: 512
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
test:
batch_size: 128
dataloader:
num_workers: 2
pin_memory: False
scheduler:
epochs: 120
type: multistep
milestones: [90]
lr_decay: 0.1
augmentation:
use_random_horizontal_flip: True
use_label_smoothing: True
use_random_affine: False
use_random_color: True
random_crop:
scale: (0.9, 1.0)
random_horizontal_flip:
prob: 0.5
random_color:
brightness: 0.5
contrast: 0.5
saturation: 0.5

Просмотреть файл

@ -0,0 +1,66 @@
device: cuda
cudnn:
benchmark: True
deterministic: True
dataset:
name: CIFAR10H
noise_temperature: 10.0
num_samples: 5000
model:
type: cifar
name: resnet
init_mode: kaiming_fan_out
use_dropout: True
resnet:
depth: 110
initial_channels: 16
block_type: basic
apply_l2_norm: True
train:
resume_epoch: 0
seed: 1
batch_size: 256
optimizer: sgd
base_lr: 0.1
momentum: 0.9
nesterov: True
weight_decay: 1e-4
output_dir: experiments/cifar10h_noise_30/vanilla_with_dropout
log_period: 100
checkpoint_period: 100
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
validation:
batch_size: 512
dataloader:
num_workers: 2
drop_last: False
pin_memory: False
test:
batch_size: 512
dataloader:
num_workers: 2
pin_memory: False
scheduler:
epochs: 160
type: multistep
milestones: [80, 120]
lr_decay: 0.1
augmentation:
use_random_horizontal_flip: True
use_label_smoothing: False
use_random_affine: True
use_random_color: True
random_horizontal_flip:
prob: 0.5
random_affine:
max_angle: 15
max_horizontal_shift: 0.05
max_vertical_shift: 0.05
max_shear: 5
random_color:
brightness: 0.5
contrast: 0.5
saturation: 0.5

Просмотреть файл

@ -0,0 +1,62 @@
device: cuda:0
cudnn:
benchmark: True
deterministic: True
dataset:
name: NoisyChestXray
dataset_dir: /datadrive/rsna_train
n_channels: 3
n_classes: 2
cxr_consolidation_noise_rate: 0.10
model:
type: cxr
name: resnet50
train:
use_balanced_sampler: True
pretrained: False
seed: 5
batch_size: 32
optimizer: adam
base_lr: 1e-5
output_dir: cxr/nih10/vanilla
log_period: 5
checkpoint_period: 20
use_co_teaching: False
use_teacher_model: False
dataloader:
num_workers: 6
drop_last: False
pin_memory: False
validation:
batch_size: 32
dataloader:
num_workers: 6
drop_last: False
pin_memory: False
preprocess:
center_crop_size: 224
resize: 256
scheduler:
epochs: 150
type: constant
augmentation:
use_random_horizontal_flip: True
use_random_affine: True
use_random_color: True
use_random_crop: True
use_gamma_transform: False
use_random_erasing: False
add_gaussian_noise: False
random_horizontal_flip:
prob: 0.5
random_affine:
max_angle: 30
max_horizontal_shift: 0.00
max_vertical_shift: 0.00
max_shear: 15
random_color:
brightness: 0.2
contrast: 0.2
saturation: 0.0
random_crop:
scale: (0.8, 1.0)

Просмотреть файл

@ -0,0 +1,67 @@
device: cuda:3
cudnn:
benchmark: True
deterministic: True
dataset:
name: NoisyChestXray
dataset_dir: /datadrive/rsna_train
n_channels: 3
n_classes: 2
cxr_consolidation_noise_rate: 0.10
model:
type: cxr
name: resnet50
train:
use_balanced_sampler: True
pretrained: False
seed: 5
batch_size: 32
optimizer: adam
base_lr: 1e-5
output_dir: cxr/nih10/co_teaching
log_period: 5
checkpoint_period: 20
use_co_teaching: True
use_teacher_model: False
co_teaching_consistency_loss: False
co_teaching_use_graph: False
co_teaching_forget_rate: 0.15
co_teaching_num_gradual: 0
co_teaching_num_warmup: 20
dataloader:
num_workers: 6
drop_last: False
pin_memory: False
validation:
batch_size: 16
dataloader:
num_workers: 6
drop_last: False
pin_memory: False
preprocess:
center_crop_size: 224
resize: 256
scheduler:
epochs: 150
type: constant
augmentation:
use_random_horizontal_flip: True
use_random_affine: True
use_random_color: True
use_random_crop: True
use_gamma_transform: False
use_random_erasing: False
add_gaussian_noise: False
random_horizontal_flip:
prob: 0.5
random_affine:
max_angle: 30
max_horizontal_shift: 0.00
max_vertical_shift: 0.00
max_shear: 15
random_color:
brightness: 0.2
contrast: 0.2
saturation: 0.0
random_crop:
scale: (0.8, 1.0)

Просмотреть файл

@ -0,0 +1,72 @@
device: cuda:3
cudnn:
benchmark: True
deterministic: True
dataset:
name: NoisyChestXray
dataset_dir: /datadrive/rsna_train
n_channels: 3
n_classes: 2
cxr_consolidation_noise_rate: 0.10
model:
type: cxr
name: resnet50
train:
use_balanced_sampler: True
pretrained: False
seed: 5
batch_size: 32
optimizer: adam
base_lr: 1e-6
output_dir: cxr/nih10/co_ssl
log_period: 5
checkpoint_period: 20
use_co_teaching: True
co_teaching_consistency_loss: False
co_teaching_use_graph: False
co_teaching_forget_rate: 0.15
co_teaching_num_gradual: 0
co_teaching_num_warmup: 10
use_teacher_model: False
use_self_supervision: True
self_supervision:
freeze_encoder: False
checkpoints: ["/datadrive/InnerEye-DataQuality/logs/nih/ssup/byol_seed_3/checkpoints/epoch=999.ckpt",
"/datadrive/InnerEye-DataQuality/logs/nih/ssup/byol_seed_3/checkpoints/epoch=999.ckpt"]
dataloader:
num_workers: 6
drop_last: False
pin_memory: False
validation:
batch_size: 32
dataloader:
num_workers: 6
drop_last: False
pin_memory: False
scheduler:
epochs: 100
type: constant
preprocess:
center_crop_size: 224
resize: 256
augmentation:
use_random_horizontal_flip: True
use_random_affine: True
use_random_color: True
use_random_crop: True
use_gamma_transform: False
use_random_erasing: False
add_gaussian_noise: False
random_horizontal_flip:
prob: 0.5
random_affine:
max_angle: 30
max_horizontal_shift: 0.00
max_vertical_shift: 0.00
max_shear: 15
random_color:
brightness: 0.2
contrast: 0.2
saturation: 0.0
random_crop:
scale: (0.8, 1.0)

Просмотреть файл

@ -0,0 +1,5 @@
selector:
type: ['PosteriorBasedSelectorJoint']
model_name: 'Coteaching'
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_co_teaching.yaml'

Просмотреть файл

@ -0,0 +1,5 @@
selector:
type: ['PosteriorBasedSelectorJoint']
model_name: 'Vanilla'
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet.yaml'

Просмотреть файл

@ -0,0 +1,5 @@
selector:
type: ['PosteriorBasedSelectorJoint']
model_name: 'Self-Supervision-Joint'
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_self_supervision_v3.yaml'

Просмотреть файл

@ -0,0 +1,6 @@
selector:
type: ['PosteriorBasedSelectorJoint']
model_name: 'Coteaching'
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_co_teaching.yaml'
output_directory: 'outputs/cifar10h_30/coteaching'

Просмотреть файл

@ -0,0 +1,6 @@
selector:
type: ['PosteriorBasedSelectorJoint']
model_name: 'Coteaching-Active'
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_co_teaching.yaml'
use_active_relabelling: True

Просмотреть файл

@ -0,0 +1,6 @@
selector:
type: ['PosteriorBasedSelectorJoint']
model_name: 'Vanilla'
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet.yaml'
output_directory: 'outputs/cifar10h_30/vanilla'

Просмотреть файл

@ -0,0 +1,6 @@
selector:
type: ['PosteriorBasedSelectorJoint']
model_name: 'Vanilla-Active'
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet.yaml'
use_active_relabelling: True

Просмотреть файл

@ -0,0 +1,6 @@
selector:
type: ['BaldSelector']
model_name: 'BALD-MC-Active'
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_with_dropout.yaml'
use_active_relabelling: True

Просмотреть файл

@ -0,0 +1,6 @@
selector:
type: ['PosteriorBasedSelectorJoint']
model_name: 'Self-Supervision'
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_self_supervision_v3.yaml'
output_directory: 'outputs/cifar10h_30/ssup_v3'

Просмотреть файл

@ -0,0 +1,6 @@
selector:
type: ['PosteriorBasedSelectorJoint']
model_name: 'Self-Supervision-Active'
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_self_supervision_v3.yaml'
use_active_relabelling: True

Просмотреть файл

@ -0,0 +1,4 @@
selector:
type: ['PosteriorBasedSelector']
model_name: 'Coteaching'
model_config_path: 'InnerEyeDataQuality/configs/models/cxr/resnet_coteaching.yaml'

Просмотреть файл

@ -0,0 +1,5 @@
selector:
type: ['PosteriorBasedSelector']
model_name: 'SSL+Coteaching'
model_config_path: 'InnerEyeDataQuality/configs/models/cxr/resnet_ssl_coteaching.yaml'
use_active_relabelling: False

Просмотреть файл

@ -0,0 +1,5 @@
selector:
type: ['PosteriorBasedSelector']
model_name: 'SSL+Coteaching+Active'
model_config_path: 'InnerEyeDataQuality/configs/models/cxr/resnet_ssl_coteaching.yaml'
use_active_relabelling: True

Просмотреть файл

@ -0,0 +1,4 @@
selector:
type: ['PosteriorBasedSelector']
model_name: 'Vanilla'
model_config_path: 'InnerEyeDataQuality/configs/models/cxr/resnet.yaml'

Просмотреть файл

@ -0,0 +1,5 @@
selector:
type: ['PosteriorBasedSelector']
model_name: 'Vanilla (heavy aug) Active'
model_config_path: 'InnerEyeDataQuality/configs/models/cxr/resnet.yaml'
use_active_relabelling: True

Просмотреть файл

@ -0,0 +1,30 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from InnerEyeDataQuality.configs.config_node import ConfigNode
config = ConfigNode()
# data selector
config.selector = ConfigNode()
config.selector.type = None
config.selector.model_name = None
config.selector.model_config_path = None
config.selector.use_active_relabelling = False
# Other selector parameters (unused)
config.selector.training_dynamics_data_path = None
config.selector.burnout_period = 0
config.selector.number_samples_to_relabel = 10
# output files
config.selector.output_directory = None
# tensorboard
config.tensorboard = ConfigNode()
config.tensorboard.save_events = False
def get_default_selector_config() -> ConfigNode:
return config.clone()

Просмотреть файл

@ -0,0 +1,140 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from typing import Callable, Optional, Tuple
import numpy as np
import torchvision
import PIL.Image
from default_paths import FIGURE_DIR
from InnerEyeDataQuality.datasets.label_distribution import LabelDistribution
from InnerEyeDataQuality.datasets.label_noise_model import get_cifar10h_confusion_matrix, \
get_cifar10_asym_noise_model, get_cifar10_sym_noise_model
from InnerEyeDataQuality.utils.plot import plot_confusion_matrix
from InnerEyeDataQuality.datasets.cifar10_utils import get_cifar10_label_names
class CIFAR10AsymNoise(torchvision.datasets.CIFAR10):
"""
Dataset class for the CIFAR10 dataset where target labels are sampled from a confusion matrix.
"""
def __init__(self,
root: str,
train: bool,
transform: Optional[Callable] = None,
download: bool = True,
use_fixed_labels: bool = True,
seed: int = 1234,
) -> None:
"""
:param root: The directory in which the CIFAR10 images will be stored.
:param train: If True, creates dataset from training set, otherwise creates from test set.
:param transform: Transform to apply to the images.
:param download: Whether to download the dataset if it is not already in the local disk.
:param use_fixed_labels: If true labels are sampled only once and are kept fixed. If false labels are sampled at
each get_item() function call from label distribution.
:param seed: The random seed that defines which samples are train/test and which labels are sampled.
"""
super().__init__(root, train=train, transform=transform, target_transform=None, download=download)
self.seed = seed
self.targets = np.array(self.targets, dtype=np.int64) # type: ignore
self.num_classes = np.unique(self.targets, return_counts=False).size
self.num_samples = len(self.data)
self.label_counts = np.eye(self.num_classes, dtype=np.int64)[self.targets]
self.np_random_state = np.random.RandomState(seed)
self.use_fixed_labels = use_fixed_labels
self.clean_targets = np.argmax(self.label_counts, axis=1)
logging.info(f"Preparing dataset: CIFAR10-Asym-Noise (N={self.num_samples})")
# Create label distribution for simulation of label adjudication
self.label_distribution = LabelDistribution(seed, self.label_counts, temperature=1.0)
# Add asymmetric noise on the labels
self.noise_model = self.create_noise_transition_model(self.targets, self.num_classes, "cifar10_sym")
# Sample fixed labels from the distribution
if use_fixed_labels:
self.targets = self.sample_labels_from_model()
# Identify label noise cases
noise_rate = np.mean(self.clean_targets != self.targets) * 100.0
# Check the class distribution after sampling
class_distribution = self.get_class_frequencies(targets=self.targets, num_classes=self.num_classes)
# Log dataset details
logging.info(f"Class distribution (%) (true labels): {class_distribution * 100.0}")
logging.info(f"Label noise rate: {noise_rate}")
@staticmethod
def create_noise_transition_model(labels: np.ndarray, num_classes: int, noise_model: str) -> np.ndarray:
logging.info(f"Using {noise_model} label noise model")
if noise_model == "cifar10h":
transition_matrix = get_cifar10h_confusion_matrix(temperature=2.0)
elif noise_model == "cifar10_asym":
transition_matrix = get_cifar10_asym_noise_model(eta=0.4)
elif noise_model == "cifar10_sym":
transition_matrix = get_cifar10_sym_noise_model(eta=0.4)
else:
raise ValueError("Unknown noise transition model")
assert(np.all(np.sum(transition_matrix, axis=1) - 1.00 < 1e-6)) # Checks = it sums up to one.
# Visualise the noise model
plot_confusion_matrix(list(), list(), get_cifar10_label_names(), cm=transition_matrix, save_path=FIGURE_DIR)
# Compute the expected noise rate
assert labels.ndim == 1
exp_noise_rate = 1.0 - np.sum(
np.diag(transition_matrix) * CIFAR10AsymNoise.get_class_frequencies(labels, num_classes))
logging.info(f"Expected noise rate (transition model): {exp_noise_rate}")
return transition_matrix
def sample_labels_from_model(self, sample_index: Optional[int] = None) -> np.ndarray:
# Sample based on the transition matrix and original labels (labels, transition mat, seed)
if sample_index is not None:
cur_label = self.targets[sample_index]
label = self.np_random_state.choice(self.num_classes, 1, p=self.noise_model[cur_label, :])[0]
return label
noisy_targets = np.zeros_like(self.targets)
for ii in range(self.num_samples):
cur_label = self.targets[ii]
noisy_targets[ii] = self.np_random_state.choice(self.num_classes, 1, p=self.noise_model[cur_label, :])[0]
return noisy_targets
@staticmethod
def get_class_frequencies(targets: np.ndarray, num_classes: int) -> np.ndarray:
"""
Returns normalised frequency of each semantic class
"""
assert targets.ndim == 1
class_ids, class_counts = np.unique(targets, return_counts=True)
class_distribution = np.zeros(num_classes, dtype=np.float)
for ii in range(num_classes):
if np.any(class_ids == ii):
class_distribution[ii] = class_counts[class_ids == ii]/targets.size
return class_distribution
def __getitem__(self, index: int) -> Tuple[PIL.Image.Image, int]:
"""
:param index: The index of the sample to be fetched
:return: The image and label tensors
"""
img = PIL.Image.fromarray(self.data[index])
if self.transform is not None:
img = self.transform(img)
if self.use_fixed_labels:
target = self.targets[index]
else:
target = self.sample_labels_from_model(sample_index=index)
return img, int(target)

Просмотреть файл

@ -0,0 +1,153 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from typing import Callable, Optional, Tuple, List, Generator
import numpy as np
import PIL.Image
import torch
import torchvision
from InnerEyeDataQuality.datasets.label_distribution import LabelDistribution
from InnerEyeDataQuality.datasets.tools import get_instance_noise_model
from InnerEyeDataQuality.utils.generic import convert_labels_to_one_hot
from pl_bolts.models.self_supervised.resnets import resnet50_bn
from InnerEyeDataQuality.datasets.cifar10h import CIFAR10H
def chunks(lst: List, n: int) -> Generator:
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n]
class CIFAR10IDN(torchvision.datasets.CIFAR10):
"""
Dataset class for the CIFAR10 dataset where target labels are sampled from a confusion matrix.
"""
def __init__(self,
root: str,
train: bool,
noise_rate: float,
transform: Optional[Callable] = None,
download: bool = True,
use_fixed_labels: bool = True,
seed: int = 1
) -> None:
"""
:param root: The directory in which the CIFAR10 images will be stored.
:param train: If True, creates dataset from training set, otherwise creates from test set.
:param transform: Transform to apply to the images.
:param download: Whether to download the dataset if it is not already in the local disk.
:param noise_rate: Expected noise rate in the sampled labels.
:param use_fixed_labels: If true labels are sampled only once and are kept fixed. If false labels are sampled at
each get_item() function call from label distribution.
:param seed: The random seed that defines which samples are train/test and which labels are sampled.
"""
super().__init__(root, train=train, transform=transform, target_transform=None, download=download)
self.seed = seed
self.targets = np.array(self.targets, dtype=np.int64) # type: ignore
self.num_classes = np.unique(self.targets, return_counts=False).size
self.num_samples = len(self.data)
self.clean_targets = np.copy(self.targets)
self.np_random_state = np.random.RandomState(seed)
self.use_fixed_labels = use_fixed_labels
self.indices = np.array(range(self.num_samples))
logging.info(f"Preparing dataset: CIFAR10-IDN (N={self.num_samples})")
# Set seed for torch operations
initial_state = torch.get_rng_state()
torch.manual_seed(self.seed)
torch.cuda.manual_seed_all(self.seed)
# Collect image embeddings
embeddings = self.get_cnn_image_embeddings(self.data)
targets = torch.from_numpy(self.targets)
# Create label distribution for simulation of label adjudication
label_counts = convert_labels_to_one_hot(self.clean_targets, n_classes=self.num_classes) if train else \
CIFAR10H.download_cifar10h_labels(self.root)
self.label_distribution = LabelDistribution(seed, label_counts, temperature=1.0)
# Add asymmetric noise on the labels
self.noise_models = get_instance_noise_model(n=noise_rate,
dataset=zip(embeddings, targets),
labels=targets,
num_classes=self.num_classes,
feature_size=embeddings.shape[1],
norm_std=0.01,
seed=self.seed)
if self.use_fixed_labels:
# Sample target labels
self.targets = self.sample_labels_from_model()
# Check the class distribution after sampling
class_distribution = self.get_class_frequencies(targets=self.targets, num_classes=self.num_classes)
noise_rate = np.mean(self.clean_targets != self.targets) * 100.0
# Log dataset details
logging.info(f"Class distribution (%) (true labels): {class_distribution * 100.0}")
logging.info(f"Label noise rate: {noise_rate}")
else:
self.targets = None
# Restore initial state
torch.set_rng_state(initial_state)
def sample_labels_from_model(self, sample_index: Optional[int] = None) -> np.ndarray:
"""
Samples class labels for each data point based on true label and instance dependent noise model
"""
classes = [i for i in range(self.num_classes)]
if sample_index is not None:
_t = self.np_random_state.choice(classes, p=self.noise_models[sample_index])
else:
_t = [self.np_random_state.choice(classes, p=self.noise_models[i]) for i in range(self.num_samples)]
return np.array(_t)
@staticmethod
def get_class_frequencies(targets: np.ndarray, num_classes: int) -> np.ndarray:
"""
Returns normalised frequency of each semantic class
"""
assert targets.ndim == 1
class_ids, class_counts = np.unique(targets, return_counts=True)
class_distribution = np.zeros(num_classes, dtype=np.float)
for ii in range(num_classes):
if np.any(class_ids == ii):
class_distribution[ii] = class_counts[class_ids == ii]/targets.size
return class_distribution
@staticmethod
def get_cnn_image_embeddings(data: np.ndarray) -> torch.Tensor:
"""
Extracts image embeddings using a pre-trained model
"""
num_samples = data.shape[0]
embeddings = list()
data = torch.from_numpy(data).float().cuda()
encoder = resnet50_bn(return_all_feature_maps=False, pretrained=True).cuda().eval()
with torch.no_grad():
for i in chunks(list(range(num_samples)), n=100):
input = data[i].permute(0, 3, 1, 2)
embeddings.append(encoder(input)[-1].cpu())
return torch.cat(embeddings, dim=0).view(num_samples, -1)
def __getitem__(self, index: int) -> Tuple[PIL.Image.Image, int]:
"""
:param index: The index of the sample to be fetched
:return: The image and label tensors
"""
img = PIL.Image.fromarray(self.data[index])
if self.transform is not None:
img = self.transform(img)
if self.use_fixed_labels:
target = self.targets[index]
else:
target = self.sample_labels_from_model(sample_index=index)
return img, int(target)

Просмотреть файл

@ -0,0 +1,165 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import io
import logging
import os
import pickle
import tarfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
import requests
logging.getLogger('matplotlib').setLevel(logging.WARNING)
def split_dataset(labels: np.ndarray,
reference_split: float,
shuffle: bool = True,
seed: int = 1234) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
:param labels: Complete label array that is split into two subsets, reference and noisy test set.
:param reference_split: Ratio of total samples that will be put in the reference set.
:param shuffle: If set to true, rows of the label matrix are shuffled prior to split.
:param seed: Random seed used in sample shuffle
"""
# Create two testing sets from CIFAR10H Test Samples (10k - 10 classes)
num_samples = labels.shape[0]
num_samples_set1 = int(num_samples * reference_split)
perm = np.random.RandomState(seed=seed).permutation(num_samples) if shuffle else np.array(range(num_samples))
d_set1 = labels[perm[:num_samples_set1], :]
d_set2 = labels[perm[num_samples_set1:], :]
return d_set1, d_set2, perm
def get_cifar10h_labels(reference_split: float = 0.5,
shuffle: bool = True,
immutable: bool = True) -> Tuple[np.ndarray, np.ndarray]:
"""
:param reference_split: The sample ratio between gold standard test set and full test set
:param shuffle: Shuffle samples prior to split.
:param immutable: If set to True, the returned arrays are only read only.
"""
cifar10h_counts = download_cifar10h_data() # Num_samples x n_classes
d_split1_standard, d_split2_complete, permutation = split_dataset(cifar10h_counts, reference_split, shuffle)
d_split1_permutation = permutation[:d_split1_standard.shape[0]]
if immutable:
d_split1_standard.setflags(write=False)
d_split2_complete.setflags(write=False)
return d_split1_standard, d_split1_permutation
def download_cifar10h_data() -> np.ndarray:
"""
Pulls cifar10h label data stream and returns it in numpy array.
"""
url = 'https://raw.githubusercontent.com/jcpeterson/cifar-10h/master/data/cifar10h-counts.npy'
response = requests.get(url)
response.raise_for_status()
if response.status_code == requests.codes.ok:
cifar10h_data = np.load(io.BytesIO(response.content))
else:
raise ValueError('CIFAR10H content was not found.')
return cifar10h_data
def download_cifar10_data() -> Path:
"""
Download CIFAR10 dataset and returns path to the test set
"""
import wget
local_path = Path.cwd() / 'InnerEyeDataQuality' / 'downloaded_data'
local_path_to_test_batch = local_path / 'cifar-10-batches-py/test_batch'
if not local_path_to_test_batch.exists():
local_path.mkdir(parents=True, exist_ok=True)
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
path_to_tar = local_path / 'cifar10.tar.gz'
wget.download(url, str(path_to_tar))
tf = tarfile.open(str(path_to_tar))
tf.extractall(local_path)
os.remove(path_to_tar)
return local_path_to_test_batch
def get_cifar10_label_names(file: Optional[Path] = None) -> List[str]:
"""
TBD
"""
if file:
dict = load_cifar10_file(file)
label_names = [_s.decode("utf-8") for _s in dict[b"label_names"]]
else:
label_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
return label_names
def load_cifar10_file(file: Path) -> Dict:
"""
TBD
"""
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
def plot_cifar10_images(sample_ids: List[int], save_directory: Path) -> None:
"""
Displays a set of CIFAR-10 Images based on the input sample ids. In the title of the figure,
label distribution is displayed as well to understand the sample difficulty.
"""
path_cifar10_test_batch = download_cifar10_data()
test_batch = load_cifar10_file(path_cifar10_test_batch)
plot_images(test_batch[b'data'], sample_ids, save_directory=save_directory)
def plot_images(images: np.ndarray,
selected_sample_ids: List[int],
cifar10h_labels: Optional[np.ndarray] = None,
label_names: Optional[List[str]] = None,
save_directory: Optional[Path] = None) -> None:
"""
Displays a set of CIFAR-10 Images based on the input sample ids. In the title of the figure,
label distribution is displayed as well to understand the sample difficulty.
"""
f, ax = plt.subplots(figsize=(2, 2))
for sample_id in selected_sample_ids:
img = np.reshape(images[sample_id, :], (3, 32, 32)).transpose(1, 2, 0)
ax.imshow(img)
if (cifar10h_labels is not None) and (label_names is not None):
num_k_classes = 3
label_distribution = cifar10h_labels[sample_id, :]
k_min_val = np.sort(label_distribution)[-num_k_classes]
available_classes = np.where(label_distribution >= k_min_val)[0]
class_counts = label_distribution[available_classes]
class_names = [label_names[_c] for _c in available_classes]
ax_title = ''.join([a + '_' + str(b) + ' ' for a, b in zip(class_names, class_counts)])
ax.set_title(ax_title)
else:
ax_title = f'CIFAR10H - Sample ID {sample_id}'
ax.set_title(ax_title)
if save_directory:
save_directory.mkdir(parents=True, exist_ok=True)
f.savefig(save_directory / f"{sample_id}.png", bbox_inches='tight')
ax.clear()
plt.close(f)
else:
plt.show()
f, ax = plt.subplots(figsize=(2, 2))

Просмотреть файл

@ -0,0 +1,184 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import io
import logging
from pathlib import Path
from typing import Callable, List, Optional, Tuple
import PIL.Image
import numpy as np
import requests
import torchvision
from InnerEyeDataQuality.datasets.cifar10_utils import get_cifar10_label_names
from InnerEyeDataQuality.datasets.label_distribution import LabelDistribution
from InnerEyeDataQuality.evaluation.metrics import compute_label_entropy
from InnerEyeDataQuality.selection.simulation_statistics import SimulationStats, get_ambiguous_sample_ids
from InnerEyeDataQuality.utils.generic import convert_labels_to_one_hot
TOTAL_CIFAR10H_DATASET_SIZE = 10000
class CIFAR10H(torchvision.datasets.CIFAR10):
"""
Dataset class for the CIFAR10H dataset. The CIFAR10H dataset is the CIFAR10 test set but all the samples have
been labelled my multiple annotators
"""
def __init__(self,
root: str,
transform: Optional[Callable] = None,
num_samples: Optional[int] = None,
preset_indices: Optional[np.ndarray] = None,
seed: int = 1234,
noise_temperature: float = 1.0,
noise_offset: float = 0.0) -> None:
"""
:param root: The directory in which the CIFAR10 images will be stored
:param transform: Transform to apply to the images
:param num_samples: The number of samples to use out of a maximum of TOTAL_CIFAR10H_DATASET_SIZE
:param seed: The random seed that defines which samples are train/test and which labels are sampled
:param preset_indices: Image indices that will be used to create a subset of CIFAR10H dataset.
If not specified and num_samples < 10000 then random sub-selection is performed.
:param shuffle: Whether to shuffle the data before splitting into training and test sets.
:param noise_temperature: A temperature a value that is used to temperature scale the label distribution.
:param noise_offset: Offset parameter to control the noise rate in sampling initial labels.
"""
super().__init__(root, train=False, transform=transform, target_transform=None, download=True)
num_samples = TOTAL_CIFAR10H_DATASET_SIZE if num_samples is None else num_samples
self.seed = seed
cifar10h_labels = self.download_cifar10h_labels(self.root)
self.num_classes = cifar10h_labels.shape[1]
assert cifar10h_labels.shape[0] == TOTAL_CIFAR10H_DATASET_SIZE
assert self.num_classes == 10
assert 0 < num_samples <= TOTAL_CIFAR10H_DATASET_SIZE
self.num_samples = num_samples
# Create a set indices that
self.indices = self.get_dataset_indices(num_samples, cifar10h_labels, keep_hard_samples=True, seed=seed) \
if preset_indices is None else preset_indices
self.verify_data_indices()
self.label_counts = cifar10h_labels[self.indices]
self.label_counts.flags.writeable = False
self.true_label_entropy = compute_label_entropy(label_counts=self.label_counts)
self.label_distribution = LabelDistribution(seed, self.label_counts, noise_temperature, noise_offset)
self.targets = self.label_distribution.sample_initial_labels_for_all()
# Check the class distribution
_, class_counts = np.unique(self.targets, return_counts=True)
class_distribution = np.array([_c/self.num_samples for _c in class_counts])
logging.info(f"Preparing dataset: CIFAR10H (N={self.num_samples})")
logging.info(f"Class distribution (%) (true labels): {class_distribution * 100.0}")
self.clean_targets = np.argmax(self.label_counts, axis=1)
# Identify true ambiguous and clear label noise cases
self._identify_sample_types()
def _identify_sample_types(self) -> None:
"""
Stores and logs clear label noise and ambiguous case types.
"""
label_stats = SimulationStats(name="cifar10h", true_label_counts=self.label_counts,
initial_labels=convert_labels_to_one_hot(self.targets, self.num_classes))
self.ambiguous_mislabelled_cases = label_stats.mislabelled_ambiguous_sample_ids[0]
self.clear_mislabeled_cases = label_stats.mislabelled_not_ambiguous_sample_ids[0]
self.ambiguity_metric_args = {"ambiguous_mislabelled_ids": self.ambiguous_mislabelled_cases,
"clear_mislabelled_ids": self.clear_mislabeled_cases,
"true_label_entropy": self.true_label_entropy}
# Log dataset details
logging.info(f"Ambiguous mislabeled cases: {100 * len(self.ambiguous_mislabelled_cases) / self.num_samples}%")
logging.info(f"Clear mislabeled cases: {100 * len(self.clear_mislabeled_cases) / self.num_samples}%\n")
@classmethod
def download_cifar10h_labels(self, root: str = ".") -> np.ndarray:
"""
Pulls cifar10h label data stream and returns it in numpy array.
"""
try:
cifar10h_labels = np.load(Path(root) / "cifar10h-counts.npy")
except FileNotFoundError:
url = 'https://raw.githubusercontent.com/jcpeterson/cifar-10h/master/data/cifar10h-counts.npy'
response = requests.get(url)
response.raise_for_status()
if response.status_code == requests.codes.ok:
cifar10h_labels = np.load(io.BytesIO(response.content))
else:
raise ValueError('Failed to download CIFAR10H labels!')
return cifar10h_labels
def __getitem__(self, index: int) -> Tuple[PIL.Image.Image, int]:
"""
:param index: The index of the sample to be fetched
:return: The image and label tensors
"""
img = PIL.Image.fromarray(self.data[self.indices[index]])
target = self.targets[index]
if self.transform is not None:
img = self.transform(img)
return img, int(target)
def __len__(self) -> int:
"""
:return: The size of the dataset
"""
return len(self.indices)
def get_label_names(self) -> List[str]:
return get_cifar10_label_names()
def verify_data_indices(self) -> None:
assert isinstance(self.indices, np.ndarray)
assert self.indices.size == self.num_samples
assert np.all(self.indices < TOTAL_CIFAR10H_DATASET_SIZE)
assert np.all(0 <= self.indices)
_, c = np.unique(self.indices, return_counts=True)
assert np.all(c == 1)
def get_dataset_indices(self,
num_samples: int,
true_label_counts: np.ndarray,
keep_hard_samples: bool,
seed: int = 1234) -> np.ndarray:
"""
Function to choose a subset of the CIFAR10H dataset. Returns selected subset of sample
indices in a shuffled order.
:param num_samples: Number of samples in the selected subset.
If the full dataset size is specified then indices are just shuffled and returned.
:param true_label_counts: True label counts of CIFAR10H images (num_samples x num_classes)
:param keep_hard_samples: If set to True, all hard examples are kept in the selected subset of points.
:param seed: Random seed used in shuffling data indices.
"""
random_state = np.random.RandomState(seed=seed)
assert num_samples <= TOTAL_CIFAR10H_DATASET_SIZE
if (not keep_hard_samples) or (num_samples == TOTAL_CIFAR10H_DATASET_SIZE):
indices = random_state.permutation(true_label_counts.shape[0])
return indices[:num_samples]
# Identify difficult samples and keep them in the dataset
hard_sample_indices = get_ambiguous_sample_ids(true_label_counts)
if hard_sample_indices.shape[0] > num_samples:
logging.info(f"Total number of hard samples: {hard_sample_indices.shape[0]} and requested: {num_samples}")
hard_sample_indices = hard_sample_indices[:num_samples]
num_hard_samples = hard_sample_indices.shape[0]
# Sample the remaining indices randomly and aggregate
remaining_indices = np.setdiff1d(range(TOTAL_CIFAR10H_DATASET_SIZE), hard_sample_indices)
easy_sample_indices = random_state.choice(remaining_indices, num_samples - num_hard_samples, replace=False)
indices = np.concatenate([hard_sample_indices, easy_sample_indices], axis=0)
random_state.shuffle(indices)
# Assert that there are no repeated indices
_, _counts = np.unique(indices, return_counts=True)
assert not np.any(_counts > 1)
return indices

Просмотреть файл

@ -0,0 +1,116 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union
import PIL
import numpy as np
import pandas as pd
import pydicom as dicom
from PIL import Image
from torch.utils.data import Dataset
KAGGLE_TOTAL_SIZE = 26684
class KaggleCXR(Dataset):
def __init__(self,
data_directory: str,
use_training_split: bool,
train_fraction: float = 0.8,
seed: int = 1234,
shuffle: bool = True,
transform: Optional[Callable] = None,
num_samples: int = None,
return_index: bool = True) -> None:
"""
Class for the full Kaggle RSNA Pneumonia Detection Dataset.
:param data_directory: the directory containing all training images from the Challenge (stage 1) as well as the
dataset.csv containing the kaggle and the original labels.
:param use_training_split: whether to return the training or the validation split of the dataset.
:param train_fraction: the proportion of samples to use for training
:param seed: random seed to use for dataset creation
:param shuffle: whether to shuffle the dataset prior to spliting between validation and training
:param transform: a preprocessing function that takes a PIL image as input and returns a tensor
:param num_samples: number of the samples to return (has to been smaller than the dataset split)
"""
self.data_directory = Path(data_directory)
if not self.data_directory.exists():
logging.error(
f"The data directory {self.data_directory} does not exist. Make sure to download to Kaggle data "
f"first.The kaggle dataset can "
"be acceded via the Kaggle CLI kaggle competitions download -c rsna-pneumonia-detection-challenge or "
"on the main page of the challenge "
"https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data?select=stage_2_train_images")
self.train = use_training_split
self.train_fraction = train_fraction
self.seed = seed
self.random_state = np.random.RandomState(seed)
full_dataset = pd.read_csv(self.data_directory / "dataset.csv")
self.dataset_dataframe = full_dataset
self.transforms = transform
targets = self.dataset_dataframe.label.values.astype(np.int64)
subjects_ids = self.dataset_dataframe.subject.values
self.num_classes = 2
self.num_datapoints = len(self.dataset_dataframe)
all_indices = np.arange(len(self.dataset_dataframe))
# ------------- Split the data into training and validation sets ------------- #
num_samples_set1 = int(self.num_datapoints * self.train_fraction)
sampled_indices = self.random_state.permutation(all_indices) \
if shuffle else all_indices
train_indices = sampled_indices[:num_samples_set1]
val_indices = sampled_indices[num_samples_set1:]
self.indices = train_indices if use_training_split else val_indices
# ------------- Select subset of current split ------------- #
if num_samples is not None:
assert 0 < num_samples <= len(self.indices)
self.indices = self.indices[:num_samples]
self.subject_ids = subjects_ids[self.indices]
self.targets = targets[self.indices].reshape(-1)
dataset_type = "TRAIN" if use_training_split else "VAL"
logging.info(f"Proportion of positive labels - {dataset_type}: {np.mean(self.targets)}")
logging.info(f"Number samples - {dataset_type}: {self.targets.shape[0]}")
self.return_index = return_index
self.weight = np.mean(self.targets)
logging.info(f"Weight negative {self.weight:.2f} - weight positive {(1 - self.weight):.2f}")
def __getitem__(self, index: int) -> Union[Tuple[int, PIL.Image.Image, int], Tuple[PIL.Image.Image, int]]:
"""
:param index: The index of the sample to be fetched
:return: The image and label tensors
"""
subject_id = self.subject_ids[index]
filename = self.data_directory / f"{subject_id}.dcm"
target = self.targets[index]
scan_image = dicom.dcmread(filename).pixel_array
scan_image = Image.fromarray(scan_image)
if self.transforms is not None:
scan_image = self.transforms(scan_image)
if self.return_index:
return index, scan_image, int(target)
return scan_image, int(target)
def __len__(self) -> int:
"""
:return: The size of the dataset
"""
return len(self.indices)
def get_label_names(self) -> List[str]:
return ["Normal", "Opacity"]

Просмотреть файл

@ -0,0 +1,90 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import numpy as np
class LabelDistribution(object):
"""
LabelDistribution class handles sampling from a label distribution with reproducible behavior given a seed
"""
def __init__(self,
seed: int,
label_counts: np.ndarray,
temperature: float = 1.0,
offset: float = 0.0) -> None:
"""
:param seed: The random seed used to ensure reproducible behaviour
:param label_counts: An array of shape (num_samples, num_classes) where each entry represents the number of
labels available for each sample and class
:param temperature: A temperature a value that will be used to temperature scale the distribution, default is
1.0 which is equivalent to no scaling; temperature must be greater than 0.0, values between 0 and 1 will result
in a sharper distribution and values greater than 1 in a more uniform distribution over classes.
:param offset: Offset parameter to control the noise rate in sampling initial labels.
All classes are assigned a uniform fixed offset amount.
"""
assert label_counts.dtype == np.int64
assert label_counts.ndim == 2
self.num_classes = label_counts.shape[1]
self.num_samples = label_counts.shape[0]
self.seed = seed
self.random_state = np.random.RandomState(seed)
self.label_counts = label_counts
self.temperature = temperature
# make the distribution
self.distribution = label_counts / np.sum(label_counts, axis=1, keepdims=True)
assert np.isfinite(self.distribution).all()
assert self.temperature > 0
# scale distribution based on temperature and offset
_d = np.power(self.distribution, 1. / temperature)
_d = _d / np.sum(_d, axis=1, keepdims=True)
self.distribution_temp_scaled = self.add_noise_to_distribution(offset, _d, 'asym') if offset > 0.0 else _d
# check if there are multiple labels per data point
self.is_multi_label_per_sample = np.all(np.sum(label_counts, axis=1) > 1.0)
def sample_initial_labels_for_all(self) -> np.ndarray:
"""
Sample one label for each sample in the dataset according to its label distribution
:return: None
"""
if not self.is_multi_label_per_sample:
RuntimeWarning("Sampling labels from one-hot encoded distribution - Multi labels are not available")
sampling_fn = lambda p: self.random_state.choice(self.num_classes, 1, p=p)
return np.squeeze(np.apply_along_axis(sampling_fn, arr=self.distribution_temp_scaled, axis=1))
def sample(self, sample_idx: int) -> int:
"""
Sample one label for a given sample index
:param sample_idx: The sample index for which the label will be sampled
:return: None
"""
return self.random_state.choice(self.num_classes, 1, p=self.distribution[sample_idx])[0]
def add_noise_to_distribution(self, offset: float, distribution: np.ndarray, noise_model_type: str) -> np.ndarray:
from InnerEyeDataQuality.datasets.label_noise_model import get_cifar10_asym_noise_model
from InnerEyeDataQuality.datasets.label_noise_model import get_cifar10_sym_noise_model
# Create noise model
if noise_model_type == 'sym':
noise_model = get_cifar10_sym_noise_model(eta=1.0)
elif noise_model_type == 'asym':
noise_model = get_cifar10_asym_noise_model(eta=1.0)
else:
raise ValueError("Unknown noise model type")
np.fill_diagonal(noise_model, 0.0)
noise_model *= offset
# Add this noise on top of every sample in the dataset
for _ii in range(self.num_samples):
true_label = np.argmax(self.label_counts[_ii])
distribution[_ii] += noise_model[true_label]
# Normalise the distribution
return distribution / np.sum(distribution, axis=1, keepdims=True)

Просмотреть файл

@ -0,0 +1,86 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import numpy as np
from default_paths import CIFAR10_ROOT_DIR
from InnerEyeDataQuality.datasets.cifar10h import CIFAR10H
from InnerEyeDataQuality.selection.simulation_statistics import get_ambiguous_sample_ids
from sklearn.metrics import confusion_matrix
def get_cifar10h_confusion_matrix(temperature: float = 1.0, only_difficult_cases: bool = False) -> np.ndarray:
"""
Generates a class confusion matrix based on the label distribution in CIFAR10H.
"""
cifar10h_labels = CIFAR10H.download_cifar10h_labels(str(CIFAR10_ROOT_DIR))
if only_difficult_cases:
ambiguous_sample_ids = get_ambiguous_sample_ids(cifar10h_labels)
cifar10h_labels = cifar10h_labels[ambiguous_sample_ids, :]
# Temperature scale the original distribution
if temperature > 1.0:
orig_distribution = cifar10h_labels / np.sum(cifar10h_labels, axis=1, keepdims=True)
_d = np.power(orig_distribution, 1. / temperature)
scaled_distribution = _d / np.sum(_d, axis=1, keepdims=True)
sample_counts = (scaled_distribution * np.sum(cifar10h_labels, axis=1, keepdims=True)).astype(np.int64)
else:
sample_counts = cifar10h_labels
y_pred, y_true = list(), list()
for image_index in range(sample_counts.shape[0]):
image_label_counts = sample_counts[image_index]
for _iter, _el in enumerate(image_label_counts.tolist()):
y_pred.extend([_iter] * _el)
y_true.extend([np.argmax(image_label_counts)] * np.sum(image_label_counts))
cm = confusion_matrix(y_true, y_pred, normalize="true")
return cm
def get_cifar10_asym_noise_model(eta: float = 0.3) -> np.ndarray:
"""
CLASS-DEPENDENT ASYMMETRIC LABEL NOISE
https://proceedings.neurips.cc/paper/2018/file/f2925f97bc13ad2852a7a551802feea0-Paper.pdf
TRUCK -> AUTOMOBILE, BIRD -> AIRPLANE, DEER -> HORSE, CAT -> DOG, and DOG -> CAT
:param eta: The likelihood of true label switching from one of the specified classes to nearest class.
In other words, likelihood of introducing a class-dependent label noise
"""
# Generate a noise transition matrix.
assert (0.0 <= eta) and (eta <= 1.0)
eps = 1e-12
num_classes = 10
conf_mat = np.eye(N=num_classes)
indices = [[2, 0], [9, 1], [5, 3], [3, 5], [4, 7]]
for ind in indices:
conf_mat[ind[0], ind[1]] = eta / (1.0 - eta + eps)
return conf_mat / np.sum(conf_mat, axis=1, keepdims=True)
def get_cifar10_sym_noise_model(eta: float = 0.3) -> np.ndarray:
"""
Symmetric LABEL NOISE
:param eta: The likelihood of true label switching from true class to rest of the classes.
"""
# Generate a noise transition matrix.
assert (0.0 <= eta) and (eta <= 1.0)
assert isinstance(eta, float)
num_classes = 10
conf_mat = np.eye(N=num_classes)
for ind in range(num_classes):
conf_mat[ind, ind] -= eta
other_classes = np.setdiff1d(range(num_classes), ind)
for o_c in other_classes:
conf_mat[ind, o_c] += eta / other_classes.size
assert np.all(np.abs(np.sum(conf_mat, axis=1) - 1.0) < 1e-9)
return conf_mat

Просмотреть файл

@ -0,0 +1,106 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Dict, Union
import PIL
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
NIH_TOTAL_SIZE = 112120
class NIHCXR(Dataset):
def __init__(self,
data_directory: str,
use_training_split: bool,
seed: int = 1234,
shuffle: bool = True,
transform: Optional[Callable] = None,
num_samples: int = None,
return_index: bool = True) -> None:
"""
Class for the full NIH ChestXray Dataset (112k images)
:param data_directory: the directory containing all training images from the dataset as well as the
Data_Entry_2017.csv file containing the dataset labels.
:param use_training_split: whether to return the training or the test split of the dataset.
:param seed: random seed to use for dataset creation
:param shuffle: whether to shuffle the dataset prior to spliting between validation and training
:param transform: a preprocessing function that takes a PIL image as input and returns a tensor
:param num_samples: number of the samples to return (has to been smaller than the dataset split)
"""
self.data_directory = Path(data_directory)
if not self.data_directory.exists():
logging.error(
f"The data directory {self.data_directory} does not exist. Make sure to download the NIH data "
f"first.The dataset can on the main page"
"https://www.kaggle.com/nih-chest-xrays/data. Make sure all images are placed directly under the "
"data_directory folder. Make sure you downloaded the Data_Entry_2017.csv file to this directory as"
"well.")
self.train = use_training_split
self.seed = seed
self.random_state = np.random.RandomState(seed)
self.dataset_dataframe = pd.read_csv(self.data_directory / "Data_Entry_2017.csv")
self.dataset_dataframe["pneumonia_like"] = self.dataset_dataframe["Finding Labels"].apply(
lambda x: x.split("|")).apply(lambda x: "pneumonia" in x.lower()
or "infiltration" in x.lower()
or "consolidation" in x.lower())
self.transforms = transform
orig_labels = self.dataset_dataframe.pneumonia_like.values.astype(np.int64)
subjects_ids = self.dataset_dataframe["Image Index"].values
is_train_ids = self.dataset_dataframe["train"].values
self.num_classes = 2
self.indices = np.where(is_train_ids)[0] if use_training_split else np.where(~is_train_ids)[0]
self.indices = self.random_state.permutation(self.indices) \
if shuffle else self.indices
# ------------- Select subset of current split ------------- #
if num_samples is not None:
assert 0 < num_samples <= len(self.indices)
self.indices = self.indices[:num_samples]
self.subject_ids = subjects_ids[self.indices]
self.orig_labels = orig_labels[self.indices].reshape(-1)
self.targets = self.orig_labels
# Identify case ids for ambiguous and clear label noise cases
self.ambiguity_metric_args: Dict = dict()
dataset_type = "TRAIN" if use_training_split else "VAL"
logging.info(f"Proportion of positive labels - {dataset_type}: {np.mean(self.targets)}")
logging.info(f"Number samples - {dataset_type}: {self.targets.shape[0]}")
self.return_index = return_index
def __getitem__(self, index: int) -> Union[Tuple[int, PIL.Image.Image, int], Tuple[PIL.Image.Image, int]]:
"""
:param index: The index of the sample to be fetched
:return: The image and label tensors
"""
subject_id = self.subject_ids[index]
filename = self.data_directory / f"{subject_id}"
target = self.targets[index]
scan_image = Image.open(filename).convert("L")
if self.transforms is not None:
scan_image = self.transforms(scan_image)
if self.return_index:
return index, scan_image, int(target)
return scan_image, int(target)
def __len__(self) -> int:
"""
:return: The size of the dataset
"""
return len(self.indices)
def get_label_names(self) -> List[str]:
return ["NotPneunomiaLike", "PneunomiaLike"]

Просмотреть файл

@ -0,0 +1,112 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import json
from pathlib import Path
import pandas as pd
import numpy as np
def create_nih_dataframe(mapping_file_path: Path) -> pd.DataFrame:
"""
This function loads the json file mapping NIH ids to Kaggle images.
Loads the original NIH label (multiple labels for each image).
Then it creates the grouping by NIH categories (pneumonia, pneumonia like,
other disease, no finding).
:param mapping_file_path: path to the json mapping from NIH to Kaggle dataset (on the
RSNA webpage)
:return: dataframe with original NIH labels for each patient in the Kaggle dataset.
"""
with open(mapping_file_path) as f:
list_subjects = json.load(f)
orig_dataset = pd.DataFrame(columns=["subject", "orig_label"])
orig_dataset["subject"] = [l["subset_img_id"] for l in list_subjects] # noqa: E741
orig_labels = [str(l['orig_labels']).lower() for l in list_subjects] # noqa: E741
orig_dataset["nih_pneumonia"] = ["pneumonia" in l for l in orig_labels] # noqa: E741
orig_dataset["nih_pneumonia_like"] = [(("infiltration" in l or "consolidation" in l) and ("pneumonia" not in l)) for
l in orig_labels] # noqa: E741
orig_dataset["no_finding"] = ["no finding" in str(l).lower() for l in orig_labels] # noqa: E741
orig_dataset["orig_label"] = orig_labels
orig_dataset["orig_label"].apply(lambda x: sorted(x))
orig_dataset[
"nih_other_disease"] = ~orig_dataset.nih_pneumonia_like & ~orig_dataset.nih_pneumonia & ~orig_dataset.no_finding
orig_dataset[
"nih_category"] = 1 * orig_dataset.nih_pneumonia + 2 * orig_dataset.no_finding + 3 * \
orig_dataset.nih_other_disease
orig_dataset["StudyInstanceUID"] = [l["StudyInstanceUID"] for l in list_subjects] # noqa: E741
orig_dataset.nih_category = orig_dataset.nih_category.apply(lambda x: ["Consolidation/Infiltration", "Pneumonia",
"No finding", "Other disease"][x])
return orig_dataset
def process_detailed_probs_dataset(detailed_probs_path: Path) -> pd.DataFrame:
"""
This function loads the csv file with the detailed information for each bounding boxes as annotated
by the readers during the adjudication for the preparation of the challenge. It maps low, medium and high
probabilities label to a numerical scale from 1 to 3. Computes the minimum, maximum, average confidence for each
patient for which at least one bounding box was present.
:param detailed_probs_path: path to detailed_probs csv file released in Kaggle challenge.
:return: dataframe with metrics for confidence in bounding boxes by patient.
"""
conversion_map = {"Lung Opacity (Low Prob)": 1, "Lung Opacity (Med Prob)": 2, "Lung Opacity (High Prob)": 3}
detailed_probs_dataset = pd.read_csv(detailed_probs_path)
detailed_probs_dataset["ClassProb"] = detailed_probs_dataset["labelName"].apply(lambda x: conversion_map[x])
process_details = detailed_probs_dataset.groupby("StudyInstanceUID")["ClassProb"].agg(
[np.mean, np.min, np.max, np.count_nonzero, list])
process_details.rename(columns={"mean": "avg_conf_score", "amin": "min_conf_score", "amax": "max_conf_score"},
inplace=True)
return process_details
def create_mapping_dataset_nih(mapping_file_path: Path,
kaggle_dataset_path: Path,
detailed_class_info_path: Path,
detailed_probs_path: Path) -> pd.DataFrame:
"""
Creates the final chest x-ray dataset combining labels from NIH, kaggle and detailed information about kaggle
labels from the detailed_class_info and detailed_probs csv file released during the challenge.
:param mapping_file_path:
:param kaggle_dataset_path:
:param detailed_class_info_path:
:param detailed_probs_path:
:return: detailed dataset
"""
orig_dataset = create_nih_dataframe(mapping_file_path)
kaggle_dataset = pd.read_csv(kaggle_dataset_path)
difficulty = pd.read_csv(detailed_class_info_path).drop_duplicates()
detailed_probs_dataset = process_detailed_probs_dataset(detailed_probs_path)
# Merge NIH info with Kaggle dataset
merged = pd.merge(orig_dataset, kaggle_dataset)
merged.rename(columns={"label": "label_kaggle"}, inplace=True)
# Define binary label from original NIH label, for consolidation/infiltration
# mapping is not clear, use kaggle label.
merged.loc[merged.nih_pneumonia, "binary_nih_initial_label"] = True
merged.loc[merged.no_finding | merged.nih_other_disease, "binary_nih_initial_label"] = False
merged.loc[merged.nih_pneumonia_like, "binary_nih_initial_label"] = merged.loc[
merged.nih_pneumonia_like, "label_kaggle"]
# Add subclass information from Kaggle challenge to define ambiguous cases
merged = pd.merge(merged, difficulty, left_on="subject", right_on="patientId")
merged["not_normal"] = (merged["class"] == "No Lung Opacity / Not Normal")
# Add difficulty information from Kaggle based on presence of bounding boxes
merged = pd.merge(merged, detailed_probs_dataset, on="StudyInstanceUID", how="left")
merged.drop(columns=["patientId"], inplace=True)
merged.fillna(-1, inplace=True)
# Ambiguous if there was only low probability boxes and the adjudicated label is true
merged.loc[merged.label_kaggle, ["ambiguous"]] = (merged.loc[merged.label_kaggle].min_conf_score == 1) & (
merged.loc[merged.label_kaggle].max_conf_score == 1)
# Ambiguous if there was some bounding boxes but adjudicated label is false
merged.loc[~merged.label_kaggle, ["ambiguous"]] = merged.loc[~merged.label_kaggle].min_conf_score > -1
return merged
if __name__ == "__main__":
from default_paths import INNEREYE_DQ_DIR
current_dir = Path(__file__).parent
mapping_file = Path(__file__).parent / "pneumonia-challenge-dataset-mappings_2018.json"
kaggle_dataset_path = current_dir / "stage_2_train_labels.csv"
detailed_class_info = current_dir / "stage_2_detailed_class_info.csv"
detailed_probs = current_dir / "RSNA_pneumonia_all_probs.csv"
dataset = create_mapping_dataset_nih(mapping_file, kaggle_dataset_path, detailed_class_info, detailed_probs)
dataset.to_csv(INNEREYE_DQ_DIR / "datasets" / "noisy_chestxray_dataset.csv", index=False)

Просмотреть файл

@ -0,0 +1,183 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import Callable, List, Optional, Tuple
import PIL
import numpy as np
import pandas as pd
import pydicom as dicom
from PIL import Image
from torch.utils.data import Dataset
from InnerEyeDataQuality.datasets.label_distribution import LabelDistribution
from InnerEyeDataQuality.selection.simulation_statistics import SimulationStats
from InnerEyeDataQuality.utils.generic import convert_labels_to_one_hot
from InnerEyeDataQuality.evaluation.metrics import compute_label_entropy
class NoisyKaggleSubsetCXR(Dataset):
def __init__(self, data_directory: str,
use_training_split: bool,
consolidation_noise_rate: float,
train_fraction: float = 0.5,
seed: int = 1234,
shuffle: bool = True,
transform: Optional[Callable] = None,
num_samples: Optional[int] = None,
use_noisy_fixed_labels: bool = True) -> None:
"""
Class for the noisy Kaggle RSNA Pneumonia Detection Dataset. This dataset uses the kaggle dataset with noisy
labels
as the original labels from RSNA and the clean labels are the Kaggle labels.
:param data_directory: the directory containing all training images from the Challenge (stage 1) as well as the
dataset.csv containing the kaggle and the original labels.
:param use_training_split: whether to return the training or the validation split of the dataset.
:param train_fraction: the proportion of samples to use for training
:param seed: random seed to use for dataset creation
:param shuffle: whether to shuffle the dataset prior to spliting between validation and training
:param transform: a preprocessing function that takes a PIL image as input and returns a tensor
:param num_samples: number of the samples to return (has to been smaller than the dataset split)
:param use_noisy_fixed_labels: if True use the original labels as the initial labels else use the clean labels.
:param consolidation_noise_rate: proportion of noisy samples among consolidation/infiltration NIH category.
"""
dataset_type = "TRAIN" if use_training_split else "VAL"
self.data_directory = Path(data_directory)
if not self.data_directory.exists():
raise RuntimeError(
f"The data directory {self.data_directory} does not exist. Make sure to download to Kaggle data "
f"first.The kaggle dataset can "
"be acceded via the Kaggle CLI kaggle competitions download -c rsna-pneumonia-detection-challenge or "
"on the main page of the challenge "
"https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data?select=stage_2_train_images")
path_to_noisy_csv = Path(__file__).parent / "noisy_chestxray_dataset.csv"
if not path_to_noisy_csv.exists():
raise RuntimeError(f"The noisy dataset csv can not be found in {path_to_noisy_csv}, make sure to run "
"create_noisy_chestxray_dataset.py first. See readme for more detailed instructions on the pre-requisite"
" for running the noisy Chest Xray benchmark.")
self.train = use_training_split
self.train_fraction = train_fraction
self.seed = seed
self.random_state = np.random.RandomState(seed)
self.dataset_dataframe = pd.read_csv(str(Path(__file__).parent / "noisy_chestxray_dataset.csv"))
self.transforms = transform
self.dataset_dataframe["initial_label"] = self.dataset_dataframe.binary_nih_initial_label
# Random uniform noise among consolidation
pneumonia_like_subj = self.dataset_dataframe.loc[self.dataset_dataframe.nih_pneumonia_like, "subject"].values
selected = self.random_state.choice(pneumonia_like_subj,
replace=False,
size=int(len(pneumonia_like_subj) * consolidation_noise_rate))
self.dataset_dataframe.loc[self.dataset_dataframe.subject.isin(selected), "initial_label"] = \
~self.dataset_dataframe.loc[self.dataset_dataframe.subject.isin(selected), "initial_label"]
self.dataset_dataframe["is_noisy"] = self.dataset_dataframe.label_kaggle != self.dataset_dataframe.initial_label
initial_labels = self.dataset_dataframe.initial_label.values.astype(np.int64).reshape(-1, 1)
kaggle_labels = self.dataset_dataframe.label_kaggle.values.astype(np.int64).reshape(-1, 1)
subjects_ids = self.dataset_dataframe.subject.values
is_ambiguous = self.dataset_dataframe.ambiguous.values
orig_label = self.dataset_dataframe.orig_label.values
nih_category = self.dataset_dataframe.nih_category.values
# Convert clean labels to one-hot to populate label counts
# i.e for easy cases assume the true distribution is 100% ground truth
kaggle_label_counts = convert_labels_to_one_hot(kaggle_labels, n_classes=2)
# For ambiguous cases: [0, 1] -> [1, 2] and [1, 0] -> [2, 1]
kaggle_label_counts[is_ambiguous, :] = kaggle_label_counts[is_ambiguous, :] * 2 + 1
_, self.num_classes = kaggle_label_counts.shape
assert self.num_classes == 2
# ------------- Split the data into training and validation sets ------------- #
self.num_datapoints = len(self.dataset_dataframe)
all_indices = np.arange(self.num_datapoints)
num_samples_set1 = int(self.num_datapoints * self.train_fraction)
all_indices = self.random_state.permutation(all_indices) \
if shuffle else all_indices
train_indices = all_indices[:num_samples_set1]
val_indices = all_indices[num_samples_set1:]
self.indices = train_indices if use_training_split else val_indices
# ------------- Select subset of current split ------------- #
# If n_samples is set to restrict dataset i.e. for data_curation
num_samples = self.num_datapoints if num_samples is None else num_samples
if num_samples < self.num_datapoints:
assert 0 < num_samples <= len(self.indices)
self.indices = self.indices[:num_samples]
# ------------ Finalize dataset --------------- #
self.subject_ids = subjects_ids[self.indices]
# Label distribution is constructed from the true labels
self.label_counts = kaggle_label_counts[self.indices]
self.label_distribution = LabelDistribution(seed, self.label_counts)
self.initial_labels = initial_labels[self.indices].reshape(-1)
self.kaggle_labels = kaggle_labels[self.indices].reshape(-1)
self.targets = self.initial_labels if use_noisy_fixed_labels else self.kaggle_labels
self.orig_labels = orig_label[self.indices]
self.is_ambiguous = is_ambiguous[self.indices]
self.nih_category = nih_category[self.indices]
# Identify case ids for ambiguous and clear label noise cases
label_stats = SimulationStats(name="NoisyChestXray", true_label_counts=self.label_counts,
initial_labels=convert_labels_to_one_hot(self.targets, self.num_classes))
self.clear_mislabeled_cases = label_stats.mislabelled_not_ambiguous_sample_ids[0]
self.ambiguous_mislabelled_cases = label_stats.mislabelled_ambiguous_sample_ids[0]
self.true_label_entropy = compute_label_entropy(label_counts=self.label_counts)
self.ambiguity_metric_args = {"ambiguous_mislabelled_ids": self.ambiguous_mislabelled_cases,
"clear_mislabelled_ids": self.clear_mislabeled_cases,
"true_label_entropy": self.true_label_entropy}
self.num_samples = self.targets.shape[0]
logging.info(self.num_samples)
logging.info(len(self.targets))
logging.info(len(self.indices))
logging.info(f"Proportion of positive clean labels - {dataset_type}: {np.mean(self.kaggle_labels)}")
logging.info(f"Proportion of positive noisy labels - {dataset_type}: {np.mean(self.targets)}")
logging.info(
f"Total noise rate on the {dataset_type} dataset: {np.mean(self.kaggle_labels != self.targets)} \n")
selected_df = self.dataset_dataframe.loc[self.dataset_dataframe.subject.isin(self.subject_ids)]
noisy_df = selected_df.loc[selected_df.is_noisy]
noisy_df["nih_noise"] = ~noisy_df.nih_pneumonia_like
logging.info(f"\n{pd.crosstab(noisy_df.nih_noise, noisy_df.ambiguous).to_string()}")
# self.weight = np.mean(self.kaggle_labels)
# logging.info(f"Weight negative {self.weight:.2f} - weight positive {(1 - self.weight):.2f}")
self.png_files = (self.data_directory / f"{self.subject_ids[0]}.png").exists()
def __getitem__(self, index: int) -> Tuple[int, PIL.Image.Image, int]:
"""
:param index: The index of the sample to be fetched
:return: The image and label tensors
"""
subject_id = self.subject_ids[index]
target = self.targets[index]
if self.png_files:
filename = self.data_directory / f"{subject_id}.png"
scan_image = Image.open(filename)
else:
filename = self.data_directory / f"{subject_id}.dcm"
scan_image = dicom.dcmread(filename).pixel_array
scan_image = Image.fromarray(scan_image)
if self.transforms is not None:
scan_image = self.transforms(scan_image)
if scan_image.shape == 2:
scan_image = scan_image.unsqueeze(dim=0)
return index, scan_image, int(target)
def __len__(self) -> int:
"""
:return: The size of the dataset
"""
return len(self.subject_ids)
def get_label_names(self) -> List[str]:
return ["Normal", "Opacity"]

Просмотреть файл

@ -0,0 +1,45 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, List, Union
import numpy as np
import torch
from math import inf
from scipy import stats
import torch.nn.functional as F
def get_instance_noise_model(n: float, dataset: Any, labels: Union[List, torch.Tensor], num_classes: int,
feature_size: int, norm_std: float, seed: int) -> np.ndarray:
"""
:param n: noise_rate
:param dataset: cifar10 # not train_loader
:param labels: labels (targets)
:param num_classes: class number
:param feature_size: the size of input images (e.g. 28*28)
:param norm_std: default 0.1
:param seed: random_seed
"""
if isinstance(labels, list):
labels = torch.FloatTensor(labels)
P = []
random_state = np.random.RandomState(seed)
flip_distribution = stats.truncnorm((0 - n) / norm_std, (1 - n) / norm_std, loc=n, scale=norm_std)
flip_rate = flip_distribution.rvs(labels.shape[0], random_state=seed)
W = random_state.randn(num_classes, feature_size, num_classes)
W = torch.FloatTensor(W)
for i, (x, y) in enumerate(dataset):
# (1 x M) * (M x 10) = (1 x 10)
A = x.view(1, -1).mm(W[y]).squeeze(0)
A[y] = -inf
A = flip_rate[i] * F.softmax(A, dim=0)
A[y] += 1 - flip_rate[i]
P.append(A)
P = torch.stack(P, 0).numpy()
return P

Просмотреть файл

@ -0,0 +1,5 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

Просмотреть файл

@ -0,0 +1,26 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torchvision
from torch import nn
from InnerEyeDataQuality.configs.config_node import ConfigNode
class Network(nn.Module):
"""
DenseNet121 as implement in torchvision
"""
def __init__(self, config: ConfigNode) -> None:
super().__init__()
self.densenet121 = torchvision.models.densenet121(pretrained=config.train.pretrained, progress=False)
num_ftrs = self.densenet121.classifier.in_features
self.densenet121.classifier = nn.Linear(num_ftrs, config.dataset.n_classes)
self.projection = self.densenet121.classifier
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.densenet121(x)
return x

Просмотреть файл

@ -0,0 +1,26 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torchvision
from torch import nn
from InnerEyeDataQuality.configs.config_node import ConfigNode
class Network(nn.Module):
"""
Resnet50 as implement in torchvision
"""
def __init__(self, config: ConfigNode) -> None:
super().__init__()
self.resnet = torchvision.models.resnet50(pretrained=config.train.pretrained, progress=False)
num_ftrs = self.resnet.fc.in_features
self.resnet.fc = nn.Linear(num_ftrs, config.dataset.n_classes)
self.projection = self.resnet.fc
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.resnet(x)
return x

Просмотреть файл

@ -0,0 +1,117 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from typing import Dict
import numpy as np
import torch
class EMA:
"""Exponential moving average of model parameters.
Args:
model (torch.nn.Module): Model with parameters whose EMA will be kept.
decay (float): Decay rate for exponential moving average.
step_max (int): Maximum required number of steps to reach specified decay rate from zero.
In the initial epochs, decay rate is kept small to keep the teacher model up-to-date.
"""
def __init__(self, model: torch.nn.Module, decay: float = 0.99, step_max: int = 150) -> None:
self.decay_max = decay
self.step_max = step_max
self.step_count = int(0)
self.shadow = {}
self.original: Dict[str, torch.Tensor] = {}
self.model = model # reference to the student model
# Register model parameters
for name, param in self.model.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()
logging.info("Creating a teacher model (EMA)")
def update(self) -> None:
"""
Receives a new set of parameter values and merges them with the previously stored ones.
"""
self.step_count += int(1)
decay = self._get_decay_rate(self.step_count)
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
new_average = (1.0 - decay) * param.data + decay * self.shadow[name]
self.shadow[name] = new_average.clone()
def _get_decay_rate(self, step: int) -> float:
"""
Return decay rate for current stored parameters
"""
if step <= self.step_max:
ratio = step / self.step_max
return 0.5 * self.decay_max * (1 - np.cos(ratio * np.pi))
else:
return self.decay_max
@torch.no_grad()
def inference(self, inputs: torch.Tensor) -> torch.Tensor:
"""
Runs an inference using the template of student model
"""
is_train = self.model.training
self._assign()
self.model.eval()
self._switch_hooks(False)
outputs = self.model.forward(inputs).detach()
self._switch_hooks(True)
self._restore()
self.model.train() if is_train else self.model.eval()
return outputs
def _switch_hooks(self, bool_value: bool) -> None:
for layer in self.model.children():
if hasattr(layer, "use_hook"):
layer.use_hook = bool_value # type: ignore
def _assign(self) -> None:
"""Assign exponential moving average of parameter values to the
respective parameters.
Args:
model (torch.nn.Module): Model to assign parameter values.
"""
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
self.original[name] = param.data.clone()
param.data = self.shadow[name]
def _restore(self) -> None:
"""Restore original parameters to a model. That is, put back
the values that were in each parameter at the last call to `assign`.
Args:
model (torch.nn.Module): Model to assign parameter values.
"""
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
param.data = self.original[name]
def save_model(self, save_path: str) -> None:
"""
Saves EMA model parameters to a checkpoint file.
"""
self._assign()
state = {'ema_model': self.model.state_dict(),
'ema_params': self.shadow,
'step_count': self.step_count}
torch.save(state, save_path)
self._restore()
def restore_from_checkpoint(self, path: str) -> None:
state = torch.load(path)
self.shadow = state['ema_params']
self.step_count = state['step_count']

Просмотреть файл

@ -0,0 +1,22 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Callable
import torch
import torch.nn as nn
class Lambda(nn.Module):
"""
Lambda torch nn module that can be used to modularise nn.functional methods.
"""
def __init__(self, fn: Callable) -> None:
super(Lambda, self).__init__()
self.lambda_func = fn
def forward(self, input: torch.Tensor) -> torch.Tensor:
return self.lambda_func(input)

Просмотреть файл

@ -0,0 +1,84 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import List, Optional, Union
import numpy as np
import torch
from InnerEyeDataQuality.deep_learning.self_supervised.ssl_classifier_module import PretrainedClassifier, SSLClassifier
class CollectCNNEmbeddings:
"""
This class takes care of registering a forward hook to get the embeddings for a given model.
"""
def __init__(self, use_only_in_train: bool, store_input_tensor: bool) -> None:
"""
:param use_only_in_train: If set to True, hooks are registered only for model forward passes in training mode.
:param store_input_tensor:
"""
self.inputs: list = list()
self.layer: Optional[torch.nn.Module] = None
self.use_only_in_train = use_only_in_train
self.store_input_tensor = store_input_tensor
def __call__(self,
module: torch.nn.Module,
module_in: List[torch.Tensor],
module_out: torch.Tensor) -> None:
# If model is in validation state and only training time collection is allowed, then exit.
if self.use_only_in_train and not module.training:
return
if module.use_hook:
_tensor = module_in[0] if self.store_input_tensor else module_out
self.inputs.append(_tensor.detach().cpu())
def return_embeddings(self, return_numpy: int = True) -> Union[np.ndarray, torch.Tensor, None]:
if len(self.inputs) == 0:
return None
embeddings = torch.cat(self.inputs, dim=0)
return embeddings.cpu().numpy() if return_numpy else embeddings
def reset(self) -> None:
self.inputs = list()
def register_embeddings_collector(models: List[torch.nn.Module],
use_only_in_train: bool = False) -> List[CollectCNNEmbeddings]:
"""
Takes a list of models and register a foward hook for each model
:param models: Torch module
:param use_only_in_train: If set to True, hooks are registered only for model forward passes in training mode.
"""
assert(isinstance(use_only_in_train, bool))
all_model_cnn_embeddings = []
for model in models:
store_input_tensor = False if isinstance(model, SSLClassifier) else True
cnn_embeddings = CollectCNNEmbeddings(use_only_in_train, store_input_tensor)
if hasattr(model, "projection"):
cnn_embeddings.layer = model.projection # type: ignore
elif hasattr(model, "resnet"):
cnn_embeddings.layer = model.resnet.fc # type: ignore
elif isinstance(model, PretrainedClassifier):
cnn_embeddings.layer = model.classifier_head
else:
cnn_embeddings.layer = model.fc # type: ignore
cnn_embeddings.layer.use_hook = True # type: ignore
cnn_embeddings.layer.register_forward_hook(cnn_embeddings) # type: ignore
all_model_cnn_embeddings.append(cnn_embeddings)
return all_model_cnn_embeddings
def get_all_embeddings(embeddings_collectors: List[CollectCNNEmbeddings]) -> List[np.ndarray]:
"""
Returns all embeddings from a list of embeddings collectors and resets the list
"""
output = list()
for cnn_embeddings in embeddings_collectors:
output.append(cnn_embeddings.return_embeddings(return_numpy=True))
cnn_embeddings.reset()
return output

Просмотреть файл

@ -0,0 +1,83 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, Callable, List, Tuple
import numpy as np
import torchvision
from torchvision.transforms import ToTensor
from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.deep_learning.transforms import AddGaussianNoise, CenterCrop, ElasticTransform, \
ExpandChannels, RandomAffine, RandomColorJitter, RandomErasing, RandomGamma, \
RandomHorizontalFlip, RandomResizeCrop, Resize
def _get_dataset_stats(
config: ConfigNode) -> Tuple[np.ndarray, np.ndarray]:
name = config.dataset.name
if name == 'CIFAR10':
mean = np.array([0.4914, 0.4822, 0.4465])
std = np.array([0.2470, 0.2435, 0.2616])
else:
raise ValueError()
return mean, std
def create_transform(config: ConfigNode, is_train: bool) -> Callable:
if config.dataset.name in ["NoisyChestXray", "Kaggle"]:
return create_chest_xray_transform(config, is_train)
elif config.dataset.name in ["CIFAR10", "CIFAR10H", "CIFAR10IDN", "CIFAR10H_TRAIN_VAL", "CIFAR10SYM"]:
return create_cifar_transform(config, is_train)
else:
raise ValueError
def create_cifar_transform(config: ConfigNode,
is_train: bool) -> Callable:
transforms: List[Any] = list()
if is_train:
if config.augmentation.use_random_affine:
transforms.append(RandomAffine(config))
if config.augmentation.use_random_crop:
transforms.append(RandomResizeCrop(config))
if config.augmentation.use_random_horizontal_flip:
transforms.append(RandomHorizontalFlip(config))
if config.augmentation.use_random_color:
transforms.append(RandomColorJitter(config))
transforms += [ToTensor()]
return torchvision.transforms.Compose(transforms)
def create_chest_xray_transform(config: ConfigNode,
is_train: bool) -> Callable:
"""
Defines the image transformations pipeline for Chest-Xray datasets.
"""
transforms: List[Any] = []
if is_train:
if config.augmentation.use_random_affine:
transforms.append(RandomAffine(config))
if config.augmentation.use_random_crop:
transforms.append(RandomResizeCrop(config))
else:
transforms.append(Resize(config))
if config.augmentation.use_random_horizontal_flip:
transforms.append(RandomHorizontalFlip(config))
if config.augmentation.use_gamma_transform:
transforms.append(RandomGamma(config))
if config.augmentation.use_random_color:
transforms.append(RandomColorJitter(config))
if config.augmentation.use_elastic_transform:
transforms.append(ElasticTransform(config))
transforms += [CenterCrop(config), ToTensor()]
if config.augmentation.use_random_erasing:
transforms.append(RandomErasing(config))
if config.augmentation.add_gaussian_noise:
transforms.append(AddGaussianNoise(config))
else:
transforms += [Resize(config), CenterCrop(config), ToTensor()]
transforms.append(ExpandChannels())
return torchvision.transforms.Compose(transforms)

Просмотреть файл

@ -0,0 +1,61 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any
import numpy as np
import torch
import yacs.config
from torch.utils.data import DataLoader
from torchvision.datasets.vision import VisionDataset
def get_number_of_samples_per_epoch(dataloader: torch.utils.data.DataLoader) -> int:
"""
Returns the expected number of samples for a single epoch
"""
total_num_samples = len(dataloader.dataset) # type: ignore
batch_size = dataloader.batch_size
drop_last = dataloader.drop_last
num_samples = int(total_num_samples / batch_size) * batch_size if drop_last else total_num_samples # type:ignore
return num_samples
def get_train_dataloader(train_dataset: VisionDataset,
config: yacs.config.CfgNode,
seed: int,
**kwargs: Any) -> DataLoader:
if config.train.use_balanced_sampler:
counts = np.bincount(train_dataset.targets)
class_weights = counts.sum() / counts
sample_weights = class_weights[train_dataset.targets]
sample = torch.utils.data.WeightedRandomSampler(weights=sample_weights, num_samples=len(train_dataset))
kwargs.pop("shuffle", None)
kwargs.update({"sampler": sample})
return torch.utils.data.DataLoader(
train_dataset,
batch_size=config.train.batch_size,
num_workers=config.train.dataloader.num_workers,
pin_memory=config.train.dataloader.pin_memory,
worker_init_fn=WorkerInitFunc(seed),
**kwargs)
class WorkerInitFunc:
def __init__(self, seed: int) -> None:
self.seed = seed
def __call__(self, worker_id: int) -> None:
return np.random.seed(self.seed + worker_id)
def get_val_dataloader(val_dataset: VisionDataset, config: yacs.config.CfgNode, seed: int) -> DataLoader:
return torch.utils.data.DataLoader(
val_dataset,
batch_size=config.validation.batch_size,
shuffle=False,
num_workers=config.validation.dataloader.num_workers,
pin_memory=config.validation.dataloader.pin_memory,
worker_init_fn=WorkerInitFunc(seed))

Просмотреть файл

@ -0,0 +1,111 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from typing import Tuple
import numpy as np
import scipy
import torch
from sklearn.neighbors import kneighbors_graph
from InnerEyeDataQuality.algorithms.graph import GraphParameters, build_connectivity_graph, label_diffusion
from InnerEyeDataQuality.utils.generic import convert_labels_to_one_hot, find_set_difference_torch
class GraphClassifier:
"""
Graph based classifier. Builds a graph and runs label diffusion to classify new points.
"""
def __init__(self,
num_samples: int,
num_classes: int,
labels: np.ndarray,
device: torch.device) -> None:
self.graph = None
self.device = device
self.num_samples = num_samples
self.num_classes = num_classes
self.graph_params = GraphParameters(n_neighbors=12,
diffusion_alpha=0.95,
cg_solver_max_iter=10,
diffusion_batch_size=None,
distance_kernel="cosine")
# Convert to one-hot label distribution
self.labels = np.array(labels) if isinstance(labels, list) else labels
self.one_hot_labels = convert_labels_to_one_hot(self.labels, n_classes=num_classes)
assert np.all(self.one_hot_labels.sum(axis=1) == 1.0)
assert self.labels.shape[0] == num_samples
def build_graph(self, embeddings: np.ndarray) -> None:
logging.info("Building a new connectivity graph")
assert embeddings.shape[0] == self.num_samples
# Build a connectivity graph and k-nearest neighbours.
n_neighbors = self.graph_params.n_neighbors
self.knn = kneighbors_graph(embeddings, n_neighbors, metric=self.graph_params.distance_kernel, n_jobs=-1)
self.graph = build_connectivity_graph(normalised=True,
embeddings=embeddings,
n_neighbors=n_neighbors,
distance_kernel=self.graph_params.distance_kernel)
laplacian = scipy.sparse.eye(self.num_samples) - self.graph_params.diffusion_alpha * self.graph # type: ignore
self.laplacian_inv = scipy.sparse.linalg.inv(laplacian.tocsc()).todense()
def fit(self, query_batch_ids: np.ndarray) -> np.ndarray:
"""
Run label diffusion and identify potential labels for query samples
"""
diffused_labels = label_diffusion(inv_laplacian=self.laplacian_inv,
labels=self.one_hot_labels,
query_batch_ids=query_batch_ids,
diffusion_normalizing_factor=0.01)
assert np.all(diffused_labels.shape == (query_batch_ids.size, self.num_classes))
assert not np.isnan(diffused_labels).any()
return diffused_labels
def filter_cases(self,
local_ind_keep: torch.Tensor,
local_ind_exclude: torch.Tensor,
global_ind: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Filter list of cases to drop based on diffused labels. If labels and diffused labels agree, do not
exclude the sample.
:param local_ind_keep: original list of samples to keep
:param local_ind_exclude: original list of samples to drop
:param global_ind: list of global indices
:return: updated list of indices to keep and drop
"""
# If input indices are empty return an empty tensor
if torch.numel(local_ind_exclude) == 0:
return local_ind_keep, local_ind_exclude
# Check input variable consistency
num_samples = torch.numel(global_ind)
assert num_samples == torch.numel(local_ind_exclude) + torch.numel(local_ind_keep)
global_ind = global_ind.to(local_ind_keep.device)
all_local_ind = torch.tensor(range(num_samples), device=local_ind_keep.device, dtype=local_ind_keep.dtype)
# Run graph diffusion to filter out incorrectly picked indices.
global_ind_exclude = global_ind[local_ind_exclude].cpu().numpy()
diffused_probs = self.fit(global_ind_exclude)
graph_pred = np.argmax(diffused_probs, axis=1)
initial_labels = self.labels[global_ind_exclude]
# Update the local indices for exclude
local_ind_exclude_updated = local_ind_exclude[graph_pred != initial_labels]
local_ind_keep_updated = find_set_difference_torch(all_local_ind, local_ind_exclude_updated)
return local_ind_keep_updated, local_ind_exclude_updated
def compute_mingling_index(self, indices: np.ndarray) -> None:
"""
Computes mingling index of each graph node based on label distribution in local graphs.
"""
mingling = np.zeros(indices.shape, dtype=np.float)
for loop_id, _ind in enumerate(indices):
disagreement = self.labels[self.knn[_ind].indices] != self.labels[_ind]
mingling[loop_id] = np.sum(disagreement) / float(disagreement.size)

Просмотреть файл

@ -0,0 +1,75 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Union
import yacs
import torch
from torch import Tensor as T
import torch.nn.functional as F
def tanh_loss(logits: T) -> T:
"""
Penalises sparse logits with large values
"Bootstrap Your Own Latent A New Approach to Self-Supervised Learning" Supplementary C.1
"""
alpha = 10
tclip = alpha * torch.tanh(logits / alpha)
return torch.mean(torch.pow(tclip, 2.0))
# Computes consistency loss between unlabelled or excluded points
def consistency_loss(logits_source: T, logits_target: T) -> Union[T, float]:
"""
Class probability consistency loss based on total variation.
"""
if (logits_source.numel() > 0) & (logits_source.shape == logits_target.shape):
_prob_source = torch.softmax(logits_source, dim=-1)
_prob_target = torch.softmax(logits_target, dim=-1)
loss = torch.mean(torch.norm(_prob_source - _prob_target.detach(), p=2, dim=-1))
return loss
else:
return 0.0
def early_regularisation_loss(student_logits: torch.Tensor, ema_logits: torch.Tensor) -> torch.Tensor:
"""
Implements early regularisation loss term proposed in:
https://arxiv.org/abs/2007.00151
"""
posteriors_teacher = torch.softmax(ema_logits, dim=-1).detach()
posteriors_student = torch.softmax(student_logits, dim=-1)
inner_prod = torch.sum(posteriors_teacher * posteriors_student, dim=-1)
early_regularisation = torch.mean(torch.log(1 - inner_prod))
return early_regularisation
def onehot_encoding(label: torch.Tensor, n_classes: int) -> torch.Tensor:
return torch.zeros(label.size(0), n_classes).to(label.device).scatter_(1, label.view(-1, 1), 1)
class CrossEntropyLoss:
"""
Cross entropy loss - implements label smoothing
"""
def __init__(self, config: yacs.config.CfgNode):
self.n_classes = config.dataset.n_classes
self.use_label_smoothing = config.augmentation.use_label_smoothing
self.epsilon = config.augmentation.label_smoothing.epsilon
def __call__(self, predictions: torch.Tensor, targets: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
if self.use_label_smoothing:
device = predictions.device
onehot = onehot_encoding(targets, self.n_classes).type_as(predictions).to(device)
targets = onehot * (1 - self.epsilon) + torch.ones_like(onehot).to(device) * self.epsilon / self.n_classes
logp = F.log_softmax(predictions, dim=1)
loss_per_sample = torch.sum(-logp * targets, dim=1)
if reduction == 'none':
return loss_per_sample
elif reduction == 'mean':
return loss_per_sample.mean()
else:
raise NotImplementedError
else:
return F.cross_entropy(predictions, targets, reduction=reduction)

Просмотреть файл

Просмотреть файл

@ -0,0 +1,175 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Any, Optional
import matplotlib.pyplot as plt
import numpy as np
from InnerEyeDataQuality.datasets.cifar10_utils import get_cifar10_label_names
from InnerEyeDataQuality.deep_learning.metrics.sample_metrics import SampleMetrics
from InnerEyeDataQuality.deep_learning.metrics.plots_tensorboard import (get_scatter_plot, plot_disagreement_per_sample,
plot_excluded_cases_coteaching)
from InnerEyeDataQuality.deep_learning.transforms import ToNumpy
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import CIFAR10
@dataclass()
class JointMetrics():
"""
Stores metrics for co-teaching models.
"""
num_samples: int
num_epochs: int
dataset: Optional[Any] = None
ambiguous_mislabelled_ids: np.ndarray = None
clear_mislabelled_ids: np.ndarray = None
true_label_entropy: np.ndarray = None
plot_dropped_images: bool = False
def reset(self) -> None:
self.kl_divergence_symmetric = np.full([self.num_samples], np.nan)
self.active = False
def __post_init__(self) -> None:
self.reset()
self.prediction_disagreement = np.zeros([self.num_samples, self.num_epochs], dtype=np.bool)
self._initialise_dataset_properties()
self.case_drop_histogram = np.zeros([self.num_samples, self.num_epochs + 1], dtype=np.bool)
if not isinstance(self.clear_mislabelled_ids, np.ndarray) or not isinstance(self.ambiguous_mislabelled_ids,
np.ndarray):
return
self.true_mislabelled_ids = np.concatenate([self.ambiguous_mislabelled_ids, self.clear_mislabelled_ids], axis=0)
self.case_drop_histogram[self.clear_mislabelled_ids, -1] = True
self.case_drop_histogram[self.ambiguous_mislabelled_ids, -1] = True
def _initialise_dataset_properties(self) -> None:
if self.dataset is not None:
self.label_names = get_cifar10_label_names() if isinstance(self.dataset, CIFAR10) \
else self.dataset.get_label_names()
self.dataset.transform = ToNumpy() # type: ignore
def log_results(self, writer: SummaryWriter, epoch: int, sample_metrics: SampleMetrics) -> None:
if (not self.active) or (self.ambiguous_mislabelled_ids is None) or (self.clear_mislabelled_ids is None):
return
# KL Divergence between the two posteriors
writer.add_scalars(main_tag='symmetric-kl-divergence', tag_scalar_dict={
'all': np.nanmean(self.kl_divergence_symmetric),
'ambiguous': np.nanmean(self.kl_divergence_symmetric[self.ambiguous_mislabelled_ids]),
'clear_noise': np.nanmean(self.kl_divergence_symmetric[self.clear_mislabelled_ids])},
global_step=epoch)
# Disagreement rate between the models
writer.add_scalars(main_tag='disagreement_rate', tag_scalar_dict={
'all': np.nanmean(self.prediction_disagreement[:, epoch]),
'ambiguous': np.nanmean(self.prediction_disagreement[self.ambiguous_mislabelled_ids, epoch]),
'clear_noise': np.nanmean(self.prediction_disagreement[self.clear_mislabelled_ids, epoch])},
global_step=epoch)
# Add histogram for the loss values
self.log_loss_values(writer, sample_metrics.loss_per_sample[:, epoch], epoch)
# Add disagreement metrics
fig = get_scatter_plot(self.true_label_entropy, self.kl_divergence_symmetric,
x_label="Label entropy", y_label="Symmetric-KL", y_lim=[0.0, 2.0])
writer.add_figure('Sym-KL vs Label Entropy', figure=fig, global_step=epoch, close=True)
fig = plot_disagreement_per_sample(self.prediction_disagreement, self.true_label_entropy)
writer.add_figure('Disagreement of prediction', figure=fig, global_step=epoch, close=True)
# Excluded cases diagnostics
self.log_dropped_cases_metrics(writer=writer, epoch=epoch)
# Every 10 epochs, display the dropped cases in the co-teaching algorithm
if epoch % 10 and self.plot_dropped_images:
self.log_dropped_images(writer=writer, predictions=sample_metrics.predictions, epoch=epoch)
# Close all figures
plt.close('all')
def log_dropped_cases_metrics(self, writer: SummaryWriter, epoch: int) -> None:
"""
Creates all diagnostics for dropped cases analysis.
"""
entropy_sorted_indices = np.argsort(self.true_label_entropy)
drop_cur_epoch_mask = self.case_drop_histogram[:, epoch]
drop_cur_epoch_ids = np.where(drop_cur_epoch_mask)[0]
is_sample_dropped = np.any(drop_cur_epoch_mask)
title = None
if is_sample_dropped:
n_dropped = float(drop_cur_epoch_ids.size)
average_label_entropy_dropped_cases = np.mean(self.true_label_entropy[drop_cur_epoch_mask])
n_detected_mislabelled = np.intersect1d(drop_cur_epoch_ids, self.true_mislabelled_ids).size
n_clean_dropped = int(n_dropped - n_detected_mislabelled)
n_detected_mislabelled_ambiguous = np.intersect1d(drop_cur_epoch_ids, self.ambiguous_mislabelled_ids).size
n_detected_mislabelled_clear = np.intersect1d(drop_cur_epoch_ids, self.clear_mislabelled_ids).size
perc_detected_mislabelled = n_detected_mislabelled / n_dropped * 100
perc_detected_clear_mislabelled = n_detected_mislabelled_clear / n_dropped * 100
perc_detected_ambiguous_mislabelled = n_detected_mislabelled_ambiguous / n_dropped * 100
title = f"Dropped Cases: Avg label entropy {average_label_entropy_dropped_cases:.3f}\n " \
f"Dropped cases: {n_detected_mislabelled} mislabelled ({perc_detected_mislabelled:.1f}%) - " \
f"{n_clean_dropped} clean ({(100 - perc_detected_mislabelled):.1f}%)\n" \
f"Num ambiguous mislabelled among detected cases: {n_detected_mislabelled_ambiguous}" \
f" ({perc_detected_ambiguous_mislabelled:.1f}%)\n" \
f"Num clear mislabelled among detected cases: {n_detected_mislabelled_clear}" \
f" ({perc_detected_clear_mislabelled:.1f}%)"
writer.add_scalars(main_tag='Number of dropped cases', tag_scalar_dict={
'clean_cases': n_clean_dropped,
'all_mislabelled_cases': n_detected_mislabelled,
'mislabelled_clear_cases': n_detected_mislabelled_clear,
'mislabelled_ambiguous_cases': n_detected_mislabelled_ambiguous}, global_step=epoch)
writer.add_scalar(tag="Percentage of mislabelled among dropped cases",
scalar_value=perc_detected_mislabelled, global_step=epoch)
fig = plot_excluded_cases_coteaching(case_drop_mask=self.case_drop_histogram,
entropy_sorted_indices=entropy_sorted_indices, title=title,
num_epochs=self.num_epochs, num_samples=self.num_samples)
writer.add_figure('Histogram of excluded cases', figure=fig, global_step=epoch, close=True)
def log_loss_values(self, writer: SummaryWriter, loss_values: np.ndarray, epoch: int) -> None:
"""
Logs histogram of loss values of one of the co-teaching models.
"""
writer.add_histogram('loss/all', loss_values, epoch)
writer.add_histogram('loss/ambiguous_noise', loss_values[self.ambiguous_mislabelled_ids], epoch)
writer.add_histogram('loss/clear_noise', loss_values[self.clear_mislabelled_ids], epoch)
def log_dropped_images(self, writer: SummaryWriter, predictions: np.ndarray, epoch: int) -> None:
"""
Logs images dropped during co-teaching training
"""
dropped_cases = np.where(self.case_drop_histogram[:, epoch])[0]
if dropped_cases.size > 0 and self.dataset is not None:
dropped_cases = dropped_cases[np.argsort(self.true_label_entropy[dropped_cases])]
fig = self.plot_batch_images_and_labels(predictions, list_indices=dropped_cases[:64])
writer.add_figure("Dropped images with lowest entropy", figure=fig, global_step=epoch, close=True)
fig = self.plot_batch_images_and_labels(predictions, list_indices=dropped_cases[-64:])
writer.add_figure("Dropped images with highest entropy", figure=fig, global_step=epoch, close=True)
kept_cases = np.where(~self.case_drop_histogram[:, epoch])[0]
kept_cases = kept_cases[np.argsort(self.true_label_entropy[kept_cases])]
fig = self.plot_batch_images_and_labels(predictions, kept_cases[-64:])
writer.add_figure("Kept images with highest entropy", figure=fig, global_step=epoch, close=True)
def plot_batch_images_and_labels(self, predictions: np.ndarray, list_indices: np.ndarray) -> plt.Figure:
"""
Plots of batch of images along with their labels and predictions. Noise cases are colored in red, clean cases
in green. Images are assumed to be numpy images (use ToNumpy() transform).
"""
assert self.dataset is not None
fig, ax = plt.subplots(8, 8, figsize=(8, 10))
ax = ax.ravel()
for i, index in enumerate(list_indices):
predicted = int(predictions[index])
color = "red" if index in self.true_mislabelled_ids else "green"
_, img, training_label = self.dataset.__getitem__(index)
ax[i].imshow(img)
ax[i].set_axis_off()
ax[i].set_title(f"Label: {self.label_names[training_label]}\nPred: {self.label_names[predicted]}",
color=color, fontsize="x-small")
return fig

Просмотреть файл

@ -0,0 +1,87 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Optional, List, Tuple, Union
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.use('Agg') # No display
def get_scatter_plot(x_data: np.ndarray, y_data: np.ndarray, scale: np.ndarray = None,
title: str = '', x_label: str = '', y_label: str = '',
y_lim: Optional[List[float]] = None) -> plt.Figure:
fig, ax = plt.subplots()
ax.scatter(x_data, y_data, alpha=0.3, s=scale)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_title(title)
ax.grid()
if y_lim:
ax.set_ylim(y_lim)
return fig
def get_histogram_plot(data: Union[List[np.ndarray], np.ndarray], num_bins: int, title: str = '',
x_label: str = '', x_lim: Tuple[float, float] = None) -> plt.Figure:
"""
Creates a histogram plot for a given set of numpy arrays specified in `data` object.
Return the generated figure object.
"""
fig, ax = plt.subplots()
cm = plt.cm.get_cmap('Pastel1')
if isinstance(data, list):
for _d_id, _d in enumerate(data):
ax.hist(_d, density=True, bins=num_bins, color=cm(_d_id), alpha=0.6)
else:
ax.hist(data, density=True, bins=num_bins)
ax.set_ylabel('Sample density')
ax.set_xlabel(x_label)
ax.set_title(title)
ax.grid()
if x_lim:
ax.set_xlim(x_lim)
return fig
def plot_excluded_cases_coteaching(case_drop_mask: np.ndarray,
entropy_sorted_indices: np.ndarray,
num_epochs: int,
num_samples: int,
title: Optional[str] = None) -> plt.Figure:
"""
Plots the excluded cases in co-teaching training - training epochs vs sample_ids
Samples are sorted based on their true ambiguity score.
"""
fig, ax = plt.subplots(figsize=(15, 10))
ax.imshow(case_drop_mask[entropy_sorted_indices, :].astype(np.uint8) * 255, cmap="gray",
extent=[0, num_epochs, 0, num_samples], vmin=0, vmax=10, aspect='auto')
ax.set_xlabel("Number of epochs")
ax.set_ylabel("Training sample ids (ordered by true entropy)")
if title:
ax.set_title(title)
return fig
def plot_disagreement_per_sample(prediction_disagreement: np.ndarray, true_label_entropy: np.ndarray) -> plt.Figure:
"""
Plots predicted class disagreement between two models - training epochs vs sample_ids
Samples are sorted based on their true ambiguity score.
"""
entropy_sorted_indices = np.argsort(true_label_entropy)
fig, ax = plt.subplots()
num_epochs = prediction_disagreement.shape[1]
num_samples = prediction_disagreement.shape[0]
ax.imshow(prediction_disagreement[entropy_sorted_indices, :].astype(np.uint8) * 2, cmap="gray",
extent=[0, num_epochs, 0, num_samples], vmin=0, vmax=1, aspect='auto')
ax.set_xlabel("Number of epochs")
ax.set_ylabel("Training sample ids (ordered by true entropy)")
return fig

Просмотреть файл

@ -0,0 +1,170 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from dataclasses import dataclass
from typing import Optional, List
import numpy as np
import torch
from sklearn.metrics import auc, f1_score, precision_recall_curve, roc_auc_score
from torch.utils.tensorboard import SummaryWriter
@dataclass
class SampleMetrics():
"""
Stores data required for training monitoring of individual model and
post-training analysis in sample selection
"""
name: str
num_epochs: int
num_samples: int
num_classes: int
clear_labels: np.ndarray = None
ambiguous_mislabelled_ids: np.ndarray = None
clear_mislabelled_ids: np.ndarray = None
embeddings_size: Optional[int] = None
true_label_entropy: Optional[np.ndarray] = None
def __post_init__(self) -> None:
self.reset()
self.loss_per_sample = np.full([self.num_samples, self.num_epochs], np.nan)
self.logits_per_sample = np.full([self.num_samples, self.num_classes, self.num_epochs], np.nan)
if self.clear_mislabelled_ids is not None:
full_set = set(range(self.num_samples))
if self.ambiguous_mislabelled_ids is not None:
self.all_mislabelled_ids = np.concatenate(
[self.ambiguous_mislabelled_ids, self.clear_mislabelled_ids], 0)
else:
self.all_mislabelled_ids = self.clear_mislabelled_ids
self.all_clean_label_ids = np.array(list(full_set.difference(self.all_mislabelled_ids)))
def reset(self) -> None:
self.loss_optimised: List[float] = list()
self.predictions = np.full([self.num_samples], np.nan)
self.labels = np.full([self.num_samples], np.nan)
self.probabilities = np.full([self.num_samples, self.num_classes], np.nan)
self.correct_predictions_teacher = np.full(self.num_samples, np.nan)
self.correct_predictions = np.full([self.num_samples], np.nan)
self.embeddings_per_sample = np.full([self.num_samples, self.embeddings_size],
np.nan) if self.embeddings_size is not None else None
def get_average_accuracy(self) -> np.ndarray:
return np.nanmean(self.correct_predictions)
def get_average_loss(self) -> np.ndarray:
return np.mean(self.loss_optimised)
def log_results(self, epoch: int, name: str, writer: Optional[SummaryWriter] = None) -> None:
mean_loss = self.get_average_loss()
accuracy = self.get_average_accuracy()
accuracy_on_clean = np.nanmean(self.predictions == self.clear_labels)
logging.info(f'{self.name} \t accuracy: {accuracy:.2f} \t loss: {mean_loss:.2e}')
if writer is None:
return
# Store accuracy and loss metrics
writer.add_scalar(tag='loss', scalar_value=mean_loss, global_step=epoch)
writer.add_scalar(tag='accuracy/acc_on_sampled_labels', scalar_value=accuracy, global_step=epoch)
writer.add_scalar(tag='accuracy/acc_on_clean_labels', scalar_value=accuracy_on_clean, global_step=epoch)
# Binary classification case
if self.num_classes == 2:
logits = self.logits_per_sample[:, :, epoch]
labels = self.labels
predictions = np.argmax(logits, axis=-1)
available_indices = np.where(~np.isnan(self.labels))[0]
roc_auc = roc_auc_score(labels[available_indices], logits[available_indices, 1].reshape(-1))
f1 = f1_score(labels[available_indices], predictions[available_indices])
precision, recall, _ = precision_recall_curve(labels[available_indices], logits[available_indices, 1].reshape(-1))
pr_auc = auc(recall, precision)
writer.add_scalar(tag='roc_auc', scalar_value=roc_auc, global_step=epoch)
writer.add_scalar(tag='f1_score', scalar_value=f1, global_step=epoch)
writer.add_scalar(tag='pr_auc', scalar_value=pr_auc, global_step=epoch)
logging.info(f'{name} \t roc_auc: {roc_auc: .2f} \t pr_auc: {pr_auc: .2f} \t f1_score: {f1: .2f}')
if self.clear_mislabelled_ids is not None:
get_sub_f1 = lambda ind: f1_score(labels[ind], predictions[ind]) if len(ind) > 0 else 0
clean_available = np.intersect1d(self.all_clean_label_ids, available_indices)
mislabelled_available = np.intersect1d(self.all_mislabelled_ids, available_indices)
scalar_dict = {
'clean_cases': get_sub_f1(clean_available),
'all_mislabelled_cases': get_sub_f1(mislabelled_available)}
writer.add_scalars(main_tag='f1_breakdown', tag_scalar_dict=scalar_dict, global_step=epoch)
get_sub_auc = lambda ind: roc_auc_score(labels[ind], predictions[ind]) if len(ind) > 0 else 0
scalar_dict = {
'clean_cases': get_sub_auc(clean_available),
'all_mislabelled_cases': get_sub_auc(mislabelled_available)}
writer.add_scalars(main_tag='auc_breakdown', tag_scalar_dict=scalar_dict, global_step=epoch)
# Add histogram for the loss values
self.log_loss_values(writer, self.loss_per_sample[:, epoch], epoch)
# Breakdown of the accuracy on different sample types
if self.clear_mislabelled_ids is not None:
get_sub_acc = lambda ind: np.nanmean(self.correct_predictions[ind])
scalar_dict = {
'clean_cases': get_sub_acc(self.all_clean_label_ids),
'all_mislabelled_cases': get_sub_acc(self.all_mislabelled_ids),
'mislabelled_clear_cases': get_sub_acc(self.clear_mislabelled_ids)}
if self.ambiguous_mislabelled_ids is not None:
scalar_dict.update({'mislabelled_ambiguous_cases': get_sub_acc(self.ambiguous_mislabelled_ids)})
writer.add_scalars(main_tag='accuracy_breakdown', tag_scalar_dict=scalar_dict, global_step=epoch)
# Log mean teacher's accuracy
if not np.isnan(self.correct_predictions_teacher).any():
writer.add_scalar("teacher_accuracy", np.nanmean(self.correct_predictions_teacher), epoch)
def get_margin(self, epoch: int) -> np.ndarray:
"""
Get the margin for each sample defined as logits(y) - max_{y != t}[logits(t)]
"""
margin = np.full(self.num_samples, np.nan)
logits = self.logits_per_sample[:, :, epoch]
for i in range(self.num_samples):
label = int(self.labels[i])
assigned_logit = logits[i, label]
order = np.argsort(logits[i])
order = order[order != label]
other_max_logits = logits[i, order[-1]]
margin[i] = assigned_logit - other_max_logits
return margin
def log_loss_values(self, writer: SummaryWriter, loss_values: np.ndarray, epoch: int) -> None:
"""
Logs histogram of loss values of one of the co-teaching models.
"""
writer.add_histogram('loss/all', loss_values, epoch)
def append_batch(
self,
epoch: int,
logits: torch.Tensor,
labels: torch.Tensor,
loss: float,
indices: list,
per_sample_loss: torch.Tensor,
embeddings: Optional[np.ndarray] = None,
teacher_logits: Optional[torch.Tensor] = None) -> None:
"""
Append stats collected from batch of samples to metrics
"""
if teacher_logits is not None:
self.correct_predictions_teacher[indices] = torch.eq(torch.argmax(teacher_logits, dim=-1),
labels).type(torch.float32).cpu()
self.correct_predictions[indices] = torch.eq(torch.argmax(logits, dim=-1), labels).type(torch.float32).cpu()
self.predictions[indices] = torch.argmax(logits, dim=-1).cpu()
self.loss_per_sample[indices, epoch] = per_sample_loss.cpu()
self.logits_per_sample[indices, :, epoch] = logits.cpu()
self.probabilities[indices, :] = torch.softmax(logits, dim=-1).cpu()
self.labels[indices] = labels.cpu()
self.loss_optimised.append(loss)
if embeddings is not None:
# We don't know the size of the features in advance.
if self.embeddings_size is None:
self.embeddings_size = embeddings.shape[1]
self.embeddings_per_sample = np.full([self.num_samples, self.embeddings_size], np.nan)
assert self.embeddings_per_sample is not None
self.embeddings_per_sample[indices, :] = embeddings

Просмотреть файл

@ -0,0 +1,92 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
from InnerEyeDataQuality.deep_learning.metrics.joint_metrics import JointMetrics
from InnerEyeDataQuality.deep_learning.metrics.sample_metrics import SampleMetrics
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
class MetricTracker(object):
"""
"""
def __init__(self,
output_dir: str,
num_epochs: int,
num_samples_total: int,
num_samples_per_epoch: int,
num_classes: int,
save_tf_events: bool,
dataset: Optional[Dataset] = None,
name: str = "default_metric",
**sample_info_kwargs: Any):
"""
Class to track model training metrics.
If a co-teaching model is trained, joint model metrics are stored such as disagreement rate and kl divergence.
Similarly, it stores loss and logits values on a per sample basis for each epoch for post-training analysis.
This stored data can be utilised in data selection simulation.
"""
Path(output_dir).mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
self.name = name
self.num_classes = num_classes
self.num_samples_total = num_samples_total
self.num_samples_per_epoch = num_samples_per_epoch
clean_targets = dataset.clean_targets if hasattr(dataset, "clean_targets") else None # type: ignore
self.joint_model_metrics = JointMetrics(num_samples_total, num_epochs, dataset, **sample_info_kwargs)
self.sample_metrics = SampleMetrics(name, num_epochs, num_samples_total, num_classes,
clear_labels=clean_targets,
embeddings_size=None, **sample_info_kwargs)
self.writer = SummaryWriter(log_dir=output_dir) if save_tf_events else None
def reset(self) -> None:
self.sample_metrics.reset()
self.joint_model_metrics.reset()
def log_epoch_and_reset(self, epoch: int) -> None:
# assert np.count_nonzero(~np.isnan(self.sample_metrics.loss_per_sample[:, epoch])) == self.num_samples_per_epoch
self.sample_metrics.log_results(epoch=epoch, name=self.name, writer=self.writer)
if self.writer:
self.joint_model_metrics.log_results(self.writer, epoch, self.sample_metrics)
# Reset epoch metrics
self.reset()
def append_batch_aggregate(self, epoch: int, logits_x: torch.Tensor, logits_y: torch.Tensor,
dropped_cases: torch.Tensor, indices: torch.Tensor) -> None:
"""
Stores the disagreement stats for co-teaching models
"""
post_x = torch.softmax(logits_x, dim=-1)
post_y = torch.softmax(logits_y, dim=-1)
sym_kl_per_sample = torch.sum(post_x * torch.log(post_x / post_y) + post_y * torch.log(post_y / post_x), dim=-1)
pred_x = torch.argmax(logits_x, dim=-1)
pred_y = torch.argmax(logits_y, dim=-1)
class_pred_disagreement = pred_x != pred_y
self.joint_model_metrics.kl_divergence_symmetric[indices] = sym_kl_per_sample.cpu().numpy()
self.joint_model_metrics.prediction_disagreement[indices, epoch] = class_pred_disagreement.cpu().numpy()
self.joint_model_metrics.case_drop_histogram[dropped_cases.cpu().numpy(), epoch] = True
self.joint_model_metrics.active = True
def save_loss(self) -> None:
output_path = os.path.join(self.output_dir, f'{self.name}_training_stats.npz')
if hasattr(self.joint_model_metrics, "case_drop_histogram"):
np.savez(output_path,
loss_per_sample=self.sample_metrics.loss_per_sample,
logits_per_sample=self.sample_metrics.logits_per_sample,
dropped_cases=self.joint_model_metrics.case_drop_histogram[:, :-1])
else:
np.savez(output_path,
loss_per_sample=self.sample_metrics.loss_per_sample,
logits_per_sample=self.sample_metrics.logits_per_sample)

Просмотреть файл

@ -0,0 +1,69 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, List, Tuple
import numpy as np
from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.deep_learning.collect_embeddings import get_all_embeddings, register_embeddings_collector
from InnerEyeDataQuality.deep_learning.utils import get_run_config
from InnerEyeDataQuality.deep_learning.trainers.vanilla_trainer import VanillaTrainer
from InnerEyeDataQuality.deep_learning.trainers.model_trainer_base import ModelTrainer
from InnerEyeDataQuality.deep_learning.trainers.co_teaching_trainer import CoTeachingTrainer
from InnerEyeDataQuality.deep_learning.dataloader import get_val_dataloader
from InnerEyeDataQuality.utils.custom_types import SelectorTypes as ST
NUM_MULTIPLE_RUNS = 10
def inference_model(dataloader: Any, model_trainer: ModelTrainer, use_mc_sampling: bool = False) -> Tuple[List, List]:
"""
Performs an inference pass on a single model
:param config:
:return:
"""
# Inference on given dataloader
all_model_cnn_embeddings = register_embeddings_collector(model_trainer.models, use_only_in_train=False)
trackers = model_trainer.run_inference(dataloader, use_mc_sampling)
embs = get_all_embeddings(all_model_cnn_embeddings)
probs = [metric_tracker.sample_metrics.probabilities for metric_tracker in trackers]
return embs, probs
def inference_ensemble(dataset: Any, config: ConfigNode) -> Tuple[np.ndarray, np.ndarray, np.ndarray, ModelTrainer]:
"""
Returns:
embeddings: 2D numpy array - containing sample embeddings obtained from a CNN.
[num_samples, embedding_size]
posteriors: 2D numpy array - containing class posteriors obtained from a CNN.
[num_samples, num_classes]
trainer: Model trainer object built using config file
"""
# Reload the model from config
model_trainer_class = CoTeachingTrainer if config.train.use_co_teaching else VanillaTrainer
config = get_run_config(config, config.train.seed)
model_trainer = model_trainer_class(config=config)
model_trainer.load_checkpoints(restore_scheduler=False)
# Prepare output data structures
all_embeddings = []
all_posteriors = []
# Run inference on the given dataset
use_mc_sampling = config.model.use_dropout and ST(config.selector.type[0]) == ST.BaldSelector
multi_inference = config.train.use_self_supervision or config.model.use_dropout
num_runs = NUM_MULTIPLE_RUNS if multi_inference else 1
for run_ind in range(num_runs):
dataloader = get_val_dataloader(dataset, config, seed=config.train.seed + run_ind)
embeddings, posteriors = inference_model(dataloader, model_trainer, use_mc_sampling)
all_embeddings.append(embeddings)
all_posteriors.append(posteriors)
# Aggregate results and return
embeddings = np.mean(all_embeddings, axis=(0, 1))
all_posteriors = np.stack(all_posteriors, axis=0)
avg_posteriors = np.mean(all_posteriors, axis=(0, 1))
return embeddings, avg_posteriors, all_posteriors, model_trainer

Просмотреть файл

@ -0,0 +1,51 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import numpy as np
class ForgetRateScheduler(object):
"""
Forget rate scheduler for the co-teaching model
"""
def __init__(self,
num_epochs: int,
forget_rate: float = 0.0,
num_gradual: int = 10,
num_warmup_epochs: int = 25,
start_epoch: int = 0):
"""
:param num_epochs: The total number of training epochs
:param forget_rate: The base forget rate
:param num_gradual: The number of epochs to gradual increase the forget_rate to its base value
:param start_epoch: Allows manually set the start epoch if training is resumed.
"""
logging.info(f"No samples will be excluded in co-teaching for the first {num_warmup_epochs} epochs.")
if num_gradual <= num_warmup_epochs:
logging.warning(f"Num gradual {num_gradual} <= num warm up epochs. This argument will be ignored.")
assert 0 <= forget_rate < 1.
self.forget_rate_schedule = np.ones(num_epochs) * forget_rate
self.forget_rate_schedule[:num_gradual] = np.linspace(0, forget_rate, num_gradual)
self.forget_rate_schedule[:num_warmup_epochs] = np.zeros(num_warmup_epochs)
self.current_epoch = start_epoch
def step(self) -> None:
"""
Step the current epoch by one
:return:
"""
self.current_epoch += 1
@property
def get_forget_rate(self) -> float:
"""
:return: The current forget rate
"""
current_epoch = min(self.current_epoch, len(self.forget_rate_schedule) - 1)
return float(self.forget_rate_schedule[current_epoch])

Просмотреть файл

@ -0,0 +1,5 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

Просмотреть файл

@ -0,0 +1,60 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Tuple, Any
from InnerEyeDataQuality.deep_learning.architectures.lambda_layer import Lambda
from InnerEyeDataQuality.deep_learning.self_supervised.ssl_classifier_module import get_encoder_output_dim
from InnerEyeDataQuality.deep_learning.self_supervised.utils import create_ssl_encoder
from torch import Tensor as T
from torch import nn
class _MLP(nn.Module):
def __init__(self, input_dim: int, hidden_size: int, output_dim: int) -> None:
super().__init__()
self.output_dim = output_dim
self.input_dim = input_dim
self.model = nn.Sequential(
nn.Linear(input_dim, hidden_size, bias=False),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, output_dim, bias=True))
def forward(self, x: T) -> T:
x = self.model(x)
return x
class SSLEncoder(nn.Module):
def __init__(self, encoder_name: str, dataset_name: str, use_output_pooling: bool = True):
super().__init__()
self.cnn_model = create_ssl_encoder(encoder_name=encoder_name, dataset_name=dataset_name)
self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
self.use_output_pooling = use_output_pooling
def forward(self, x: T) -> T:
x = self.cnn_model(x)
x = x[-1] if isinstance(x, list) else x
x = self.avgpool(x).view(x.size(0), -1) if self.use_output_pooling else x
return x
def get_output_feature_dim(self) -> int:
return get_encoder_output_dim(self)
class SiameseArm(nn.Module):
def __init__(self, *encoder_kwargs: Any) -> None:
super().__init__()
self.encoder = SSLEncoder(*encoder_kwargs) # Encoder
self.projector = _MLP(input_dim=self.encoder.get_output_feature_dim(), hidden_size=2048, output_dim=128)
self.predictor = _MLP(input_dim=self.projector.output_dim, hidden_size=128, output_dim=128)
self.projector_normalised = nn.Sequential(self.projector,
Lambda(lambda x: nn.functional.normalize(x, dim=-1)))
def forward(self, x: T) -> Tuple[T, T, T]:
y = self.encoder(x)
z = self.projector(y)
h = self.predictor(z)
return y, z, h

Просмотреть файл

@ -0,0 +1,142 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from copy import deepcopy
from typing import Any, Dict, Iterator, List, Tuple, Optional
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from InnerEyeDataQuality.deep_learning.self_supervised.byol.byol_models import SiameseArm
from InnerEyeDataQuality.deep_learning.self_supervised.byol.byol_moving_average import BYOLMAWeightUpdate
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torch import Tensor as T
from torch.optim import Adam
BatchType = Tuple[List, T]
class BYOLInnerEye(pl.LightningModule):
"""
Implementation of `Bootstrap Your Own Latent (BYOL) <https://arxiv.org/pdf/2006.07733.pdf>`
"""
def __init__(self,
num_samples: int,
learning_rate: float,
batch_size: int,
encoder_name: str,
warmup_epochs: int,
dataset_name: Optional[str] = None,
weight_decay: float = 1e-6,
**kwargs: Any) -> None:
"""
Args:
num_samples: Number of samples present in training dataset / dataloader.
learning_rate: Optimizer learning rate.
batch_size: Sample batch size used in gradient updates.
encoder_name: Type of CNN encoder used to extract image embeddings. The options are:
{'resnet18', 'resnet50', 'resnet101'}.
warmup_epochs: Number of epochs for scheduler warm up (linear increase from 0 to base_lr).
dataset_name: Name of training dataset - If set to "CIFAR10" then the encoder is adjusted to image size.
weight_decay: L2-norm weight decay.
"""
super().__init__()
self.save_hyperparameters()
self.min_learning_rate = 1e-4
self.online_network = SiameseArm(encoder_name, dataset_name)
self.target_network = deepcopy(self.online_network)
self.weight_callback = BYOLMAWeightUpdate()
def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
# Add callback for user automatically since it's key to BYOL weight update
self.weight_callback.on_before_zero_grad(self.trainer, self)
def forward(self, x: T) -> T: # type: ignore
return self.target_network.encoder(x)
def cosine_loss(self, a: T, b: T) -> T:
a = F.normalize(a, dim=-1)
b = F.normalize(b, dim=-1)
neg_cos_sim = -(a * b).sum(dim=-1).mean()
return neg_cos_sim
def shared_step(self, batch: BatchType, batch_idx: int) -> T:
(img_1, img_2), y = batch
# Image 1 to image 2 loss
_, _, h_img1 = self.online_network(img_1)
_, _, h_img2 = self.online_network(img_2)
with torch.no_grad():
_, z_img1, _ = self.target_network(img_1)
_, z_img2, _ = self.target_network(img_2)
loss = 0.5 * (self.cosine_loss(h_img1, z_img2.detach())
+ self.cosine_loss(h_img2, z_img1.detach()))
return loss
def training_step(self, batch: BatchType, batch_idx: int) -> T: # type: ignore
loss = self.shared_step(batch, batch_idx)
self.log_dict({'byol/train_loss': loss, 'byol/tau': self.weight_callback.current_tau})
return loss
def validation_step(self, batch: BatchType, batch_idx: int) -> T: # type: ignore
loss = self.shared_step(batch, batch_idx)
self.log_dict({'byol/validation_loss': loss})
return loss
def setup(self, *args: Any, **kwargs: Any) -> None:
global_batch_size = self.trainer.world_size * self.hparams.batch_size # type: ignore
self.train_iters_per_epoch = self.hparams.num_samples // global_batch_size # type: ignore
def configure_optimizers(self) -> Any:
# TRICK 1 (Use lars + filter weights)
# exclude certain parameters
parameters = self.exclude_from_wt_decay(self.online_network.named_parameters(),
weight_decay=self.hparams.weight_decay) # type: ignore
optimizer = LARSWrapper(Adam(parameters, lr=self.hparams.learning_rate)) # type: ignore
# Trick 2 (after each step)
self.hparams.warmup_epochs = self.hparams.warmup_epochs * self.train_iters_per_epoch # type: ignore
max_epochs = self.trainer.max_epochs * self.train_iters_per_epoch
linear_warmup_cosine_decay = LinearWarmupCosineAnnealingLR(
optimizer,
warmup_epochs=self.hparams.warmup_epochs, # type: ignore
max_epochs=max_epochs,
warmup_start_lr=0,
eta_min=self.min_learning_rate,
)
scheduler = {'scheduler': linear_warmup_cosine_decay, 'interval': 'step', 'frequency': 1}
return [optimizer], [scheduler]
def exclude_from_wt_decay(self,
named_params: Iterator[Tuple[str, T]],
weight_decay: float,
skip_list: List[str] = ['bias', 'bn']) -> List[Dict[str, Any]]:
"""
Convolution-Linear bias-terms and batch-norm parameters are excluded from l2-norm weight decay regularisation.
https://arxiv.org/pdf/2006.07733.pdf Section 3.3 Optimisation and Section F.5.
"""
params = []
excluded_params = []
for name, param in named_params:
if not param.requires_grad:
continue
elif any(layer_name in name for layer_name in skip_list):
excluded_params.append(param)
else:
params.append(param)
return [
{'params': params, 'weight_decay': weight_decay},
{'params': excluded_params, 'weight_decay': 0.}
]

Просмотреть файл

@ -0,0 +1,58 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import math
import torch
import pytorch_lightning as pl
from pytorch_lightning import Callback
class BYOLMAWeightUpdate(Callback):
"""
Weight updates for BYOL moving average encoder (e.g. teacher). Pl_module is expected to contain three attributes:
- ``pl_module.online_network``
- ``pl_module.target_network``
- ``pl_module.global_step``
Updates the target_network params using an exponential moving average update rule weighted by tau.
Tau parameter is increased from its base value to 1.0 with every training step scheduled with a cosine function.
global_step correspond to the total number of sgd updates expected to happen throughout the BYOL training.
Target network is updated at the end of each SGD update on training batch.
"""
def __init__(self, initial_tau: float = 0.99):
"""
Args:
initial_tau: starting tau. Auto-updates with every training step
"""
super().__init__()
self.initial_tau = initial_tau
self.current_tau = initial_tau
def on_before_zero_grad(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: # type: ignore
# get networks
online_net = pl_module.online_network
target_net = pl_module.target_network
assert(isinstance(online_net, torch.nn.Module))
assert(isinstance(target_net, torch.nn.Module))
# update weights
self.update_weights(online_net, target_net)
# update tau after
self.current_tau = self.update_tau(pl_module, trainer)
def update_tau(self, pl_module: pl.LightningModule, trainer: pl.Trainer) -> float:
max_steps = len(trainer.train_dataloader) * trainer.max_epochs # type: ignore
tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / max_steps) + 1) / 2
return tau
def update_weights(self, online_net: torch.nn.Module, target_net: torch.nn.Module) -> None:
# apply MA weight update
for current_params, ma_params in zip(online_net.parameters(), target_net.parameters()):
up_weight, old_weight = current_params.data, ma_params.data
ma_params.data = old_weight * self.current_tau + (1 - self.current_tau) * up_weight

Просмотреть файл

@ -0,0 +1,5 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

Просмотреть файл

@ -0,0 +1,16 @@
dataset:
name: CIFAR10H
num_workers: 6
train:
seed: 1
batch_size: 512
base_lr: 1e-3
output_dir: cifar10h/self_supervised
resume_from_last_checkpoint: False
self_supervision:
type: byol
encoder_name: resnet50
scheduler:
epochs: 2500
preprocess:
resize: 32

Просмотреть файл

@ -0,0 +1,16 @@
dataset:
name: CIFAR10H
num_workers: 6
train:
seed: 1
batch_size: 512
base_lr: 1e-3
output_dir: cifar10h/self_supervised
resume_from_last_checkpoint: False
self_supervision:
type: simclr
encoder_name: resnet101
scheduler:
epochs: 2000
preprocess:
resize: 32

Просмотреть файл

@ -0,0 +1,49 @@
dataset:
name: NIH
dataset_dir: datadrive/NIH
train:
seed: 3
batch_size: 2400
base_lr: 1e-3
output_dir: nih/ssup
resume_from_last_checkpoint: False
self_supervision:
type: byol
encoder_name: resnet50
use_balanced_binary_loss_for_linear_head: True
scheduler:
epochs: 1500
preprocess:
center_crop_size: 224
resize: 256
augmentation:
use_random_horizontal_flip: True
use_random_affine: True
use_random_color: True
use_random_crop: True
use_gamma_transform: True
use_random_erasing: True
add_gaussian_noise: True
use_elastic_transform: True
random_horizontal_flip:
prob: 0.5
random_erasing:
scale: (0.15, 0.4)
ratio: (0.33, 3)
random_affine:
max_angle: 180
max_horizontal_shift: 0.00
max_vertical_shift: 0.00
max_shear: 40
elastic_transform:
sigma: 4
alpha: 34
p_apply: 0.4
random_color:
brightness: 0.2
contrast: 0.2
saturation: 0.0
random_crop:
scale: (0.4, 1.0)
gaussian_noise:
std: 0.05

Просмотреть файл

@ -0,0 +1,46 @@
dataset:
name: NIH
dataset_dir: datadrive/NIH
train:
seed: 2
batch_size: 1600
base_lr: 1e-3
output_dir: nih/nohist
resume_from_last_checkpoint: False
self_supervision:
type: simclr
encoder_name: resnet50
scheduler:
epochs: 1500
preprocess:
center_crop_size: 224
resize: 256
augmentation:
use_random_horizontal_flip: True
use_random_affine: True
use_random_color: True
use_random_crop: True
use_gamma_transform: False
use_random_erasing: True
add_gaussian_noise: True
use_elastic_transform: True
random_horizontal_flip:
prob: 0.5
random_erasing:
scale: (0.005, 0.08)
ratio: (0.33, 3)
random_affine:
max_angle: 180
max_horizontal_shift: 0.00
max_vertical_shift: 0.00
max_shear: 40
elastic_transform:
sigma: 4
alpha: 34
p_apply: 0.4
random_color:
brightness: 0.3
contrast: 0.3
saturation: 0.0
random_crop:
scale: (0.5, 1.0)

Просмотреть файл

@ -0,0 +1,85 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from InnerEyeDataQuality.configs.config_node import ConfigNode
config = ConfigNode()
config.dataset = ConfigNode()
config.dataset.name = 'CIFAR10'
config.dataset.dataset_dir = ''
config.dataset.num_workers = None
config.train = ConfigNode()
config.train.seed = None
config.train.batch_size = None
config.train.output_dir = None
config.train.resume_from_last_checkpoint = False
config.train.base_lr = None
config.train.self_supervision = ConfigNode()
config.train.self_supervision.type = None
config.train.self_supervision.encoder_name = None
config.train.self_supervision.use_balanced_binary_loss_for_linear_head = False
config.train.checkpoint_period = 200
config.scheduler = ConfigNode()
config.scheduler.epochs = None
config.augmentation = ConfigNode()
config.augmentation.use_random_crop = False
config.augmentation.use_random_horizontal_flip = False
config.augmentation.use_random_affine = False
config.augmentation.use_label_smoothing = False
config.augmentation.use_random_color = False
config.augmentation.add_gaussian_noise = False
config.augmentation.use_gamma_transform = False
config.augmentation.use_random_erasing = False
config.augmentation.use_elastic_transform = False
config.augmentation.random_crop = ConfigNode()
config.augmentation.random_crop.scale = (0.9, 1.0)
config.augmentation.elastic_transform = ConfigNode()
config.augmentation.elastic_transform.sigma = 4
config.augmentation.elastic_transform.alpha = 35
config.augmentation.elastic_transform.p_apply = 0.5
config.augmentation.gaussian_noise = ConfigNode()
config.augmentation.gaussian_noise.std = 0.01
config.augmentation.gaussian_noise.p_apply = 0.5
config.augmentation.random_horizontal_flip = ConfigNode()
config.augmentation.random_horizontal_flip.prob = 0.5
config.augmentation.random_affine = ConfigNode()
config.augmentation.random_affine.max_angle = 0
config.augmentation.random_affine.max_horizontal_shift = 0.0
config.augmentation.random_affine.max_vertical_shift = 0.0
config.augmentation.random_affine.max_shear = 5
config.augmentation.random_color = ConfigNode()
config.augmentation.random_color.brightness = 0.0
config.augmentation.random_color.contrast = 0.1
config.augmentation.random_color.saturation = 0.1
config.augmentation.gamma = ConfigNode()
config.augmentation.gamma.scale = (0.5, 1.5)
config.augmentation.label_smoothing = ConfigNode()
config.augmentation.label_smoothing.epsilon = 0.1
config.augmentation.random_erasing = ConfigNode()
config.augmentation.random_erasing.scale = (0.01, 0.1)
config.augmentation.random_erasing.ratio = (0.3, 3.3)
config.preprocess = ConfigNode()
config.preprocess.use_resize = False
config.preprocess.use_center_crop = False
config.preprocess.center_crop_size = 224
config.preprocess.histogram_normalization = ConfigNode()
config.preprocess.histogram_normalization.disk_size = 30
config.preprocess.resize = 32
def get_default_model_config() -> ConfigNode:
return config.clone()

Просмотреть файл

@ -0,0 +1,156 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, Callable, Optional
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.datasets.kaggle_cxr import KAGGLE_TOTAL_SIZE, KaggleCXR
from InnerEyeDataQuality.datasets.nih_cxr import NIHCXR, NIH_TOTAL_SIZE
from InnerEyeDataQuality.deep_learning.create_dataset_transforms import create_chest_xray_transform
from InnerEyeDataQuality.deep_learning.dataloader import WorkerInitFunc
from InnerEyeDataQuality.deep_learning.transforms import DualViewTransformWrapper
import numpy as np
class KaggleDataModule(LightningDataModule):
def __init__(self, config: ConfigNode,
num_devices: int,
num_workers: int,
*args: Any, **kwargs: Any) -> None:
"""
This is the data module to load and prepare the Kaggle Pneumonia detection challenge dataset.
:param config:
:param num_devices: The number of GPUs to use. The total batch size specified in the config will be divided
by the number of GPUs.
:param num_workers: The number of dataloader workers.
"""
super().__init__(*args, **kwargs)
self.config = config
self.num_samples = KAGGLE_TOTAL_SIZE
self.batch_size = config.train.batch_size // num_devices
self.num_workers = num_workers
self.train_transforms = DualViewTransformWrapper(create_chest_xray_transform(self.config, is_train=True))
self.val_transforms = DualViewTransformWrapper(create_chest_xray_transform(self.config, is_train=False))
self.train_dataset = KaggleCXR(self.config.dataset.dataset_dir, use_training_split=True,
transform=self.train_transforms, return_index=False)
self.class_weights: Optional[torch.Tensor] = None
if config.train.self_supervision.use_balanced_binary_loss_for_linear_head:
# Weight = inverse class proportion.
class_weights = len(self.train_dataset.targets) / np.bincount(self.train_dataset.targets)
# Normalized class weights
class_weights /= class_weights.sum()
self.class_weights = torch.tensor(class_weights)
@property
def num_classes(self) -> int:
return 2
def train_dataloader(self) -> DataLoader: # type: ignore
"""
Returns Kaggle training set (80% of total dataset)
"""
return torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=False,
worker_init_fn=WorkerInitFunc(self.config.train.seed),
drop_last=True)
def val_dataloader(self) -> DataLoader: # type: ignore
"""
Returns Kaggle validation set (20% of total dataset)
"""
val_dataset = KaggleCXR(self.config.dataset.dataset_dir, use_training_split=False,
transform=self.val_transforms, return_index=False)
loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=False,
worker_init_fn=WorkerInitFunc(self.config.train.seed),
drop_last=True)
return loader
def test_dataloader(self) -> DataLoader: # type: ignore
"""
No Kaggle test split implemented
"""
raise NotImplementedError
def default_transforms(self) -> Callable:
transform = create_chest_xray_transform(self.config, is_train=False)
return transform
class NIHDataModule(LightningDataModule):
def __init__(self, config: ConfigNode, num_devices: int, num_workers: int, *args: Any, **kwargs: Any) -> None:
"""
This is the data module to load and prepare the NIH dataset (112k scans).
"""
super().__init__(*args, **kwargs)
self.config = config
self.batch_size = config.train.batch_size // num_devices
self.num_workers = num_workers
self.train_transforms = DualViewTransformWrapper(create_chest_xray_transform(self.config, is_train=True))
self.val_transforms = DualViewTransformWrapper(create_chest_xray_transform(self.config, is_train=False))
self.num_samples = NIH_TOTAL_SIZE
self.train_dataset = NIHCXR(self.config.dataset.dataset_dir, use_training_split=True,
transform=self.train_transforms, return_index=False)
self.class_weights: Optional[torch.Tensor] = None
if config.train.self_supervision.use_balanced_binary_loss_for_linear_head:
# Weight = inverse class proportion.
class_weights = len(self.train_dataset.targets) / np.bincount(self.train_dataset.targets)
# Normalized class weights
class_weights /= class_weights.sum()
self.class_weights = torch.tensor(class_weights)
@property
def num_classes(self) -> int:
return 2
def train_dataloader(self) -> DataLoader: # type: ignore
"""
Returns NIH training set (80% of total dataset)
"""
return torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=False,
worker_init_fn=WorkerInitFunc(self.config.train.seed),
drop_last=True)
def val_dataloader(self) -> DataLoader: # type: ignore
"""
Returns NIH validation set (20% of total dataset)
"""
val_dataset = NIHCXR(self.config.dataset.dataset_dir,
use_training_split=False, transform=self.val_transforms, return_index=False)
loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=False,
worker_init_fn=WorkerInitFunc(self.config.train.seed),
drop_last=True)
return loader
def test_dataloader(self) -> DataLoader: # type: ignore
"""
No NIH test split implemented
"""
raise NotImplementedError
def default_transforms(self) -> Callable:
transform = create_chest_xray_transform(self.config, is_train=False)
return transform

Просмотреть файл

@ -0,0 +1,82 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, List
import torch
from InnerEyeDataQuality.datasets.cifar10h import TOTAL_CIFAR10H_DATASET_SIZE
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torch.utils.data import DataLoader, random_split
from torchvision import transforms as transform_lib
class CIFAR10HDataModule(CIFAR10DataModule):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
"""
super().__init__(*args, **kwargs)
self.num_samples = TOTAL_CIFAR10H_DATASET_SIZE
self.class_weights = None
def train_dataloader(self) -> DataLoader:
"""
CIFAR train set removes a subset to use for validation
"""
transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
dataset = self.DATASET(self.data_dir, train=False, download=True, transform=transforms, **self.extra_args)
loader = DataLoader(dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True)
assert len(dataset) == TOTAL_CIFAR10H_DATASET_SIZE
return loader
def val_dataloader(self) -> DataLoader:
"""
CIFAR10 val set uses a subset of the training set for validation
"""
transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
dataset = self.DATASET(self.data_dir, train=True, download=False, transform=transforms, **self.extra_args)
num_samples = len(dataset)
_, dataset_val = random_split(dataset,
[num_samples - self.val_split, self.val_split],
generator=torch.Generator().manual_seed(self.seed))
assert len(dataset_val) == self.val_split
loader = DataLoader(dataset_val,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True)
return loader
def test_dataloader(self) -> DataLoader:
"""
CIFAR10 test set uses the test split
"""
transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms
dataset = self.DATASET(self.data_dir, train=True, download=False, transform=transforms, **self.extra_args)
num_samples = len(dataset)
dataset_test, _ = random_split(dataset,
[num_samples - self.val_split, self.val_split],
generator=torch.Generator().manual_seed(self.seed))
loader = DataLoader(dataset_test,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True)
return loader
def default_transforms(self) -> List[object]:
cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
return cf10_transforms

Просмотреть файл

@ -0,0 +1,37 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import multiprocessing
import torch
import pytorch_lightning as pl
from InnerEyeDataQuality.configs.config_node import ConfigNode
from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
from .chestxray_datamodule import KaggleDataModule, NIHDataModule
from .cifar10h_datamodule import CIFAR10HDataModule
num_gpus = torch.cuda.device_count()
num_devices = num_gpus if num_gpus > 0 else 1
def create_ssl_data_modules(config: ConfigNode) -> pl.LightningDataModule:
"""
Returns torch lightining data module.
"""
num_workers = config.dataset.num_workers if config.dataset.num_workers else multiprocessing.cpu_count()
if config.dataset.name == "Kaggle":
dm = KaggleDataModule(config, num_devices=num_devices, num_workers=num_workers) # type: ignore
elif config.dataset.name == "NIH":
dm = NIHDataModule(config, num_devices=num_devices, num_workers=num_workers) # type: ignore
elif config.dataset.name == "CIFAR10H":
dm = CIFAR10HDataModule(num_workers=num_workers,
batch_size=config.train.batch_size // num_devices,
seed=1234)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
else:
raise NotImplementedError(f"No pytorch data module implemented for dataset type: {config.dataset.name}")
return dm

Просмотреть файл

@ -0,0 +1,22 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
from torchvision.models import densenet121
class DenseNet121Encoder(torch.nn.Module):
"""
This module creates a Densenet121 encoder i.e. Densenet121 model without
its classification head.
"""
def __init__(self) -> None:
super().__init__()
self.densenet121 = densenet121()
self.cnn_model = self.densenet121.features
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.cnn_model(x)

Просмотреть файл

@ -0,0 +1,118 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import argparse
import logging
import sys
from pathlib import Path
from typing import Dict, Union
import pytorch_lightning as pl
import torch
from default_paths import EXPERIMENT_DIR
from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.deep_learning.self_supervised.byol.byol_module import BYOLInnerEye
from InnerEyeDataQuality.deep_learning.self_supervised.datamodules.utils import create_ssl_data_modules
from InnerEyeDataQuality.deep_learning.self_supervised.simclr_module import SimCLRInnerEye
from InnerEyeDataQuality.deep_learning.self_supervised.ssl_classifier_module import (SSLOnlineEvaluatorInnerEye,
get_encoder_output_dim)
from InnerEyeDataQuality.deep_learning.utils import load_ssl_model_config
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
num_gpus = torch.cuda.device_count()
num_devices = num_gpus if num_gpus > 0 else 1
trained_kwargs: Dict[str, Union[str, int, float]] = {"precision": 16} if num_gpus > 0 else {}
if num_gpus > 1: # If multi-gpu training update the parameters for DDP
trained_kwargs.update({"distributed_backend": "ddp", "sync_batchnorm": True})
def get_last_checkpoint_path(default_root_dir: str, model_version: str) -> str:
return str(Path(default_root_dir) / model_version / "checkpoints" / "last.ckpt")
def cli_main(config: ConfigNode, debug: bool = False) -> None:
"""
Runs self-supervised training on imaging data using contrastive loss.
Currently it supports only ``BYOL`` and ``SimCLR``.
:param config: A ssl_model config specifying training configurations (and augmentations).
Beware the augmentations parameters are ignored at the moment when using CIFAR10 as we
always use the default augmentations from PyTorch Lightning.
:param debug: If set to True, only runs training and validation on 1% of the data.
"""
# set seed
seed_everything(config.train.seed)
# self-supervision type
ssl_type = config.train.self_supervision.type
default_root_dir = EXPERIMENT_DIR / config.train.output_dir
model_version = f'{ssl_type}_seed_{config.train.seed}'
checkpoint_dir = str(default_root_dir / model_version / "checkpoints")
# Model checkpointing callback
checkpoint_callback = ModelCheckpoint(period=config.train.checkpoint_period, save_top_k=-1, dirpath=checkpoint_dir)
checkpoint_callback_last = ModelCheckpoint(save_last=True, dirpath=checkpoint_dir)
lr_logger = LearningRateMonitor()
tb_logger = pl.loggers.TensorBoardLogger(save_dir=str(default_root_dir),
version='logs',
name=model_version)
# Create SimCLR data modules and model
dm = create_ssl_data_modules(config)
if ssl_type == "simclr":
model = SimCLRInnerEye(num_samples=dm.num_samples, # type: ignore
batch_size=dm.batch_size, # type: ignore
lr=config.train.base_lr,
dataset_name=config.dataset.name,
encoder_name=config.train.self_supervision.encoder_name)
# Create BYOL model
else:
model = BYOLInnerEye(num_samples=dm.num_samples, # type: ignore
learning_rate=config.train.base_lr,
dataset_name=config.dataset.name,
encoder_name=config.train.self_supervision.encoder_name,
batch_size=dm.batch_size, # type: ignore
warmup_epochs=10)
model.hparams.update({'ssl_type': ssl_type})
# Online fine-tunning using an MLP
online_eval = SSLOnlineEvaluatorInnerEye(class_weights=dm.class_weights, # type: ignore
z_dim=get_encoder_output_dim(model, dm),
num_classes=dm.num_classes, # type: ignore
dataset=config.dataset.name,
drop_p=0.2) # type: ignore
# Load latest checkpoint
resume_from_last_checkpoint = get_last_checkpoint_path(default_root_dir, model_version) if \
config.train.resume_from_last_checkpoint else None
if debug:
overfit_batches = num_devices / (min(len(dm.val_dataloader()), len(dm.train_dataloader())) * 2.0)
trained_kwargs.update({"overfit_batches": overfit_batches})
# Create trainer and run training
trainer = pl.Trainer(gpus=num_gpus,
logger=tb_logger,
default_root_dir=str(default_root_dir),
benchmark=True,
max_epochs=config.scheduler.epochs,
callbacks=[lr_logger, online_eval, checkpoint_callback, checkpoint_callback_last],
resume_from_checkpoint=resume_from_last_checkpoint,
**trained_kwargs)
trainer.fit(model, dm) # type: ignore
if __name__ == "__main__":
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
parser = argparse.ArgumentParser(description='Train a self-supervised model')
parser.add_argument('--config', dest='config', type=str, required=True,
help='Path to config file characterising trained CNN model/s')
args, unknown_args = parser.parse_known_args()
config_path = args.config
config = load_ssl_model_config(config_path)
# Launch the script
cli_main(config)

Просмотреть файл

@ -0,0 +1,38 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import numpy as np
import torch
from pytorch_lightning.metrics import Metric
from pytorch_lightning.metrics.functional import auroc
class AreaUnderRocCurve(Metric):
"""
Computes the area under the receiver operating curve (ROC).
"""
def __init__(self) -> None:
super().__init__(dist_sync_on_step=False)
self.add_state("preds", default=[], dist_reduce_fx=None)
self.add_state("targets", default=[], dist_reduce_fx=None)
self.name = "auc"
def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: # type: ignore
assert preds.dim() == 2 and targets.dim() == 1 and \
preds.shape[1] == 2 and preds.shape[0] == targets.shape[
0], f"Expected 2-dim preds, 1-dim targets, but got: preds = {preds.shape}, targets = {targets.shape}"
self.preds.append(preds) # type: ignore
self.targets.append(targets) # type: ignore
def compute(self) -> torch.Tensor:
"""
Computes a metric from the stored predictions and targets.
"""
preds = torch.cat(self.preds) # type: ignore
targets = torch.cat(self.targets) # type: ignore
if torch.unique(targets).numel() == 1:
return torch.tensor(np.nan)
return auroc(preds[:, 1], targets)

Просмотреть файл

@ -0,0 +1,43 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from InnerEyeDataQuality.deep_learning.self_supervised.byol.byol_models import SSLEncoder
from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR
class _Projection(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
super().__init__()
self.output_dim = output_dim
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.model = nn.Sequential(
nn.Linear(self.input_dim, self.hidden_dim, bias=True),
nn.BatchNorm1d(self.hidden_dim),
nn.ReLU(),
nn.Linear(self.hidden_dim, self.output_dim, bias=False))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.model(x)
return F.normalize(x, dim=1)
class SimCLRInnerEye(SimCLR):
def __init__(self, encoder_name: str, dataset_name: str, **kwargs: Any) -> None:
"""
Args:
encoder_name [str]: Image encoder name (predefined models)
dataset_name [str]: Image dataset name (e.g. cifar10, kaggle, etc.)
"""
super().__init__(**kwargs)
self.save_hyperparameters()
self.encoder = SSLEncoder(encoder_name, dataset_name, use_output_pooling=True)
self.projection = _Projection(input_dim=self.encoder.get_output_feature_dim(), hidden_dim=2048, output_dim=128)

Просмотреть файл

@ -0,0 +1,194 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, List, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
from torch import Tensor as T
from torch.nn import ModuleList, functional as F
from InnerEyeDataQuality.deep_learning.self_supervised.metrics import AreaUnderRocCurve
from pytorch_lightning.metrics import Accuracy
BatchType = Tuple[List, T]
class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
def __init__(self, class_weights: Optional[torch.Tensor] = None, **kwargs: Any) -> None:
"""
Creates a hook to evaluate a linear model on top of an SSL embedding.
:param class_weights: The class weights to use when computing the cross entropy loss. If set to None,
no weighting will be done.
"""
super().__init__(**kwargs)
self.training_step = int(0)
self.weight_decay = 1e-4
self.learning_rate = 1e-4
self.train_metrics = ModuleList([AreaUnderRocCurve(), Accuracy()]) if self.num_classes == 2 else ModuleList([Accuracy()])
self.val_metrics = ModuleList([AreaUnderRocCurve(), Accuracy()]) if self.num_classes == 2 else ModuleList([Accuracy()])
self.class_weights = class_weights
def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
# Move metrics and class weights to module device
for metric in [*self.train_metrics, *self.val_metrics]:
metric.to(device=pl_module.device) # type: ignore
if self.class_weights is not None:
self.class_weights = self.class_weights.float().to(device=pl_module.device)
pl_module.non_linear_evaluator = SSLEvaluator(n_input=self.z_dim,
n_classes=self.num_classes,
p=self.drop_p,
n_hidden=self.hidden_dim).to(pl_module.device)
assert isinstance(pl_module.non_linear_evaluator, torch.nn.Module)
self.optimizer = torch.optim.Adam(pl_module.non_linear_evaluator.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay)
# Use only one of the transformed images
@staticmethod
def to_device(batch: BatchType, device: Union[str, torch.device]) -> Tuple[T, T]:
(x1, x2), y = batch
x1 = x1.to(device)
y = y.to(device)
return x1, y
def shared_step(self, batch: BatchType, pl_module: pl.LightningModule, is_training: bool) -> T:
x, y = self.to_device(batch, pl_module.device)
with torch.no_grad():
representations = self.get_representations(pl_module, x)
representations = representations.detach()
assert isinstance(pl_module.non_linear_evaluator, torch.nn.Module)
# Run the linear-head with SSL embeddings.
mlp_preds = pl_module.non_linear_evaluator(representations)
mlp_loss = F.cross_entropy(mlp_preds, y, weight=self.class_weights)
with torch.no_grad():
posteriors = F.softmax(mlp_preds, dim=-1)
for metric in (self.train_metrics if is_training else self.val_metrics):
metric(posteriors, y)
return mlp_loss
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): # type: ignore
loss = self.shared_step(batch, pl_module, is_training=False)
# Log classification metrics
pl_module.log('ssl/online_val_loss', loss, on_step=False, on_epoch=True, sync_dist=False)
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: # type: ignore
logger = trainer.logger.experiment # type: ignore
loss = self.shared_step(batch, pl_module, is_training=True)
# update finetune weights
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
self.training_step += 1
# log metrics
logger.add_scalar('ssl/online_train_loss', loss, global_step=self.training_step)
class SSLClassifier(torch.nn.Module):
"""
SSL Image classifier that combines pre-trained SSL encoder with a trainable linear-head.
"""
def __init__(self, num_classes: int, encoder: torch.nn.Module, projection: torch.nn.Module):
super().__init__()
self.encoder = encoder
self.projection = projection
self.encoder.eval(), self.projection.eval()
self.classifier_head = SSLEvaluator(n_input=get_encoder_output_dim(self.encoder),
n_hidden=None,
n_classes=num_classes,
p=0.20)
def train(self, mode: bool = True) -> Any:
self.training = mode
self.encoder.train(False)
self.projection.train(False)
self.classifier_head.train(mode)
return self
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert isinstance(self.encoder.avgpool, torch.nn.Module)
with torch.no_grad():
# Generate representations
repr = self.encoder(x)
# Generate image embeddings
self.projection(repr)
# Generate class logits
agg_repr = self.encoder.avgpool(repr) if repr.ndim > 2 else repr
agg_repr = agg_repr.reshape(agg_repr.size(0), -1).detach()
return self.classifier_head(agg_repr)
class PretrainedClassifier(torch.nn.Module):
def __init__(self, num_classes: int, encoder: torch.nn.Module):
super().__init__()
self.encoder = encoder
self.classifier_head = torch.nn.Linear(get_encoder_output_dim(self.encoder), num_classes)
def train(self, mode: bool = True) -> Any:
self.classifier_head.train(mode)
self.encoder.train(mode)
return self
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Generate representations
repr = self.encoder(x)
# Generate class logits
agg_repr = self.encoder.avgpool(repr) if repr.ndim > 2 else repr # type: ignore
agg_repr = agg_repr.reshape(agg_repr.size(0), -1)
return self.classifier_head(agg_repr)
def get_encoder_output_dim(pl_module: Union[pl.LightningModule, torch.nn.Module],
dm: Optional[pl.LightningDataModule] = None) -> int:
"""
Calculates the output dimension of ssl encoder by making a single forward pass.
:param pl_module: pl encoder module
:param dm: pl datamodule
"""
# Target device
device = pl_module.device if isinstance(pl_module, pl.LightningDataModule) else \
next(pl_module.parameters()).device # type: ignore
assert (isinstance(device, torch.device))
# Create a dummy input image
if dm is not None:
batch = iter(dm.train_dataloader()).next() # type: ignore
x, _ = SSLOnlineEvaluatorInnerEye.to_device(batch, device)
else:
x = torch.rand((1, 3, 256, 256)).to(device)
# Extract the number of output feature dimensions
with torch.no_grad():
representations = pl_module(x)
return representations.shape[1]
def WrapSSL(ssl_class: Any, num_classes: int) -> Any:
"""
Wraps a given SSL encoder and adds a non-linear evaluator to it. This is done to load pre-trained SSL checkpoints.
PL requires non_linear_evaluator to be included in pl_module at SSL training time.
:param num_classes: Number of target classes for the linear head.
:param ssl_class: SSL object either BYOL or SimCLR.
"""
class _wrap(ssl_class): # type: ignore
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.non_linear_evaluator = SSLEvaluator(n_input=get_encoder_output_dim(self),
n_classes=num_classes,
n_hidden=None)
return _wrap

Просмотреть файл

@ -0,0 +1,63 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Optional
import logging
import torch
from InnerEyeDataQuality.deep_learning.self_supervised.ssl_classifier_module import PretrainedClassifier, SSLClassifier, \
WrapSSL
from pl_bolts.models.self_supervised.resnets import resnet18, resnet50_bn, resnet101
def create_ssl_encoder(encoder_name: str, dataset_name: Optional[str] = None) -> torch.nn.Module:
"""
"""
if encoder_name == 'resnet18':
encoder = resnet18(return_all_feature_maps=False)
elif encoder_name == 'resnet50':
encoder = resnet50_bn(return_all_feature_maps=False)
elif encoder_name == 'resnet101':
encoder = resnet101(return_all_feature_maps=False)
else:
raise ValueError("Unknown model type")
if dataset_name is not None:
if dataset_name in ["CIFAR10", "CIFAR10H"]:
logging.info("Updating the initial convolution in order not to shrink CIFAR10 images")
encoder.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
return encoder
def create_ssl_image_classifier(num_classes: int, pl_checkpoint_path: str, freeze_encoder: bool = True) -> torch.nn.Module:
from InnerEyeDataQuality.deep_learning.self_supervised.byol.byol_module import BYOLInnerEye
from InnerEyeDataQuality.deep_learning.self_supervised.simclr_module import SimCLRInnerEye
"""
"""
ssl_type = torch.load(pl_checkpoint_path, map_location=lambda storage, loc: storage)["hyper_parameters"]["ssl_type"]
logging.info(f"Creating a {ssl_type} based image classifier")
logging.info(f"Loading pretrained {ssl_type} weights from:\n {pl_checkpoint_path}")
if ssl_type == "byol":
byol_module = WrapSSL(BYOLInnerEye, num_classes).load_from_checkpoint(pl_checkpoint_path, strict=False)
if freeze_encoder:
model = SSLClassifier(num_classes=num_classes, encoder=byol_module.target_network.encoder,
projection=byol_module.target_network.projector_normalised)
else:
model = PretrainedClassifier(num_classes=num_classes, # type: ignore
encoder=byol_module.target_network.encoder)
elif ssl_type == "simclr":
simclr_module = WrapSSL(SimCLRInnerEye, num_classes).load_from_checkpoint(pl_checkpoint_path, strict=False)
if freeze_encoder:
model = SSLClassifier(num_classes=num_classes, encoder=simclr_module.encoder,
projection=simclr_module.projection)
else:
model = PretrainedClassifier(num_classes=num_classes, # type: ignore
encoder=simclr_module.encoder)
else:
raise NotImplementedError(f"Unknown unsupervised model: {ssl_type}")
return model

Просмотреть файл

@ -0,0 +1,51 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import argparse
import logging
import os
from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.deep_learning.utils import create_logger, get_run_config, load_model_config
from InnerEyeDataQuality.deep_learning.trainers.co_teaching_trainer import CoTeachingTrainer
from InnerEyeDataQuality.deep_learning.trainers.elr_trainer import ELRTrainer
from InnerEyeDataQuality.deep_learning.trainers.vanilla_trainer import VanillaTrainer
from InnerEyeDataQuality.utils.generic import set_seed
def train(config: ConfigNode) -> None:
create_logger(config.train.output_dir)
logging.info('Starting training...')
if config.train.use_co_teaching and config.train.use_elr:
raise ValueError("You asked for co-teaching and ELR at the same time. Please double check your configuration.")
if config.train.use_co_teaching:
model_trainer_class = CoTeachingTrainer
elif config.train.use_elr:
model_trainer_class = ELRTrainer # type: ignore
else:
model_trainer_class = VanillaTrainer # type: ignore
model_trainer_class(config).run_training()
def train_ensemble(config: ConfigNode, num_runs: int) -> None:
for i, _ in enumerate(range(num_runs)):
config_run = get_run_config(config, config.train.seed + i)
set_seed(config_run.train.seed)
os.makedirs(config_run.train.output_dir, exist_ok=True)
train(config_run)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Parser for model training.')
parser.add_argument('--config', type=str, required=True,
help='Path to config file characterising trained CNN model/s')
parser.add_argument('--num_runs', type=int, default=1, help='Number of runs (ensemble)')
args, unknown_args = parser.parse_known_args()
# Load config
config = load_model_config(args.config)
# Launch training
train_ensemble(config, args.num_runs)

Просмотреть файл

@ -0,0 +1,5 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

Просмотреть файл

@ -0,0 +1,203 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, List, Optional, Tuple
import torch
from torch.utils.data.dataloader import DataLoader
from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.deep_learning.architectures.ema import EMA
from InnerEyeDataQuality.deep_learning.collect_embeddings import get_all_embeddings
from InnerEyeDataQuality.deep_learning.graph.classifier import GraphClassifier
from InnerEyeDataQuality.deep_learning.loss import CrossEntropyLoss, consistency_loss
from InnerEyeDataQuality.deep_learning.trainers.model_trainer_base import IndexContainer, Loss, ModelTrainer
from InnerEyeDataQuality.deep_learning.scheduler import ForgetRateScheduler
from InnerEyeDataQuality.deep_learning.utils import create_model
from InnerEyeDataQuality.utils.generic import find_union_set_torch, map_to_device
class CoTeachingTrainer(ModelTrainer):
"""
Implements co-teaching training using two models
"""
def __init__(self, config: ConfigNode):
self.use_teacher_model = config.train.use_teacher_model
super().__init__(config)
self.forget_rate_scheduler = ForgetRateScheduler(
config.scheduler.epochs,
forget_rate=config.train.co_teaching_forget_rate,
num_gradual=config.train.co_teaching_num_gradual,
start_epoch=config.train.resume_epoch if config.train.resume_epoch > 0 else 0,
num_warmup_epochs=config.train.co_teaching_num_warmup)
self.joint_metric_tracker = self.train_trackers[0]
self.use_consistency_loss = config.train.co_teaching_consistency_loss
self.use_graph = config.train.co_teaching_use_graph
self.consistency_loss_weight = 0.10
self.num_models = len(self.models)
self.loss_fn = CrossEntropyLoss(config)
self.ema_models = [EMA(self.models[0]), EMA(self.models[1])] if config.train.use_teacher_model else None
self.graph_classifiers = [GraphClassifier(num_samples=len(self.train_loader.dataset), # type: ignore
num_classes=config.dataset.n_classes,
labels=self.train_loader.dataset.targets, # type: ignore
device=config.device)
for _ in range(2)]
# Create two models for co-teaching
def get_models(self, config: ConfigNode) -> List[torch.nn.Module]:
"""
:param config: The job config
:return: A list of two models to be trained
"""
return [create_model(config, model_id=0), create_model(config, model_id=1)]
def update_teacher_models(self) -> None:
if self.ema_models:
for i in range(len(self.models)):
self.ema_models[i].update()
def deploy_teacher_models(self, inputs: torch.Tensor) -> Optional[List[torch.Tensor]]:
if not self.ema_models:
return None
elif isinstance(inputs, list):
return [_m.inference(_i) for _m, _i in zip(self.ema_models, inputs)]
else:
return [_m.inference(inputs) for _m in self.ema_models]
@torch.no_grad()
def _get_samples_for_update(self,
outputs: List[torch.Tensor],
labels: torch.Tensor,
global_indices: torch.Tensor,
teacher_logits: Optional[List[torch.Tensor]]) -> Tuple[IndexContainer, IndexContainer]:
"""
Return a list of indices that should be kept for gradient updates.
"""
def _get_small_loss_sample_ids(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
num_samples = labels.shape[0]
num_remember = int((1. - self.forget_rate_scheduler.get_forget_rate) * num_samples)
per_sample_loss = self.loss_fn(logits, labels, reduction='none')
ind_sorted = torch.argsort(per_sample_loss)
return ind_sorted[:num_remember], ind_sorted[num_remember:]
judge = lambda i: teacher_logits[i] if self.use_teacher_model and teacher_logits else outputs[i]
ind_1_keep, ind_1_exc = _get_small_loss_sample_ids(judge(0))
ind_0_keep, ind_0_exc = _get_small_loss_sample_ids(judge(1))
# Use graph based classifier
if self.graph_classifiers[0].graph is not None:
ind_0_keep, ind_0_exc = self.graph_classifiers[0].filter_cases(ind_0_keep, ind_0_exc, global_indices)
ind_1_keep, ind_1_exc = self.graph_classifiers[1].filter_cases(ind_1_keep, ind_1_exc, global_indices)
return IndexContainer(ind_0_keep, ind_0_exc), IndexContainer(ind_1_keep, ind_1_exc)
# Co-teaching loss
def compute_loss(self,
outputs: List[torch.Tensor],
labels: torch.Tensor,
indices: Optional[Tuple[IndexContainer, IndexContainer]] = None,
**kwargs: Any) -> List[Loss]:
"""
Implements the co-teaching loss using the outputs of two different models
:param outputs: A list of logits outputed by each model
:param labels: The target labels
:param indices: Sample indices that should be kept and excluded in loss computation (for both models).
:return: A list of Loss object, each element contains the loss that is fed to the optimizer and a
tensor of per sample losses
"""
options = {'ema_logits': None}
options.update(kwargs)
ema_logits: Optional[List[torch.Tensor]] = options['ema_logits']
loss_obj = list()
indices: Tuple[IndexContainer, IndexContainer] = [ # type: ignore
IndexContainer(keep=torch.arange(labels.shape[0]), exclude=torch.tensor([], dtype=torch.long)) for _ in
range(len(outputs))] if indices is None else indices
assert indices is not None # for mypy
assert len(outputs) == len(indices) == 2
for _output, _index in zip(outputs, indices):
# The indices to keep for each model is determined by the loss of the other.
per_sample_loss = self.loss_fn(_output, labels, reduction='none')
if self.weight is not None:
per_sample_loss *= self.weight[labels]
loss_update = torch.mean(per_sample_loss[_index.keep])
loss_obj.append(Loss(per_sample_loss, loss_update))
# Consistency loss between predictions on noisy samples
if self.use_consistency_loss and ema_logits:
joint_excluded = find_union_set_torch(indices[0].exclude, indices[1].exclude)
c_loss0 = consistency_loss(outputs[0][joint_excluded], ema_logits[0][joint_excluded])
c_loss1 = consistency_loss(outputs[1][joint_excluded], ema_logits[1][joint_excluded])
loss_obj[0].loss += self.consistency_loss_weight * c_loss0
loss_obj[1].loss += self.consistency_loss_weight * c_loss1
return loss_obj
def run_epoch(self, dataloader: DataLoader, epoch: int, is_train: bool = False) -> None:
"""
Run a training or validation epoch of the base model trainer but also step the forget rate scheduler
:param dataloader: A dataloader object.
:param epoch: Current epoch id.
:param is_train: Whether this is a training epoch or not.
:param run_inference_on_training_set: If True, record all metrics using the train_trackers
(even if is_train = False)
:return:
"""
for model in self.models:
model.train() if is_train else model.eval()
trackers = self.train_trackers if is_train else self.val_trackers
# Consume input dataloader and update model
for indices, images, labels in dataloader:
images, labels = map_to_device(images, self.device), labels.to(self.device)
outputs = self.forward(images, requires_grad=is_train)
ema_logits = self.deploy_teacher_models(images)
selected_ind = self._get_samples_for_update(outputs, labels, indices, ema_logits) if is_train else None
losses = self.compute_loss(outputs, labels, selected_ind, ema_logits=ema_logits)
assert (len(outputs) == len(losses)) & (len(outputs) == 2)
if is_train:
assert selected_ind is not None
self.step_optimizers(losses)
self.update_teacher_models()
self.joint_metric_tracker.append_batch_aggregate(epoch=epoch,
logits_x=outputs[0].detach(),
logits_y=outputs[1].detach(),
dropped_cases=indices[selected_ind[0].exclude],
indices=indices)
# Collect model embeddings
embeddings = get_all_embeddings(self.all_model_cnn_embeddings)
# Log training and validation stats in metric tracker
for i, (logits, loss) in enumerate(zip(outputs, losses)):
teacher_logits = ema_logits[i].detach() if self.ema_models else None # type: ignore
trackers[i].sample_metrics.append_batch(epoch=epoch,
logits=logits.detach(),
labels=labels.detach(),
loss=loss.loss.item(),
indices=indices.tolist(),
per_sample_loss=loss.per_sample_loss.detach(),
embeddings=embeddings[i],
teacher_logits=teacher_logits)
# Adjust forget rate for co-teaching
if is_train:
self.forget_rate_scheduler.step()
if self.use_graph:
self.graph_classifiers[0].build_graph(embeddings=trackers[1].sample_metrics.embeddings_per_sample)
self.graph_classifiers[1].build_graph(embeddings=trackers[0].sample_metrics.embeddings_per_sample)
def save_checkpoint(self, epoch: int) -> bool:
is_save = super().save_checkpoint(epoch=epoch)
if is_save and self.ema_models:
is_last_epoch = epoch == self.config.scheduler.epochs - 1
suffix = '_last_epoch.pt' if is_last_epoch else f'_epoch_{epoch:d}.pt'
for i, ema_model in enumerate(self.ema_models):
save_path = str(self.checkpoint_dir / f'checkpoint_ema_model_{i:d}') + suffix
ema_model.save_model(save_path)
return is_save

Просмотреть файл

@ -0,0 +1,96 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import List
import torch
from torch import nn
import torch.nn.functional as F
from torch.types import Device
from torch.utils.data import DataLoader
from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.deep_learning.collect_embeddings import get_all_embeddings
from InnerEyeDataQuality.deep_learning.loss import CrossEntropyLoss
from InnerEyeDataQuality.deep_learning.trainers.model_trainer_base import Loss
from InnerEyeDataQuality.deep_learning.trainers.vanilla_trainer import VanillaTrainer
class ELRTrainer(VanillaTrainer):
def __init__(self, config: ConfigNode):
super().__init__(config)
self.num_classes = config.dataset.n_classes
self.loss_fn = {"TRAIN": ELRLoss(num_examples=self.train_trackers[0].num_samples_total,
num_classes=config.dataset.n_classes,
device=self.device),
"VAL": CrossEntropyLoss(config)}
def compute_loss(self, is_train: bool, outputs: List[torch.Tensor], labels: torch.Tensor,
indices: torch.Tensor = None) -> Loss:
"""
Implements the standard cross-entropy loss using one model
:param outputs: A list of logits outputed by each model
:param labels: The target labels
:return: A list of Loss object, each element contains the loss that is fed to the optimizer and a
tensor of per sample losses
"""
logits = outputs[0]
if is_train:
per_sample_loss = self.loss_fn["TRAIN"](predictions=logits, targets=labels, indices=indices) # type: ignore
else:
per_sample_loss = self.loss_fn["VAL"](predictions=logits, targets=labels, reduction='none') # type: ignore
loss = torch.mean(per_sample_loss)
return Loss(per_sample_loss, loss)
def run_epoch(self, dataloader: DataLoader, epoch: int, is_train: bool = False) -> None:
"""
Run a training or validation epoch of the base model trainer but also step the forget rate scheduler
:param dataloader: A dataloader object
:param epoch: Current epoch id.
:param is_train: Whether this is a training epoch or not
:param run_inference_on_training_set: If True, record all metrics using the train_trackers
(even if is_train = False)
:return:
"""
for model in self.models:
model.train() if is_train else model.eval()
for indices, images, labels in dataloader:
images, labels = images.to(self.device), labels.to(self.device)
outputs = self.forward(images, requires_grad=is_train)
embeddings = get_all_embeddings(self.all_model_cnn_embeddings)[0]
losses = self.compute_loss(is_train, outputs, labels, indices)
if is_train:
self.step_optimizers([losses])
# Log training and validation stats in metric tracker
tracker = self.train_trackers[0] if is_train else self.val_trackers[0]
tracker.sample_metrics.append_batch(epoch, outputs[0].detach(), labels.detach(), losses.loss.item(),
indices.cpu().tolist(), losses.per_sample_loss.detach(), embeddings)
class ELRLoss(nn.Module):
"""
Adapted from https://github.com/shengliu66/ELR.
"""
def __init__(self, num_examples: int, num_classes: int, device: Device, beta: float = 0.9, _lambda: float = 3):
super().__init__()
self.num_classes = num_classes
self.targets = torch.zeros(num_examples, self.num_classes, device=device)
self.beta = beta
self._lambda = _lambda
def forward(self, predictions: torch.Tensor, targets: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
y_pred = F.softmax(predictions, dim=1)
y_pred = torch.clamp(y_pred, 1e-4, 1.0 - 1e-4)
y_pred_ = y_pred.data.detach()
self.targets[indices] = self.beta * self.targets[indices] + (1 - self.beta) * (
(y_pred_) / (y_pred_).sum(dim=1, keepdim=True))
ce_loss = F.cross_entropy(predictions, targets, reduction="none")
elr_reg = (1 - (self.targets[indices] * y_pred).sum(dim=1)).log()
per_sample_loss = ce_loss + self._lambda * elr_reg
return per_sample_loss

Просмотреть файл

@ -0,0 +1,284 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union
import torch
from torch.utils.data.dataloader import DataLoader
from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.deep_learning.architectures.ema import EMA
from InnerEyeDataQuality.deep_learning.collect_embeddings import register_embeddings_collector
from InnerEyeDataQuality.deep_learning.dataloader import (get_number_of_samples_per_epoch, get_train_dataloader,
get_val_dataloader)
from InnerEyeDataQuality.deep_learning.metrics.tracker import MetricTracker
from InnerEyeDataQuality.utils.dataset_utils import get_datasets
from PyTorchImageClassification.optim import create_optimizer
from PyTorchImageClassification.scheduler import create_scheduler
@dataclass(frozen=True)
class IndexContainer:
keep: torch.Tensor
exclude: torch.Tensor
@dataclass
class Loss:
per_sample_loss: torch.Tensor # the loss value computed for each sample without any modifications
loss: torch.Tensor # the loss that is fed to the optimizer; can be derived from per_sample loss given some rule
class ModelTrainer(object):
"""
ModelTrainer class handles the training of models, logging, saving checkpoints and validation. This class trains one
or more similar models at a time each having its own optimizer but which can interact in the loss function.
This is an abstract class for which some methods are left not implemented;
child classes must implement these methods
"""
def __init__(self, config: ConfigNode) -> None:
self.config = config
self.checkpoint_dir = Path(self.config.train.output_dir) / 'checkpoints'
self.log_dir = Path(self.config.train.output_dir) / 'logs'
self.seed = config.train.seed
self.device = torch.device(config.device)
train_dataset, val_dataset = get_datasets(config)
self.weight = torch.tensor([train_dataset.weight, (1-train_dataset.weight)], # type: ignore
device=self.device, dtype=torch.float) if hasattr(train_dataset, "weight") else None
self.train_loader = get_train_dataloader(train_dataset, config, seed=self.seed,
drop_last=config.train.dataloader.drop_last, shuffle=True)
self.val_loader = get_val_dataloader(val_dataset, config, seed=self.seed)
self.models = self.get_models(config)
self.ema_models: Optional[List[EMA]] = None
self.schedulers = [create_scheduler(config, create_optimizer(config, model), len(self.train_loader))
for model in self.models]
self.train_trackers, self.val_trackers = self._create_metric_trackers(config)
self.all_trackers = self.train_trackers + self.val_trackers
self.all_model_cnn_embeddings = register_embeddings_collector(self.models, use_only_in_train=True)
def _create_metric_trackers(self, config: ConfigNode) -> Tuple[List[MetricTracker], List[MetricTracker]]:
"""
Creates metric trackers used at model training and validation.
"""
train_loader = self.train_loader
val_loader = self.val_loader
num_models = len(self.models)
if hasattr(train_loader.dataset, "ambiguity_metric_args"):
ambiguity_metric_args = train_loader.dataset.ambiguity_metric_args # type: ignore
else:
ambiguity_metric_args = dict()
save_tf_events = config.tensorboard.save_events if hasattr(config.tensorboard, 'save_events') else True
metric_kwargs = {"num_epochs": config.scheduler.epochs,
"num_classes": config.dataset.n_classes,
"save_tf_events": save_tf_events}
# A dataset without augmentations and not normalized
dataset_train, dataset_val = get_datasets(config)
train_trackers = [MetricTracker(dataset=dataset_train,
output_dir=str(self.log_dir / f'model_{i:d}_train'),
num_samples_total=len(train_loader.dataset), # type: ignore
num_samples_per_epoch=get_number_of_samples_per_epoch(train_loader),
name=f"model_{i}_train",
**{**metric_kwargs, **ambiguity_metric_args})
for i in range(num_models)]
val_trackers = [MetricTracker(dataset=dataset_val,
output_dir=str(self.log_dir / f'model_{i:d}_val'),
num_samples_total=len(val_loader.dataset), # type: ignore
num_samples_per_epoch=get_number_of_samples_per_epoch(val_loader),
name=f"model_{i}_valid",
**metric_kwargs) for i in range(num_models)]
return train_trackers, val_trackers
def get_models(self, config: ConfigNode) -> List[torch.nn.Module]:
"""
:param config: The job config
:return: A list of models to be trained
"""
raise NotImplementedError
def forward(self, images: torch.Tensor, requires_grad: bool = True) -> List[torch.Tensor]:
"""
Performs the forward pass for all models in the ModelTrainer class
:param images: The input images
:param requires_grad: Flag to indicate if forward pass required .grad attributes
:return: A list of the logits for each model
"""
def _forward(inputs: Union[List[torch.Tensor], torch.Tensor]) -> List[torch.Tensor]:
if isinstance(inputs, list):
return [_model(_input) for _model, _input in zip(self.models, inputs)]
else:
return [_model(inputs) for _model in self.models]
@torch.no_grad()
def _forward_inference_only(inputs: Union[List[torch.Tensor], torch.Tensor]) -> List[torch.Tensor]:
return _forward(inputs)
if requires_grad:
return _forward(images)
else:
return _forward_inference_only(images)
def compute_loss(self, outputs: List[torch.Tensor], labels: torch.Tensor,
indices: Optional[Tuple[IndexContainer, IndexContainer]] = None) -> Union[List[Loss], Loss]:
"""
Compute the losses that will be optimized
:param outputs: A list of logits outputed by each model
:param labels: The target labels
:return: A list of Loss object, each element contains the loss that is fed to the optimizer and a
tensor of per sample losses
"""
raise NotImplementedError
def step_optimizers(self, losses: List[Loss]) -> None:
"""
Take an optimizer step for every model's optimizer
:param losses: A list of Loss objects
:return:
"""
for loss, scheduler in zip(losses, self.schedulers):
scheduler.optimizer.zero_grad()
loss.loss.backward()
scheduler.optimizer.step()
scheduler.step()
def run_epoch(self, dataloader: DataLoader, epoch: int, is_train: bool = False) -> None:
"""
Run a training or validation epoch
:param dataloader: A dataloader object
:param epoch: Current training epoch id
:param is_train: Whether this is a training epoch or not
:return:
"""
raise NotImplementedError
def save_checkpoint(self, epoch: int) -> bool:
"""
Save checkpoints for the models for the current epoch
:param epoch: The current epoch
:return: Save success
"""
is_last_epoch = epoch == self.config.scheduler.epochs - 1
is_save = is_last_epoch or (epoch > 0 and epoch % self.config.train.checkpoint_period == 0)
if is_save:
logging.info(f"Saving model checkpoints, epoch {epoch}")
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
for ii in range(len(self.models)):
path = str(self.checkpoint_dir / f'checkpoint_model_{ii:d}')
full_save_name = path + '_last_epoch.pt' if is_last_epoch else path + f'_epoch_{epoch:d}.pt'
state = {'epoch': epoch,
'model': self.models[ii].state_dict(),
'scheduler': self.schedulers[ii].state_dict(),
'optimizer': self.schedulers[ii].optimizer.state_dict()}
torch.save(state, full_save_name)
return is_save
def load_checkpoints(self, restore_scheduler: bool, epoch: Optional[int] = None) -> None:
"""
If epoch is not specified, latest checkpoint files are loaded to restore the state
of model weights and optimisers
:param restore_scheduler (bool): Restores the state of optimiser and scheduler from checkpoint
:param epoch (int): Training epoch id.
"""
suffix = f'_epoch_{epoch:d}.pt' if epoch else '_last_epoch.pt'
for ii in range(len(self.models)):
path = str(self.checkpoint_dir / f'checkpoint_model_{ii:d}') + suffix
logging.info(f"Loading model-{ii} from checkpoint:\n {path}")
state = torch.load(str(path))
self.models[ii].load_state_dict(state['model'])
if restore_scheduler:
self.schedulers[ii].load_state_dict(state['scheduler'])
self.schedulers[ii].optimizer.load_state_dict(state['optimizer'])
if epoch is not None:
logging.info(f"Model is loaded from epoch: {epoch}")
assert state['epoch'] == epoch
if self.ema_models:
for ii in range(len(self.ema_models)):
path = str(self.checkpoint_dir / f'checkpoint_ema_model_{ii:d}') + suffix
logging.info(f"Loading ema teacher model-{ii} from checkpoint:\n {path}")
self.ema_models[ii].restore_from_checkpoint(path)
def run_training(self) -> None:
"""
Perform model training.
Model/s specified in config are trained for `num_epoch` epochs and results are stored in tf events.
"""
num_epochs = self.config.scheduler.epochs
epoch_range = range(num_epochs)
if self.config.train.resume_epoch > 0:
resume_epoch = self.config.train.resume_epoch
epoch_range = range(resume_epoch + 1, num_epochs)
self.load_checkpoints(restore_scheduler=self.config.train.restore_scheduler, epoch=resume_epoch)
# Model evaluation - startup
logging.info("Running evaluation on the validation set before training ...")
self.run_epoch(self.val_loader, is_train=False, epoch=0)
self.val_trackers[0].log_epoch_and_reset(epoch=0)
# Model training loop
for epoch in epoch_range:
epoch_start = time.time()
logging.info('\n' + f'Epoch {epoch:d}')
self.run_epoch(self.train_loader, is_train=True, epoch=epoch)
self.run_epoch(self.val_loader, is_train=False, epoch=epoch)
self.save_checkpoint(epoch)
for _t in self.all_trackers:
_t.log_epoch_and_reset(epoch)
logging.info('Epoch time: {0:.2f} secs'.format(time.time() - epoch_start))
# Store loss values for post-training analysis
for _t in self.all_trackers:
_t.save_loss()
def run_inference(self, dataloader: Any, use_mc_sampling: bool = False) -> List[MetricTracker]:
"""
Deployment of pre-trained model.
"""
dataset = dataloader.dataset
trackers = [MetricTracker(os.path.join(self.config.train.output_dir, f'model_{i:d}_inference'),
num_samples_total=len(dataset),
num_samples_per_epoch=len(dataset),
name=f"model_{i}_inference",
num_epochs=1,
num_classes=self.config.dataset.n_classes,
save_tf_events=False) for i in range(len(self.models))]
for model in self.models:
model.eval()
if use_mc_sampling:
logging.info("Applying MC sampling at inference time.")
dropout_layers = [layer for layer in model.modules() if isinstance(layer, torch.nn.Dropout)]
for _layer in dropout_layers:
_layer.training = True
with torch.no_grad():
for indices, images, labels in dataloader:
images, labels = images.to(self.device), labels.to(self.device)
outputs = self.forward(images, requires_grad=False)
losses = self.compute_loss(outputs, labels, indices=None)
if not isinstance(losses, List):
losses = [losses]
# Log training and validation stats in metric tracker
for i, (logits, loss) in enumerate(zip(outputs, losses)):
trackers[i].sample_metrics.append_batch(epoch=0,
logits=logits.detach(),
labels=labels.detach(),
loss=loss.loss.item(),
indices=indices.cpu().tolist(),
per_sample_loss=loss.per_sample_loss.detach())
return trackers

Просмотреть файл

@ -0,0 +1,80 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import List, Optional, Tuple
import torch
from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.deep_learning.collect_embeddings import get_all_embeddings
from InnerEyeDataQuality.deep_learning.trainers.model_trainer_base import IndexContainer, Loss, ModelTrainer
from InnerEyeDataQuality.deep_learning.utils import create_model
from InnerEyeDataQuality.deep_learning.loss import tanh_loss
from torch.utils.data.dataloader import DataLoader
from InnerEyeDataQuality.deep_learning.loss import CrossEntropyLoss
class VanillaTrainer(ModelTrainer):
"""
Implements vanilla cross entropy training with one model
"""
def __init__(self, config: ConfigNode):
super().__init__(config)
self.tanh_regularisation = config.train.tanh_regularisation
self.loss_fn = CrossEntropyLoss(config)
def get_models(self, config: ConfigNode) -> List[torch.nn.Module]:
"""
:param config: The job config
:return: A list with one model to be trained
"""
return [create_model(config, model_id=0)]
def compute_loss(self, outputs: List[torch.Tensor], labels: torch.Tensor,
indices: Optional[Tuple[IndexContainer, IndexContainer]] = None) -> Loss:
"""
Implements the standard cross-entropy loss using one model
:param outputs: A list of logits outputed by each model
:param labels: The target labels
:return: A list of Loss object, each element contains the loss that is fed to the optimizer and a
tensor of per sample losses
"""
logits = outputs[0]
per_sample_loss = self.loss_fn(predictions=logits, targets=labels, reduction='none')
if self.weight is not None:
per_sample_loss *= self.weight[labels]
loss = torch.mean(per_sample_loss)
if self.tanh_regularisation != 0.0:
loss += self.tanh_regularisation * tanh_loss(logits)
return Loss(per_sample_loss, loss)
def run_epoch(self, dataloader: DataLoader, epoch: int, is_train: bool = False) -> None:
"""
Run a training or validation epoch of the base model trainer but also step the forget rate scheduler
:param dataloader: A dataloader object
:param epoch: Current epoch id.
:param is_train: Whether this is a training epoch or not
:param run_inference_on_training_set: If True, record all metrics using the train_trackers
(even if is_train = False)
:return:
"""
for model in self.models:
model.train() if is_train else model.eval()
for indices, images, labels in dataloader:
images, labels = images.to(self.device), labels.to(self.device)
outputs = self.forward(images, requires_grad=is_train)
embeddings = get_all_embeddings(self.all_model_cnn_embeddings)[0]
losses = self.compute_loss(outputs, labels)
if is_train:
self.step_optimizers([losses])
# Log training and validation stats in metric tracker
tracker = self.train_trackers[0] if is_train else self.val_trackers[0]
tracker.sample_metrics.append_batch(epoch, outputs[0].detach(), labels.detach(), losses.loss.item(),
indices.cpu().tolist(), losses.per_sample_loss.detach(), embeddings)

Просмотреть файл

@ -0,0 +1,185 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import random
from typing import Any, Callable, Tuple
import PIL
import PIL.Image
import numpy as np
import torch
import torchvision
from scipy.ndimage import gaussian_filter, map_coordinates
from skimage.filters import rank
from skimage.morphology import disk
from InnerEyeDataQuality.configs.config_node import ConfigNode
class BaseTransform:
def __init__(self, config: ConfigNode):
self.transform = lambda x: x
def __call__(self, data: PIL.Image.Image) -> PIL.Image.Image:
return self.transform(data)
class Standardize:
def __init__(self, mean: np.ndarray, std: np.ndarray):
self.mean = np.array(mean)
self.std = np.array(std)
def __call__(self, image: PIL.Image.Image) -> np.ndarray:
image = np.asarray(image).astype(np.float32) / 255.
image = (image - self.mean) / self.std
return image
class CenterCrop(BaseTransform):
def __init__(self, config: ConfigNode):
self.transform = torchvision.transforms.CenterCrop(config.preprocess.center_crop_size)
class RandomCrop(BaseTransform):
def __init__(self, config: ConfigNode):
self.transform = torchvision.transforms.RandomCrop(
config.dataset.image_size,
padding=config.augmentation.random_crop.padding,
fill=config.augmentation.random_crop.fill,
padding_mode=config.augmentation.random_crop.padding_mode)
class RandomResizeCrop(BaseTransform):
def __init__(self, config: ConfigNode):
self.transform = torchvision.transforms.RandomResizedCrop(
size=config.preprocess.resize,
scale=config.augmentation.random_crop.scale)
class RandomHorizontalFlip(BaseTransform):
def __init__(self, config: ConfigNode):
self.transform = torchvision.transforms.RandomHorizontalFlip(
config.augmentation.random_horizontal_flip.prob)
class RandomAffine(BaseTransform):
def __init__(self, config: ConfigNode):
self.transform = torchvision.transforms.RandomAffine(degrees=config.augmentation.random_affine.max_angle, # 15
translate=(
config.augmentation.random_affine.max_horizontal_shift,
config.augmentation.random_affine.max_vertical_shift),
shear=config.augmentation.random_affine.max_shear)
class Resize(BaseTransform):
def __init__(self, config: ConfigNode):
self.transform = torchvision.transforms.Resize(config.preprocess.resize)
class RandomColorJitter(BaseTransform):
def __init__(self, config: ConfigNode) -> None:
self.transform = torchvision.transforms.ColorJitter(brightness=config.augmentation.random_color.brightness,
contrast=config.augmentation.random_color.contrast,
saturation=config.augmentation.random_color.saturation)
class RandomErasing(BaseTransform):
def __init__(self, config: ConfigNode) -> None:
self.transform = torchvision.transforms.RandomErasing(p=0.5,
scale=config.augmentation.random_erasing.scale,
ratio=config.augmentation.random_erasing.ratio)
class RandomGamma(BaseTransform):
def __init__(self, config: ConfigNode) -> None:
self.min, self.max = config.augmentation.gamma.scale
def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
gamma = random.uniform(self.min, self.max)
return torchvision.transforms.functional.adjust_gamma(image, gamma=gamma)
class HistogramNormalization:
def __init__(self, config: ConfigNode) -> None:
self.disk_size = config.preprocess.histogram_normalization.disk_size
def __call__(self, image: PIL.Image.Image) -> np.ndarray:
# Apply local histogram equalization
image = np.array(image)
return PIL.Image.fromarray(rank.equalize(image, selem=disk(self.disk_size)))
class ExpandChannels:
def __call__(self, data: torch.Tensor) -> torch.Tensor:
return torch.repeat_interleave(data, 3, dim=0)
class ToNumpy:
def __call__(self, image: PIL.Image.Image) -> np.ndarray:
return np.array(image)
class AddGaussianNoise:
def __init__(self, config: ConfigNode) -> None:
"""
Transformation to add Gaussian noise N(0, std) to
an image. Where std is set with the config.augmentation.gaussian_noise.std
argument. The transformation will be applied with probability
config.augmentation.gaussian_noise.p_apply
"""
self.std = config.augmentation.gaussian_noise.std
self.p_apply = config.augmentation.gaussian_noise.p_apply
def __call__(self, data: torch.Tensor) -> torch.Tensor:
if np.random.random(1) > self.p_apply:
return data
noise = torch.randn(size=data.shape) * self.std
data = torch.clamp(data + noise, 0, 1)
return data
class ElasticTransform:
"""Elastic deformation of images as described in [Simard2003]_.
.. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
Convolutional Neural Networks applied to Visual Document Analysis", in
Proc. of the International Conference on Document Analysis and
Recognition, 2003.
https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.160.8494&rep=rep1&type=pdf
:param sigma: elasticity coefficient
:param alpha: intensity of the deformation
:param p_apply: probability of applying the transformation
"""
def __init__(self, config: ConfigNode) -> None:
self.alpha = config.augmentation.elastic_transform.alpha
self.sigma = config.augmentation.elastic_transform.sigma
self.p_apply = config.augmentation.elastic_transform.p_apply
def __call__(self, image: PIL.Image) -> PIL.Image:
if np.random.random(1) > self.p_apply:
return image
image = np.asarray(image).squeeze()
assert len(image.shape) == 2
shape = image.shape
dx = gaussian_filter((np.random.random(shape) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha
dy = gaussian_filter((np.random.random(shape) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha
x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))
return PIL.Image.fromarray(map_coordinates(image, indices, order=1).reshape(shape))
class DualViewTransformWrapper:
def __init__(self, transforms: Callable):
self.transforms = transforms
def __call__(self, sample: PIL.Image.Image) -> Tuple[Any, Any]:
transform = self.transforms
xi = transform(sample)
xj = transform(sample)
return xi, xj

Просмотреть файл

@ -0,0 +1,154 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import collections
import logging
from importlib import import_module
from pathlib import Path
from typing import Optional, Union
import torch
from yacs.config import CfgNode
from default_paths import PROJECT_ROOT_DIR
from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.configs import model_config
from InnerEyeDataQuality.configs.selector_config import get_default_selector_config
from InnerEyeDataQuality.deep_learning.self_supervised.configs import ssl_model_config
from InnerEyeDataQuality.deep_learning.self_supervised.utils import create_ssl_image_classifier
from InnerEyeDataQuality.utils.generic import setup_cudnn, get_train_output_dir
def get_run_config(config: ConfigNode, run_seed: int) -> ConfigNode:
config_run = config.clone()
config_run.defrost()
config_run.train.seed = run_seed
config_run.train.output_dir = get_train_output_dir(config)
config_run.freeze()
return config_run
def load_ssl_model_config(config_path: Path) -> ConfigNode:
'''
Loads configs required for self supervised learning. Does not setup cudann as this is being
taken care of by lightining.
'''
config = ssl_model_config.get_default_model_config()
config.merge_from_file(config_path)
update_model_config(config)
# Freeze config entries
config.freeze()
return config
def load_model_config(config_path: Path) -> ConfigNode:
'''
Loads configs required for model training and inference.
'''
config = model_config.get_default_model_config()
config.merge_from_file(config_path)
update_model_config(config)
setup_cudnn(config)
# Freeze config entries
config.freeze()
return config
def update_model_config(config: ConfigNode) -> ConfigNode:
'''
Adds dataset specific parameters in model config
'''
if config.dataset.name in ['CIFAR10', 'CIFAR100']:
dataset_dir = f'~/.torch/datasets/{config.dataset.name}'
config.dataset.dataset_dir = dataset_dir
config.dataset.image_size = 32
config.dataset.n_channels = 3
config.dataset.n_classes = int(config.dataset.name[5:])
elif config.dataset.name in ['MNIST']:
dataset_dir = '~/.torch/datasets'
config.dataset.dataset_dir = dataset_dir
config.dataset.image_size = 28
config.dataset.n_channels = 1
config.dataset.n_classes = 10
if not torch.cuda.is_available():
config.device = 'cpu'
return config
def override_config(source: ConfigNode, overrides: ConfigNode) -> ConfigNode:
'''
Overrides the keys and values present in node `overrides` into source object recursively.
'''
for key, value in overrides.items():
if isinstance(value, collections.Mapping) and value:
returned = override_config(source.get(key, {}), value) # type: ignore
returned = CfgNode(returned)
source[key] = returned
else:
source[key] = overrides[key]
return source
def load_selector_config(config_path: str) -> ConfigNode:
"""
Loads a selector config and merges with its model config
"""
selector_config = get_default_selector_config()
selector_config.merge_from_file(config_path)
model_config = load_model_config(PROJECT_ROOT_DIR / selector_config.selector.model_config_path)
merged_config = override_config(source=model_config, overrides=selector_config)
merged_config.freeze()
return merged_config
def create_logger(output_dir: Union[str, Path]) -> None:
if isinstance(output_dir, str):
output_dir = Path(output_dir)
log_path = output_dir.absolute() / 'training.log'
logging.basicConfig(filename=log_path,
filemode='w',
format='%(asctime)s %(name)-4s %(levelname)-6s %(message)s',
datefmt='%m-%d %H:%M',
level=logging.DEBUG)
# define a Handler which writes INFO messages or higher to the sys.stderr
console = logging.StreamHandler()
console.setLevel(logging.INFO)
# set a format which is simpler for console use
formatter = logging.Formatter('%(name)-4s: %(levelname)-6s %(message)s')
# tell the handler to use this format
console.setFormatter(formatter)
# add the handler to the root logger
logging.getLogger().addHandler(console)
def create_model(config: ConfigNode, model_id: Optional[int]) -> torch.nn.Module:
device = torch.device(config.device)
if config.train.use_self_supervision:
assert isinstance(model_id, int) and model_id < 2
model_checkpoint_name = config.train.self_supervision.checkpoints[model_id]
model_checkpoint_path = PROJECT_ROOT_DIR / model_checkpoint_name
model = create_ssl_image_classifier(num_classes=config.dataset.n_classes,
pl_checkpoint_path=str(model_checkpoint_path),
freeze_encoder=config.train.self_supervision.freeze_encoder)
else:
try:
module = import_module('PyTorchImageClassification.models'f'.{config.model.type}.{config.model.name}')
except ModuleNotFoundError:
module = import_module(
f'InnerEyeDataQuality.deep_learning.architectures.{config.model.type}.{config.model.name}')
model = getattr(module, 'Network')(config)
return model.to(device)

Просмотреть файл

@ -0,0 +1,77 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import numpy as np
def compute_label_entropy(label_counts: np.ndarray) -> np.ndarray:
"""
:param label_counts: Input label histogram (n_samples x n_classes)
"""
label_distribution = label_counts / np.sum(label_counts, axis=-1, keepdims=True)
return cross_entropy(label_distribution, label_distribution)
def compute_accuracy(source: np.ndarray, target: np.ndarray) -> float:
"""
Computes the agreement rate between two tensors
:param source: source array (n_samples, n_classes)
:param target: target array (n_samples, n_classes)
"""
# Compare single label case against all available labels
source = np.argmax(source, axis=1)
target = np.argmax(target, axis=1)
# Accuracy
acc = 100.0 * np.sum(source == target) / source.size
return acc
def compute_model_disagreement_score(posteriors: np.ndarray) -> np.ndarray:
"""
Measure model disagreement score (Ref BALD)
:param posteriors: numpy array (shape: (model_candidates, batch_size, num_classes))
:return: Disagreement score (BALD) for each sample (shape: (batch))
"""
def _entropy(x: np.ndarray, log_base: int, epsilon: float = 1e-12) -> np.ndarray:
return -np.sum(x * (np.log(x + epsilon) / np.log(log_base)), axis=-1)
num_classes = int(posteriors.shape[-1])
avg_posteriors = np.mean(posteriors, axis=0)
avg_entropy = _entropy(avg_posteriors, num_classes)
exp_conditional_entropy = np.mean(_entropy(posteriors, num_classes), axis=0)
bald_score = avg_entropy - exp_conditional_entropy
return bald_score
def cross_entropy(predicted_distribution: np.ndarray, target_distribution: np.ndarray) -> np.ndarray:
"""
Compute the normalised cross-entropy between the predicted and target distributions
:param predicted_distribution: Predicted distribution shape = (num_samples, num_classes)
:param target_distribution: Target distribution shape = (num_samples, num_classes)
:return: The cross-entropy for each sample
"""
num_classes = predicted_distribution.shape[1]
return -np.sum(target_distribution * np.log(predicted_distribution + 1e-12) / np.log(num_classes), axis=-1)
def max_prediction_error(predicted_distribution: np.ndarray, target_distribution: np.ndarray) -> np.ndarray:
"""
Compute the max (class-wise) prediction error between the predicted and target distributions
:param predicted_distribution: Predicted distribution shape = (num_samples, num_classes)
:param target_distribution: Target distribution shape = (num_samples, num_classes)
:return: The max (class-wise) prediction error for each sample
"""
current_target_class = np.argmax(target_distribution, axis=1)
current_target_pred_prob = predicted_distribution[range(len(current_target_class)), current_target_class]
prediction_errors = 1.0 - current_target_pred_prob
return prediction_errors
def total_variation(predicted_distribution: np.ndarray, target_distribution: np.ndarray) -> np.ndarray:
"""
Compute the total variation error between the predicted and target distributions
:param predicted_distribution: Predicted distribution shape = (num_samples, num_classes)
:param target_distribution: Target distribution shape = (num_samples, num_classes)
:return: The total variation for each sample
"""
return np.sum(np.abs(predicted_distribution - target_distribution), axis=-1) / 2.

Просмотреть файл

@ -0,0 +1,408 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import re
import numpy as np
import pandas as pd
import seaborn as sns
from itertools import cycle
from matplotlib import pyplot as plt
from pathlib import Path
from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_auc_score, roc_curve
from typing import Any, Dict, List, Optional, Tuple, Union
from default_paths import MAIN_SIMULATION_DIR
from InnerEyeDataQuality.selection.simulation_statistics import SimulationStatsDistribution
from InnerEyeDataQuality.utils.generic import create_folder, save_obj
def _plot_mean_and_confidence_intervals(ax: plt.Axes, y: np.ndarray, color: str, label: str, linestyle: str = '-',
clip: Tuple[float, float] = (-np.inf, np.inf),
use_derivative: bool = False,
include_auc: bool = True,
linewidth: float = 3.0) -> None:
if use_derivative:
y = (y - y[:, 0][:, np.newaxis]) / (np.arange(y.shape[1]) + 1)
x = np.arange(0, y.shape[1])
mean = np.mean(y, axis=0)
total_auc = auc([0, np.max(x)], [100, 100])
auc_mean = auc(x, mean) / total_auc
label += f" - AUC: {auc_mean:.4f}" if include_auc else ""
std_error = np.std(y, axis=0) / np.sqrt(y.shape[0]) * 1.96
ax.plot(x, mean, color=color, label=label, linestyle=linestyle, linewidth=linewidth)
ax.fill_between(x,
np.clip((mean - std_error), a_min=clip[0], a_max=clip[1]),
np.clip((mean + std_error), a_min=clip[0], a_max=clip[1]),
color=color, alpha=.2)
def plot_stats_for_all_selectors(stats: Dict[str, SimulationStatsDistribution],
y_attr_names: List[str],
y_attr_labels: List[str],
title: str = '',
x_label: str = '',
y_label: str = '',
legend_loc: Union[int, str] = 2,
fontsize: int = 12,
figsize: Tuple[int, int] = (14, 10),
plot_n_not_ambiguous_noise_cases: bool = False,
**plot_kwargs: Any) -> Tuple[plt.Figure, plt.Axes]:
"""
Given the dictionary with SimulationStatsDistribution plot a curve for each attribute in the y_attr_names list for
each selector. Total number of curves will be num_selectors * num_y_attr_names.
:param stats: A dictionary where each entry corresponds to the SimulationStatsDistribution of each selector.
:param y_attr_names: The names of the attributes to plot on the y axis.
:param y_attr_labels: The labels for the legend of the attributes to plot on the y axis.
:param title: The title of the figure.
:param x_label: The title of the x-axis.
:param y_label: The title of the y-axis.
:param legend_loc: The location of the legend.
:param plot_n_not_ambiguous_noise_cases: If True, indicate the number of noise easy cases left at the end of the
simulation
for each selector. If all cases are selected indicate the corresponding iteration.
:return: The figure and axis with all the curves plotted.
"""
colors = ['blue', 'red', 'orange', 'brown', 'black', 'purple', 'green', 'gray', 'olive', 'cyan', 'yellow', 'pink']
linestyles = ['-', '--']
fig, ax = plt.subplots(figsize=figsize)
# Sort alphabetically
ordered_keys = sorted(stats.keys())
# Put always oracle and random first (this crashes if oracle / random not there)
ordered_keys.remove("Oracle")
ordered_keys.remove("Random")
ordered_keys.insert(0, "Oracle")
ordered_keys.insert(1, "Random")
for _stat_key, color in zip(ordered_keys, cycle(colors)):
_stat = stats[_stat_key]
if '(Posterior)' in _stat.name:
_stat.name = _stat.name.split('(Posterior)')[0]
for y_attr_name, y_attr_label, linestyle in zip(y_attr_names, y_attr_labels, cycle(linestyles)):
color = assign_preset_color(name=_stat.name, color=color)
_plot_mean_and_confidence_intervals(ax=ax, y=_stat.__getattribute__(y_attr_name), color=color,
label=y_attr_label + _stat.name, linestyle=linestyle, **plot_kwargs)
if plot_n_not_ambiguous_noise_cases:
average_value_attribute = np.mean(_stat.__getattribute__(y_attr_name), axis=0)
std_value_attribute = np.std(_stat.__getattribute__(y_attr_name), axis=0)
logging.debug(f"Method {_stat.name} - {y_attr_name} std: {std_value_attribute[-1]}")
n_average_mislabelled_not_ambiguous = np.mean(_stat.num_remaining_mislabelled_not_ambiguous, axis=0)
ax.text(average_value_attribute.size + 1, average_value_attribute[-1],
f"{average_value_attribute[-1]:.2f}", fontsize=fontsize - 4)
no_mislabelled_not_ambiguous_remaining = np.where(n_average_mislabelled_not_ambiguous == 0)[0]
if no_mislabelled_not_ambiguous_remaining.size > 0:
idx = np.min(no_mislabelled_not_ambiguous_remaining)
ax.scatter(idx + 1, average_value_attribute[idx], color=color, marker="s")
ax.grid()
ax.set_title(title, fontsize=fontsize)
ax.set_xlabel(x_label, fontsize=fontsize)
ax.set_ylabel(y_label, fontsize=fontsize)
ax.legend(loc=legend_loc, fontsize=fontsize - 2)
ax.xaxis.set_tick_params(labelsize=fontsize - 2)
ax.yaxis.set_tick_params(labelsize=fontsize - 2)
return fig, ax
def assign_preset_color(name: str, color: str) -> str:
if re.search('oracle', name, re.IGNORECASE):
return 'blue'
elif re.search('random', name, re.IGNORECASE):
return 'red'
elif re.search('bald', name, re.IGNORECASE):
if re.search('active', name, re.IGNORECASE):
return 'green'
return 'purple'
elif re.search('clean', name, re.IGNORECASE):
return 'pink'
elif re.search('self-supervision', name, re.IGNORECASE) \
or re.search('SSL', name, re.IGNORECASE) or re.search('pretrained', name, re.IGNORECASE):
return 'orange'
elif re.search('imagenet', name, re.IGNORECASE):
return 'green'
elif re.search('self-supervision', name, re.IGNORECASE):
if re.search('graph', name, re.IGNORECASE):
return 'green'
elif re.search('With entropy', name, re.IGNORECASE):
return 'olive'
elif re.search('active', name, re.IGNORECASE):
return 'cyan'
else:
return 'orange'
elif re.search('coteaching', name, re.IGNORECASE):
return 'brown'
elif re.search('less', name, re.IGNORECASE):
return 'grey'
elif re.search('vanilla', name, re.IGNORECASE):
if re.search('active', name, re.IGNORECASE):
return 'grey'
return 'black'
else:
return color
def plot_minimal_sampler(stats: Dict[str, SimulationStatsDistribution],
n_samples: int, ax: plt.Axes, fontsize: int = 12, legend_loc: int = 2) -> None:
"""
Starting with one initial label, a noisy sample needs to be relabeled at least twice (best case scenario
the sampled class during relabeling is equal to the correct label
"""
n_mislabelled_ambiguous = list(stats.values())[0].num_initial_mislabelled_ambiguous
n_mislabelled_not_ambiguous = list(stats.values())[0].num_initial_mislabelled_not_ambiguous
accuracy_beginning = 100 * float(n_samples - n_mislabelled_ambiguous - n_mislabelled_not_ambiguous) / n_samples
ax.plot([0, 2 * (n_mislabelled_not_ambiguous + n_mislabelled_ambiguous)], [accuracy_beginning, 100], linestyle="--",
label="Minimal sampler")
# accuracy_easy_cases = 100 * float(n_samples - n_mislabelled_ambiguous) / n_samples
# max_accuracy = max([np.max(_stat.accuracy) for _stat in stats.values()])
# plt.scatter(2 * n_mislabelled_not_ambiguous, accuracy_easy_cases,
# marker='s', label="No non-ambiguous noise cases left")
# ax.set_ylim(accuracy_beginning - 0.25, max_accuracy + 0.25)
ax.legend(loc=legend_loc, fontsize=fontsize - 2)
def plot_stats(stats: Dict[str, SimulationStatsDistribution],
dataset_name: str,
n_samples: int,
filename_suffix: str,
save_path: Optional[Path] = None,
sample_indices: Optional[List[int]] = None) -> None:
"""
:param stats:
:param dataset_name:
:param n_samples:
:param filename_suffix:
:param save_path:
:param sample_indices: Image indices used in the dataset to visualise the selected cases
"""
if save_path:
input_args = locals()
save_path = save_path / dataset_name
save_path.mkdir(exist_ok=True)
save_obj(input_args, save_path / "inputs_to_plot_stats.pkl")
noise_rate = 100 * float(
list(stats.values())[0].num_initial_mislabelled_ambiguous + list(stats.values())[
0].num_initial_mislabelled_not_ambiguous) / n_samples
# Label accuracy vs num relabels
fontsize, legend_loc = 20, 2
fig, ax = plot_stats_for_all_selectors(stats, ['accuracy'], [''],
title=f'Dataset Curation - {dataset_name} (N={n_samples}, '
f'{noise_rate:.1f}% noise)',
x_label='Number of collected relabels on the dataset',
y_label='Percentage of correct labels',
legend_loc=legend_loc,
fontsize=fontsize,
figsize=(11, 10),
plot_n_not_ambiguous_noise_cases=True)
if dataset_name != "NoisyChestXray":
plot_minimal_sampler(stats, n_samples, ax, fontsize=fontsize, legend_loc=legend_loc)
if save_path:
fig.savefig(save_path / f"simulation_label_accuracy_{filename_suffix}.pdf", bbox_inches='tight')
fig.savefig(save_path / f"simulation_label_accuracy_{filename_suffix}.png", bbox_inches='tight')
# Label accuracy vs num relabels - To Origin
fig, ax = plot_stats_for_all_selectors(stats, ['accuracy'], [''],
title=f'Dataset Curation - {dataset_name} (N={n_samples})',
x_label='Number of collected relabels on the dataset',
y_label='Percentage of correct labels',
legend_loc=1,
plot_n_not_ambiguous_noise_cases=False,
use_derivative=True)
if save_path:
fig.savefig(save_path / f"simulation_label_accuracy_to_origin_{filename_suffix}.png", bbox_inches='tight')
# Average total variation vs num relabels
fig, ax = plot_stats_for_all_selectors(stats, ['avg_total_variation'], [''],
title=f'Dataset Curation - {dataset_name} (N={n_samples})',
x_label='Number of collected relabels on the dataset',
y_label='Total Variation (Full vs Sampled Distributions)',
fontsize=fontsize,
include_auc=False,
figsize=(10, 11),
legend_loc=1)
if save_path:
fig.savefig(save_path / f"simulation_avg_total_variation_{filename_suffix}.png", bbox_inches='tight')
fig.savefig(save_path / f"simulation_avg_total_variation_{filename_suffix}.pdf", bbox_inches='tight')
# Remaining number of noisy cases vs num relabels
fig, ax = plot_stats_for_all_selectors(stats,
['num_remaining_mislabelled_not_ambiguous'], [''],
x_label='Number of collected relabels on the dataset',
y_label='# of remaining clear noisy samples',
fontsize=fontsize + 1,
figsize=(10, 10),
include_auc=False,
linewidth=5.0,
legend_loc="lower left")
if save_path:
fig.savefig(save_path / f"simulation_remaining_clear_mislabelled_{filename_suffix}.png", bbox_inches='tight')
fig.savefig(save_path / f"simulation_remaining_clear_mislabelled_{filename_suffix}.pdf", bbox_inches='tight')
# Remaining number ambiguous cases vs num relabels
fig, ax = plot_stats_for_all_selectors(stats,
['num_remaining_mislabelled_ambiguous'], [''],
x_label='Number of collected relabels on the dataset',
y_label='# of remaining difficult noisy samples',
fontsize=fontsize + 1,
include_auc=False,
figsize=(10, 10),
linewidth=5.0,
legend_loc="lower left")
if save_path:
fig.savefig(save_path / f"simulation_remaining_ambiguous_mislabelled_{filename_suffix}.png",
bbox_inches='tight')
fig.savefig(save_path / f"simulation_remaining_ambiguous_mislabelled_{filename_suffix}.pdf",
bbox_inches='tight')
def plot_roc_curve(scores: np.ndarray, labels: np.ndarray,
type_of_cases: str = "mislabelled",
ax: Optional[plt.Axes] = None,
color: str = "b",
legend: str = "AUC",
linestyle: str = "-") -> None:
fpr, tpr, threshold = roc_curve(labels, scores)
roc_auc = roc_auc_score(labels, scores)
if ax is None:
fig, ax = plt.subplots()
ax.plot(fpr, tpr, color=color, label=f"{legend}: {roc_auc:.2f}", linestyle=linestyle)
ax.set_ylabel("True positive rate")
ax.set_xlabel("False positive rate")
ax.set_title(f"ROC curve - {type_of_cases} detection")
ax.legend(loc="lower right")
ax.grid(b=True)
def plot_pr_curve(scores: np.ndarray, labels: np.ndarray,
type_of_cases: str = "mislabelled",
ax: Optional[plt.Axes] = None,
color: str = "b",
legend: str = "AUC",
linestyle: str = "-") -> None:
precision, recall, _ = precision_recall_curve(y_true=labels, probas_pred=scores)
pr_auc = auc(recall, precision)
if ax is None:
fig, ax = plt.subplots()
ax.plot(recall, precision, color=color, label=f"{legend}: {pr_auc:.2f}", linestyle=linestyle)
ax.set_xlim([0.0, 1.05])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title(f"Precision-Recall curve - {type_of_cases} detection")
ax.legend(loc="lower right")
ax.grid(b=True)
def plot_binary_confusion(scores: np.ndarray, labels: np.ndarray, type_of_cases: str = "mislabelled",
ax: Optional[plt.Axes] = None) -> None:
if ax is None:
fig, ax = plt.subplots()
fpr, tpr, threshold = roc_curve(labels, scores)
optimal_threshold = threshold[np.argmax(tpr - fpr)]
prediction = scores > optimal_threshold
tn, fp, fn, tp = confusion_matrix(labels, prediction).ravel()
rates = tn / (tn + tp), fp / (tn + fp), fn / (tp + fn), tp / (tp + fn)
cf_matrix = confusion_matrix(labels, prediction)
group_names = ["True Neg", "False Pos", "False Neg", "True Pos"]
group_counts = [f"{value:.0f}" for value in cf_matrix.flatten()]
group_percentages = [f"{value:.3f}" for value in rates]
annotations = [f"{v1}\n{v2}\n{v3}" for v1, v2, v3 in zip(group_names, group_counts, group_percentages)]
annotations = np.asarray(annotations).reshape(2, 2)
sns.heatmap(cf_matrix, annot=annotations, fmt="", cmap=sns.light_palette("navy", reverse=True), ax=ax,
cbar=False)
ax.set_xlabel("Prediction")
ax.set_ylabel("Actual")
ax.set_title(f"Confusion matrix - {type_of_cases} cases\nThreshold = {optimal_threshold:.2f}")
def plot_stats_scores(selector_name: str, scores_mislabelled: np.ndarray, labels_mislabelled: np.ndarray,
scores_ambiguous: Optional[np.ndarray] = None, labels_ambiguous: Optional[np.ndarray] = None,
save_path: Optional[Path] = None) -> None:
if scores_ambiguous is None:
fig, ax = plt.subplots(3, 1, figsize=(5, 10))
else:
fig, ax = plt.subplots(3, 2, figsize=(10, 15))
fig.suptitle(selector_name)
ax = ax.ravel(order="F")
plot_roc_curve(scores_mislabelled, labels_mislabelled, type_of_cases="mislabelled", ax=ax[0])
plot_pr_curve(scores_mislabelled, labels_mislabelled, type_of_cases="mislabelled", ax=ax[1])
plot_binary_confusion(scores_mislabelled, labels_mislabelled, type_of_cases="mislabelled", ax=ax[2])
if scores_ambiguous is not None and labels_ambiguous is not None and np.sum(labels_ambiguous) != 0:
plot_roc_curve(scores_ambiguous, labels_ambiguous, type_of_cases="ambiguous", ax=ax[3])
plot_pr_curve(scores_ambiguous, labels_ambiguous, type_of_cases="ambiguous", ax=ax[4])
plot_binary_confusion(scores_ambiguous, labels_ambiguous, "ambiguous", ax[5])
# ambiguous detection given that is mislabelled
scores_ambiguous_given_mislabelled = scores_ambiguous[labels_mislabelled == 1]
labels_ambiguous_given_mislabelled = labels_ambiguous[labels_mislabelled == 1]
if roc_auc_score(labels_ambiguous_given_mislabelled, scores_ambiguous_given_mislabelled) < .5:
scores_ambiguous_given_mislabelled *= -1
plot_roc_curve(scores_ambiguous_given_mislabelled,
labels_ambiguous_given_mislabelled,
type_of_cases="ambiguous",
ax=ax[3],
color="red", legend="AUC given mislabelled=True", linestyle="--")
plot_pr_curve(scores_ambiguous_given_mislabelled,
labels_ambiguous_given_mislabelled,
type_of_cases="ambiguous",
ax=ax[4],
color="red", legend="AUC given mislabelled=True", linestyle="--")
if save_path:
plt.savefig(save_path / f"{selector_name}_stats_scoring.png", bbox_inches="tight")
plt.show()
def plot_relabeling_score(true_majority: np.ndarray,
starting_scores: np.ndarray,
current_majority: np.ndarray,
selector_name: str) -> None:
"""
Plots for ranking of samples score-wise
"""
create_folder(MAIN_SIMULATION_DIR / "scores_histogram")
create_folder(MAIN_SIMULATION_DIR / "noise_detection_heatmaps")
is_noisy = true_majority != current_majority
total_noise_rate = np.mean(is_noisy)
# Compute how many noisy sampled where present in the
# highest n_noisy samples.
target_perc = int((1 - total_noise_rate) * 100)
q = np.percentile(starting_scores, q=target_perc)
noisy_cases_detected = is_noisy[starting_scores > q]
percentage_noisy_detected = 100 * float(noisy_cases_detected.sum()) / is_noisy.sum()
# Plot histogram of scores differentiated by noisy or not.
df = pd.DataFrame({"scores": starting_scores,
"is_noisy": is_noisy})
plt.close()
sns.histplot(data=df, x="scores", hue="is_noisy", multiple="dodge", bins=10)
plt.title(f"Histogram of relabeling scores {selector_name}\n"
f"{percentage_noisy_detected:.1f}% noise cases > {target_perc}th percentile of scores")
plt.savefig(MAIN_SIMULATION_DIR / "scores_histogram" / selector_name)
plt.close()
# Plot heatmap showing where the noisy cases are located in the score ranking.
idx = np.argsort(starting_scores)
sorted_is_noisy = is_noisy[idx]
plt.title(f"{selector_name}\n"
f"Location of noisy cases by increasing scores (from left to right)\n"
f"{percentage_noisy_detected:.1f}% noise cases > {target_perc}th percentile of scores")
sns.heatmap(sorted_is_noisy.reshape(1, -1), yticklabels=False, vmax=1.3, cbar=False)
plt.savefig(MAIN_SIMULATION_DIR / "noise_detection_heatmaps" / selector_name)
plt.close()

Просмотреть файл

@ -0,0 +1,115 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import argparse
import copy
import datetime
import logging
import multiprocessing
import numpy as np
from itertools import product
from typing import Any, Dict, Tuple
from joblib import Parallel, delayed
from default_paths import MAIN_SIMULATION_DIR
from InnerEyeDataQuality.evaluation.plot_stats import plot_stats
from InnerEyeDataQuality.selection.data_curation_utils import get_user_specified_selectors, \
update_trainer_for_simulation
from InnerEyeDataQuality.selection.selectors.base import SampleSelector
from InnerEyeDataQuality.selection.selectors.label_based import (LabelBasedDecisionRule, LabelDistributionBasedSampler,
PosteriorBasedSelector)
from InnerEyeDataQuality.selection.selectors.bald import BaldSelector
from InnerEyeDataQuality.selection.selectors.random_selector import RandomSelector
from InnerEyeDataQuality.selection.simulation import DataCurationSimulator
from InnerEyeDataQuality.selection.simulation_statistics import SimulationStats, SimulationStatsDistribution
from InnerEyeDataQuality.utils.dataset_utils import load_dataset_and_initial_labels_for_simulation
from InnerEyeDataQuality.utils.generic import create_folder, get_data_selection_parser, get_logger, set_seed
EXP_OUTPUT_DIR = MAIN_SIMULATION_DIR / datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
def main(args: argparse.Namespace) -> None:
dataset, initial_labels = load_dataset_and_initial_labels_for_simulation(args.config[0], args.on_val_set)
n_classes = dataset.num_classes
n_samples = dataset.num_samples
true_distribution = dataset.label_distribution
user_specified_selectors = get_user_specified_selectors(list_configs=args.config,
dataset=dataset,
output_path=MAIN_SIMULATION_DIR,
plot_embeddings=args.plot_embeddings)
# Data selection simulations for annotation
default_selector = {
'Random': RandomSelector(n_samples, n_classes, name='Random'),
'Oracle': PosteriorBasedSelector(true_distribution.distribution, n_samples,
num_classes=n_classes,
name='Oracle',
allow_repeat_samples=True,
decision_rule=LabelBasedDecisionRule.INV)}
sample_selectors = {**user_specified_selectors, **default_selector}
# Benchmark 2
# Determine the number of simulation iterations based on the noise rate
expected_noise_rate = np.mean(np.argmax(true_distribution.distribution, -1) != dataset.targets[:n_samples])
relabel_budget = int(min(n_samples * expected_noise_rate, n_samples) * 0.35)
if dataset.name == "NoisyChestXray":
relabel_budget = min(int(n_samples * expected_noise_rate * 2.5), n_samples)
else:
relabel_budget = min(int(n_samples * expected_noise_rate * 3.0), n_samples)
logging.info(f"Expected noise rate {expected_noise_rate} - Allocated relabelling budget {relabel_budget}")
# Setup the simulation function.
def _run_simulation_for_selector(name: str,
seed: int,
sample_selector: SampleSelector) -> Tuple[str, SimulationStats]:
if isinstance(sample_selector, (LabelDistributionBasedSampler, BaldSelector)):
update_trainer_for_simulation(sample_selector, seed=seed)
simulator = DataCurationSimulator(initial_labels=copy.deepcopy(initial_labels),
label_distribution=copy.deepcopy(true_distribution),
relabel_budget=relabel_budget,
name=name,
seed=seed,
sample_selector=copy.deepcopy(sample_selector))
simulator.run_simulation()
if sample_selector.output_directory is not None:
simulator.save_simulator_results(MAIN_SIMULATION_DIR / sample_selector.output_directory / f"seed_{seed}")
return name, simulator.global_stats
# Run the simulation over multiple seeds and selectors
simulation_iter = product(sample_selectors.items(), args.seeds)
if args.debug:
parallel_output = [_run_simulation_for_selector(_name, _seed, _sel) for (_name, _sel), _seed in simulation_iter]
else:
num_jobs = min(len(sample_selectors) * len(args.seeds), multiprocessing.cpu_count())
parallel = Parallel(n_jobs=num_jobs)
parallel_output = parallel(delayed(_run_simulation_for_selector)(_name, _seed, _selector)
for (_name, _selector), _seed in simulation_iter)
# Aggregate parallel output arrays
global_stats: Dict[str, Any] = {name: list() for name in set([name for name, _ in parallel_output])}
[global_stats[name].append(stats) for name, stats in parallel_output]
# Analyse simulation stats
stats_dist = {name: SimulationStatsDistribution(stats) for name, stats in global_stats.items()}
plot_filename_suffix = "val" if args.on_val_set else "train"
plot_stats(stats_dist,
dataset_name=dataset.name,
n_samples=n_samples,
save_path=EXP_OUTPUT_DIR,
filename_suffix=plot_filename_suffix)
if __name__ == '__main__':
create_folder(EXP_OUTPUT_DIR)
get_logger(EXP_OUTPUT_DIR / 'dataread.log')
parser = get_data_selection_parser()
args, unknown_args = parser.parse_known_args()
set_seed(seed=12345)
main(args)

Просмотреть файл

@ -0,0 +1,146 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import argparse
import logging
from pathlib import Path
from typing import List, Union
import numpy as np
import pandas as pd
from default_paths import MAIN_SIMULATION_DIR, MODEL_SELECTION_BENCHMARK_DIR, FIGURE_DIR
from InnerEyeDataQuality.datasets.cifar10_utils import get_cifar10_label_names
from InnerEyeDataQuality.deep_learning.model_inference import inference_ensemble
from InnerEyeDataQuality.deep_learning.utils import create_logger, load_model_config, load_selector_config
from InnerEyeDataQuality.selection.selectors.label_based import cross_entropy
from InnerEyeDataQuality.selection.simulation import DataCurationSimulator
from InnerEyeDataQuality.selection.simulation_statistics import get_ambiguous_sample_ids
from InnerEyeDataQuality.utils.dataset_utils import get_datasets
from InnerEyeDataQuality.utils.generic import create_folder
from InnerEyeDataQuality.utils.plot import plot_confusion_matrix
def get_rank(array: Union[np.ndarray, List]) -> int:
"""
Returns the ranking of an array where the highest value has a rank of 1 and
the lowest value has the highest rank.
"""
return len(array) - np.argsort(array).argsort()
def main(args: argparse.Namespace) -> None:
# Parameters
number_of_runs = 1
evaluate_on_ambiguous_samples = False
# Create the evaluation dataset - Make sure that it's the same dataset for all configs
assert isinstance(args.config, list)
_, dataset = get_datasets(load_model_config(args.config[0]),
use_augmentation=False,
use_noisy_labels_for_validation=True,
use_fixed_labels=False if number_of_runs > 1 else True)
# Choose a subset of the dataset
if evaluate_on_ambiguous_samples:
ind = get_ambiguous_sample_ids(dataset.label_counts, threshold=0.10) # type: ignore
else:
ind = range(len(dataset)) # type: ignore
# If specified load curated dataset labels
curated_target_labels = dict()
for cfg_path in args.curated_label_config if args.curated_label_config else list():
cfg = load_selector_config(cfg_path)
search_dir = MAIN_SIMULATION_DIR / cfg.selector.output_directory
targets_list = list()
for _f in search_dir.glob('**/*.hdf'):
_label_counts = DataCurationSimulator.load_simulator_results(_f)
_targets = np.argmax(_label_counts, axis=1)
targets_list.append(_targets)
curated_target_labels[str(Path(cfg_path).stem)] = targets_list
# Define class labels for noisy, clean and curated datasets
df_rows_list = []
metric_names = ["accuracy", "top_n_accuracy", "cross_entropy", "accuracy_per_class"]
# Run the same experiment multiple time
for _run_id in range(number_of_runs):
target_labels = {"clean": dataset.clean_targets[ind], # type: ignore
"noisy": np.array([dataset.__getitem__(_i)[2] for _i in ind]),
**{_n: _l[_run_id] for _n, _l in curated_target_labels.items()}}
# Loops over different models
for config_id, config in enumerate([load_model_config(cfg) for cfg in args.config]):
posteriors = inference_ensemble(dataset, config)[1][ind]
# Collect metrics
for _label_name, _label in target_labels.items():
df_row = {"model": Path(args.config[config_id]).stem, "run_id": _run_id,
"dataset": _label_name, "count": _label.size}
for _metric_name in metric_names:
_val = benchmark_metrics(posteriors, observed_labels=_label, metric_name=_metric_name,
true_labels=target_labels["clean"])
df_row.update({_metric_name: _val}) # type: ignore
df_rows_list.append(df_row)
df = pd.DataFrame(df_rows_list)
df = df.sort_values(by=["dataset", "model"], axis=0)
logging.info(f"\n{df.to_string()}")
# Aggregate multiple runs
group_cols = ['model', 'dataset']
df_grouped = df.groupby(group_cols, as_index=False)['accuracy', 'count', 'cross_entropy'].agg([np.mean, np.std])
logging.info(f"\n{df_grouped.to_string()}")
# Plot the observed confusion matrix
plot_confusion_matrix(target_labels["clean"], target_labels["noisy"],
get_cifar10_label_names(), save_path=FIGURE_DIR)
def benchmark_metrics(posteriors: np.ndarray,
observed_labels: np.ndarray,
metric_name: str,
true_labels: np.ndarray) -> Union[float, List[float]]:
"""
Defines metrics to be used in model comparison.
"""
predictions = np.argmax(posteriors, axis=1)
# Accuracy averaged across all classes
if metric_name == "accuracy":
return np.mean(predictions == observed_labels) * 100.0
# Cross-entropy loss across all samples
elif metric_name == "top_n_accuracy":
N = 2
sorted_class_predictions = np.argsort(posteriors, axis=1)[:, ::-1]
correct = int(0)
for _i in range(observed_labels.size):
correct += np.any(sorted_class_predictions[_i, :N] == observed_labels[_i])
return correct * 100.0 / observed_labels.size
elif metric_name == "cross_entropy":
return np.mean(cross_entropy(posteriors, np.eye(10)[observed_labels]))
# Average accuracy per class - samples are groupped based on their true class label
elif metric_name == "accuracy_per_class":
vals = list()
for _class in np.unique(true_labels, return_counts=False):
mask = true_labels == _class
val = np.mean(predictions[mask] == observed_labels[mask]) * 100.0
vals.append(np.around(val, decimals=3))
return vals
else:
raise ValueError("Unknown metric")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Execute benchmark 3')
parser.add_argument('--config', dest='config', type=str, required=True, nargs='+',
help='Path to config file(s) characterising trained CNN model/s')
parser.add_argument('--curated-label-config', dest='curated_label_config', type=str, required=False, nargs='+',
help='Path to config file(s) corresponding to curated labels in adjudication simulation')
args, unknown_args = parser.parse_known_args()
create_folder(MODEL_SELECTION_BENCHMARK_DIR)
create_logger(MODEL_SELECTION_BENCHMARK_DIR)
main(args)

Просмотреть файл

@ -0,0 +1,5 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше